Win11でJAX!

OpenAIのwhisperをpytorchからjaxに書き直して70倍速くなった (sanchit-gandhi/whisper-jax)というニュースでjaxに興味持ちました。 実のところ、pytorch→jaxの寄与分は、この70倍のうち2倍とのことなのですが、それでもかなりのパフォーマンスです。

まずはjaxの手配から。WSL2 or Dockerとも思いましたが、Windowsネイティブで実行を目指しました。

jaxのビルド(Windows)

jaxはWindows向けに公式バイナリが配布されておらず、自分でビルドする必要があります。少し前までコミュニティビルドバイナリがあったようなのですが、23年4月30日現在、jaxlib 0.3.17+CUDA 11.1などバージョンが古いものしか見当たりません。

ビルド環境

コンパイラのバージョンの組み合わせなどによってビルドが通ったり通らんかったりしそうなのでメモしておきます。

このほか大事な点
  • Win 11の設定で開発者モードを有効にする
  • Bazelのバージョンは大事。bazeliskを使うと良いバージョンのbazelを選んでくれる。
  • realpathなど、bash系のコマンドを導入すること。Git bash付属のものを利用可能。(msys2のScoopならC:\Users\{ユーザ名}\scoop\apps\msys2\2023-03-18\usr\binなど)
  • jaxのクローンはできるだけドライブ直下に。パス長がギリギリになる。
  • exFATのドライブを使うとSynbolicLinkを作れないのでエラーが出る
  • サブコンポになってるTensorFlowなどがVC2022に対応してないかも (Issue #60062 · tensorflow/tensorflow). 必要に応じて環境変数BAZEL_VCを定義する(C:\Program Files(x86)\Microsoft Visual Studio\2019\BuildTools\VC)
私が確認した範囲ではjaxlib v0.3.24+ CUDA 11.7がWindowsでビルドできる最新の組み合わせでした。v0.4.7やCUDA 12.1はダメそう。

WindowsをP3MV3(1.6 USD/hr)をSelf-Hostedするお金があったらMatrix作って確かめます!

Git bash付属のコマンド類

Bazelの公式ではmsys2と書かれているがgit bashと周辺ツールのほうが素性が良さそうで、このあたりが使えるよう環境変数の$env:pathに追加します。

  • C:\Program Files\Git\cmd
  • C:\Program Files\Git\mingw64\bin
  • C:\Program Files\Git\usr\bin

追加後、特にBazelが使いたがるrealpathが動作すれば良い。

あわせてBAZEL_SH=C:\Program Files\Git\usr\bin\bash.exeとしておく。

ビルド作業

環境を整えたあと、ビルド。condaはjax用に環境つくっておきました。

conda create -n jax python=3.10
conda activate jax
conda install numpy

cd d:/
git clone https://github.com/google/jax.git
cd jax
git checkout jaxlib-v0.3.24
python .\build\build.py --enable_cuda `
  --cuda_path="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.7" `
  --cudnn_path="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.7" `
  --cuda_version="11.7" --cudnn_version="8.5.0" --bazel_statup_options="--output_user_root=d:/tmp" `
  --bazel_path="D:/bazel.exe"

ビルド時間はCore i7 10th Gen(8 Core)で3時間くらいというとこでしょうか。

途中、CUDAコードのビルド時に依存パッケージの文字コード警告が出て、その後エラーが出てしまうことがありました。

external/org_tensorflow/tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc(1): warning C4819: The file contains a character that cannot be represented in the current code page (932). Save the file in Unicode format to prevent data loss
...(中略)...
external/org_tensorflow/tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc(106): error C2065: 'ptxas_path': undeclared identifier
...(後略)...

こちらはコンパイラの警告 (レベル 1) C4819 | Microsoft Learnに従って、asm_compiler.ccをBOM付UTF8(さらに念のため改行コードをCRLFに変換)し、build.pyを実行するコマンドを繰り返すことで完了までたどり着きました。 当該ファイルに変な文字が入っているようには見えなかったので不思議です。

最後までビルドが通ると下記ログがでます。

C:\Users\chachay\miniconda3\envs\jax\lib\site-packages\setuptools\command\install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.
warnings.warn(
C:\Users\chachay\miniconda3\envs\jax\lib\site-packages\wheel\bdist_wheel.py:83: RuntimeWarning: Config variable 'Py_DEBUG' is unset, Python ABI tag may be incorrect
if get_flag("Py_DEBUG", hasattr(sys, "gettotalrefcount"), warn=(impl == "cp")):
Output wheel: D:\jax\dist\jaxlib-0.3.24-cp310-cp310-win_amd64.whl

To install the newly-built jaxlib wheel, run:
  pip install D:\jax\dist\jaxlib-0.3.24-cp310-cp310-win_amd64.whl

余談

github actionでバイナリ作ろうとしたらホストのメモリが足りずヒープエラーで強制終了したのですが、

  1. github actionでページングを有効化する(actions/configure-pages)
  2. BAZELの最大利用メモリを制限する.--local_ram_resources=2048(TF - Bazel Build options)
といった方法で解決できるそうです。ちなみに1を使いました。ただ、Github actionが360分でタイムアウトするので工夫が必要かなと思います。

whl配布するならselfhosted serverが欲しくなります…。

インストール

完成品のjaxlibはd:/jax/distにあります。jaxやflaxとあわせてインストールします。

cd d:/jax/dist
conda activate jax
pip install flax==0.6.4 . .\dist\jaxlib-0.3.24-cp310-cp310-win_amd64.whl

Bazelのキャッシュをきれいにするなら

bazel clean
bazel shutdown

jaxの試食

動作するか確認します。付属のサンプルスクリプトを走らせます。

python .\examples\kernel_lsq.py
MSE: 3.916308e-08

jaxpr of gram(linear_kernel):
{ lambda ; a:f32[100,20]. let
    b:f32[100,100] = dot_general[
      dimension_numbers=(((1,), (1,)), ((), ()))
      precision=(<Precision.HIGH: 1>, <Precision.HIGH: 1>)
      preferred_element_type=None
    ] a a
  in (b,) }

jaxpr of gram(rbf_kernel):
{ lambda ; a:f32[100,20]. let
    b:f32[100,1,20] = broadcast_in_dim[
      broadcast_dimensions=(0, 2)
      shape=(100, 1, 20)
    ] a
    c:f32[1,100,20] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 100, 20)
    ] a
    d:f32[100,100,20] = sub b c
    e:f32[100,100,20] = integer_pow[y=2] d
    f:f32[100,100] = reduce_sum[axes=(2,)] e
    g:f32[100,100] = neg f
    h:f32[100,100] = exp g
  in (h,) }

動いた! 寝る!

参考

No comments:

Post a Comment