物理シミュレーションと数値積分

最近、グランツーリスモ7のプレイヤーから、物理シミュレーションが破綻する現象が報告されています。特にサンババスで極端な破綻が起きるとのことで、空中に60万km/h以上で打ち出される事例の報告までされています。 報告を見る限り、リアにバラストを積み、車高を下げることで発生するようです。

一般に物理エンジンは数値積分を内部的に使っています。物理エンジンでシミュレーションする系に固いレートのバネが入っていると計算が発散しやすく、 GT7で報告されているような車が宙に突如放り出されるような挙動に陥りやすくなります。 GT7でも同じことが起きているとまでは断言できませんが、簡単なモデルを使って「再現」をしてみましょう。

サスペンションモデル

車両のサスペンションを含む最も単純な物理運動表現として、車体をタイヤとサスペンションに単純化したモデルを導入します。

サスペンションモデル

このモデルでは、タイヤのバネ定数 $ k_1 \text{\(N/m\)} $ や、ハブを起点に評価したホイールレート $ k_2 \text{\(N/m\)} $ を代表的なパラメータとして取り扱います。

システム方程式は

$$ \frac{d}{dt} \begin{pmatrix} x_1 \\ x_2 \\ \dot{x}_1 \\ \dot{x}_2 \end{pmatrix} = \begin{pmatrix} 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \\ -\frac{k_1 + k_2}{m_t} & \frac{k_2}{m_t} & -\frac{c_2}{m_t} & \frac{c_2}{m_t} \\ \frac{k_2}{m_b} & -\frac{k_2}{m_b} & \frac{c_2}{m_b} & -\frac{c_2}{m_b} \end{pmatrix} \begin{pmatrix} x_1 \\ x_2 \\ \dot{x}_1 \\ \dot{x}_2 \end{pmatrix} $$

GT7で発散が報告されているパラメータを元に下記の式からシミュレーションで使う仮の値を概算で設定します。

  • 車体重量 $ m_b $ : 1,000 kg
  • タイヤ総重量 $ m_t $ : 100 kg
  • サスの固有振動数 $ f_2 $ : 1.5 Hz
  • サスの減衰比 $ \zeta_2 $ : 0.25

このほか、タイヤのバネ定数 $ k_1 $ をおよそ $ 2.0 \cdot 10^6 \text{N/m} $ とします。

システム方程式で使うパラメータはバネマスダンパー系の特性値の公式を使って概算値として求めます。

$$ \begin{aligned} k_2 &= m_b (2 \pi f_2)^2 \\ c_2 &= 2 m_b \zeta_2 (2 \pi f_2) \end{aligned} $$

物理シミュレーション

物理シミュレーションでは、現実世界の物理現象を数値的に再現するために、微分方程式を解く必要があります。これに数値積分法が使用されます。原理が簡単な前進オイラー法と、物理シミュレーションエンジンでポピュラーな半陰オイラー法を試します。

物理モデルが行列 $ A \in \mathrm{GL}_{2n}(\mathbb{R}) $ を使って、微分方程式

$$ \dot{x} = \frac{d}{dt} \begin{pmatrix} X \\ V \end{pmatrix} = A x $$

で与えられたとき、数値積分で逐次的にステップ $ n $ の $ x_n $ を求めます。

本件のモデルに対しては

$$ A = \begin{pmatrix} 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \\ -2 \cdot 10^4 & 8.9 \cdot 10^2 & -4.7 \cdot 10 & 4.7 \cdot 10 \\ 8.9 \cdot 10 & -8.9 \cdot 10 & 4.7 & -4.7 \\ \end{pmatrix} $$

です。

前進オイラー法

前進オイラー法は、次の状態を現在の状態からタイムステップを使って計算するシンプルな数値積分法です。数式で表すと次のようになります。

$$ x_{n+1} = x_n + \Delta t A x_n $$

先程のサスペンションモデルをこの数値積分で解いてみましょう。

通常、ゲーム向けのシミュレーションではシミュレーションのタイムステップと画面のリフレッシュレートを一致させますから、GT7のリフレッシュレート120 Hzをひとつの水準にします。また、対比のため、1kHzでのシミュレーションもしてみましょう。

初期値として車両が加速トルクをうけて、準静的に1 cmスクワットし、これが開放された瞬間を取ります ($ x_0^\intercal= [0,-0.01,0,0] $ )。

