例によって自分用のメモ書き
#誰かが幸せになるとそれもいいよね
説明する必要もないかもしれないが、JAXはGoogleが出している比較的低レベルのライブラリである。
Autograd and XLAと書かれているように自動微分と高速な行列演算が出来るが、それで何が嬉しいかというと主に機械学習分野である。もちろん流体計算などのシミュレーション分野に利用されてもいるので幅広く使われている模様。
深層学習用にはJAXそのままではなく更に一段ライブラリを挟むのが主流である。(そういう意味で比較的低レベルと称した)
で、公にはWindowsバイナリは配布されていないが
【高速なnumpy】WindowsでJAXを使おう【Python】 - Qiita
こういう記事が出ている。
つまり、ここにバイナリがある。
https://whls.blob.core.windows.net/unstable/index.html
適当なものをダウンロードすれば難なく入る。
CUDA使う人は事前に入れましょう。
今のところ不具合もないのでそのうち公式にも正式対応してバイナリ配布されるんじゃないかと甘く見ている。
いくつかのサンプルも問題なく動くが完全に自作のコードでパフォーマンスを出すのはそんなに簡単なことじゃない感じ。
また、以下に囲碁の強化学習を実装されたものがあったので試しに実行してみた。
GitHub - NTT123/a0-jax: AlphaZero in JAX
RTX3090で最初のステップが2時間半程度と出たので、200ステップで500時間かぁ。
3週間コースだな。と適当に思っていたがこれは自己対局分で学習部分を入れるとほぼ倍になる。
(直感的には学習時間が長すぎる気がする)
加えて、メモリ128GB程度ではスワップ始めて数ステップで全然進まなくなることが分かった。
過激なサンプルプログラムはどのくらいリソース使うかちょっと書いておいて欲しいね。