JAXとPyTorchの速度検証

JAXとPyTorchの性能差を検証します。今回、ベンチマークはpytestのbenchmarkモジュールを使って行いました。

今回のコードはここに置いてあります。python_bench/neural_network at master · Chachay/python_bench

環境

  • python 3.10
  • numpy 1.24.3
  • pytorch 2.0.0
  • jax
  • CUDA 11.7 + cudnn 8.5.0
  • NVIDIA Driver Version: 531.14
  • NVIDIA RTX3060

あいにくWindowsではPyTorch2.0のJIT機能は使えません。

ライブラリの性能

AlexNetおよびGoogleNetで比較し、JAXはPyTorchの1.4~2.0倍の性能とわかりました。Whisper-jaxでPyTorchからJAXに書き換えた性能向上分2倍と同等です。また、JAXのチュートリアルでも、2.5~3.4倍の性能と紹介されており、妥当な結果と思われます。

性能差はAlexNetやVGGのような単純な2次元畳み込みよりも、GoogleNetやResNetのようなモデルのほうがつきやすいようです。

両ライブラリ推論性能(ms)
Network
Flax PyTorch
AlexNet 2.7 (1.0) 3.8 (1.4)
GoogleNet 37.0 (1.0) 81.1 (2.2)

AlexNetの実装

推論(Inference)での性能差を計測するため、全結合層や学習に関する層を省略したAlexNetモデルを定義した。

import torch
from torch import nn

class AlexNetPyTorch(nn.Module):
    def __init__(self):
        super(AlexNetPyTorch, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(96, 256, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

    def forward(self,x):
        x = self.features(x)
        return x
from flax import linen as nn

class AlexNetFlax(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(96, (11, 11), strides=(4, 4), name='conv1')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2))

        x = nn.Conv(256, (5, 5), padding=((2, 2),(2, 2)), name='conv2')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2 ,2))

        x = nn.Conv(384, (3, 3), padding=((1, 1),(1, 1)), name='conv3')(x)
        x = nn.relu(x)

        x = nn.Conv(384, (3, 3), padding=((1, 1),(1 ,1)), name='conv4')(x)
        x = nn.relu(x)

        x = nn.Conv(256, (3, 3), padding=((1, 1),(1, 1)), name='conv5')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2))

        return x

GoogleNet

こちらはGitHubのレポジトリをご参照ください。

ベンチのコード

ベンチはpytest-benchmarkを使って準備しました。イメージを共有するため簡略化したものを紹介します。 ベンチはrun(x)を20回繰り返し実行(iterations)する計測を3セット(rounds)した結果を返します。 各測定の前に2回のwarmupを行います。

import pytest
import numpy as np

import torch
import jax
import jax.numpy as jnp

from models_flax import AlexNetFlax
from models_pytorch import AlexNetPyTorch

@pytest.mark.benchmark(
    group="AlexNet",
    warmup=True
)
def test_AlexNetPytorch(benchmark):
    model = AlexNetPyTorch()
    model.to('cuda')
    model.eval()

    # FlaxとPyTorchでデータの並び順が異なることに注意
    # バッチ数、チャネル数、画像高さ、画像幅 [N, C, H, W]
    x = np.random.rand(16, 3, 224, 224).astype(np.float32)
    x = torch.from_numpy(x).to('cuda')

    def run(_x):
        with torch.no_grad():
            return model(_x)

    benchmark.pedantic(run, args=(x,), warmup_rounds=2, iterations=20, rounds=3)

@pytest.mark.benchmark(
    group="AlexNet",
    warmup=True
)
def test_AlexNetFlax(benchmark):
    model = AlexNetFlax()

    key1, key2 = jax.random.split(jax.random.PRNGKey(0))

    # FlaxとPyTorchでデータの並び順が異なることに注意
    # バッチ数、画像高さ、画像幅、チャネル数 [N, H, W, C]
    x = jax.random.normal(key1, (16, 224, 224, 3))
    weight = model.init(key2, x) # Initialization cal

    @jax.jit
    def run(_x):
        y = model.apply(weight, _x)
        # JAXは非同期実行するのでベンチのため結果がでるのを待ちます。
        jax.block_until_ready(y)
        return y

    # warm_upラウンドを2回いれることで、jitの時間を除外する
    benchmark.pedantic(run, args=(x,), warmup_rounds=2, iterations=20, rounds=3)

if __name__ == "__main__":
    pytest.main(['-v', __file__])

実行

全部の条件のベンチを行うコマンドはこちらです。

pytest benchmark_main.py --benchmark-compare

結果はこのようにグループで出力されます。

--------------------------------------------------------------------------- benchmark 'AlexNet': 2 tests ---------------------------------------------------------------------------
Name (time in ms)          Min               Max              Mean            StdDev            Median               IQR            Outliers       OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_AlexNetFlax        2.7234 (1.0)      2.7352 (1.0)      2.7289 (1.0)      0.0059 (1.0)      2.7283 (1.0)      0.0088 (1.0)           1;0  366.4413 (1.0)           3          20
test_AlexNetPytorch     3.7357 (1.37)     3.9136 (1.43)     3.8180 (1.40)     0.0897 (15.15)    3.8047 (1.39)     0.1335 (15.09)         1;0  261.9172 (0.71)          3          20
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

----------------------------------------------------------------------------- benchmark 'GoogleNet': 2 tests ----------------------------------------------------------------------------
Name (time in ms)             Min                Max               Mean            StdDev             Median               IQR            Outliers      OPS            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_GoogleNetFlax        36.6693 (1.0)      37.2125 (1.0)      36.9712 (1.0)      0.2766 (1.17)     37.0316 (1.0)      0.4074 (1.15)          1;0  27.0481 (1.0)           3          20
test_GoogleNetPytorch     80.8444 (2.20)     81.3176 (2.19)     81.0850 (2.19)     0.2367 (1.0)      81.0931 (2.19)     0.3549 (1.0)           1;0  12.3327 (0.46)          3          20
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

個別に実行する場合

# ひとつだけ
pytest .\benchmark_main.py::test_AlexNetPytorch
# 複数
pytest .\benchmark_main.py -k "test_GoogleNetPytorch or test_GoogleNetFlax"

参考

No comments:

Post a Comment