- Published on
transformers Flax+JAXで文章分類を試してみる
- Author
- Name
- Hideki Ono
- @yellowback
株式会社イエローバックの機械学習エンジニアです
はじめに
pytorch や tensorflow にならぶ DL フレームワークとして Flax,JAX が流行っているようです。 JAX は自動微分や GPU/TPU のアクセラレータが付いた numpy のようなもの、Flax は JAX 向けのハイパフォーマンス NN ライブラリです。 Flax,JAX ともに google 製ですが tensorflow とは別物です。
Huggingface transformers でも Flax,JAX に対応しつつあります。そこで今回は transformers+Flax,JAX を使って文章分類を試してみます。
transformers での Flax,JAX 対応状況
transformers 側の Flax,JAX 対応としては最低限必要となるモデルの対応や、実際に学習などに利用する際のサンプルコードの対応などがあります。
モデル対応状況
transformers 4.6.1 の時点で Flax,JAX に対応しているモデルは bert, electra, roberta の 3 モデルのみのようです。 master ブランチには、gpt2, bart, big_bird, vit, clip などがマージされるなど続々と対応モデルが増えています。
サンプル学習コード対応状況
examples/flax 以下に Flax,JAX のサンプルコードがあります。transformers 4.6.1 の時点では以下の 2 つのみです。
- language-modeling/run_mlm_flax.py (言語モデル)
- text-classification/run_flax_glue.py (テキスト分類)
こちらも徐々に増えてくるのではないでしょうか。
インストール
まずは以前の環境構築の記事により CuDA や transformers などをインストールします。
次に、jax,flax などをインストールします。 jax のドキュメントを参考にして、jax と CuDA11.1 対応の jaxlib をインストールします。 flax 使用時に flax.metrics.tensorboard が import されますが、tensorflow も import していますので tensorflow も必要となります。
$ pip install --upgrade jax jaxlib==0.1.67+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
$ pip install flax
$ pip install tensorflow
環境とインストールされたパッケージ類は以下の通りです。
- GPU: RTX2080Ti (Mem:11GB)
- python: 3.8
- CuDA: 11.1
- transformers: 4.6.1
- jax: 0.2.14
- jaxlib: 0.1.67+cuda111
- flax: 0.3.4
- (pytorch: 1.8.1+cu111)
データセット
以前も使用したlivedoor ニュースコーパスを使います。 ここに示した方法により train.csv,dev.csv を作成します。
モデル
transformers+flax ではまだ mixed precision training の機能が充分ではないので 32bit で学習する必要があります。
GPU メモリの制約で大きなモデルは使えないので、今回はモデルとしてbert-base-japanese-whole-word-masking
を使います。
bert-base-japanese-whole-word-masking
では、flax 用の weight ファイル(flax_model.msgpack)が提供されているのでそのまま使えます。
学習用スクリプト
transformers 4.6.1 用の文章分類用サンプルスクリプトを利用します。
$ git clone https://github.com/huggingface/transformers.git
$ cd transformers
$ git checkout -b v4.6.1 v4.6.1
ファインチューニング
pytorch 用の run_glue.py と比べると --max_seq_length
が--max_length
になるなどオプションが少し違っています。
$ BS=8
$ SEQLEN=512
$ python ./examples/flax/text-classification/run_flax_glue.py
--model_name_or_path cl-tohoku/bert-base-japanese-whole-word-masking \
--output_dir ./output \
--per_device_train_batch_size $BS \
--per_device_eval_batch_size $BS \
--train_file ~/ldn/data/train.csv \
--validation_file ~/ldn/data/dev.csv \
--num_train_epochs 5 \
--max_length $SEQLEN
nvidia-smi コマンドで確認すると、ちゃんと GPU を使って学習が行われているようです。 学習は 5epoch で、20 分弱で終了しました。
===== Starting training (5 epochs) =====
Epoch 1
Training...
.
.
Epoch 5
Training...
Done! Training metrics: {'learning_rate': DeviceArray(1.3589859e-08, dtype=float32), 'loss': DeviceArray(0.00091459, dtype=float32)}
Evaluating...
Done! Eval metrics: {'accuracy': 0.9511533242876526}
pytorch バージョンでも 32bit で同様の学習をして、学習時間と accuracy を比較してみました。
学習時間(秒) | Accuracy | |
---|---|---|
flax(jax) | 1,103 | 0.951 |
pytorch | 1,118 | 0.951 |
学習時間、Accuracy ともに同程度の結果となりました。
備考
flax 用の weight ファイル(flax_model.msgpack)が提供されていないモデルでは、run_flax_glue.py を以下のように修正することで、pytorch 用の weight ファイル(pytorch.bin)を読み込んでくれるようです。
--- run_flax_glue.py 2021-05-31 17:30:21.527718856 +0900
+++ run_flax_glue.py 2021-06-14 17:34:21.750590205 +0900
@@ -327,7 +327,7 @@
# Load pretrained model and tokenizer
config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
- model = FlaxAutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, config=config)
+ model = FlaxAutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, config=config, from_pt=True)
# Preprocessing the datasets
if args.task_name is not None:
おわりに
transformers + Flax,JAX で GPU を使って文章分類の学習を試してみました。 結果として pytorch 版と同程度の結果を得ることができました。
現状では、Flax,JAX では Mixed Precision(fp16)が使えないなど不便なところがありますが、 開発は活発に行われているようですので今後に期待したいです。