Published on

transformersでdeepspeedを使ってみる(文章分類編)

Author
株式会社イエローバックの機械学習エンジニアです

前回のおさらい

前回の記事で deepspeed を使用するための環境構築を行いました。 今回は、この環境を使って実際に学習を試してみます。

環境は以下の通りです。

  • GPU: RTX2080Ti (Mem:11GB)
  • python: 3.8
  • CuDA: 11.1
  • pytorch: 1.8.1
  • transformers: 4.6.1
  • deepspeed: 0.3.16

補足

Microsoft DeepSpeed と同様のライブラリに Facebook の FairScale が存在しています。 FairScale を動作させるには 2GPU 以上が必要なので、今回の目的の 1GPU には適さないため DeepSpeed を採用しました。

ZeRO-2

deepspeed 0.3.16 では Zero Redundancy Optimizer(ZeRO)の stage2(ZeRO-2)および stage3(ZeRO-3)が実装されています。 transformers 4.6.1 でも deepspeed(ZeRO-2)および deepspeed(ZeRO-3)用のコードが含まれています。 本記事ではまず ZeRO-2 の動作を確認してみます。

参考: Fit More and Train Faster With ZeRO via DeepSpeed and FairScale

文章分類

今回のターゲットは文章分類としてその学習をおこないます。 モデルには少し大きめの xlm-roberta-large(パラメタ数:559M)を利用します。

データセット

日本語の文章分類用のデータセットとしてはメジャーなlivedoor ニュースコーパスを使います。 これから train.csv,dev.csv を作成します。

まずhttps://nxdataka.netlify.app/ldncsv/で公開されている ldn2csv.py を使って上記のコーパスを csv に変換します。

次に以下のスクリプトで、train と dev に分割します。

split-livedoornews.py
import pandas as pd
from sklearn.model_selection import train_test_split

data = pd.read_csv("livedoornews.csv")

def remove_nl(line):
    return line.replace('\n', ' ')

data['text'] = data['body'].apply(remove_nl)
data['label'] = data['media']

data2 = data[['label', 'text']]

x_train, x_dev, = train_test_split(data2, train_size=0.8, test_size=0.2, random_state=55)

x_train.to_csv('train.csv', index=False)
x_dev.to_csv('dev.csv', index=False)
$ python ldncsv.py
$ python split-livedoornews.py
$ mkdir -p ~/ldn/data
$ cp train.csv dev.csv ~/ldn/data

学習用スクリプト

transformers 4.6.1 用の文章分類用サンプルスクリプトを利用します。

$ git clone https://github.com/huggingface/transformers.git
$ cd transformers
$ git checkout -b v4.6.1 v4.6.1

GPU で学習(deepspeed なし)

$ BS=1
$ SEQLEN=512
$ python examples/pytorch/text-classification/run_glue.py \
       --model_name_or_path xlm-roberta-large \
       --output_dir ~/ldn/output \
       --per_device_train_batch_size $BS \
       --per_device_eval_batch_size $BS \
       --gradient_accumulation_steps 1 \
       --max_seq_length $SEQLEN \
       --train_file ~/ldn/data/train.csv \
       --validation_file ~/ldn/data/dev.csv \
       --test_file ~/ldn/data/dev.csv \
       --num_train_epochs 1 \
       --fp16 \
       --fp16_opt_level O1 \
       --do_train \
       --do_eval \
       --do_predict

実行すると以下のような OOM エラーとなります。

RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 10.76 GiB total capacity; 9.42 GiB already allocated; 5.44 MiB free; 9.44 GiB reserved in total by PyTorch)

xlm-roberta-large は seqlen=512 ではバッチサイズ 1 でもエラーになってしまうので、通常の方法では GPU で学習できないということになります。

Deepspeed ZeRO-2 で学習

deepspeed を簡単に使うには上記の python を deepspeed コマンドに変更し--deepspeed <CONFIG_FILE> オプションを追加します。

zero2 config

config ファイルは、clone した transformers の tests ディレクトリに存在しています。

  • tests/deepspeed/ds_config_zero2.json

今回は一部修正して使います。

$ cp tests/deepspeed/ds_config_zero2.json ~/ldn/ds_config_zero2.json

修正箇所は、

  • オリジナルと同じくなるようにスケジューラの変更。
  • GPU のメモリが少なめなので bucket_size を減少
@@ -19,21 +19,22 @@
     },

     "scheduler": {
-        "type": "WarmupLR",
+        "type": "WarmupDecayLR",
         "params": {
             "warmup_min_lr": "auto",
             "warmup_max_lr": "auto",
-            "warmup_num_steps": "auto"
+            "warmup_num_steps": "auto",
+            "total_num_steps": "auto"
         }
     },

     "zero_optimization": {
         "stage": 2,
         "allgather_partitions": true,
-        "allgather_bucket_size": 2e8,
+        "allgather_bucket_size": 6e7,
         "overlap_comm": true,
         "reduce_scatter": true,
-        "reduce_bucket_size": 2e8,
+        "reduce_bucket_size": 6e7,
         "contiguous_gradients": true,
         "cpu_offload": true
     },
