Published on

transformers Flax+JAXで文章分類を試してみる

Author
株式会社イエローバックの機械学習エンジニアです

はじめに

pytorch や tensorflow にならぶ DL フレームワークとして Flax,JAX が流行っているようです。 JAX は自動微分や GPU/TPU のアクセラレータが付いた numpy のようなもの、Flax は JAX 向けのハイパフォーマンス NN ライブラリです。 Flax,JAX ともに google 製ですが tensorflow とは別物です。

Huggingface transformers でも Flax,JAX に対応しつつあります。そこで今回は transformers+Flax,JAX を使って文章分類を試してみます。

transformers での Flax,JAX 対応状況

transformers 側の Flax,JAX 対応としては最低限必要となるモデルの対応や、実際に学習などに利用する際のサンプルコードの対応などがあります。

モデル対応状況

transformers 4.6.1 の時点で Flax,JAX に対応しているモデルは bert, electra, roberta の 3 モデルのみのようです。 master ブランチには、gpt2, bart, big_bird, vit, clip などがマージされるなど続々と対応モデルが増えています。

サンプル学習コード対応状況

examples/flax 以下に Flax,JAX のサンプルコードがあります。transformers 4.6.1 の時点では以下の 2 つのみです。

  • language-modeling/run_mlm_flax.py (言語モデル)
  • text-classification/run_flax_glue.py (テキスト分類)

こちらも徐々に増えてくるのではないでしょうか。

インストール

まずは以前の環境構築の記事により CuDA や transformers などをインストールします。

次に、jax,flax などをインストールします。 jax のドキュメントを参考にして、jax と CuDA11.1 対応の jaxlib をインストールします。 flax 使用時に flax.metrics.tensorboard が import されますが、tensorflow も import していますので tensorflow も必要となります。

$ pip install --upgrade jax jaxlib==0.1.67+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
$ pip install flax
$ pip install tensorflow

環境とインストールされたパッケージ類は以下の通りです。

  • GPU: RTX2080Ti (Mem:11GB)
  • python: 3.8
  • CuDA: 11.1
  • transformers: 4.6.1
  • jax: 0.2.14
  • jaxlib: 0.1.67+cuda111
  • flax: 0.3.4
  • (pytorch: 1.8.1+cu111)

データセット

以前も使用したlivedoor ニュースコーパスを使います。 ここに示した方法により train.csv,dev.csv を作成します。

モデル

transformers+flax ではまだ mixed precision training の機能が充分ではないので 32bit で学習する必要があります。 GPU メモリの制約で大きなモデルは使えないので、今回はモデルとしてbert-base-japanese-whole-word-maskingを使います。

bert-base-japanese-whole-word-maskingでは、flax 用の weight ファイル(flax_model.msgpack)が提供されているのでそのまま使えます。

学習用スクリプト

transformers 4.6.1 用の文章分類用サンプルスクリプトを利用します。

$ git clone https://github.com/huggingface/transformers.git
$ cd transformers
$ git checkout -b v4.6.1 v4.6.1

ファインチューニング

pytorch 用の run_glue.py と比べると --max_seq_length--max_lengthになるなどオプションが少し違っています。

$ BS=8
$ SEQLEN=512
$ python ./examples/flax/text-classification/run_flax_glue.py
       --model_name_or_path cl-tohoku/bert-base-japanese-whole-word-masking \
       --output_dir ./output \
       --per_device_train_batch_size $BS \
       --per_device_eval_batch_size $BS \
       --train_file ~/ldn/data/train.csv \
       --validation_file ~/ldn/data/dev.csv \
       --num_train_epochs 5 \
       --max_length $SEQLEN

nvidia-smi コマンドで確認すると、ちゃんと GPU を使って学習が行われているようです。 学習は 5epoch で、20 分弱で終了しました。

   ===== Starting training (5 epochs) =====
   Epoch 1
   Training...

         .
         .

   Epoch 5
   Training...
       Done! Training metrics: {'learning_rate': DeviceArray(1.3589859e-08, dtype=float32), 'loss': DeviceArray(0.00091459, dtype=float32)}
     Evaluating...
       Done! Eval metrics: {'accuracy': 0.9511533242876526}

pytorch バージョンでも 32bit で同様の学習をして、学習時間と accuracy を比較してみました。

学習時間(秒)Accuracy
flax(jax)1,1030.951
pytorch1,1180.951

学習時間、Accuracy ともに同程度の結果となりました。

備考

flax 用の weight ファイル(flax_model.msgpack)が提供されていないモデルでは、run_flax_glue.py を以下のように修正することで、pytorch 用の weight ファイル(pytorch.bin)を読み込んでくれるようです。

--- run_flax_glue.py  2021-05-31 17:30:21.527718856 +0900
+++ run_flax_glue.py    2021-06-14 17:34:21.750590205 +0900
@@ -327,7 +327,7 @@
     # Load pretrained model and tokenizer
     config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
     tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
-    model = FlaxAutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, config=config)
+    model = FlaxAutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, config=config, from_pt=True)

     # Preprocessing the datasets
     if args.task_name is not None:

おわりに

transformers + Flax,JAX で GPU を使って文章分類の学習を試してみました。 結果として pytorch 版と同程度の結果を得ることができました。

現状では、Flax,JAX では Mixed Precision(fp16)が使えないなど不便なところがありますが、 開発は活発に行われているようですので今後に期待したいです。