前進オイラー法によるシミュレーション結果

120Hzの水準では車体の振幅が次第に拡大していき、ほんの3秒の間にはね飛ぶようになります(グラフ上段)。1kHzの水準ではダンパーの効果もあり、振幅がゼロに収束します。厳密解(実車体の動き)はこちらに近いはずです。

グラフ下段に示した通り、120Hzのシミュレーション条件ではシステムへ仕事の供給がないのに系のエネルギーが指数的に増えていきます。

半陰オイラー法

半陰オイラー法は、前進オイラー法の安定性を改善するため、速度の更新結果に基づいて位置を更新します。

$$ \begin{aligned} V_{n+1} &= V_n + \Delta t \dot{V_n} \\ X_{n+1} &= X_n + \Delta t V_{n+1} \\ &= X_n + \Delta t V_n + (\Delta t)^2 \dot{V_n} \end{aligned} $$

これを折り込み、前進オイラー法と同様に1ステップで計算するには…

$$ x_{n+1} = x_n + \Delta t (I + \Delta t \begin{pmatrix} 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \\ 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 \end{pmatrix} ) A x_n $$

この式で下のように $ \tilde{A} $ を取れば実装上は前進オイラー法と同じです。

$$ \tilde{A} = (I + \Delta t \begin{pmatrix} 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \\ 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 \end{pmatrix} ) A $$

これを使って修正した行列で同様にサスペンションモデルをシミュレーションすると120Hzでも発散しないことがわかります。

半陰(semi-implicit)オイラー法の結果(緑)

本シミュレーションの条件では発生しませんでしたが、より厳しいパラメータを与えると、この半陰オイラー法を採用しても発散します。GT7も、一般の物理シミュレーション同様に、この数値積分法を使っていると考えられますので、 剛体同士の接触などが起きていたのだろうと思います。

数値積分の安定域

シンプルな常微分方程式

$$ \begin{aligned} \dot{y} &= \lambda y \\ y(0) &= 1 \\ \lambda & \in \{ \text{Re}(x)<0 | x \in \mathbb{C}\} \end{aligned} $$

を考えます。この常微分方程式の解は、$ y = exp(\lambda t) $ です。この方程式を前進オイラー法で解くことを考えると、タイムステップ $ \Delta t $ に対して、数値解は

$$ \begin{aligned} y_{n+1} &= y_n + \Delta t \lambda y_n \\ &= (1+\Delta t \lambda) y_n \\ &= (1+\Delta t \lambda)^n y_0. \end{aligned} $$

このとき、$ \lambda $ の定義から $ t \rightarrow \infty $ で $ y = 0 $ ですから、数値解の収束条件は $ | 1 + \Delta t \lambda | < 1 $ とわかります。

同様に

$$ B = \begin{pmatrix} \lambda_{1} & & \\ & \ddots & \\ & & \lambda_{2n} \end{pmatrix} $$

を考え、$ \dot{y} = By $ について考察すると、すべての $ \lambda_k $ が前記の収束条件を満たしている必要性があります。

また、仮に$ A $ が $ B = P^{-1}AP $ と対角化可能であれば、$ B $ の対角に並ぶ要素は $ A $ の固有値です。数値積分の収束について議論する場合は、Aの固有値に注目すれば良さそうです。

本ケースの考察

サスペンションモデルの $ A $ の固有値を求めると $ -23.7±141.8i, -2.2±9.0i $であることから、120Hz ($ \Delta t = \frac{1}{120} $) の条件で、(絶対値が大きい方の固有値で) $ | 1+ \frac{-23.7+141.8i}{120}| > 1 $ となり、収束条件を満たしません。

前進オイラー法では、収束条件を満たすには $ \Delta t < \frac{1}{433} $ (433 Hz以上)である必要があります。実際、1kHzの条件では収束していました。

次に半陰オイラー法について $ \tilde{A} $ の固有値を求めると $ -110.8±91.7i, -2.5±9.0i $です。収束条件は $ \tilde{A} $ について、前進オイラー法と同じ議論は成り立つことから、収束条件を確認すると、 $ | 1+ \frac{-110.8+91.7i}{120}| \approx 0.77 < 1 $ で、120Hzの数値積分でも収束条件を満たすことがわかります。

参考

JAXとPyTorchの速度検証

