OpenAIのwhisperをpytorchからjaxに書き直して70倍速くなった (sanchit-gandhi/whisper-jax)というニュースでjaxに興味持ちました。 実のところ、pytorch→jaxの寄与分は、この70倍のうち2倍とのことなのですが、それでもかなりのパフォーマンスです。
まずはjaxの手配から。WSL2 or Dockerとも思いましたが、Windowsネイティブで実行を目指しました。
jaxのビルド(Windows)
jaxはWindows向けに公式バイナリが配布されておらず、自分でビルドする必要があります。少し前までコミュニティビルドバイナリがあったようなのですが、23年4月30日現在、jaxlib 0.3.17+CUDA 11.1などバージョンが古いものしか見当たりません。
ビルド環境
コンパイラのバージョンの組み合わせなどによってビルドが通ったり通らんかったりしそうなのでメモしておきます。
- google jax v0.3.24
- Bazel bazelisk-windows-amd64.exe v1.10.1
- miniconda
- python 3.10
- numpy 1.24.3
- Visual Studio 2019 Build Tools
- CUDA 11.7 + cudnn 8.5.0
- Win 11の設定で開発者モードを有効にする
- Bazelのバージョンは大事。bazeliskを使うと良いバージョンのbazelを選んでくれる。
- realpathなど、bash系のコマンドを導入すること。Git bash付属のものを利用可能。(msys2のScoopならC:\Users\{ユーザ名}\scoop\apps\msys2\2023-03-18\usr\binなど)
- jaxのクローンはできるだけドライブ直下に。パス長がギリギリになる。
- exFATのドライブを使うとSynbolicLinkを作れないのでエラーが出る
- サブコンポになってるTensorFlowなどがVC2022に対応してないかも (Issue #60062 · tensorflow/tensorflow). 必要に応じて環境変数BAZEL_VCを定義する(C:\Program Files(x86)\Microsoft Visual Studio\2019\BuildTools\VC)
WindowsをP3MV3(1.6 USD/hr)をSelf-Hostedするお金があったらMatrix作って確かめます!
Git bash付属のコマンド類
Bazelの公式ではmsys2と書かれているがgit bashと周辺ツールのほうが素性が良さそうで、このあたりが使えるよう環境変数の$env:path
に追加します。
- C:\Program Files\Git\cmd
- C:\Program Files\Git\mingw64\bin
- C:\Program Files\Git\usr\bin
追加後、特にBazelが使いたがるrealpath
が動作すれば良い。
あわせてBAZEL_SH=C:\Program Files\Git\usr\bin\bash.exe
としておく。
ビルド作業
環境を整えたあと、ビルド。condaはjax用に環境つくっておきました。
conda create -n jax python=3.10
conda activate jax
conda install numpy
cd d:/
git clone https://github.com/google/jax.git
cd jax
git checkout jaxlib-v0.3.24
python .\build\build.py --enable_cuda `
--cuda_path="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.7" `
--cudnn_path="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.7" `
--cuda_version="11.7" --cudnn_version="8.5.0" --bazel_statup_options="--output_user_root=d:/tmp" `
--bazel_path="D:/bazel.exe"
ビルド時間はCore i7 10th Gen(8 Core)で3時間くらいというとこでしょうか。
途中、CUDAコードのビルド時に依存パッケージの文字コード警告が出て、その後エラーが出てしまうことがありました。
external/org_tensorflow/tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc(1): warning C4819: The file contains a character that cannot be represented in the current code page (932). Save the file in Unicode format to prevent data loss
...(中略)...
external/org_tensorflow/tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc(106): error C2065: 'ptxas_path': undeclared identifier
...(後略)...
こちらはコンパイラの警告 (レベル 1) C4819 | Microsoft Learnに従って、asm_compiler.ccをBOM付UTF8(さらに念のため改行コードをCRLFに変換)し、build.pyを実行するコマンドを繰り返すことで完了までたどり着きました。 当該ファイルに変な文字が入っているようには見えなかったので不思議です。
最後までビルドが通ると下記ログがでます。
C:\Users\chachay\miniconda3\envs\jax\lib\site-packages\setuptools\command\install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.
warnings.warn(
C:\Users\chachay\miniconda3\envs\jax\lib\site-packages\wheel\bdist_wheel.py:83: RuntimeWarning: Config variable 'Py_DEBUG' is unset, Python ABI tag may be incorrect
if get_flag("Py_DEBUG", hasattr(sys, "gettotalrefcount"), warn=(impl == "cp")):
Output wheel: D:\jax\dist\jaxlib-0.3.24-cp310-cp310-win_amd64.whl
To install the newly-built jaxlib wheel, run:
pip install D:\jax\dist\jaxlib-0.3.24-cp310-cp310-win_amd64.whl
余談
github actionでバイナリ作ろうとしたらホストのメモリが足りずヒープエラーで強制終了したのですが、
- github actionでページングを有効化する(actions/configure-pages)
- BAZELの最大利用メモリを制限する.
--local_ram_resources=2048
(TF - Bazel Build options)
whl配布するならselfhosted serverが欲しくなります…。
インストール
完成品のjaxlibはd:/jax/dist
にあります。jaxやflaxとあわせてインストールします。
cd d:/jax/dist
conda activate jax
pip install flax==0.6.4 . .\dist\jaxlib-0.3.24-cp310-cp310-win_amd64.whl
Bazelのキャッシュをきれいにするなら
bazel clean
bazel shutdown
jaxの試食
動作するか確認します。付属のサンプルスクリプトを走らせます。
python .\examples\kernel_lsq.py
MSE: 3.916308e-08
jaxpr of gram(linear_kernel):
{ lambda ; a:f32[100,20]. let
b:f32[100,100] = dot_general[
dimension_numbers=(((1,), (1,)), ((), ()))
precision=(<Precision.HIGH: 1>, <Precision.HIGH: 1>)
preferred_element_type=None
] a a
in (b,) }
jaxpr of gram(rbf_kernel):
{ lambda ; a:f32[100,20]. let
b:f32[100,1,20] = broadcast_in_dim[
broadcast_dimensions=(0, 2)
shape=(100, 1, 20)
] a
c:f32[1,100,20] = broadcast_in_dim[
broadcast_dimensions=(1, 2)
shape=(1, 100, 20)
] a
d:f32[100,100,20] = sub b c
e:f32[100,100,20] = integer_pow[y=2] d
f:f32[100,100] = reduce_sum[axes=(2,)] e
g:f32[100,100] = neg f
h:f32[100,100] = exp g
in (h,) }
動いた! 寝る!
参考
- Additional Notes for Building jaxlib from source on Windows - JAX documentation
- [MSVC]Tensorflow failed to error C2678: binary '==': no operator found which takes a left-hand operand of type 'const _Ty' · Issue #60062 · tensorflow/tensorflow
- Installing Bazel on Windows
- Using Bazel on Windows
- Output Directory Layout | Bazel
- Building Bazel on a Windows machine - Stack Overflow
No comments:
Post a Comment