「DeZero で眺める自動微分 - らんだむな記憶」の内容を、今度はソースコード側から眺める。
diff --git a/dezero/core.py b/dezero/core.py index 04985c9..a2a2c6b 100644 --- a/dezero/core.py +++ b/dezero/core.py @@ -38,6 +38,21 @@ try: except ImportError: array_types = (np.ndarray) +class UniqueId: + def __init__(self): + self.ids = {} + self.id = 1 + + def __call__(self, obj): + hash_ = hash(obj) + if hash_ in self.ids: + return self.ids[hash_] + new_id = self.id + self.ids[hash_] = new_id + self.id += 1 + return new_id + +get_id = UniqueId() class Variable: __array_priority__ = 200 @@ -92,6 +107,7 @@ class Variable: if self.grad is None: xp = dezero.cuda.get_array_module(self.data) self.grad = Variable(xp.ones_like(self.data)) + print(f"[{get_id(self)}] {self.__class__.__name__} {self.grad}") funcs = [] seen_set = set() @@ -115,7 +131,9 @@ class Variable: for x, gx in zip(f.inputs, gxs): if x.grad is None: x.grad = gx + print(f"[{get_id(x)}] {f} {x.grad.data}") else: + print(f"[{get_id(x)}] {f} {x.grad.data} is added by {gx.data} -> {x.grad.data + gx.data}") x.grad = x.grad + gx if x.creator is not None: @@ -205,6 +223,9 @@ class Function: def backward(self, gys): raise NotImplementedError() + def __str__(self): + return self.__class__.__name__ + # ============================================================================= # 四則演算 / 演算子のオーバーロード
といった変更を入れて、
def main(): x = dezero.Variable(np.array(3)) y = x**2 + x**3 y.backward() print(y, x.grad)
というコードを実行する。すると、
[1] Variable variable(1) [2] Add 1 [3] Add 1 [4] Pow 27 [4] Pow 27 is added by 6 -> 33 variable(36) variable(33)
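という出力になる。x = 3 において dy/dx = 2x + 3x² = 6 + 27 = 33 であり、ログでも x**3 側から逆伝播してきた勾配 27 に x**2 側の勾配 6 が加算されて 33 になっているので、まぁ、期待通りの動きかなと。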
もっと詳しく見るために、以下のように変更する。
diff --git a/dezero/core.py b/dezero/core.py index 04985c9..70be3c4 100644 --- a/dezero/core.py +++ b/dezero/core.py @@ -2,6 +2,7 @@ import weakref import numpy as np import contextlib import dezero +import inspect # ============================================================================= @@ -38,6 +39,21 @@ try: except ImportError: array_types = (np.ndarray) +class UniqueId: + def __init__(self): + self.ids = {} + self.id = 1 + + def __call__(self, obj): + hash_ = hash(obj) + if hash_ in self.ids: + return self.ids[hash_] + new_id = self.id + self.ids[hash_] = new_id + self.id += 1 + return new_id + +get_id = UniqueId() class Variable: __array_priority__ = 200 @@ -47,6 +63,12 @@ class Variable: if not isinstance(data, array_types): raise TypeError('{} is not supported'.format(type(data))) + if name == "tmp" or name == "ones_like": + print(f"[{get_id(self)}] '{data}' created", end="") + name = None + else: + print(f"[{get_id(self)}] '{data}' created") + self.data = data self.name = name self.grad = None @@ -91,7 +113,8 @@ class Variable: def backward(self, retain_grad=False, create_graph=False): if self.grad is None: xp = dezero.cuda.get_array_module(self.data) - self.grad = Variable(xp.ones_like(self.data)) + self.grad = Variable(xp.ones_like(self.data), name="ones_like") + print(f" in backward {self.grad} ") funcs = [] seen_set = set() @@ -115,7 +138,9 @@ class Variable: for x, gx in zip(f.inputs, gxs): if x.grad is None: x.grad = gx + print(f"[{get_id(x)}] {f} {x.grad.data}") else: + print(f"[{get_id(x)}] {f} {x.grad.data} is added by {gx.data} -> {x.grad.data + gx.data}") x.grad = x.grad + gx if x.creator is not None: @@ -188,7 +213,8 @@ class Function: ys = self.forward(*xs) if not isinstance(ys, tuple): ys = (ys,) - outputs = [Variable(as_array(y)) for y in ys] + outputs = [Variable(as_array(y), name="tmp") for y in ys] + print(f" by {inspect.stack()[1].function}") if Config.enable_backprop: self.generation = max([x.generation for x in inputs]) @@ -205,6 
+231,9 @@ class Function: def backward(self, gys): raise NotImplementedError() + def __str__(self): + return self.__class__.__name__ + # ============================================================================= # 四則演算 / 演算子のオーバーロード
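すると以下のようなログになる。かなり多くの中間オブジェクトとしてのテンソルが作られていることが分かる。なお、[6] や [7]、[9] のように同じ ID が別の値で再登場しているのは、get_id が hash(obj) をキーにしているためで、解放済みの一時オブジェクトと同じハッシュ値を持つ新しいオブジェクトに同じ ID が割り当てられた結果だと考えられる。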
[1] '3' created [2] '9' created by pow [3] '27' created by pow [4] '36' created by add [5] '1' created in backward variable(1) [2] Add 1 [3] Add 1 [6] '9' created by pow [7] '3' created [8] '27' created by mul [7] '27' created by mul [1] Pow 27 [6] '3' created by pow [9] '2' created [10] '6' created by mul [9] '6' created by mul [1] Pow 27 is added by 6 -> 33 [11] '33' created by add variable(36) variable(33)