~/ldn/ds_config_zero2.json
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },

    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },

    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 6e7,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 6e7,
        "contiguous_gradients": true,
        "cpu_offload": true
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}

学習実行

$ BS=1
$ SEQLEN=512
$ deepspeed examples/pytorch/text-classification/run_glue.py \
       --model_name_or_path xlm-roberta-large \
       --output_dir ~/ldn/output \
       --per_device_train_batch_size $BS \
       --per_device_eval_batch_size $BS \
       --gradient_accumulation_steps 1 \
       --max_seq_length $SEQLEN \
       --train_file ~/ldn/data/train.csv \
       --validation_file ~/ldn/data/dev.csv \
       --test_file ~/ldn/data/dev.csv \
       --num_train_epochs 1 \
       --fp16 \
       --fp16_opt_level O1 \
       --do_train \
       --do_eval \
       --do_predict \
       --deepspeed ~/ldn/ds_config_zero2.json

deepspeed を使うと OOM エラーにならず学習ができました。

バッチサイズと bucket_size の関係

zero2 のパラメタ allgather_bucket_sizereduce_bucket_size が大きすぎると OOM エラーになってしまいます。 また、これらが小さすぎると以下のように dimension size のエラーとなることもあります。

RuntimeError: start (0) + length (256002048) exceeds dimension size (45000000).

そこで、SEQLEN=512(固定)としたときに、バッチサイズ(BS)と bucket_size(=allgather_bucket_size=reduce_bucket_size)を変化させて学習可能かどうかを調べてみました。 結果が以下の表です。OOM は Out Of Memory エラー、DIM は dimension size のエラー、OK は学習できたことを示しています。

BS=2BS=4BS=6BS=8
オリジナル(ds 無し)OOMOOMOOMOOM
bucket_size=1e7DIMDIMDIMOOM
bucket_size=6e7OKOKOKOOM
bucket_size=2e8OKOKOOMOOM
bucket_size=6e7 w/o CPUOOMOOMOOMOOM
bucket_size=2e8 w/o CPUOOMOOMOOMOOM

また、エラーなく学習できたケースで学習時間と評価時間は以下のようになりました。

BS学習時間(秒)評価時間(秒)
bucket_size=6e761,49418.6
bucket_size=6e741,95718.6
bucket_size=6e723,75819.4
bucket_size=2e842,10618.6
bucket_size=2e823,93919.4

バッチサイズ 6 のケースがもっとも早く終わりました。

CPU オフロード

次に CPU オフロードの効果を確認します。さきほど設定した deepspeed config ファイル(ds_config_zero2.json)では CPU オフロードが有効になっていましたが、 以下のようにしてオフにします。

    "zero_optimization": {
              .
              .
        "cpu_offload": false
    },

確認結果は以下になります。

BS=1BS=2BS=4BS=6BS=8
bucket_size=6e7 CPU オフロード有OKOKOKOOM
bucket_size=2e8 CPU オフロード有OKOKOOMOOM
bucket_size=6e7 CPU オフロード無OOMOOMOOMOOMOOM
bucket_size=2e8 CPU オフロード無OOMOOMOOMOOMOOM

CPU オフロードを無効にすると、BS=1 でも OOM となってしまいました。

bert-large-japanese

モデル cl-tohoku/bert-large-japanese でも同様の実験をおこないました。bert-large-japanese はパラメタ数 337M と xlm-roberta-large(559M)よりも少ないです。

結果は以下です。(SEQLEN=512)

BS=2BS=4BS=6BS=8
オリジナル(DS 無し)OKOOMOOMOOM
bucket_size=1e7OKOKOKOOM
bucket_size=6e7OKOKOKOOM
bucket_size=3e8OKOOMOOMOOM
bucket_size=1e7 w/o CPUOKOOMOOMOOM
bucket_size=6e7 w/o CPUOKOOMOOMOOM
bucket_size=3e8 w/o CPUOOMOOMOOMOOM

学習できたものの学習時間は以下のようになりました。

BS学習時間(秒)
オリジナル(DS 無し)2580
bucket_size=1e722,346
bucket_size=1e741,267
bucket_size=1e76909
bucket_size=6e722,321
bucket_size=6e741,261
bucket_size=6e76908
bucket_size=3e822,349
bucket_size=1e7 w/o CPU2599
bucket_size=6e7 w/o CPU2582

学習時間は CPU オフロードを有効にした場合、オリジナルに比べて大きめのバッチサイズまで学習可能となります。 CPU オフロードを無効にすると、オリジナルと同バッチサイズ程度しか学習できないようです。

おわりに

本記事では deepspeed ZeRO-2 を使って、1GPU での文章分類学習を試してみました。 通常では GPU で学習できないようなモデルに対しては有効であることが確認できました。 ZeRO-2 使用時には bucket_size パラメタの調整が必要となるようです。