JAXとPyTorchの性能差を検証します。今回、ベンチマークはpytestのbenchmarkモジュールを使って行いました。
今回のコードはここに置いてあります。python_bench/neural_network at master · Chachay/python_bench
環境
- python 3.10
- numpy 1.24.3
- pytorch 2.0.0
- jax
- jaxlib v0.3.24
- jax 0.3.24
- flax 0.6.4
- CUDA 11.7 + cudnn 8.5.0
- NVIDIA Driver Version: 531.14
- NVIDIA RTX3060
あいにくWindowsではPyTorch2.0のJIT機能は使えません。
ライブラリの性能
AlexNetおよびGoogleNetで比較し、JAXはPyTorchの1.4~2.0倍の性能とわかりました。Whisper-jaxでPyTorchからJAXに書き換えた性能向上分2倍と同等です。また、JAXのチュートリアルでも、2.5~3.4倍の性能と紹介されており、妥当な結果と思われます。
性能差はAlexNetやVGGのような単純な2次元畳み込みよりも、GoogleNetやResNetのようなモデルのほうがつきやすいようです。
Network | ||
---|---|---|
Flax | PyTorch | |
AlexNet | 2.7 (1.0) | 3.8 (1.4) |
GoogleNet | 37.0 (1.0) | 81.1 (2.2) |
AlexNetの実装
推論(Inference)での性能差を計測するため、全結合層や学習に関する層を省略したAlexNetモデルを定義した。
import torch
from torch import nn
class AlexNetPyTorch(nn.Module):
def __init__(self):
super(AlexNetPyTorch, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 96, kernel_size=11, stride=4),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(96, 256, kernel_size=5, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(256, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(384, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(384, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2)
)
def forward(self,x):
x = self.features(x)
return x
from flax import linen as nn
class AlexNetFlax(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Conv(96, (11, 11), strides=(4, 4), name='conv1')(x)
x = nn.relu(x)
x = nn.max_pool(x, (3, 3), strides=(2, 2))
x = nn.Conv(256, (5, 5), padding=((2, 2),(2, 2)), name='conv2')(x)
x = nn.relu(x)
x = nn.max_pool(x, (3, 3), strides=(2 ,2))
x = nn.Conv(384, (3, 3), padding=((1, 1),(1, 1)), name='conv3')(x)
x = nn.relu(x)
x = nn.Conv(384, (3, 3), padding=((1, 1),(1 ,1)), name='conv4')(x)
x = nn.relu(x)
x = nn.Conv(256, (3, 3), padding=((1, 1),(1, 1)), name='conv5')(x)
x = nn.relu(x)
x = nn.max_pool(x, (3, 3), strides=(2, 2))
return x
GoogleNet
こちらはGitHubのレポジトリをご参照ください。
ベンチのコード
ベンチはpytest-benchmarkを使って準備しました。イメージを共有するため簡略化したものを紹介します。
ベンチはrun(x)
を20回繰り返し実行(iterations)する計測を3セット(rounds)した結果を返します。
各測定の前に2回のwarmupを行います。
import pytest
import numpy as np
import torch
import jax
import jax.numpy as jnp
from models_flax import AlexNetFlax
from models_pytorch import AlexNetPyTorch
@pytest.mark.benchmark(
group="AlexNet",
warmup=True
)
def test_AlexNetPytorch(benchmark):
model = AlexNetPyTorch()
model.to('cuda')
model.eval()
# FlaxとPyTorchでデータの並び順が異なることに注意
# バッチ数、チャネル数、画像高さ、画像幅 [N, C, H, W]
x = np.random.rand(16, 3, 224, 224).astype(np.float32)
x = torch.from_numpy(x).to('cuda')
def run(_x):
with torch.no_grad():
return model(_x)
benchmark.pedantic(run, args=(x,), warmup_rounds=2, iterations=20, rounds=3)
@pytest.mark.benchmark(
group="AlexNet",
warmup=True
)
def test_AlexNetFlax(benchmark):
model = AlexNetFlax()
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
# FlaxとPyTorchでデータの並び順が異なることに注意
# バッチ数、画像高さ、画像幅、チャネル数 [N, H, W, C]
x = jax.random.normal(key1, (16, 224, 224, 3))
weight = model.init(key2, x) # Initialization cal
@jax.jit
def run(_x):
y = model.apply(weight, _x)
# JAXは非同期実行するのでベンチのため結果がでるのを待ちます。
jax.block_until_ready(y)
return y
# warm_upラウンドを2回いれることで、jitの時間を除外する
benchmark.pedantic(run, args=(x,), warmup_rounds=2, iterations=20, rounds=3)
if __name__ == "__main__":
pytest.main(['-v', __file__])
実行
全部の条件のベンチを行うコマンドはこちらです。
pytest benchmark_main.py --benchmark-compare
結果はこのようにグループで出力されます。
--------------------------------------------------------------------------- benchmark 'AlexNet': 2 tests ---------------------------------------------------------------------------
Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_AlexNetFlax 2.7234 (1.0) 2.7352 (1.0) 2.7289 (1.0) 0.0059 (1.0) 2.7283 (1.0) 0.0088 (1.0) 1;0 366.4413 (1.0) 3 20
test_AlexNetPytorch 3.7357 (1.37) 3.9136 (1.43) 3.8180 (1.40) 0.0897 (15.15) 3.8047 (1.39) 0.1335 (15.09) 1;0 261.9172 (0.71) 3 20
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------- benchmark 'GoogleNet': 2 tests ----------------------------------------------------------------------------
Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_GoogleNetFlax 36.6693 (1.0) 37.2125 (1.0) 36.9712 (1.0) 0.2766 (1.17) 37.0316 (1.0) 0.4074 (1.15) 1;0 27.0481 (1.0) 3 20
test_GoogleNetPytorch 80.8444 (2.20) 81.3176 (2.19) 81.0850 (2.19) 0.2367 (1.0) 81.0931 (2.19) 0.3549 (1.0) 1;0 12.3327 (0.46) 3 20
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
個別に実行する場合
# ひとつだけ
pytest .\benchmark_main.py::test_AlexNetPytorch
# 複数
pytest .\benchmark_main.py -k "test_GoogleNetPytorch or test_GoogleNetFlax"
No comments:
Post a Comment