- Published on
Colab TPUでtransformers Flax+JAXで文章分類を試してみる
- Author
- Name
- Hideki Ono
- @yellowback
株式会社イエローバックの機械学習エンジニアです
はじめに
以前のポストtransformers Flax+JAX で文章分類を試してみるで、 ローカル GPU を使って Flax/JAX の文章分類を試しました。 今回は、Google Colab で、TPU(v2-8)を用いて文章分類を試してみます。
Colab notebook
transformers のflax exampleに記載されている分類用のノートブックを使います。
前回と同様にlivedoor ニュースコーパスを用いて文章の分類をします。 学習用のデータとしては、ここに示した方法により生成した train.csv,dev.csv を使います。
notebook の修正
ローカルデータを読み込んでラベルの加工を追加や、 オリジナルのノートブックのままではエラーになる箇所を修正しました。
以下、変更箇所概要です。全体はここに置いてあります。
python パッケージなど
!pip install datasets
#!pip install git+https://github.com/huggingface/transformers.git
!pip install transformers
!pip install flax
#!pip install git+https://github.com/deepmind/optax.git
!pip install optax
!pip install ipadic fugashi
ipadic と fugashi を追加、transformers と optax は head ではなくリリース最新版を利用します。
モデルと batch size
task = None
model_checkpoint = "cl-tohoku/bert-base-japanese-whole-word-masking"
per_device_batch_size = 4
task は自前のタスク、model は東北大の日本語 BERT モデルを利用します。 デバイス毎のバッチサイズは 4 としました。ちなみに 8 にすると OOM となってしまいました。
データセット読み込み
data_files = {
"train": "drive/MyDrive/colab/ldn/data/train.csv",
"validation": "drive/MyDrive/colab/ldn/data/dev.csv"
}
raw_dataset = load_dataset("csv", data_files=data_files)
metric = load_metric("accuracy")
データセットは、mount 済の Google Drive から読み込みます。
abel_list = raw_dataset["train"].unique("label")
label_list.sort() # Let's sort it for determinism
num_labels = len(label_list)
label_to_id = {v: i for i, v in enumerate(label_list)}
ラベルリストの生成をします。
データセット前処理
sentence1_key, sentence2_key = ("text", None)
def preprocess_function(examples):
texts = (
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
)
processed = tokenizer(*texts, padding="max_length", max_length=512, truncation=True)
processed["labels"] = [label_to_id[l] for l in examples["label"]]
return processed
センテンスは text1 文のみ、labels を数値のリストに変換します。
flax 関連
from flax import traverse_util
traverse_util
が import されていなかったので追加します。
学習用パラメタ
num_train_epochs = 5
learning_rate = 5e-5
epoch=5, lr=5e-5 とします。
TrainState
state = TrainState.create(
apply_fn=model.__call__,
params=model.params,
tx=adamw(0.0),
logits_function=eval_function,
loss_function=loss_function,
)
存在していないgradient_transformation
を adamw(0.0)に変更します。adamw の引数はweight_decay
です。以前の GPU での確認時と同様に 0.0 としています。
環境など
今回の確認環境は以下のようになりました。
- Memory: 13GB (Colab Pro)
- TPU: v2-8
- python: 3.7
- transformers: 4.10.0
- jax: 0.2.19
- jaxlib: 0.1.70+cuda110
- flax: 0.3.4
ファインチューニング
ランタイムのタイプを TPU にして、すべてのセルを実行します。
The overall batch size (both for training and eval) is 32
total batch size = 32 (4 * 8)となっています。
Epoch ...: 100% 5/5 [15:58<00:00, 136.04s/it]
1/5 | Train loss: 0.058 | Eval accuracy: 0.922
2/5 | Train loss: 0.03 | Eval accuracy: 0.952
3/5 | Train loss: 0.015 | Eval accuracy: 0.95
4/5 | Train loss: 0.003 | Eval accuracy: 0.95
5/5 | Train loss: 0.002 | Eval accuracy: 0.951
学習ループに入ると 9 分程度止まっているような感じですが、 その後プログレスバーが動き始めます。
5epoch の学習は 16 分程度で終了しました。
プログレスバーを観察していると、1epoch 目がスタートするまで 9 分以上かかかっていて、 それ以降は 90 秒/epoch 程度(2.2it/sec)程度のペースで学習が進んでいました。
(試しに epoch=10 としたところ、26 分弱で学習が終わりました)
以前にローカル GPU(RTX2080Ti)で試したときとの比較をしてみました。
学習時間(秒) | Accuracy | |
---|---|---|
flax(jax) on TPU(v2-8) | 958 | 0.951 |
flax(jax) on GPU(RTX2080Ti) | 1,103 | 0.951 |
pytorch on GPU(RTX2080Ti) | 1,118 | 0.951 |
5epoch だと TPU の効果があまり発揮できていないようです。
おわりに
transformers + Flax,JAX で TPU を使って文章分類の学習を試してみました。 Accuracy は GPU と同程度の結果を得ることができました。
epoch を増やすと TPU の効果が発揮できそうです。