Published on

GPUでのJax/Flaxモデルのロード時メモリエラーの対処法

Author
株式会社イエローバックの機械学習エンジニアです

今回は transformers で jax/flax モデルを GPU で扱うときの tips です。(transformers 限定の話ではありません)

jax/flax のサンプルコードをローカルの GPU(RTX 2080Ti)で実行したときにエラーとなったのでその対処方法です。

サンプルコード

https://github.com/huggingface/transformers/blob/master/examples/research_projects/jax-projects/README.md#local-computer

ここに記載の以下のコードを利用します。

jaxsample.py
from transformers import FlaxRobertaModel, RobertaTokenizerFast
from datasets import load_dataset
import jax

dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)

dummy_input = next(iter(dataset))["text"]

tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
input_ids = tokenizer(dummy_input, return_tensors="np").input_ids[:, :10]

model = FlaxRobertaModel.from_pretrained("julien-c/dummy-unknown")

# run a forward pass, should return an object `FlaxBaseModelOutputWithPooling`
model(input_ids)

これをローカル GPU 環境で実行すると以下のエラーが出ました。

Traceback (most recent call last):
  File "jaxsample.py", line 12, in <module>
    model = FlaxRobertaModel.from_pretrained("julien-c/dummy-unknown")
  File "/home/ono/py38-ds-cuda11/lib/python3.8/site-packages/transformers/modeling_flax_utils.py", line 343, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
  File "/home/ono/py38-ds-cuda11/lib/python3.8/site-packages/transformers/models/roberta/modeling_flax_roberta.py", line 543, in __init__
    super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
  File "/home/ono/py38-ds-cuda11/lib/python3.8/site-packages/transformers/modeling_flax_utils.py", line 105, in __init__
    random_params = self.init_weights(self.key, input_shape)
  File "/home/ono/py38-ds-cuda11/lib/python3.8/site-packages/transformers/models/roberta/modeling_flax_roberta.py", line 552, in init_weights
    params_rng, dropout_rng = jax.random.split(rng)
  File "/home/ono/py38-ds-cuda11/lib/python3.8/site-packages/jax/_src/random.py", line 143, in split
    return _return_prng_keys(wrapped, _split(key, num))
  File "/home/ono/py38-ds-cuda11/lib/python3.8/site-packages/jax/_src/random.py", line 129, in _split
    return key._split(num)
  File "/home/ono/py38-ds-cuda11/lib/python3.8/site-packages/jax/_src/prng.py", line 181, in _split
    return PRNGKeyArray(self.impl, self.impl.split(self.keys, num))
  File "/home/ono/py38-ds-cuda11/lib/python3.8/site-packages/jax/_src/prng.py", line 410, in threefry_split
    return _threefry_split(key, int(num))  # type: ignore
  File "/home/ono/py38-ds-cuda11/lib/python3.8/site-packages/jax/interpreters/xla.py", line 960, in _execute_compiled
    out_bufs = compiled.execute(input_bufs)
RuntimeError: INTERNAL: CustomCall failed: jaxlib/cuda_prng_kernels.cc:35: operation cudaGetLastError() failed: out of memory

モデル初期化の際に、jax.random.split(rng)をコールして OOM エラーを起こしているようです。

環境

ローカル環境は以下の通りです。

  • GPU: RTX2080Ti (Mem:11GB)
  • python: 3.8
  • CuDA: 11.1
  • transformers: 4.10.3
  • jax: 0.2.21
  • jaxlib: 0.1.71+cuda111
  • flax: 0.3.5
  • datasets: 1.12.1

解決方法

同じコードを Google Colab の GPU(P100,Mem:16GB)で実行すると正常に実行されます。 GPU メモリ容量に関連しそうと推測しました。

Jax のGPU memory allocation によると、デフォルトでは GPU のメモリの 90%がプリアロケートされます。プリアロケーションの動作をカスタマイズするには、XLA_PYTHON_CLIENT_MEM_FRACTIONXLA_PYTHON_CLIENT_PREALLOCATEなどの環境変数を設定します。

このプリアロケーションを減らしてあげることで上記のエラーをなくすことができました。

具体的には以下のようにします。

プリアロケーションを 80%にする

$ XLA_PYTHON_CLIENT_MEM_FRACTION=.8 python jaxsample.py

プリアロケートしない

$ XLA_PYTHON_CLIENT_PREALLOCATE=false python jaxsample.py

参考

上記スクリプトを実行したときに、GPU 使用メモリ量推移を調べてみました。

プリアロケート(90%)

$ nvidia-smi --query-gpu=timestamp,memory.used -l 1 --format=csv
2021/09/27 10:14:31.767, 0 MiB
2021/09/27 10:14:32.768, 9934 MiB
2021/09/27 10:14:33.768, 10488 MiB
2021/09/27 10:14:34.769, 11016 MiB
2021/09/27 10:14:35.770, 0 MiB        <---エラー

プリアロケート(80%)

XLA_PYTHON_CLIENT_MEM_FRACTION=.8

$ nvidia-smi --query-gpu=timestamp,memory.used -l 1 --format=csv
2021/09/27 10:13:57.729, 0 MiB
2021/09/27 10:13:58.733, 8848 MiB
2021/09/27 10:13:59.733, 9408 MiB
2021/09/27 10:14:00.734, 9966 MiB
2021/09/27 10:14:01.735, 10096 MiB
2021/09/27 10:14:02.735, 10096 MiB
2021/09/27 10:14:03.736, 10096 MiB
2021/09/27 10:14:04.737, 10262 MiB
2021/09/27 10:14:05.737, 10106 MiB
2021/09/27 10:14:06.738, 0 MiB      <---正常終了

プリアロケートなし

XLA_PYTHON_CLIENT_PREALLOCATE=false

$ nvidia-smi --query-gpu=timestamp,memory.used -l 1 --format=csv
2021/09/27 11:32:46.222, 0 MiB
2021/09/27 11:32:47.222, 158 MiB
2021/09/27 11:32:48.222, 614 MiB
2021/09/27 11:32:49.222, 1194 MiB
2021/09/27 11:32:50.222, 1406 MiB
2021/09/27 11:32:51.222, 1406 MiB
2021/09/27 11:32:52.223, 1406 MiB
2021/09/27 11:32:53.223, 1656 MiB
2021/09/27 11:32:54.223, 1700 MiB
2021/09/27 11:32:55.223, 0 MiB       <---正常終了

おわりに

今回は、メモリ少なめな GPU で jax/flax モデルを使うときの tips として、環境変数のXLA_PYTHON_CLIENT_MEM_FRACTION=.8XLA_PYTHON_CLIENT_PREALLOCATE=falseが使えることを紹介しました。