查看原文
其他

【强基固本】100行用Python实现自动求导(不import任何包的情况下)

“强基固本,行稳致远”,科学研究离不开理论基础,人工智能学科更是需要数学、物理、神经科学等基础学科提供有力支撑,为了紧扣时代脉搏,我们推出“强基固本”专栏,讲解AI领域的基础知识,为你的科研学习提供助力,夯实理论基础,提升原始创新能力,敬请关注。

来源:知乎—而今听雨

地址:https://zhuanlan.zhihu.com/p/438685307

因为看到了这个问题:
https://www.zhihu.com/question/501734446
觉得好玩就手撸了一个自动求导,一句import都没有,连builtin包都不引用的那种
原理很简单,构建一个树,树的leave是变量和常量,其他包括root在内的结点都是运算符结点
把常用运算符(四则运算、常用函数之类的)计算和求导都写在一个类里,就OK了
过于复杂的式子应该先跑一个DFS,把每个结点的gradient存起来避免重复计算,简单式子就无所谓了,写着玩的~
# 为了不import,我甚至连product都是自己写的=_=def product(items): res = 1 for i in items: res = res * i return res
# 继承关系:# Node <- Constant, Variable, Operator# Operator <- Add, Multiply, Divide, Pow, ...
# 所有的Operator都有子节点,所有的Constant和Variable都没有子结点class Node: def __init__(self, name, value=0): self.name = name self.value = value
def __eq__(self, other): return self.name == other.name
def __str__(self): return str(self.name)
def __repr__(self): return self.__str__()

class Constant(Node): def __init__(self, value): super().__init__(value, value)
def compute_value(self): return self.value
def compute_derivative(self, to_variable): return 0

class Variable(Node): def compute_value(self): return self.value
def compute_derivative(self, to_variable): if to_variable.name == self.name: return 1 else: return 0

class Operator(Node): def __init__(self, inputs, name): self.inputs = inputs self.name = f"Opt {name} of {inputs}"
def __str__(self): opt2str = {"Add": "+", "Power": "^", "Multiply": "*", "Divide": "/"} return "(" + opt2str[self.name.split(" ")[1]].join(map(str, self.inputs)) + ")"

class Add(Operator): def __init__(self, inputs): super().__init__(inputs, name="Add")
def compute_value(self): return sum(inp.compute_value() for inp in self.inputs)
def compute_derivative(self, to_variable): return sum(inp.compute_derivative(to_variable) for inp in self.inputs)

class Multiply(Operator): def __init__(self, inputs): super().__init__(inputs, name="Multiply")
def compute_value(self): return product(inp.compute_value() for inp in self.inputs)
def compute_derivative(self, to_variable): return sum( inp.compute_derivative(to_variable) * product( other_inp.compute_value() for other_inp in self.inputs if other_inp != inp ) for inp in self.inputs )

class Divide(Operator): def __init__(self, inputs): super().__init__(inputs, name="Divide")
def compute_value(self): a, b = [inp.compute_value() for inp in self.inputs] return a / b
def compute_derivative(self, to_variable): a, b = [inp.compute_value() for inp in self.inputs] da, db = [inp.compute_derivative(to_variable) for inp in self.inputs] return (da * b - db * a) / (b ** 2)

class Power(Operator): # Constant Power def __init__(self, inputs): super().__init__(inputs, name="Power")
def compute_value(self): x, n = self.inputs n = n.value return x.compute_value() ** n
def compute_derivative(self, to_variable): x, n = self.inputs n = n.value return n * (x.compute_value() ** (n - 1)) * x.compute_derivative(to_variable)

if __name__ == "__main__":    print(Add([Varaible("x"),Constant(5)]).compute_derivative())
到这里就可以work了,不过构建每个项和运算符都要实例化一个类,着实是麻烦,可以通过重写所有结点的运算符的方法来更优雅地构建较长的式子,例如像是这种:3 * (x ** 2) + 5 * x * y + 6 / x - 8 * y ** 2 + 10
def wrapper_opt(opt, self, other, r=False): opt2class = {"add": Add, "mul": Multiply, "pow": Power, "div": Divide} if not isinstance(other, Node): other = Constant(other) inputs = [other, self] if r else [self, other] node = opt2class[opt](inputs=inputs) return node

Node.__add__ = lambda self, other: wrapper_opt("add", self, other)Node.__mul__ = lambda self, other: wrapper_opt("mul", self, other)Node.__truediv__ = lambda self, other: wrapper_opt("div", self, other)Node.__pow__ = lambda self, other: wrapper_opt("pow", self, other)Node.__sub__ = lambda self, other: wrapper_opt( "add", self, wrapper_opt("mul", Constant(-1), other))Node.__radd__ = lambda self, other: wrapper_opt("add", self, other, r=True)Node.__rmul__ = lambda self, other: wrapper_opt("mul", self, other, r=True)Node.__rtruediv__ = lambda self, other: wrapper_opt("div", self, other, r=True)

if __name__ == "__main__": x = Variable(name="x") y = Variable(name="y") function = 3 * (x ** 2) + 5 * x * y + 6 / x - 8 * y ** 2 + 10
x.value = 18 y.value = 2 print(function.compute_value()) print(function.compute_derivative(x)) print(function)
小tip:把减法a-b定义成a+(-1*b)可以省去一个Sub运算符。其实同理除法a/b也可以定义成a*pow(b,-1)。
输出为:
5.0
1130.3333333333333
(((((3*(x^2))+((5*x)*y))+(6/x))+(-1*(8*(y^2))))+10)

本文目的在于学术交流,并不代表本公众号赞同其观点或对其内容真实性负责,版权归原作者所有,如有侵权请告知删除。

“强基固本”历史文章


更多强基固本专栏文章,

请点击文章底部“阅读原文”查看



分享、点赞、在看,给个三连击呗!

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存