JAXとPyTorchの性能差を検証します。今回、ベンチマークはpytestのbenchmarkモジュールを使って行いました。

今回のコードはここに置いてあります。python_bench/neural_network at master · Chachay/python_bench

環境

  • python 3.10
  • numpy 1.24.3
  • pytorch 2.0.0
  • jax
  • CUDA 11.7 + cudnn 8.5.0
  • NVIDIA Driver Version: 531.14
  • NVIDIA RTX3060

あいにくWindowsではPyTorch2.0のJIT機能は使えません。

ライブラリの性能

AlexNetおよびGoogleNetで比較し、JAXはPyTorchの1.4~2.0倍の性能とわかりました。Whisper-jaxでPyTorchからJAXに書き換えた性能向上分2倍と同等です。また、JAXのチュートリアルでも、2.5~3.4倍の性能と紹介されており、妥当な結果と思われます。

性能差はAlexNetやVGGのような単純な2次元畳み込みよりも、GoogleNetやResNetのようなモデルのほうがつきやすいようです。

両ライブラリ推論性能(ms)
Network
Flax PyTorch
AlexNet 2.7 (1.0) 3.8 (1.4)
GoogleNet 37.0 (1.0) 81.1 (2.2)

AlexNetの実装

推論(Inference)での性能差を計測するため、全結合層や学習に関する層を省略したAlexNetモデルを定義した。

import torch
from torch import nn

class AlexNetPyTorch(nn.Module):
    def __init__(self):
        super(AlexNetPyTorch, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(96, 256, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

    def forward(self,x):
        x = self.features(x)
        return x
from flax import linen as nn

class AlexNetFlax(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(96, (11, 11), strides=(4, 4), name='conv1')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2))

        x = nn.Conv(256, (5, 5), padding=((2, 2),(2, 2)), name='conv2')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2 ,2))

        x = nn.Conv(384, (3, 3), padding=((1, 1),(1, 1)), name='conv3')(x)
        x = nn.relu(x)

        x = nn.Conv(384, (3, 3), padding=((1, 1),(1 ,1)), name='conv4')(x)
        x = nn.relu(x)

        x = nn.Conv(256, (3, 3), padding=((1, 1),(1, 1)), name='conv5')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2))

        return x

GoogleNet

こちらはGitHubのレポジトリをご参照ください。

ベンチのコード

ベンチはpytest-benchmarkを使って準備しました。イメージを共有するため簡略化したものを紹介します。 ベンチはrun(x)を20回繰り返し実行(iterations)する計測を3セット(rounds)した結果を返します。 各測定の前に2回のwarmupを行います。

import pytest
import numpy as np

import torch
import jax
import jax.numpy as jnp

from models_flax import AlexNetFlax
from models_pytorch import AlexNetPyTorch

@pytest.mark.benchmark(
    group="AlexNet",
    warmup=True
)
def test_AlexNetPytorch(benchmark):
    model = AlexNetPyTorch()
    model.to('cuda')
    model.eval()

    # FlaxとPyTorchでデータの並び順が異なることに注意
    # バッチ数、チャネル数、画像高さ、画像幅 [N, C, H, W]
    x = np.random.rand(16, 3, 224, 224).astype(np.float32)
    x = torch.from_numpy(x).to('cuda')

    def run(_x):
        with torch.no_grad():
            return model(_x)

    benchmark.pedantic(run, args=(x,), warmup_rounds=2, iterations=20, rounds=3)

@pytest.mark.benchmark(
    group="AlexNet",
    warmup=True
)
def test_AlexNetFlax(benchmark):
    model = AlexNetFlax()

    key1, key2 = jax.random.split(jax.random.PRNGKey(0))

    # FlaxとPyTorchでデータの並び順が異なることに注意
    # バッチ数、画像高さ、画像幅、チャネル数 [N, H, W, C]
    x = jax.random.normal(key1, (16, 224, 224, 3))
    weight = model.init(key2, x) # Initialization cal

    @jax.jit
    def run(_x):
        y = model.apply(weight, _x)
        # JAXは非同期実行するのでベンチのため結果がでるのを待ちます。
        jax.block_until_ready(y)
        return y

    # warm_upラウンドを2回いれることで、jitの時間を除外する
    benchmark.pedantic(run, args=(x,), warmup_rounds=2, iterations=20, rounds=3)

