自動微分を実装してみた

自動微分というアルゴリズムがある。チェーンルールを使うことで、予め微分の式を与えなくてもそれと同等の精度で微分ができるというもの。もともとの計算時間の定数倍しかかからず、かつloopが入ってたりするような微分の式を書くのに困るような計算相手でも平気なので重宝する。

http://www.kmonos.net/wlog/123.html#_2257111201

がとてもよい紹介。これは自動微分のうちでもフォワードモードというもので、ある一つの変数に関しての微分を計算出来るというやつだが、リバースモードといって多変数関数に対して一気に傾きが計算出来るというものもある。またこれを組み合わせればヘシアンをわりと素早く計算できる。
実装してみたのがこれ。
https://github.com/nos13/autodiff
他人が使うことを微塵も考えてないのでライブラリの体をあまりなしていないが…。

というか作ったのは少し前で、そろそろ記憶が風化しかけているのでメモ。

forward-mode

f(x+dx) = f(x) + dx*f'(x)
の関係を使う。

普通の数を表すxという値と、微小量を表すdxという値を持つdual型というのを作る。これをとりあえずdual(x,dx)と書こう。いろんなfに対してこのdual型の演算
f(dual(x,dx)) = dual(f(x), dx*f'(x))
の関係を、足し算とか三角関数とか基本的なfに対してばーっと定義しておく。
dual(x, dx) * dual(y, dy) = dual(x*y, x*dy + y*dx) とか、
sin(dual(x,dx)) = dual(sin(x), dx*cos(x)) とか。

するとこれらを組み合わせて作った関数であれば、例えばF(x,y,z)に対しては
F(x,dual(y,1),z) = dual(F(x,y,z), ∂F/∂y (x,y,z))
などのように微分を行うことができる。

ただし場合分けのある関数についてはその境界で微分が飛ぶような振る舞いをせず、左右どちらかのpiece上での微分値が返ってくることになるのでそこだけ注意。階段関数をこれで微分したのを数値積分しても元に戻りません。これはreverse-modeでも同じ。

common lispの普通の算術関数はオーバーロードできるようには作られていないので今回はcl-generic-arithmeticというパッケージを使ってこれができるようにした。

reverse-mode

forward-modeでは一変数についての微分を行うことができるが、reverse-modeというのを使うとワンパスで多変数関数の全引数についての微分、つまりgradientを求めることができる。

これには計算しながら各関数がノードでその引数が葉になるような’計算木'を作っていって、最後に頂点から最初の引数へと遡っていくということをやる。

例えば求めたい式が F(x,y) = f(g(x,y), h(x,y)) とかだとすると、

     f
    /\
   g  h
   |\/|
   |/\|
   x  y

という木になる(x,yはそれぞれgとh両方の葉になっている)。x,y,g(x,y),h(x,y),f(g(x,y),h(x,y))をそれぞれx_1,x_2,x_3,x_4,x_5、それに対応してg,h,fをf_3,f_4,f_5としよう。依存関係が添字の順になってるので[tex: \frac{\partial f_i}{\partial x_j} = 0 (ij} \bar x_i \frac{\partial f_i}{\partial x_j}]

だ。これを使って\bar x_5(これは定義から1)から遡って行ける。
遡り終えると(\bar x_1,\bar x_2) = \nabla F(x,y)が計算出来ているという寸法。

遡るために最初計算しながら計算木を記録しておく必要があるのでメモリが必要になるが、かわりに変数の数が多くても素早く傾きの計算を行える。

ヘッシアンの計算

forward-modeとreverse-modeは実装としてはわりと直交しているので、例えば1番目の変数についてforward-modeをするのと同時にbackward-modeを行うことで、ヘッシアンの第一列
(\frac{\partial^2 f}{\partial x_1^2},\frac{\partial^2 f}{\partial x_1 \partial x_2},\frac{\partial^2 f}{\partial x_1 \partial x_3},\ldots,\frac{\partial^2 f}{\partial x_1 \partial x_n})
を一度に求めることができる。これをn回繰り返すことでヘッシアン全体が求まる。

その他

defgenericは遅いので、ちょっとした規模の数値計算に使おうと思うとちょっと厳しい。common lispはこういうオーバーローディングのようなアプローチにはあまり向いてないかもしれない。cl-autodiff https://github.com/masonium/cl-autodiff というライブラリはad-defunマクロに渡された定義を変形してゆくといういかにもlisp的な方法を採用しているのでそれを参考にした方がよいかもしれない。他の言語ではソースコードを書き換えることに相当し(しかしそういうライブラリもいくつかあるようだ)、いかにも面倒くさそうだが、lispに関してはむしろその方が向いていそう。