らんだむな記憶

blogというものを体験してみようか!的なー

JAX

自動微分に特化した NumPy とか言われる JAX。https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/quickstart.ipynb として Colab 用のサンプルがあった。

size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

The slowest run took 55.35 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 29.1 ms per loop

import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

10 loops, best of 5: 98.1 ms per loop

ループ云々は分からないけど、速い・・・らしい。非同期処理で計算しているから準備ができるまでブロックするので、.block_until_ready() を呼ぶとのこと。呼ばなくても動くのだが、呼び忘れるとどうなるのか少し気になる。

def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25 0.19661194 0.10499357]

を見てもなかなかピンとこない。

$$
\begin{align*}
f(x) = \sum_{i=1}^n \frac{1}{1 + \exp(-x_i)}
\end{align*}
$$

ということなので、

$$
\begin{align*}
\frac{\partial f}{\partial x_i}(x) = - \frac{\exp(-x_i)}{(1 + \exp(-x_i))^2}
\end{align*}
$$

となる。よって、上記のコードは順に

$$
\begin{align*}
- \frac{1}{4}, - \frac{e^{-1}}{(1 + e^{-1})^2}, - \frac{e^{-2}}{(1 + e^{-2})^2}
\end{align*}
$$

となる。後ろ 2 つは e = 2.71828 からの手計算はしんどいので、numpy で計算すると

1/np.e/(1+1/np.e)**2, 1/np.e**2/(1+1/np.e**2)**2

(0.19661193324148188, 0.1049935854035065)

となって、先の値とよく一致している。

後は、まぁ JAX Quickstart — JAX documentation とか読んでいけば良いのだろう。PFN の CuPy とか思い出すけど、他にも Numba: A High Performance Python Compiler というのもあるということ。自分であれこれ高速な計算ロジックを考える時代は終わってしまったなと感じる。

とりあえず巨大な行列を操作したい時に高速ですよと。