らんだむな記憶

blogというものを体験してみようか!的なー

DeZero で眺める自動微分 (2)

前回の記事「DeZero で眺める自動微分 - らんだむな記憶」に続いて、今回はソースコード側から眺める。

diff --git a/dezero/core.py b/dezero/core.py
index 04985c9..a2a2c6b 100644
--- a/dezero/core.py
+++ b/dezero/core.py
@@ -38,6 +38,21 @@ try:
 except ImportError:
     array_types = (np.ndarray)
 
+class UniqueId:
+    def __init__(self):
+        self.ids = {}
+        self.id = 1
+
+    def __call__(self, obj):
+        hash_ = hash(obj)
+        if hash_ in self.ids:
+            return self.ids[hash_]
+        new_id = self.id
+        self.ids[hash_] = new_id
+        self.id += 1
+        return new_id
+
+get_id = UniqueId()
 
 class Variable:
     __array_priority__ = 200
@@ -92,6 +107,7 @@ class Variable:
         if self.grad is None:
             xp = dezero.cuda.get_array_module(self.data)
             self.grad = Variable(xp.ones_like(self.data))
+            print(f"[{get_id(self)}] {self.__class__.__name__} {self.grad}")
 
         funcs = []
         seen_set = set()
@@ -115,7 +131,9 @@ class Variable:
                 for x, gx in zip(f.inputs, gxs):
                     if x.grad is None:
                         x.grad = gx
+                        print(f"[{get_id(x)}] {f} {x.grad.data}")
                     else:
+                        print(f"[{get_id(x)}] {f} {x.grad.data} is added by {gx.data} -> {x.grad.data + gx.data}")
                         x.grad = x.grad + gx
 
                     if x.creator is not None:
@@ -205,6 +223,9 @@ class Function:
     def backward(self, gys):
         raise NotImplementedError()
 
+    def __str__(self):
+        return self.__class__.__name__
+
 
 # =============================================================================
 # 四則演算 / 演算子のオーバーロード

といった変更を入れて、

def main():
    """Build y = x**2 + x**3 at x = 3, backpropagate, and print y and x.grad."""
    x = dezero.Variable(np.array(3))
    y = x**2 + x**3
    # Run backpropagation starting from y; x.grad is read on the next line.
    y.backward()
    print(y, x.grad)

というコードを実行する。すると、

[1] Variable variable(1)
[2] Add 1
[3] Add 1
[4] Pow 27
[4] Pow 27 is added by 6 -> 33
variable(36) variable(33)

という感じになる。x = 3 における勾配は 2x + 3x^2 = 6 + 27 = 33 で、x**3 側の Pow から 27、x**2 側の Pow から 6 が順に足し込まれて 33 になっているので、まぁ、そんな感じかなと。

もっと詳しく見るために、以下のように変更する。

diff --git a/dezero/core.py b/dezero/core.py
index 04985c9..70be3c4 100644
--- a/dezero/core.py
+++ b/dezero/core.py
@@ -2,6 +2,7 @@ import weakref
 import numpy as np
 import contextlib
 import dezero
+import inspect
 
 
 # =============================================================================
@@ -38,6 +39,21 @@ try:
 except ImportError:
     array_types = (np.ndarray)
 
+class UniqueId:
+    def __init__(self):
+        self.ids = {}
+        self.id = 1
+
+    def __call__(self, obj):
+        hash_ = hash(obj)
+        if hash_ in self.ids:
+            return self.ids[hash_]
+        new_id = self.id
+        self.ids[hash_] = new_id
+        self.id += 1
+        return new_id
+
+get_id = UniqueId()
 
 class Variable:
     __array_priority__ = 200
@@ -47,6 +63,12 @@ class Variable:
             if not isinstance(data, array_types):
                 raise TypeError('{} is not supported'.format(type(data)))
 
+        if name == "tmp" or name == "ones_like":
+            print(f"[{get_id(self)}] '{data}' created", end="")
+            name = None
+        else:
+            print(f"[{get_id(self)}] '{data}' created")
+
         self.data = data
         self.name = name
         self.grad = None
@@ -91,7 +113,8 @@ class Variable:
     def backward(self, retain_grad=False, create_graph=False):
         if self.grad is None:
             xp = dezero.cuda.get_array_module(self.data)
-            self.grad = Variable(xp.ones_like(self.data))
+            self.grad = Variable(xp.ones_like(self.data), name="ones_like")
+            print(f" in backward {self.grad} ")
 
         funcs = []
         seen_set = set()
@@ -115,7 +138,9 @@ class Variable:
                 for x, gx in zip(f.inputs, gxs):
                     if x.grad is None:
                         x.grad = gx
+                        print(f"[{get_id(x)}] {f} {x.grad.data}")
                     else:
+                        print(f"[{get_id(x)}] {f} {x.grad.data} is added by {gx.data} -> {x.grad.data + gx.data}")
                         x.grad = x.grad + gx
 
                     if x.creator is not None:
@@ -188,7 +213,8 @@ class Function:
         ys = self.forward(*xs)
         if not isinstance(ys, tuple):
             ys = (ys,)
-        outputs = [Variable(as_array(y)) for y in ys]
+        outputs = [Variable(as_array(y), name="tmp") for y in ys]
+        print(f" by {inspect.stack()[1].function}")
 
         if Config.enable_backprop:
             self.generation = max([x.generation for x in inputs])
@@ -205,6 +231,9 @@ class Function:
     def backward(self, gys):
         raise NotImplementedError()
 
+    def __str__(self):
+        return self.__class__.__name__
+
 
 # =============================================================================
 # 四則演算 / 演算子のオーバーロード

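すると以下のようなログになる。かなり多くの中間オブジェクトとしてのテンソルが作られていることが分かる。なお、UniqueId は hash(obj) をキーにしているため、解放済みのオブジェクトと同じハッシュ値を持つ新しいオブジェクトには同じ番号が振られる。ログ中で [6]、[7]、[9] が異なる値で二度現れているのは、おそらくこのため(一時オブジェクトが解放され、同じ番号が再利用された)である。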

[1] '3' created
[2] '9' created by pow
[3] '27' created by pow
[4] '36' created by add
[5] '1' created in backward variable(1) 
[2] Add 1
[3] Add 1
[6] '9' created by pow
[7] '3' created
[8] '27' created by mul
[7] '27' created by mul
[1] Pow 27
[6] '3' created by pow
[9] '2' created
[10] '6' created by mul
[9] '6' created by mul
[1] Pow 27 is added by 6 -> 33
[11] '33' created by add
variable(36) variable(33)