- Published on
GPUでのJax/Flaxモデルのロード時メモリエラーの対処法
- Author
- Name
- Hideki Ono
- @yellowback
株式会社イエローバックの機械学習エンジニアです
今回は transformers で jax/flax モデルを GPU で扱うときの tips です。(transformers 限定の話ではありません)
jax/flax のサンプルコードをローカルの GPU(RTX 2080Ti)で実行したときにエラーとなったのでその対処方法です。
サンプルコード
ここに記載の以下のコードを利用します。
from transformers import FlaxRobertaModel, RobertaTokenizerFast
from datasets import load_dataset
import jax
dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)
dummy_input = next(iter(dataset))["text"]
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
input_ids = tokenizer(dummy_input, return_tensors="np").input_ids[:, :10]
model = FlaxRobertaModel.from_pretrained("julien-c/dummy-unknown")
# run a forward pass, should return an object `FlaxBaseModelOutputWithPooling`
model(input_ids)
これをローカル GPU 環境で実行すると以下のエラーが出ました。
Traceback (most recent call last):
File "jaxsample.py", line 12, in <module>
model = FlaxRobertaModel.from_pretrained("julien-c/dummy-unknown")
File "/home/ono/py38-ds-cuda11/lib/python3.8/site-packages/transformers/modeling_flax_utils.py", line 343, in from_pretrained
model = cls(config, *model_args, **model_kwargs)
File "/home/ono/py38-ds-cuda11/lib/python3.8/site-packages/transformers/models/roberta/modeling_flax_roberta.py", line 543, in __init__
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
File "/home/ono/py38-ds-cuda11/lib/python3.8/site-packages/transformers/modeling_flax_utils.py", line 105, in __init__
random_params = self.init_weights(self.key, input_shape)
File "/home/ono/py38-ds-cuda11/lib/python3.8/site-packages/transformers/models/roberta/modeling_flax_roberta.py", line 552, in init_weights
params_rng, dropout_rng = jax.random.split(rng)
File "/home/ono/py38-ds-cuda11/lib/python3.8/site-packages/jax/_src/random.py", line 143, in split
return _return_prng_keys(wrapped, _split(key, num))
File "/home/ono/py38-ds-cuda11/lib/python3.8/site-packages/jax/_src/random.py", line 129, in _split
return key._split(num)
File "/home/ono/py38-ds-cuda11/lib/python3.8/site-packages/jax/_src/prng.py", line 181, in _split
return PRNGKeyArray(self.impl, self.impl.split(self.keys, num))
File "/home/ono/py38-ds-cuda11/lib/python3.8/site-packages/jax/_src/prng.py", line 410, in threefry_split
return _threefry_split(key, int(num)) # type: ignore
File "/home/ono/py38-ds-cuda11/lib/python3.8/site-packages/jax/interpreters/xla.py", line 960, in _execute_compiled
out_bufs = compiled.execute(input_bufs)
RuntimeError: INTERNAL: CustomCall failed: jaxlib/cuda_prng_kernels.cc:35: operation cudaGetLastError() failed: out of memory
モデル初期化の際に、jax.random.split(rng)
をコールして OOM エラーを起こしているようです。
環境
ローカル環境は以下の通りです。
- GPU: RTX2080Ti (Mem:11GB)
- python: 3.8
- CuDA: 11.1
- transformers: 4.10.3
- jax: 0.2.21
- jaxlib: 0.1.71+cuda111
- flax: 0.3.5
- datasets: 1.12.1
解決方法
同じコードを Google Colab の GPU(P100,Mem:16GB)で実行すると正常に実行されます。 GPU メモリ容量に関連しそうと推測しました。
Jax のGPU memory allocation によると、デフォルトでは GPU のメモリの 90%がプリアロケートされます。プリアロケーションの動作をカスタマイズするには、XLA_PYTHON_CLIENT_MEM_FRACTION
やXLA_PYTHON_CLIENT_PREALLOCATE
などの環境変数を設定します。
このプリアロケーションを減らしてあげることで上記のエラーをなくすことができました。
具体的には以下のようにします。
プリアロケーションを 80%にする
$ XLA_PYTHON_CLIENT_MEM_FRACTION=.8 python jaxsample.py
プリアロケートしない
$ XLA_PYTHON_CLIENT_PREALLOCATE=false python jaxsample.py
参考
上記スクリプトを実行したときに、GPU 使用メモリ量推移を調べてみました。
プリアロケート(90%)
$ nvidia-smi --query-gpu=timestamp,memory.used -l 1 --format=csv
2021/09/27 10:14:31.767, 0 MiB
2021/09/27 10:14:32.768, 9934 MiB
2021/09/27 10:14:33.768, 10488 MiB
2021/09/27 10:14:34.769, 11016 MiB
2021/09/27 10:14:35.770, 0 MiB <---エラー
プリアロケート(80%)
XLA_PYTHON_CLIENT_MEM_FRACTION=.8
$ nvidia-smi --query-gpu=timestamp,memory.used -l 1 --format=csv
2021/09/27 10:13:57.729, 0 MiB
2021/09/27 10:13:58.733, 8848 MiB
2021/09/27 10:13:59.733, 9408 MiB
2021/09/27 10:14:00.734, 9966 MiB
2021/09/27 10:14:01.735, 10096 MiB
2021/09/27 10:14:02.735, 10096 MiB
2021/09/27 10:14:03.736, 10096 MiB
2021/09/27 10:14:04.737, 10262 MiB
2021/09/27 10:14:05.737, 10106 MiB
2021/09/27 10:14:06.738, 0 MiB <---正常終了
プリアロケートなし
XLA_PYTHON_CLIENT_PREALLOCATE=false
$ nvidia-smi --query-gpu=timestamp,memory.used -l 1 --format=csv
2021/09/27 11:32:46.222, 0 MiB
2021/09/27 11:32:47.222, 158 MiB
2021/09/27 11:32:48.222, 614 MiB
2021/09/27 11:32:49.222, 1194 MiB
2021/09/27 11:32:50.222, 1406 MiB
2021/09/27 11:32:51.222, 1406 MiB
2021/09/27 11:32:52.223, 1406 MiB
2021/09/27 11:32:53.223, 1656 MiB
2021/09/27 11:32:54.223, 1700 MiB
2021/09/27 11:32:55.223, 0 MiB <---正常終了
おわりに
今回は、メモリ少なめな GPU で jax/flax モデルを使うときの tips として、環境変数のXLA_PYTHON_CLIENT_MEM_FRACTION=.8
やXLA_PYTHON_CLIENT_PREALLOCATE=false
が使えることを紹介しました。