Published on

Colab TPUでtransformers Flax+JAXで文章分類を試してみる

Author
株式会社イエローバックの機械学習エンジニアです

はじめに

以前のポストtransformers Flax+JAX で文章分類を試してみるで、 ローカル GPU を使って Flax/JAX の文章分類を試しました。 今回は、Google Colab で、TPU(v2-8)を用いて文章分類を試してみます。

Colab notebook

transformers のflax exampleに記載されている分類用のノートブックを使います。

前回と同様にlivedoor ニュースコーパスを用いて文章の分類をします。 学習用のデータとしては、ここに示した方法により生成した train.csv,dev.csv を使います。

notebook の修正

ローカルデータを読み込んでラベルの加工を追加や、 オリジナルのノートブックのままではエラーになる箇所を修正しました。

以下、変更箇所概要です。全体はここに置いてあります。

python パッケージなど

!pip install datasets
#!pip install git+https://github.com/huggingface/transformers.git
!pip install transformers
!pip install flax
#!pip install git+https://github.com/deepmind/optax.git
!pip install optax
!pip install ipadic fugashi

ipadic と fugashi を追加、transformers と optax は head ではなくリリース最新版を利用します。

モデルと batch size

task = None
model_checkpoint = "cl-tohoku/bert-base-japanese-whole-word-masking"
per_device_batch_size = 4

task は自前のタスク、model は東北大の日本語 BERT モデルを利用します。 デバイス毎のバッチサイズは 4 としました。ちなみに 8 にすると OOM となってしまいました。

データセット読み込み

data_files = {
    "train": "drive/MyDrive/colab/ldn/data/train.csv",
    "validation": "drive/MyDrive/colab/ldn/data/dev.csv"
}

raw_dataset = load_dataset("csv", data_files=data_files)
metric = load_metric("accuracy")

データセットは、mount 済の Google Drive から読み込みます。

abel_list = raw_dataset["train"].unique("label")
label_list.sort()  # Let's sort it for determinism
num_labels = len(label_list)
label_to_id = {v: i for i, v in enumerate(label_list)}

ラベルリストの生成をします。

データセット前処理

sentence1_key, sentence2_key = ("text", None)

def preprocess_function(examples):
    texts = (
        (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
    )
    processed = tokenizer(*texts, padding="max_length", max_length=512, truncation=True)

    processed["labels"] = [label_to_id[l] for l in examples["label"]]
    return processed

センテンスは text1 文のみ、labels を数値のリストに変換します。

flax 関連

from flax import traverse_util

traverse_utilが import されていなかったので追加します。

学習用パラメタ

num_train_epochs = 5
learning_rate = 5e-5

epoch=5, lr=5e-5 とします。

TrainState

state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=adamw(0.0),
    logits_function=eval_function,
    loss_function=loss_function,
)

存在していないgradient_transformationを adamw(0.0)に変更します。adamw の引数はweight_decayです。以前の GPU での確認時と同様に 0.0 としています。

環境など

今回の確認環境は以下のようになりました。

  • Memory: 13GB (Colab Pro)
  • TPU: v2-8
  • python: 3.7
  • transformers: 4.10.0
  • jax: 0.2.19
  • jaxlib: 0.1.70+cuda110
  • flax: 0.3.4

ファインチューニング

ランタイムのタイプを TPU にして、すべてのセルを実行します。

The overall batch size (both for training and eval) is 32

total batch size = 32 (4 * 8)となっています。

Epoch ...: 100%   5/5 [15:58<00:00, 136.04s/it]
1/5 | Train loss: 0.058 | Eval accuracy: 0.922
2/5 | Train loss: 0.03 | Eval accuracy: 0.952
3/5 | Train loss: 0.015 | Eval accuracy: 0.95
4/5 | Train loss: 0.003 | Eval accuracy: 0.95
5/5 | Train loss: 0.002 | Eval accuracy: 0.951

学習ループに入ると 9 分程度止まっているような感じですが、 その後プログレスバーが動き始めます。

5epoch の学習は 16 分程度で終了しました。

プログレスバーを観察していると、1epoch 目がスタートするまで 9 分以上かかかっていて、 それ以降は 90 秒/epoch 程度(2.2it/sec)程度のペースで学習が進んでいました。

(試しに epoch=10 としたところ、26 分弱で学習が終わりました)

以前にローカル GPU(RTX2080Ti)で試したときとの比較をしてみました。

学習時間(秒)Accuracy
flax(jax) on TPU(v2-8)9580.951
flax(jax) on GPU(RTX2080Ti)1,1030.951
pytorch on GPU(RTX2080Ti)1,1180.951

5epoch だと TPU の効果があまり発揮できていないようです。

おわりに

transformers + Flax,JAX で TPU を使って文章分類の学習を試してみました。 Accuracy は GPU と同程度の結果を得ることができました。

epoch を増やすと TPU の効果が発揮できそうです。