if __name__ == "__main__":
    pytest.main(['-v', __file__])

実行

全部の条件のベンチを行うコマンドはこちらです。

pytest benchmark_main.py --benchmark-compare

結果はこのようにグループで出力されます。

--------------------------------------------------------------------------- benchmark 'AlexNet': 2 tests ---------------------------------------------------------------------------
Name (time in ms)          Min               Max              Mean            StdDev            Median               IQR            Outliers       OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_AlexNetFlax        2.7234 (1.0)      2.7352 (1.0)      2.7289 (1.0)      0.0059 (1.0)      2.7283 (1.0)      0.0088 (1.0)           1;0  366.4413 (1.0)           3          20
test_AlexNetPytorch     3.7357 (1.37)     3.9136 (1.43)     3.8180 (1.40)     0.0897 (15.15)    3.8047 (1.39)     0.1335 (15.09)         1;0  261.9172 (0.71)          3          20
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

----------------------------------------------------------------------------- benchmark 'GoogleNet': 2 tests ----------------------------------------------------------------------------
Name (time in ms)             Min                Max               Mean            StdDev             Median               IQR            Outliers      OPS            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_GoogleNetFlax        36.6693 (1.0)      37.2125 (1.0)      36.9712 (1.0)      0.2766 (1.17)     37.0316 (1.0)      0.4074 (1.15)          1;0  27.0481 (1.0)           3          20
test_GoogleNetPytorch     80.8444 (2.20)     81.3176 (2.19)     81.0850 (2.19)     0.2367 (1.0)      81.0931 (2.19)     0.3549 (1.0)           1;0  12.3327 (0.46)          3          20
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

個別に実行する場合

# ひとつだけ
pytest .\benchmark_main.py::test_AlexNetPytorch
# 複数
pytest .\benchmark_main.py -k "test_GoogleNetPytorch or test_GoogleNetFlax"

参考

Win11でJAX!

OpenAIのwhisperをpytorchからjaxに書き直して70倍速くなった (sanchit-gandhi/whisper-jax)というニュースでjaxに興味持ちました。 実のところ、pytorch→jaxの寄与分は、この70倍のうち2倍とのことなのですが、それでもかなりのパフォーマンスです。

まずはjaxの手配から。WSL2 or Dockerとも思いましたが、Windowsネイティブで実行を目指しました。

jaxのビルド(Windows)

jaxはWindows向けに公式バイナリが配布されておらず、自分でビルドする必要があります。少し前までコミュニティビルドバイナリがあったようなのですが、23年4月30日現在、jaxlib 0.3.17+CUDA 11.1などバージョンが古いものしか見当たりません。

ビルド環境

コンパイラのバージョンの組み合わせなどによってビルドが通ったり通らんかったりしそうなのでメモしておきます。

このほか大事な点
  • Win 11の設定で開発者モードを有効にする
  • Bazelのバージョンは大事。bazeliskを使うと良いバージョンのbazelを選んでくれる。
  • realpathなど、bash系のコマンドを導入すること。Git bash付属のものを利用可能。(msys2のScoopならC:\Users\{ユーザ名}\scoop\apps\msys2\2023-03-18\usr\binなど)
  • jaxのクローンはできるだけドライブ直下に。パス長がギリギリになる。
  • exFATのドライブを使うとSynbolicLinkを作れないのでエラーが出る
  • サブコンポになってるTensorFlowなどがVC2022に対応してないかも (Issue #60062 · tensorflow/tensorflow). 必要に応じて環境変数BAZEL_VCを定義する(C:\Program Files(x86)\Microsoft Visual Studio\2019\BuildTools\VC)
私が確認した範囲ではjaxlib v0.3.24+ CUDA 11.7がWindowsでビルドできる最新の組み合わせでした。v0.4.7やCUDA 12.1はダメそう。

WindowsをP3MV3(1.6 USD/hr)をSelf-Hostedするお金があったらMatrix作って確かめます!

Git bash付属のコマンド類

Bazelの公式ではmsys2と書かれているがgit bashと周辺ツールのほうが素性が良さそうで、このあたりが使えるよう環境変数の$env:pathに追加します。

  • C:\Program Files\Git\cmd
  • C:\Program Files\Git\mingw64\bin
  • C:\Program Files\Git\usr\bin

追加後、特にBazelが使いたがるrealpathが動作すれば良い。

あわせてBAZEL_SH=C:\Program Files\Git\usr\bin\bash.exeとしておく。

ビルド作業

環境を整えたあと、ビルド。condaはjax用に環境つくっておきました。

conda create -n jax python=3.10
conda activate jax
conda install numpy

cd d:/
git clone https://github.com/google/jax.git
cd jax
git checkout jaxlib-v0.3.24
python .\build\build.py --enable_cuda `
  --cuda_path="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.7" `
  --cudnn_path="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.7" `
  --cuda_version="11.7" --cudnn_version="8.5.0" --bazel_statup_options="--output_user_root=d:/tmp" `
  --bazel_path="D:/bazel.exe"

ビルド時間はCore i7 10th Gen(8 Core)で3時間くらいというとこでしょうか。

途中、CUDAコードのビルド時に依存パッケージの文字コード警告が出て、その後エラーが出てしまうことがありました。

external/org_tensorflow/tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc(1): warning C4819: The file contains a character that cannot be represented in the current code page (932). Save the file in Unicode format to prevent data loss
...(中略)...
external/org_tensorflow/tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc(106): error C2065: 'ptxas_path': undeclared identifier
...(後略)...

こちらはコンパイラの警告 (レベル 1) C4819 | Microsoft Learnに従って、asm_compiler.ccをBOM付UTF8(さらに念のため改行コードをCRLFに変換)し、build.pyを実行するコマンドを繰り返すことで完了までたどり着きました。 当該ファイルに変な文字が入っているようには見えなかったので不思議です。

最後までビルドが通ると下記ログがでます。

C:\Users\chachay\miniconda3\envs\jax\lib\site-packages\setuptools\command\install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.
warnings.warn(
C:\Users\chachay\miniconda3\envs\jax\lib\site-packages\wheel\bdist_wheel.py:83: RuntimeWarning: Config variable 'Py_DEBUG' is unset, Python ABI tag may be incorrect
if get_flag("Py_DEBUG", hasattr(sys, "gettotalrefcount"), warn=(impl == "cp")):
Output wheel: D:\jax\dist\jaxlib-0.3.24-cp310-cp310-win_amd64.whl

To install the newly-built jaxlib wheel, run:
  pip install D:\jax\dist\jaxlib-0.3.24-cp310-cp310-win_amd64.whl

余談

github actionでバイナリ作ろうとしたらホストのメモリが足りずヒープエラーで強制終了したのですが、

  1. github actionでページングを有効化する(actions/configure-pages)
  2. BAZELの最大利用メモリを制限する.--local_ram_resources=2048(TF - Bazel Build options)
といった方法で解決できるそうです。ちなみに1を使いました。ただ、Github actionが360分でタイムアウトするので工夫が必要かなと思います。

whl配布するならselfhosted serverが欲しくなります…。

インストール

完成品のjaxlibはd:/jax/distにあります。jaxやflaxとあわせてインストールします。

cd d:/jax/dist
conda activate jax
pip install flax==0.6.4 . .\dist\jaxlib-0.3.24-cp310-cp310-win_amd64.whl

Bazelのキャッシュをきれいにするなら

bazel clean
bazel shutdown

jaxの試食

動作するか確認します。付属のサンプルスクリプトを走らせます。

python .\examples\kernel_lsq.py
MSE: 3.916308e-08

jaxpr of gram(linear_kernel):
{ lambda ; a:f32[100,20]. let
    b:f32[100,100] = dot_general[
      dimension_numbers=(((1,), (1,)), ((), ()))
      precision=(<Precision.HIGH: 1>, <Precision.HIGH: 1>)
      preferred_element_type=None
    ] a a
  in (b,) }

jaxpr of gram(rbf_kernel):
{ lambda ; a:f32[100,20]. let
    b:f32[100,1,20] = broadcast_in_dim[
      broadcast_dimensions=(0, 2)
      shape=(100, 1, 20)
    ] a
    c:f32[1,100,20] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 100, 20)
    ] a
    d:f32[100,100,20] = sub b c
    e:f32[100,100,20] = integer_pow[y=2] d
    f:f32[100,100] = reduce_sum[axes=(2,)] e
    g:f32[100,100] = neg f
    h:f32[100,100] = exp g
  in (h,) }

動いた! 寝る!

参考