izmyonの日記

奈良の山奥で研究にいそしむ大学院生の学習記録。

JAX/Flaxでゼロから作るTransformer① ―Vanilla Transformer

Preface

 このシリーズは、JAX/FlaxでゼロからTransformer系のモデルを実装し、JAX/Flaxの使い方やモデルの仕組みを理解するためのチュートリアルである。この記事では、その第一弾として、最初のTransformerであるVanilla Transformerを実装する。Vanilla Transformerという単語は、”ただの"トランスフォーマーという意味であり、"Attention is all you need"[1]で提案された最初のTransformerを指すために使われるようになったものである。この記事では、できるだけ元論文[1]の設定に従いながら、主に以下の内容を取り扱う。

  • JAX/FlaxでゼロからTransformerを実装する。

  • Multi30kの独英翻訳タスクを学習させる。

  • Greedy Searchにより翻訳する。

  • Attentionを可視化する。

また、この実装の主な特徴は以下の通りである。

  • Vanilla TransformerをMultiHeadAttentionとPositionalEncoderの自作クラスで実装する。

  • 文単位ではなく、バッチ単位で翻訳できるtranslator関数を実装している。

  • ソース文とターゲット文を最大配列長になるようにパディングすることで、jitコンパイル関数の入力形状を固定し、コンパイル回数を最小にする。それにより、学習を高速化する。(jitコンパイルされた関数は、入力の形状が変わるたびにコンパイルが行われbytecodeがキャッシュされるため、入力形状がバラバラだとかえって遅くなってしまう。そのため、入力形状を全バッチで揃えた。)

  • JAX/Flaxの実装では、クエリー方向とキー方向を同時にマスクし、マスクした位置をjnp.finfo(dtype).minに変更してAttentionを計算することが多い。しかし本実装では、Attentionの一般的な実装に倣い、キー方向のみをマスクしてマスク位置を-jnp.infに変換し、クエリー方向は損失関数内でマスクする。

なお、チュートリアルということもあり実装はColabにて行い、githubでノートブックを公開している。

github.com

colab.research.google.com

余談:バニラアイスがスタンダードなアイスであることからか、"Vanilla"という単語には、「普通の」、「ありきたりな」などの意味があるらしく、バニラアイスと同じようにこの最初のトランスフォーマーも、バニラトランスフォーマーと呼ばれるようになったようである。これは個人的には誠に遺憾である。なぜなら、僕はアイスの中ではバニラアイスが一番好きであり、僕にとっては普通でありきたりなものではないからである。それは置いといて、最近はBERTなどのEncoderだけを使ったモデルや、GPTなどのDecoderだけを使った言語モデルのほうが標準的になっているような節があり、Encoder-Decoderモデルが「普通でありふれた」ものかどうかと言われたらけっこう微妙な気がする。それに、個人的にはEncoderとDecoderの両方を使うこのVanilla Transformerが最も実装が難しいように感じ、Encoder only、Decoder onlyのモデルはこれが理解出来れば実装は比較的簡単に感じる。そのように実装面でも最初にして難関の印象があるため、「普通」という感じははあまりしない。もとはこのモデルが”Transformer"と呼ばれるものであるが、”Transformer”という言葉が指す意味が広くなっていったことで「普通の」というよりは「オリジナルの」などという意味で「バニラ」とつけたのだろうと思われる。他にも、EncoderとDecoderの両方を用いることからTransformer encoder-decoder modelと呼ばれることもある。

Vanilla Transformer from scratch with JAX/Flax

Check GPU Settings

 GPUの設定を確認する。ColabだとGPUが毎回同じとは限らないので、割り振られるたびに確認するとよい。今回はTesla T4であった。また、CUDAやDriverのversionなどはJAXとFlaxをインストールする際に確認する必要がある。

!nvidia-smi
Fri Feb 24 09:17:08 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   52C    P0    24W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
!nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Tue_Mar__8_18:18:20_PST_2022
Cuda compilation tools, release 11.6, V11.6.124
Build cuda_11.6.r11.6/compiler.31057947_0

Install and Import Libraries

 jax, flax, optaxをインストールし、必要なライブラリをインポートする。

!pip install --upgrade pip
!pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install -q flax==0.6.1 optax==0.1.3
import os
import re
from typing import Dict, Optional, List, Union, Callable, Any, Tuple
import time, datetime
from functools import partial

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import jax, optax
from jax import random, numpy as jnp
import jax.lax as lax
import numpy as np

from flax import linen as nn
from flax import jax_utils
from flax.training import train_state, checkpoints
from flax.training import common_utils

Prepare Multi30k Datasets

 今回は、簡単のために翻訳用のデータセットにはtorchtextの独英翻訳Multi30kデータセットを、トークナイザにはSpaCyを利用する。これらはPytorchのチュートリアルでも用いられていので、それらと比較してみてもいいかもしれない。(公式はTensorFlow Datasetsを推奨している。)

!pip install -U spacy
!pip install torch torchvision torchtext==0.8.0
import torch
try:
  from torchtext.datasets import Multi30k
except OSError:
  from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator

import spacy
import spacy.cli 
spacy.cli.download('en_core_web_sm')
spacy.cli.download('de_core_news_sm')

 SpaCyを使ってドイツ語と英語のトークナイズ関数、ソース(ドイツ語)とターゲット(英語)のフィールドを定義する。その後、訓練、検証、テスト用にデータセットを分け、イテレータを作成しする。バッチサイズは128とする。最後に、トークン(str型)のリストからID(int型)のリストに変換する関数とその逆変換をする関数を、それぞれドイツ語用と英語用に作成する。

spacy_de = spacy.load('de_core_news_sm')
spacy_en = spacy.load('en_core_web_sm')

def tokenize_de(text: str):
    return [tok.text for tok in spacy_de.tokenizer(text)]

def tokenize_en(text: str):
    return [tok.text for tok in spacy_en.tokenizer(text)]

SRC = Field(tokenize = tokenize_de,
            tokenizer_language="de",
            init_token = '',
            eos_token = '',
            pad_token='',
            lower = True,
            batch_first = True)

TRG = Field(tokenize = tokenize_en,
            tokenizer_language="en",
            init_token = '',
            eos_token = '',
            pad_token='',
            lower = True,
            batch_first = True)

train_data, valid_data, test_data = Multi30k.splits(exts = ('.de', '.en'),
                                                    fields = (SRC, TRG))
SRC.build_vocab(train_data, min_freq = 2)
TRG.build_vocab(train_data, min_freq = 2)
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size = 128)

def stoi_de(tokens: List[str]) -> List[int]:
  return [SRC.vocab.stoi[token] for token in tokens]

def stoi_en(tokens: List[str]) -> List[int]:
  return [TRG.vocab.stoi[token] for token in tokens]

def itos_de(ids: List[int]) -> List[str]:
  return [SRC.vocab.itos[id] for id in ids]

def itos_en(ids: List[int]) -> List[str]:
  return [TRG.vocab.itos[id] for id in ids]

 データセットの詳細を見る。SRC.vocabは、ソース(ドイツ語)の辞書を格納するオブジェクトで、SRC.vocab.stoiはトークンを受け取ってそれに対応するIDを返す辞書である。TRG.vocab.stoiはターゲット(英語)のそれである。辞書の大きさを見てみると、ソースでは7853トークン、ターゲットでは5893トークンあり、それぞれトークンからIDを返す辞書(stoi)とIDからトークンを返すリスト(itos)の大きさと一致していることがわかる。

print(SRC.vocab.stoi)
print(TRG.vocab.stoi)
print(len(SRC.vocab.stoi))
print(len(TRG.vocab.stoi))
print(len(SRC.vocab.itos))
print(len(TRG.vocab.itos))
print(len(SRC.vocab))
print(len(TRG.vocab))
{'<unk>': 0, '<pad>': 1, '<bos>': 2, '<eos>': 3, '.': 4, 'ein': 5, 'einem': 6, 'in': 7, 'eine': 8, ',': 9, 'und': 10, 'mit': 11, 'auf': 12, ....
{'<unk>': 0, '<pad>': 1, '<bos>': 2, '<eos>': 3, 'a': 4, '.': 5, 'in': 6, 'the': 7, 'on': 8, 'man': 9, 'is': 10, 'and': 11, 'of': 12, ...

7853
5893
7853
5893
7853
5893

 試しにトレインデータにある文を見てみる。イテレータの最初のバッチのソースとターゲットの最初のペアを見る。(ここで表示しているのは1ペアだけであるが、Colabノートブックでは10例表示している。)モデルの訓練ではトークンはIDとして扱うためにデータ中でも文がID列で表現されており、ID列を見ても元の文が分からないため、先ほど作成したIDのリストをトークンのリストに変換する関数を使ってトークンのリストにして表示している。

for i, batch in enumerate(train_iterator):
  src = batch.src[:10]
  trg = batch.trg[:10]
  src = list(src)
  trg = list(trg)
  for de, en in zip(src, trg):
    print(itos_de(de))
    print(itos_en(en))
    print("--------------------")
  break
['<bos>', 'drei', 'junge', 'menschen', 'pflanzen', 'blumen', 'und', 'decken', 'den', 'bereich', 'mit', 'einer', 'plane', 'ab', '.', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']
['<bos>', 'three', 'young', 'people', 'planting', 'flowers', 'and', 'covering', 'the', 'area', 'with', 'a', 'tarp', '.', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']
--------------------

ソース文は「drei junge menschen pflanzen blumen und decken den bereich mit einer palen ab.」であり、ターゲット文は「three young people planting flowers and covering the area with a tarp.」であることが分かる。また、文の前に<bos>トークンが、後に<eos>トークンが、そのあとに<pad>トークンが何個もついていることが分かる。それらはスペシャトークンと呼ばれ、モデルの訓練のためなどに使われるトークンである。<unk>は辞書にない未知語を、<bos>トークンは文の始まりを、<eos>トークンは文の終了を表し、<pad>トークンはバッチごとに系列長を揃えるために余白を埋める意味のないトークンとして使用される。

 バッチごとに系列長を揃えてあるが、ここでデータセット全体で最大の系列長を調べる。モデルが対応すべき最大系列長を調べておき、翻訳時などに生成する上限として使用したり、モデルが過度に大きくなることを防いだりする際に使用する。特に、後で解説するが今回の実装では、jitコンパイルによる高速化のために、さらに<pad>トークンを追加しデータセット全体の最大系列長にすべての入力を合わせる。

max_input_len = 0
max_target_len = 0

for iterator in [train_iterator, valid_iterator, test_iterator]:
  for batch in iterator:
    src = jnp.asarray(batch.src)
    trg = jnp.asarray(batch.trg)
    if src.shape[1] >= max_input_len:
      max_input_len = src.shape[1]
    if trg.shape[1] >= max_target_len:
      max_target_len = trg.shape[1]

print('max_input_len', max_input_len)
print('max_target_len', max_target_len)
max_input_len 46
max_target_len 43

 ソース文の最大系列長は46、ターゲット文の最大系列長は43であった。

Make Config Class for Architecture and Training

 ハイパーパラメータを保持するConfigクラスを設定する。このクラスのインスタンスは、以降すべてのクラスおよび関数で共有される。特に、最大系列長max_lenはソース文、ターゲット文両方の最大系列長である46に、単語の埋め込み次元であるembed_dim( d_{model} )は256、アテンションヘッド数であるnum_headsは8、queryの次元q_dim( d_q)とkey/valueの次元v_dim( d_k)は両方ともembed_dim ÷ num_heads = 32とした。

class Config:
  # Architecture Config
  src_vocab_size: int = 7853
  trg_vocab_size: int = 5893
  pad_idx: int = 1
  bos_idx: int = 2
  eos_idx: int = 3
  max_len: int = 46
  embed_dim: int = 256
  id_dtype: Any = jnp.int16
  dtype: Any = jnp.float32
  num_heads: int = 8
  num_layers: int = 3
  q_dim: int = embed_dim // num_heads
  v_dim: int = embed_dim // num_heads
  ff_dim: int = 512
  dropout_rate: float = 0.1
  # Training Config
  special_idxes: List[int] = [1,2,3]
  special_tokens: List[str] = ['<bos>', '<eos>', '<pad>']
  seed: int = 0
  batch_size: int = 128
  learning_rate: float = 0.0005
  warmup_steps: int = 100
  num_epochs: int = 150
  valid_every_epochs: int = 2
  save_ckpt_every_epochs: int = 1
  restore_checkpoints: bool = True
  ckpt_prefix: str = 'translation_ckpt_'
  ckpt_dir: str = '/content/drive/My Drive/checkpoints/translation’

Implement Transformer Model

 それではTransformerの実装をしていく。今ではTransformerの詳細を解説するわかりやすい記事は山のようにあるので、ここでは主に実装上の観点から解説する。Transformerの全体像は、おなじみの親の顔より見た以下の図である。

Transformerが行うのは基本的に入力系列を受け取り出力系列出力する系列変換(seq2seq)である。翻訳タスクでは、ソース文を受け取りターゲット文を出力するように学習する。

学習時には、図のようにまず[バッチサイズ、系列長]の大きさのソース文の二次元配列を入力し、埋め込みによって[バッチサイズ、系列長、埋め込み次元]のテンソルにした後、N層のEncoderに入力する。Encoderの中にはMulti-Head Attention, Add&Norm, FeedForwardなどのモジュールが存在し、その中で様々な処理を行うが、入出力の形状はすべて[バッチサイズ、系列長、埋め込み次元]のままである。この形状でEncoderから出力されたテンソルは、ソース文の情報を含んだ記憶のようなものであることからmemoryという。次に、[バッチサイズ、系列長]の大きさのターゲット文の二次元配列を入力する。この時、ターゲット文は<bos>トークンを最初のトークンとし、<eos>トークンを削除したものを入力する。その後、同じく埋め込みによって[バッチサイズ、系列長、埋め込み次元]のテンソルにし、N層のDecoderに入力する。DecoderにはMasked Multi-Head Attention, Add&Norm, Multi-Head Attention, Feed Forwardなどのモジュールが存在し、ここでも各モジュールの入出力の形状はすべて[バッチサイズ、系列長、埋め込み次元]のままである。また、DecoderのMulti-Head Attentionは、ソース文の情報を含むmemoryとターゲットを用いて計算を行うため、Source-Target Attentionとも呼ばれる。Decoderの出力は、Linear層によって[バッチサイズ、系列長、ターゲットの辞書サイズ]のテンソルに変換しsoftmaxを通し、最終次元が正規化された分布となって、Transformerの出力(logits)となる。

その後、loss関数内でlogitsを用い、今度は<bos>を削除し文の最後のトークンの後に<eos>を付けたターゲット文を教師ラベルとして、cross entropy lossを計算する。lossの出力は実数のスカラーなので、Transformerのloss関数はソース文、ターゲット文、パラメータを変数とする多変数実数値関数となる。これをパラメータに対して微分してパラメータ空間上での勾配を求め、lossが下がる方向にパラメータを更新する。ここで、Decoderでは<bos>で始まり<eos>がないターゲット文を入力し、lossの教師ラベルでは<bos>がなく<eos>で終わるターゲット文を用いた。例えば先ほど見た訓練データの文の場合だと、以下のようになる。

入力のターゲット文:['<bos>', 'three', 'young', 'people', 'planting', 'flowers', 'and', 'covering', 'the', 'area', 'with', 'a', 'tarp', '.', '', ...]

教師ラベルのターゲット文:['three', 'young', 'people', 'planting', 'flowers', 'and', 'covering', 'the', 'area', 'with', 'a', 'tarp', '.', '<eos>', '', ....]

入力のターゲット文と比べ、教師ラベルのターゲット文は1トークンだけ先に進んでいることがわかる。そして、lossを小さくするということは、logitsの各系列位置におけるターゲット辞書の大きさをもつ配列の最大となるものの位置が、教師データの同じ系列位置のIDとなるように学習するということである。つまり、memoryと入力のターゲット文を用い、下図のように各系列位置ごとに次の位置のトークンを出力するように学習する。特に、<bos>トークンが入力されたら最初のトークンを、文の終わりには<eos>トークンを出力するように学習する。

 それでは、以下でTransformerを定義していく。

Positional Encoding

 Transformerには再帰も畳み込みもなく系列の時系列を捉えることができない。そのため、系列の順序に関する情報を付加するために、位置埋め込みを行う。埋め込み後の入力文と同じ形状[バッチサイズ、系列長、埋め込み次元]の位置埋め込みテンソル作製し、埋め込み後の入力文に足す。位置埋め込みは全てのバッチで同じであり、系列位置を pos、埋め込み次元での偶数位置 2i、奇数位置を 2i+1とし、以下の式で表される。

 PE_{(pos, 2i)} = sin \left( pos/10000^{2i/d_{model}} \right) \\
PE_{(pos, 2i+1)} = cos \left( pos/10000^{2i/d_{model}} \right)

この工程を実装すると以下のようになる。途中少々複雑な計算をしているが、自分でダミーデータでどのような計算が行われているのか見てみるといいかもしれない。

class PositionalEncoder(nn.Module):
  config: Config
  """Adds sinusoidal positional embeddings to the inputs.
     
  Attribues:
     config: config class containing hyperparameters.
  """
  @nn.compact
  def __call__(self, 
               x: jnp.array
               ) -> jnp.array:
    """Applys PositionalEncoder Module.
     
    Args:
      x: inputs of shape [batch_size, length, features].

    Returns: 
      The positional embedded inputs of the same shape of the inputs.
    """
    # [Batch, SeqLen, EmbedDim]
    assert x.ndim == 3
    config = self.config

    batch_size, seq_len = x.shape[0], x.shape[1]
    # [Batch, SeqLen, EmbedDim]
    pe = jnp.empty((batch_size, seq_len, config.embed_dim))
    # [1, SeqLen]
    position = jnp.arange(0, seq_len, dtype=config.dtype)[jnp.newaxis, :]       
    # [1, EmbedDim/2]
    div_term = jnp.exp(jnp.arange(0, config.embed_dim, 2, dtype=config.dtype)
         * (-lax.log(10000.0) / config.embed_dim))[jnp.newaxis, :]              
    # [SeqLen, EmbedDim/2]
    radians = jnp.einsum('ij,kl->jl', position, div_term)                       
    pe = pe.at[:, :, 0::2].set(jnp.sin(radians))
    pe = pe.at[:, :, 1::2].set(jnp.cos(radians))
    x = x + pe
    return x.astype(config.dtype)

Multi-Head Attention

 query, key, valueを受け取り、headの数だけ新たなquery, key, valueのペアを作製してScaled Dot Product Attentionを計算し、Concatしてつなげ、最後に線形変換を行うMulti-Head Attentionを実装する。式で表すと以下のようになる。

 \rm{Attention} (Q, K, V) = \rm{softmax} \left( \frac{QK^T}{\sqrt{d_k}} \right) V \\
\rm{MultiHead} (Q, K, V) = \rm{Concat} ( \rm{head_i}, ..., \rm{head_h} ) W^O \\
\qquad where \quad \rm{head_i} = Attention \left( QW^Q_i, KW^K_i, VW^V_i \right)

ここで、 Q, K, Vそれぞれquery, key, value d_kはqueryおよびkeyの次元、 head_iはそれぞれのattention headの出力、 hはヘッドの数、 W_Q, W_K, W_V, W_Oはそれぞれquery, key, valueを線形変換するパラメータと、最後に各headからの出力をConcatしたものを線形変換するパラメータである。後々行列積を計算するためにq, kの次元 d_kは同じであり、さらに同様の理由でkeyの系列長とvalueの系列長も同じである必要がある。(実際にはkeyとvalueは同じテンソルであるため、keyとvalueの形状は常に等しくなるはずであり、あまり気にする必要はない。)

今回の実装では、[バッチサイズ、q/k/vの系列長、埋め込み次元]の形状のq, k, vを受け取り、それぞれ線形変換により最終次元が変換され、[バッチサイズ、q/k/vの系列長、ヘッドの数 × q/k/vの次元]の形状になる。今回はConfigクラス内で設定したように、 d_k = d_v = d_{model} / hであるため、最終次元は h × d_k = h × d_v = d_{model} であり、線形変換の前後で形状は変化しない。そして、reshapeして[バッチサイズ、q/k/vの系列長、ヘッドの数、q/k/vの次元]の形状に変形し、各系列ごとに複数のヘッド(今回の設定では8個)を持ち、ヘッドごとにq/k/vの次元を持つように変形する。その後、einsumによってqueryとkeyの行列積を計算し、 \sqrt{d_k}で割ってscaled dot product attentionを計算する。アインシュタインの縮約記法による計算が少し分かりづらいが、[バッチサイズ、q/kの系列長、ヘッドの数、q/kの次元]という形状だったqueryとkeyを次元を入れ替えて、[バッチサイズ、ヘッドの数、q/kの系列長、q/kの次元]にし、keyに関してはさらに最後の二つの次元を転置して[バッチサイズ、ヘッドの数、kの次元, kの系列長]として、[バッチサイズ、ヘッドの数、qの系列長、qの次元]のqueryと[バッチサイズ、ヘッドの数、kの次元, kの系列長]のkeyの最後の二つの次元に関して行列積を計算するとことで、[バッチサイズ、ヘッドの数、qの系列長, kの系列長]の形状のattention weightを計算し、 \sqrt{d_k}で割っているのと同じことをしている。次に、[バッチサイズ、1、qの系列長、kの系列長]の形状のmask(バッチサイズの数だけ、各系列ごとにマスクがある。)により、各系列ごとにすべてのヘッドに対して同じmaskを用いて[qの系列長、kの系列長]の形状のマスクをし、keyのマスク位置を-jnp.infに変更する。その後にsoftmaxを通して各attention weightをkey方向に0~1に正規化する。ここで、-jnp.infになっている箇所は0になるため、マスク位置に重みがなくなり、マスクされたトークンと他のトークンとの関連は学習されなくなる他、虐伝播においてkey, query, valueのいずれにおいてもマスク位置に勾配が流れなくなる。

class MultiHeadAttention(nn.Module):
  config: Config
  """Multi-head dot-product attention.
     
  Attribues:
     config: config class hyperparameters.
  """

  @nn.compact
  def __call__(self, 
               q : jnp.ndarray,
               k : jnp.ndarray,
               v : jnp.ndarray, 
               mask : jnp.ndarray = None
               ) -> Tuple[jnp.array, jnp.array]:
    """Applys MultiHeadAttention Module.
     
    Args:
      q: query inputs of shape [batch_size, query_length, features].
      k: key inputs of shape [batch_size, key/value_length, features].
      v: value inputs of shape [batch_size, key/value_length, features].
      mask: attention mask of shape [batch_size, 1, query_length, key/value_length].
       
    Returns: 
      The output of shape [batch_size, query_length, features], 
      and an attention matrix of shape [batch_size, num_heads, query_length, key/value_length].
    """

    # [Batch, SeqLen_q, EmbedDim]    
    assert q.ndim == 3                                                          
    # [Batch, SeqLen_k, EmbedDim]
    assert k.ndim == 3                                                          
    # [Batch, SeqLen_v, EmbedDim]
    assert v.ndim == 3                                                          
    # Same batch size
    assert q.shape[0] == k.shape[0] == v.shape[0]                               
    # SeqLen_k = SeqLen_v
    assert k.shape[1] == v.shape[1]                                             
    if mask is not None:
      # [Batch, 1, SeqLen_q, SeqLen_k]
      assert mask.ndim == 4                                                     
    config = self.config

    q_seq_len, k_seq_len = q.shape[1], k.shape[1]

    # [Batch, SeqLen_q, Head * Dim_k]
    q = nn.Dense(config.num_heads * config.k_dim)(q)                            
    # [Batch, SeqLen_k, Head * Dim_k]
    k = nn.Dense(config.num_heads * config.k_dim)(k)                            
    # [Batch, SeqLen_k, Head * Dim_v]
    v = nn.Dense(config.num_heads * config.v_dim)(v)                            

    # [Batch, SeqLen_q, Head, Dim_k]
    q = q.reshape(-1, q_seq_len, config.num_heads, config.k_dim)                
    # [Batch, SeqLen_k, Head, Dim_k]
    k = k.reshape(-1, k_seq_len, config.num_heads, config.k_dim)                
    # [Batch, SeqLen_k, Head, Dim_v]
    v = v.reshape(-1, k_seq_len, config.num_heads, config.v_dim)                

    # [Batch, Head, SeqLen_q, SeqLen_k]
    attention = (jnp.einsum('...qhd,...khd->...hqk', q, k) 
                                / jnp.sqrt(config.v_dim)).astype(config.dtype)  
    # Change the masked position to -jnp.inf.
    if mask is not None:                                                         
      attention = jnp.where(mask, attention, -jnp.inf)                          
    # [Batch, Head, SeqLen_q, SeqLen_k]
    attention = nn.softmax(attention, axis=-1).astype(config.dtype)             
    # [Batch, SeqLen_q, Head, Dim_v]  
    values = jnp.einsum('...hqk,...khv->...qhv', attention, v)                  
    # [Batch, SeqLen_q, Head × Dim_v (=EmbedDim)]
    values = values.reshape(-1, q_seq_len, config.num_heads * config.v_dim)     
    # [Batch, SeqLen_q, EmbedDim] 
    out = nn.Dense(config.embed_dim, dtype=config.dtype)(values)                

    return out.astype(config.dtype), attention.astype(config.dtype)

 今回は通常の実装通り、key方向にのみマスク位置し、-jnp.infに変更するが、Flaxではquery, key方向の両方にこの時点でマスクし、マスク位置のattentionを-jnp.infではなく-jnp.finfo(config.dtype).minに変更することが多い。このようにすると、query方向にマスクされていない行のkey方向のマスク位置は、softmaxを通すことで-jnp.finfo(config.dtype).minより小さくなり表現可能な値より小さくなることで0になる。一方で、query方向にマスクがある場合、マスク位置の行が全て-jnp.finfo(config.dtype).minになってしまい、softmaxを通すとuniformになってしまう。すると、それらは全て定数であるために逆伝播においてqueryおよびkey方向に勾配が流れませんが、value方向に勾配が流れてしまう。しかし、実際にはloss関数においてマスク位置は0,それ以外は1のweightを適用し、query方向に流れてきたマスク位置のlossをゼロにするため、マスク位置にlossは流れない。また、今回の実装でもこの時点はkeyに対してのみマスクをし、loss関数内でweightを掛けてquery方向にマスクを行っている。FlaxでこのようにQuery方向にもマスクをする理由は、おそらく近年提案されてきた様々で複雑な形状のAttention maskを、簡単に適用できるようにするためだと考えられる。しかしながら、今回のようにquery方向に一様にマスクする通常のAttention maskの場合には、今回の実装のようにquery方向のマスクをしなくても同じである。ここら辺に関しては、以下でも議論をしている。

github.com

Feed Forward Network

 線形変換した後、活性化関数ReLuを適用し、さらに線形変換するFeed Forward Networkを実装する。

 \rm{FFN}(x) = max(0, xW_1, + b_1)W_2 + b_2
class FeedForward(nn.Module):
  config: Config
  """Feed Forward Network.
     
  Attribues:
     config: config class containing hyperparameters.
  """

  @nn.compact
  def __call__(self, 
               x: jnp.array,
               deterministic: bool
               ) -> jnp.array:
    """Applys FeedForward module.
     
    Args:
      x: inputs of shape [batch_size, length, features].
      deterministic: parameter for nn.Dropout. if true, it masks and scales the inputs. during training, it should be False. Otherwise True.

    Returns: 
      The output of the same shape of the inputs.
    """
    
    assert x.ndim == 3                                                          
    # [Batch, SeqLen, EmbedDim]
    config = self.config

    # Dense Layer
    x = nn.Dense(config.ff_dim, dtype=config.dtype)(x)                          
    # ReLu
    x = nn.relu(x)    
    # Dropout                                                          
    x = nn.Dropout(rate=config.dropout_rate)(x, deterministic=deterministic)    
    # Dense Layer
    x = nn.Dense(config.embed_dim, dtype=config.dtype)(x)                       
    # Dropout
    x = nn.Dropout(rate=config.dropout_rate)(x, deterministic=deterministic)    

    return x.astype(config.dtype)

Transformer Encoder Layer

 以下の図で表されるTransformer Encoder Layerを実装する。

class TransformerEncoderLayer(nn.Module):
  config: Config
  """Transformer encoder layer (Encoder 1D block).
     
  Attribues:
     config: config class containing hyperparameters.
  """

  @nn.compact
  def __call__(self, 
              x: jnp.array,
              encoder_mask: jnp.array,
              deterministic: bool
              ) -> Tuple[jnp.array, jnp.array]:
    """Applys TransformerEncoderLayer module.
     
    Args:
      x: inputs of shape [batch_size, length, features].
      encoder_mask: attention mask for Self-Attention of shape [batch_size, 1, length, length].
      deterministic: parameter for nn.Dropout. if true, it mask and scale the inputs. during training, it should be False. Otherwise True.

    Returns: 
      The outputs of the same shape of the inputs,
      and an attention matrix in Self-Attention of shape [batch_size, num_heads, length, length].
    """

    # [Batch, SeqLen, EmbedDim]    
    assert x.ndim == 3                                                          
    # [Batch, 1, SeqLen_q, SeqLen_k]
    assert encoder_mask.ndim == 4                                               
    config = self.config

    # Residual
    res = x         
    # Self-Attention                                                            
    x, attention = MultiHeadAttention(config)(x, x, x,
                                              mask=encoder_mask)                
    # Dropout
    x = nn.Dropout(rate=config.dropout_rate)(x, deterministic=deterministic)    
    # Add & Norm
    x = nn.LayerNorm(dtype=config.dtype)(res + x)                               

    # Residual   
    res = x                                                                     
    # Feed Forward Network
    x = FeedForward(config)(x, deterministic)                                   
    # Add & Norm
    x = nn.LayerNorm(dtype=config.dtype)(res + x)                               
    
    return x.astype(config.dtype), attention.astype(config.dtype)

Transformer Decoder Layer

 以下の図で表されるTransformer Decoder Layerを実装する。

class TransformerDecoderLayer(nn.Module):
  config: Config
  """Transformer decoder layer (Encoder Decoder 1D block).
     
  Attribues:
     config: config class containing hyperparameters.
  """

  @nn.compact
  def __call__(self,
              x: jnp.array,
              memory: jnp.array,
              decoder_mask: jnp.array,
              encoder_decoder_mask: jnp.array,
              deterministic: bool
              )-> Tuple[jnp.array, jnp.array, jnp.array]:
    """Applys TransformerDecoderLayer module.
     
    Args:
      x: inputs of shape [batch_size, x_length, features].
      memory: encoded sources from Transformer Encoder of shape [batch_size, memory_length, features].
      decoder_mask: attention mask for Self-Attention of shape [batch_size, 1, x_length, x_length].
      encoder_decoder_mask: attention mask for Source-Target Attention of shape [batch_size, 1, x_length, memory_length].
      deterministic: parameter for nn.Dropout. if true, it mask and scale the inputs. during training, it should be False. Otherwise True.

    Returns: 
      The outputs of the same shape of the inputs,
      an attention matrix in Self-Attention of shape [batch_size, num_heads, x_length, x_length],
      and an attention matrix in Source-Target Attention of shape [batch_size, num_heads, x_length, memory_length].
    """

    # [Batch, SeqLen, EmbedDim]      
    assert x.ndim == 3                                                           
    # [Batch, SeqLen, EmbedDim]                       
    assert memory.ndim == 3                                                     
    # [Batch, 1, SeqLen_q, SeqLen_k]
    assert decoder_mask.ndim == 4                                               
    # [Batch, 1, SeqLen_q, SeqLen_k]
    assert encoder_decoder_mask.ndim == 4                                       
    config = self.config

    # Residual
    res = x                                                                     
    # Self-Attention
    x, self_attention = MultiHeadAttention(config)(x, x, x,
                                                   mask=decoder_mask)           
    # Dropout
    x = nn.Dropout(rate=config.dropout_rate)(x, deterministic=deterministic)    
    # Add & Norm
    x = nn.LayerNorm(dtype=config.dtype)(res + x)                               

    # Residual
    res = x                                                                     
    # Source-Target Attention
    x, src_trg_attention = MultiHeadAttention(config)(x, memory, memory,
                                                      mask=encoder_decoder_mask)
    # Dropout
    x = nn.Dropout(rate=config.dropout_rate)(x, deterministic=deterministic)    
    # Add & Norm
    x = nn.LayerNorm(dtype=config.dtype)(res + x)                               

    # Residual
    res = x                                                                     
    # Feed Forward Network
    x = FeedForward(config)(x, deterministic)                                   
    # Add & Norm
    x = nn.LayerNorm(dtype=config.dtype)(res + x)                               

    return x.astype(config.dtype), self_attention.astype(config.dtype)

Transformer Encoder

 以下の図で表されるTransformer Encoderを実装する。

class TransformerEncoder(nn.Module):
  config: Config
  """Transformer Encoder.
     
  Attribues:
     config: config class containing hyperparameters.
  """
  @nn.compact
  def __call__(self,
               src: jnp.array,
               encoder_mask: jnp.array,
               deterministic: bool,
               return_attn: bool = False
               ) -> Union[jnp.array, Tuple[jnp.array, List[jnp.array]]]:
    """Applys TransformerEncoder module.
     
    Args:
      src: sources of shape [batch_size, length].
      encoder_mask: attention mask for Self-Attention of shape [batch_size, 1, length, length].
      deterministic: parameter for nn.Dropout. if true, it mask and scale the inputs. during training, it should be False. Otherwise True.
      return_attn: if true, returns self-attention matrixes.

    Returns: 
      If return_attn is True, 
        the encoded sources of shape [batch_size, length, features],
        and list of attention matrix in Self-Attention for the number of layers.
      else,
        the encoded sources.
    """

    # [Batch, SeqLen]    
    assert src.ndim == 2                                                        
    # [Batch, 1, SeqLen_q, SeqLen_k]
    assert encoder_mask.ndim == 4                                              
    config = self.config

    # Embedding
    x = nn.Embed(num_embeddings=config.src_vocab_size, 
                 features=config.embed_dim, 
                 dtype=config.id_dtype,
                 param_dtype=config.dtype, 
                 embedding_init=nn.initializers.normal(stddev=1.0))(src)        
    # Positinal Encoding  
    x = PositionalEncoder(config)(x)                                            
    # Dropout
    x = nn.Dropout(rate=config.dropout_rate)(x, deterministic=deterministic)    

    # Encoder Layer
    attention_list = []
    for layer in [TransformerEncoderLayer(config) for _ in range(config.num_layers)]:
      x, attention = layer(x, 
                           encoder_mask, 
                           deterministic)                                       
      attention_list.append(attention)
    # Layer Normalization
    memory = nn.LayerNorm(dtype=config.dtype)(x)                                

    if return_attn:
      return memory.astype(config.dtype), attention_list
    return memory

Transformer Decoder

 以下の図で表されるTransformer Decoderを実装する。

class TransformerDecoder(nn.Module):
  config: Config
  """Transformer Decoder.
     
  Attribues:
     config: config class containing hyperparameters.
  """
  
  @nn.compact
  def __call__(self, 
               memory: jnp.array,
               trg: jnp.array,
               decoder_mask: jnp.array,
               encoder_decoder_mask: jnp.array,
               deterministic: bool,
               return_attn: bool = False
               ) -> Union[jnp.array, Tuple[jnp.array, List[jnp.array], List[jnp.array]]]:
    """Applys TransformerDecoder module.
     
    Args:
      memory: encoded sources from Transformer Encoder of shape [batch_size, src_length, features].
      trg: targets of shape [batch_size, trg_length].
      decoder_mask: attention mask for Self-Attention of shape [batch_size, 1, trg_length, trg_length].
      encoder_decoder_mask: attention mask for Source-Target Attention of shape [batch_size, 1, trg_length, src_length].
      deterministic: parameter for nn.Dropout. if true, it mask and scale the inputs. during training, it should be False. Otherwise True.
      return_attn: if true, returns Self-Attention and Source-Target-Attention matrixes.

    Returns: 
      If return_attn is True, 
        the logits of shape [batch_size, trg_length, target_vocab_size],
        list of attention matrix in Self-Attention for the number of layers,
        and list of attention matrix in Source-Target-Attention for the number of layers.
      else,
        the logits.
    """

    # [Batch, SeqLen, EmbedDim]    
    assert memory.ndim == 3                                                     
    # [Batch, SeqLen]
    assert trg.ndim == 2                                                        
    # [Batch, 1, SeqLen_q, SeqLen_k]
    assert decoder_mask.ndim == 4                                               
    # [Batch, 1, SeqLen_q, SeqLen_k]
    assert encoder_decoder_mask.ndim == 4                                       
    config = self.config

    # Embedding
    x = nn.Embed(num_embeddings=config.trg_vocab_size, 
                 features=config.embed_dim, 
                 dtype=config.id_dtype,
                 param_dtype=config.dtype, 
                 embedding_init=nn.initializers.normal(stddev=1.0))(trg)        
    # Positinal Encoding
    x = PositionalEncoder(config)(x)                                            
    # Dropout
    x = nn.Dropout(rate=config.dropout_rate)(x, deterministic=deterministic)    

    # Decoder Layer
    self_attention_list, src_trg_attention_list = [], []
    for layer in [TransformerDecoderLayer(config) for _ in range(config.num_layers)]:
      x, self_attention, src_trg_attention = layer(x, 
                                                   memory, 
                                                   decoder_mask, 
                                                   encoder_decoder_mask, 
                                                   deterministic)               
      self_attention_list.append(self_attention)
      src_trg_attention_list.append(src_trg_attention)

    # Layer Normalization
    x = nn.LayerNorm(dtype=config.dtype)(x)                                     
    # Dense Layer to reform the embed size to vocab size
    logits = nn.Dense(config.trg_vocab_size, dtype=config.dtype)(x)             

    if return_attn:
      return logits.astype(config.dtype), self_attention_list, src_trg_attention_list
    return logits.astype(config.dtype) 

Transformer

 最後に、以下の図で表されるTransformerを実装する。

class Transformer(nn.Module):
  config: Config
  """Transformer.
     
  Attribues:
     config: config class containing hyperparameters.
  """  
  
  def setup(self):
    config = self.config
    self.Encoder = TransformerEncoder(config)
    self.Decoder = TransformerDecoder(config)

  def __call__(self,
               src: jnp.array,
               trg: jnp.array,
               train: bool = False,
               ) -> jnp.array:
    """Applys Transformer module.
     
    Args:
      src: sources of shape [batch_size, src_length].
      trg: targets of shape [batch_size, trg_length].
      train: To train, it should be set True, otherwise False.

    Returns: 
      the logits of shape [batch_size, trg_length, target_vocab_size].
    """

    assert src.ndim == 2
    assert trg.ndim == 2
    config = self.config
    
    memory = self.encode(src, train=train)
    logits = self.decode(trg, src, memory, train=train)
    return logits.astype(self.config.dtype)
  
  def encode(self,
             src: jnp.array,
             train: bool = False,
             return_attn: bool = False
             ) -> Union[jnp.array, Tuple[jnp.array, List[jnp.array]]]:
    """Encode sources with Transformer Encoder.
     
    Args:
      src: sources of shape [batch_size, length].
      train: To train, it should be set True, otherwise False.
      return_attn: if true, returns self-attention matrixes.

    Returns: 
      If return_attn is True, 
        the encoded sources of shape [batch_size, length, features],
        and list of attention matrix in Self-Attention for the number of layers.
      else,
        the encoded sources.
    """
    
    assert src.ndim == 2
    config = self.config

    # [Batch, 1, SeqLen_q, SeqLen_k]
    encoder_mask = nn.make_attention_mask(
        jnp.ones_like(src),
        src != config.pad_idx, 
        dtype=bool)                                                             
                                                  
    if return_attn:
      memory, encoder_attention_list = self.Encoder(src, encoder_mask, not train, return_attn=return_attn)
      return memory, encoder_attention_list
    else:
      memory = self.Encoder(src, encoder_mask, not train, return_attn=return_attn)
      return memory
  
  def decode(self,
             trg: jnp.array,
             src: jnp.array, # only for making mask
             memory: jnp.array,
             train: bool = False,
             return_attn: bool = False
             ) -> Union[jnp.array, Tuple[jnp.array, List[jnp.array], List[jnp.array]]]:
    """Decode targets with Transformer Decoder.
     
    Args:
      trg: targets of shape [batch_size, trg_length].
      src: sources of shape [batch_size, src_length].
      memory: encoded sources from Transformer Encoder of shape [batch_size, src_length, features].
      train: To train, it should be set True, otherwise False.
      return_attn: if true, returns Self-Attention and Source-Target-Attention matrixes.

    Returns: 
      If return_attn is True, 
        the logits of shape [batch_size, trg_length, target_vocab_size],
        list of attention matrix in Self-Attention for the number of layers,
        and list of attention matrix in Source-Target-Attention for the number of layers.
      else,
        the logits.
    """
    
    assert trg.ndim == 2
    config = self.config
    
    # [Batch, 1, SeqLen_q, SeqLen_k]
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(
            jnp.ones_like(trg),
            trg != config.pad_idx, 
            dtype=bool),
        nn.make_causal_mask(trg, 
                            dtype=bool)
    )
                       
    # [Batch, 1, SeqLen_q, SeqLen_k]                                                    
    encoder_decoder_mask = nn.make_attention_mask(                 
        jnp.ones_like(trg),                       
        src != config.pad_idx,                        
        dtype=bool                              
    )                                                                           

    if return_attn:
      logits, decoder_attention_list, src_trg_attention_list = self.Decoder(memory, trg, decoder_mask, encoder_decoder_mask, not train, return_attn=return_attn)
      return logits, decoder_attention_list, src_trg_attention_list
    else:
      logits = self.Decoder(memory, trg, decoder_mask, encoder_decoder_mask, not train, return_attn=return_attn)
      return logits

前にも述べたようにここでは以下のようにmaskを作製する際にquery方向は全てTrueとしており、maskしていない。

    encoder_mask = nn.make_attention_mask(
        jnp.ones_like(src),
        src != config.pad_idx, 
        dtype=bool)  

Define Train, Valid and Translate Functions

 Transfomerのモデルを定義したところで、学習と翻訳に必要な関数を定義する。ここで、lossを計算するcompute_weighted_cross_entropy関数、train_step関数、valid_step関数をjitコンパイルする。jitコンパイルされた関数に、異なる形状が入力されるとそのたびにコンパイルされbytecodeが生成されて時間がかかってしまう。(おなじ形状の入力が来た場合は使いま回される。)今回は最大系列長が短いこともあるため、ソース文とターゲット文を含めすべての入力をさらにpaddingして形状を[バッチサイズ, 最大系列長]に揃えることで高速化を行っている。

さらに、翻訳時には、翻訳したいソース文を1文入力した場合、[1, ソース文の系列長]の形状でEncoderに入力し、[1, ソース文の系列長, 埋め込み次元]のmemoryが出力される。例えば、上の例だと[ '<bos>', 'drei', ... , 'ab', '.', '<eos>', '<pad>', ... '<pad>' ]がID列として入力され、Decoderには、最初に<bos>トークンだけを[1, 1]の形状で入力し、memoryも用いながら計算処理が行われ、[1, 1, ターゲット文の辞書サイズ]の形状のlogitsが出力される。ここでターゲット文の辞書サイズの最大位置が、予測された次のトークンのIDとなる。上の例だと、きちんと訓練データを学習していたとしたら、’three'のIDとなっているはずである。今回はgreedy searchというサンプリング手法を用いて最大位置を次のトークンの予測結果として順次採用していくため、次はDecoderの入力を[ '<bos>', 'three' ]のID列とし、さらに次のトークンを予測する。次の予測結果は'young'になっているはずである。このように自己回帰的に次のトークンをトークンが出力されるまで予測することを繰り返し、翻訳文を生成する。今回の実装では、translator関数は一文ごとに翻訳するのではだけでなく、バッチ処理に対応しており、複数文を受け取って同時に翻訳することができるようになっている。

def padding(
    array: jnp.array,
    config: Config
    ) -> jnp.array:
  """Takes a 2D array and adds pads to max length on the last dimension.
   
  Args:
    array: 2d array of shape [batch_size, length] whose length <= max length.
    config: config class containing hyperparameters.

  Returns:
    A padded array of shape [batch_size, max_length].
  """

  # [Batch, SeqLen]
  assert array.ndim == 2                                                        
  batch_size, seqlen = array.shape[0], array.shape[1]
  assert seqlen <= config.max_len

  if seqlen < config.max_len:
    # [Batch, MaxLen - SeqLen]
    pads = jnp.ones((batch_size, config.max_len - seqlen), 
                   dtype=config.id_dtype) * config.pad_idx      
    # [Batch, MaxLen]                
    padded_array = jnp.concatenate((array, pads), axis=-1)                      

  else:
    padded_array = array
  
  return padded_array

@partial(jax.jit, static_argnums=(3,))
def compute_weighted_cross_entropy(
    logits: jnp.array,
    trg: jnp.array,
    weight: jnp.array,
    label_smoothing: float =0.0
    ) -> jnp.array:
  """Calculate weighted cross entropy.
   
  Args:
    logits: output from Transformer of shape [batch_size, length, target_vocab_size].
    trg: targets of shape [batch_size, length].
    weight: boolean array of shape [batch_size, length]. the pads positions in targets is 0. otherwise, 1.
    label_smoothing: label smoothing constant.

  Returns:
    Scalar loss.
  """

  # [Batch, SeqLen, VocabSize]
  assert logits.ndim == 3                                                       
  # [Batch, SeqLen]
  assert trg.ndim == 2                                                          
  # [Batch, SeqLen]
  assert weight.ndim == 2                                                       
  
  batch_size, vocab_size = logits.shape[0], logits.shape[2]
  confidence = 1.0 - label_smoothing
  low_confidence = (1.0 - confidence) / (vocab_size - 1)
  normalizing_constant = -(
      confidence * jnp.log(confidence) +
      (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20))
  soft_target = common_utils.onehot(
      trg, vocab_size, on_value=confidence, off_value=low_confidence)
  
  loss = -jnp.sum(soft_target * nn.log_softmax(logits), axis=-1)
  loss = loss - normalizing_constant

  # normalize by batch_size
  normalizing_factor = batch_size                                               
  loss = loss * weight
  loss = loss.sum() / normalizing_factor
  
  return loss

@partial(jax.jit, static_argnums=(1,5))
def train_step(
    state: train_state.TrainState,
    model: Transformer,
    src: jnp.array,
    trg: jnp.array,
    dropout_rng: jax.random.PRNGKey,
    config: Config
    ) -> Tuple[train_state.TrainState, jnp.array]:
  """Runs a training step.
  In order to minimize the number of jit compile and accelerate, 
  this step takes padded src and trg that has always same shapes. 

  Args:
    state: training state.
    model: Transformer model.
    src: padded sources of shape [batch_size, max_length].
    trg: padded targets of shape [batch_size, max_length].
    dropout_rng: PRNGKey for dropout.
    config: config class containing hyperparameters.
  
  Returns:
    new_state: updated training state.
    loss: scalar loss. 
  """

  trg_for_input = jnp.where(trg == config.eos_idx, config.pad_idx, trg)[:, :-1]
  trg_for_loss = trg[:, 1:]
  weight = jnp.where(trg_for_input == config.pad_idx, 0, 1).astype(config.id_dtype)

  def loss_fn(params):
    logits = model.apply(
      {"params": params},
      src,
      trg_for_input,
      train = True,
      rngs={"dropout": dropout_rng})
    loss = compute_weighted_cross_entropy(logits, trg_for_loss, weight)
    return loss
    
  loss, grads = jax.value_and_grad(loss_fn)(state.params)
  new_state = state.apply_gradients(grads=grads)
  return new_state, loss

def train(
    state: train_state.TrainState,
    model: Transformer,
    train_iterator,
    config,
    dropout_rng
    ) -> Tuple[train_state.TrainState, float]:
  """Runs a training loop.

  Args:
    state: training state.
    model: transformer model.
    train_iterator: iterator for training.
    config: config class containing hyperparameters.
    dropout_rng: PRNGKey for dropout.
  
  Returns:
    state: updated training state.
    loss: average loss of 1 epoch. 
  """

  loss_history = []
  for i, batch in enumerate(train_iterator):
    src, trg = padding(jnp.asarray(batch.src), config), padding(jnp.asarray(batch.trg), config)
    dropout_rng = jax.random.fold_in(dropout_rng, state.step)
    state, loss = train_step(state, model, src, trg, dropout_rng, config)
    loss_history.append(loss)
  
  train_loss = sum(loss_history) / len(loss_history)
  return state, float(train_loss)

@partial(jax.jit, static_argnums=(1,4))
def valid_step(
    state: train_state.TrainState,
    model: Transformer,
    src: jnp.array,
    trg: jnp.array,
    config: Config
    )-> jnp.array:
  """Runs a validation step.
  In order to minimize the number of jit compile and accelerate, 
  this step takes padded src and trg that has always same shapes. 

  Args:
    state: training state.
    model: Transformer model.
    src: padded sources of shape [batch_size, max_length].
    trg: padded targets of shape [batch_size, max_length].
    config: config class containing hyperparameters.
  
  Returns:
    loss: scalar loss. 
  """

  trg_for_input = jnp.where(trg == config.eos_idx, config.pad_idx, trg)[:, :-1]
  trg_for_loss = trg[:, 1:]
  weight = jnp.where(trg_for_input != config.pad_idx, 1, 0).astype(config.id_dtype)

  logits = model.apply(
                      {"params": state.params}, 
                      src, 
                      trg_for_input, 
                      train = False)
  loss = compute_weighted_cross_entropy(logits, trg_for_loss, weight)
    
  return loss

def valid(
    state: train_state.TrainState,
    model: Transformer,
    valid_iterator,
    config: Config
    ) -> float:
  """Runs a validation loop.

  Args:
    state: training state.
    model: transformer model.
    valid_iterator: iterator for validation.
    config: config class containing hyperparameters.
  
  Returns:
    loss: average loss of 1 epoch. 
  """

  loss_history = []
  for _, batch in enumerate(valid_iterator):
    src, trg = padding(jnp.asarray(batch.src), config), padding(jnp.asarray(batch.trg), config)
    loss = valid_step(state, model, src, trg, config)
    loss_history.append(loss)
  
  valid_loss = sum(loss_history) / len(loss_history)

  return float(valid_loss)

def translator(
    src: jnp.array,
    state: train_state.TrainState,
    model: Transformer,
    config: Config,
    return_attn: bool = False
    ) -> Union[List[List[str]], Tuple[List[List[str]], List[jnp.array], List[List[jnp.array]], List[List[jnp.array]]]]:
  """Translate batch sources by greedy search.

  Args:
    src: sources of shape [Batch, length]
    state: training state.
    model: transformer model.
    config: config class containing hyperparameters.
    return_attn: if true, returns Self-Attention and Source-Target-Attention matrixes both from encoder and decoder.
  
    Returns: 
      If return_attn is True, 
        the list of translatoin list of english tokens of shape [batch_size, translation length (<=max length)],
        list of attention matrix in encoder Self-Attention for the number of encoder layers,
        list of list of attention matrix in decoder Self-Attention for the number of decoder layers for the number of translation length,
        and list of list of attention matrix in decoder Source-Target-Attention for the number of decoder layers for the number of translation length.
      else,
        the logits.
  """
    
  # [Batch, SeqLen]
  assert src.ndim == 2                                                        

  if return_attn:
    memory, encoder_attention_list = model.apply(
        {"params": state.params}, 
        src,
        train = False,
        return_attn = True,
        method=model.encode)
  else:
    memory = model.apply(
        {"params": state.params}, 
        src,
        train = False,
        return_attn = False,
        method=model.encode)

  batch_size = src.shape[0]
  # [Batch, 1]
  translation_id = jnp.ones((batch_size, 1), 
                            dtype=config.id_dtype) * config.bos_idx           
  # [Batch, 1]
  translation_done = jnp.zeros((batch_size, 1), dtype=bool)                   

  decoder_attention_his, src_trg_attention_his = [], []
  for i in range(config.max_len):
      if return_attn:
        logits, decoder_attention_list, src_trg_attention_list = model.apply(
            {"params": state.params}, 
            translation_id, 
            src, 
            memory,
            train = False,
            return_attn = True,
            method=model.decode)
          
        decoder_attention_his.append(decoder_attention_list)
        src_trg_attention_his.append(src_trg_attention_list)
      else:
        # [Batch, SeqLen, VocabSize]
        logits = model.apply(
            {"params": state.params}, 
            translation_id, 
            src, 
            memory,
            train = False,
            return_attn = False,
            method=model.decode)                                              
 
      # [Batch, 1]       
      pred_id = jnp.argmax(logits, axis=-1)[:, -1][:, jnp.newaxis]            
      # [Batch, TranslationLen]
      translation_id = jnp.concatenate((translation_id, pred_id), axis=-1)    

      translation_done = jnp.where(pred_id == config.eos_idx, True, translation_done)
      if jnp.all(translation_done):
        break
    
  translation = []
  for sent_id in translation_id:
    sent = itos_en(list(sent_id))
    translation.append(sent)

  if return_attn:
    return translation, encoder_attention_list, decoder_attention_his, src_trg_attention_his
  else:
    return translation

Mount Google Drive

 学習をする開始する前にGoogle Driveをマウントする。Configでチェックポイントを保存するckpt_dirは'/content/drive/My Drive/checkpoints/translation’にしてあるので、そのまま実行する場合はこのタイミングでGoogleドライブにchekpoints/translationを作成する必要がある。

from google.colab import drive
drive.mount('/content/drive')

Initialize a Model and Training State

 学習を開始する前にmodelのinitialize、training_stateの作成、最新のチェックポイントがあればtraining_stateの引継ぎを行い、なければそのままゼロから学習を行う。

# from jax.config import config
# config.update("jax_debug_nans", True) # For Debug
# from jax.config import config
# config.update("jax_debug_nans", False) # For Usual

config = Config()
rng = jax.random.PRNGKey(config.seed)
rng, init_rng = jax.random.split(rng)
src_shape = (config.batch_size, config.max_len)
trg_shape = (config.batch_size, config.max_len)

model = Transformer(config)

print(model.tabulate(init_rng,
          jnp.ones(src_shape, config.id_dtype),
          jnp.ones(trg_shape, config.id_dtype)))
initial_variables = jax.jit(model.init)(init_rng,
                                      jnp.ones(src_shape, dtype=config.id_dtype),
                                      jnp.ones(trg_shape, dtype=config.id_dtype)
                                      )

def rsqrt_schedule(
    init_value: float,
    shift: int = 0,
):
  def schedule(count):
    return init_value * (count + shift)**-.5 * shift**.5
  return schedule

def create_learning_rate_schedule(learning_rate: float, warmup_steps: int):
  return optax.join_schedules([
      optax.linear_schedule(
          init_value=0, end_value=learning_rate, transition_steps=warmup_steps),
      rsqrt_schedule(init_value=learning_rate, shift=warmup_steps),
  ],
                              boundaries=[warmup_steps])

learning_rate_fn = create_learning_rate_schedule(
      learning_rate=config.learning_rate, warmup_steps=config.warmup_steps)

state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=initial_variables["params"],
    tx=optax.adam(
        learning_rate=learning_rate_fn,
        b1=0.9,
        b2=0.98,
        eps=1e-9
    )
)
del initial_variables

if config.restore_checkpoints:
    latest_state_path = checkpoints.latest_checkpoint(config.ckpt_dir, config.ckpt_prefix)
    
    if latest_state_path is not None:
      state = checkpoints.restore_checkpoint(latest_state_path, state)
      last_epoch = int(re.findall(r"\d+", latest_state_path)[-1])
      print(f'Restore {latest_state_path} and restart training from epoch {last_epoch + 1}')
    else: 
      last_epoch = 0
      print('No checkpoints found. Start training from epoch 1')

else:
    last_epoch = 0
    print('Start training from epoch 1')

def get_h_m_s(seconds: int):
    min, sec = divmod(seconds, 60)
    hour, min = divmod(min, 60)
    return hour, min, sec

Training

 学習を行う。今回は150エポック行い、1エポックごとにvalidationする。また、サンプル文として「'Ich bin Student bei NAIST.'」を各エポックで翻訳する。

print("================START TRAINING================")
training_start_time = time.time()
training_start_date = datetime.datetime.now()
training_history = []
validation_history = []
sample_id = stoi_de(tokenize_de('Ich bin Student bei NAIST.'))
sample = jnp.asarray(sample_id)[jnp.newaxis, :]

for epoch in range(last_epoch + 1, last_epoch + 1 + config.num_epochs):
  epoch_start_time = time.time()
  is_last_epoch = epoch == last_epoch + config.num_epochs
  train_metrics = {}
  valid_metrics = {}
  print(f"Epoch_{epoch}")

  #TRAIN
  state, loss = train(state, model, train_iterator, config, dropout_rng=rng)
  print(f'Train      : loss {loss:.5f}')
  train_metrics["epoch"] = epoch
  train_metrics["loss"] = loss
  hour, min, sec = get_h_m_s(time.time() - epoch_start_time)
  print(f'Epoch Time : {hour:.0f}h {min:.0f}m {sec:.0f}s')
  train_metrics["hour"] = hour
  train_metrics["min"] = min
  train_metrics["sec"] = sec
  training_history.append(train_metrics)

  translation = translator(sample, state, model, config)[0]
  print(f'Translation: {translation}')

  #VALIDATE
  if epoch % config.valid_every_epochs == 0 or is_last_epoch:
    loss = valid(state, model, valid_iterator, config)
    print(f'Validate   : loss {loss:.5f}')
    valid_metrics["epoch"] = epoch
    valid_metrics["loss"] = loss
    validation_history.append(valid_metrics)

  #SAVE CHECKPOINTS
  if epoch % config.save_ckpt_every_epochs == 0 or is_last_epoch:
    checkpoints.save_checkpoint(
            ckpt_dir=config.ckpt_dir, prefix=config.ckpt_prefix,
            target=state, step=epoch, overwrite=True, keep=10)

  hour, min, sec = get_h_m_s(time.time() - training_start_time)
  print(f"-------------{hour:.0f}h {min:.0f}m {sec:.0f}s------------")

  if is_last_epoch:
    train_hour, train_min, train_sec = hour, min, sec

print("================FINISH TRAINING================")

#MAKE TRAINING LOG FILE
with open(config.ckpt_dir + f'/train_log_from_epoch{last_epoch+1}.txt', 'w') as f:
  text = f'Training Date: {training_start_date}\n'
  text += '===================Config===================\n'
  members = [attr for attr in dir(config) if not callable(getattr(config, attr)) and not attr.startswith("__")]
  for m in members:
    text += f'{m} : {getattr(config, m)}\n'
  text += '===================Training===================\n'
  for metrics in training_history:
    text += f'epoch_{metrics["epoch"]}: loss {metrics["loss"]:.5f} Epoch Time: {metrics["hour"]:.0f}h {metrics["min"]:.0f}m {metrics["sec"]:.0f}s\n'
  text += f'Whole Training took {train_hour:.0f}h {train_min:.0f}m {train_sec:.0f}s\n'
  text += '===================Validation===================\n'
  for metrics in validation_history:
    text += f'epoch_{metrics["epoch"]}: loss {metrics["loss"]:.5f}\n'
  f.write(text)
================START TRAINING================
Epoch_1
Train      : loss 71.77527
Epoch Time : 0h 0m 55s
Translation: ['<bos>', 'a', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '.', '<eos>']
-------------0h 0m 60s------------
Epoch_2
Train      : loss 51.95805
Epoch Time : 0h 0m 20s
Translation: ['<bos>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '.', '<eos>']
Validate   : loss 46.81571
-------------0h 1m 32s------------
Epoch_3
....


Epoch_150
Train      : loss 14.14630
Epoch Time : 0h 0m 20s
Translation: ['<bos>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', 'at', '<unk>', '.', '<eos>']
Validate   : loss 25.65130
-------------0h 59m 44s------------
================FINISH TRAINING================

学習が終了した。残念ながらサンプル文の'Ich bin Student bei NAIST.'は最後まで翻訳できなかったようであるが、以下でlossをプロットしてみると、train lossは減少し続け、validationは下がったのちにほんのわずかに上昇に転じ、あまり過学習しすぎないうちに学習が止まっているように見える。

train_epoch, train_loss = [], []
valid_epoch, valid_loss = [], []

for metrics in training_history:
  train_epoch.append(metrics["epoch"])
  train_loss.append(metrics["loss"])

for metrics in validation_history:
  valid_epoch.append(metrics["epoch"])
  valid_loss.append(metrics["loss"])

fig = plt.figure()
plt.subplots_adjust(hspace=0.6)

ax = fig.add_subplot(2, 1, 1)
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.set_title('training loss')
ax.plot(train_epoch, train_loss)

ax = fig.add_subplot(2, 1, 2)
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.set_title('validation loss')
ax.plot(valid_epoch, valid_loss)
plt.show()

Translate some Sentences in the Training Dataset

 訓練データに含まれる文章を10文だけ翻訳してみる。

#Check if it can translate sentences in train data.
for i, batch in enumerate(train_iterator):
  train_src = batch.src[:10]
  train_trg = batch.trg[:10]
  break

train_src, train_trg = jnp.asarray(train_src), jnp.asarray(train_trg)
translation, encoder_attention_list, decoder_attention_his, src_trg_attention_his = translator(train_src, state, model, config, return_attn=True)
for i, sent in enumerate(translation):
  print('src in train data: ', ' '.join(itos_de([id for id in train_src[i] if id not in config.special_idxes])))
  print('trg in train data: ', ' '.join(itos_en([id for id in train_trg[i] if id not in config.special_idxes])))
  print('model translation: ', ' '.join([token for token in list(sent) if token not in config.special_tokens]))
  print('------------------------------------')
src in train data:  ein junges mädchen <unk> die <unk> von <unk> .
trg in train data:  a young girl learns how electrical circuits work .
model translation:  a young girl <unk> the <unk> of her tent .
------------------------------------
src in train data:  ein kind hat sich verkleidet und hält einen <unk> in der hand .
trg in train data:  a child is dressed up in costume and holding a <unk> .
model translation:  a child is dressed up in costume and holding a <unk> . .
------------------------------------
src in train data:  ein mann und eine frau singen auf einer bühne .
trg in train data:  a man and woman are singing on a stage .
model translation:  a man and a woman are singing on a stage . .
------------------------------------
src in train data:  ein mann in einem gelben t-shirt hält einen großen camcorder .
trg in train data:  a man in a yellow t - shirt holds a large camcorder .
model translation:  a man in a yellow shirt holds a large camcorder . . .
------------------------------------
src in train data:  ein mann zieht sein gepäck auf rädern an einem fenster voll mit blühenden blumen vorbei .
trg in train data:  a man wheels his luggage past a window full of blooming flowers .
model translation:  a man is pulling his luggage on a window of blooming with flowers . .
------------------------------------
src in train data:  ein mann führt etwas für eine gruppe vor einem weißen haus auf .
trg in train data:  a man is performing for a group in front of a white house .
model translation:  a man performs for a group in front of a white house . . . .
------------------------------------
src in train data:  ein torhüter in einem gelben trikot schaut im hintergrund hoch , während ein spieler der gegnerischen mannschaft in einem grünen trikot auf allen vieren kniet und sein gesicht im rasen <unk> .
trg in train data:  as a goalie in a yellow jersey looks up in the distance , a player from the opposing team in a green jersey on his hands and knees has his face in the grass .
model translation:  a goalie in a yellow jersey looks up in the background as a player in a player in green jersey appears to be the opposing team and knees in the opposing in the background .
------------------------------------
src in train data:  leute richten sich ein um mit ihren gewehren zu schießen .
trg in train data:  men setting up to take a shot with their guns .
model translation:  people are setting up to take a shot with their guns . . . .
------------------------------------
src in train data:  ein paar sitzt auf einer brücke und beobachtet den sonnenuntergang .
trg in train data:  a couple sits to watch the sunset from a bridge .
model translation:  a couple sits on a bridge watching the sunset . .
------------------------------------
src in train data:  eine frau mit grünem t-shirt , rosafarbenem rucksack und einer mexikanischen flagge an der tasche steht in einer menschenmenge .
trg in train data:  a woman stands in a crowd wearing a green shirt , pink backpack , and a mexican flag attached to the bag .
model translation:  a woman in a green shirt , pink backpack , and an american flag stands at a bag in a crowd . . . . .
------------------------------------

サンプル文は全然だったが、学習データに含まれる文章は学習しているようである。

Plot Attention Metrics

 最後に、以下のようにAttention matrixを可視化する関数を実装し、先ほど翻訳した最初の文について、翻訳時のencoderの最終層のattentin matrix, 翻訳終了時のdecoderの最終層のself-attentionとsrc-trg attentionのattention matrixを、各headごとにそれぞれ可視化してみる。(Translator関数の仕様上、一度<eos>が出力されても全文の翻訳が終わるまで推論が終わらないため、<eos>がたくさん出力される。)

def plot_attention(query_tokens: List[str],
                   key_tokens: List[str], 
                   attention: jnp.array
                   ):
    
    attention = jnp.squeeze(attention)
    assert attention.ndim == 3
    num_heads, q_seqlen, k_seqlen = attention.shape[0], attention.shape[1], attention.shape[2]
    fig = plt.figure(figsize=(20, 30))

    for i in range(num_heads):
      attention_per_head = attention[i, ...]
      ax = fig.add_subplot(int(num_heads/2), 2, i+1)
      cax = ax.matshow(attention_per_head, cmap='bone')
      ax.tick_params(labelsize=12)
      ax.set_xticklabels([''] + key_tokens, rotation=80)
      ax.set_yticklabels([''] + query_tokens)
      ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
      ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.show()
    plt.close()
src, trg, translation_sent = train_src[0], train_trg[0], translation[0]
print('src in train data: ', ' '.join(itos_de([id for id in src if id not in config.special_idxes])))
print('trg in train data: ', ' '.join(itos_en([id for id in trg if id not in config.special_idxes])))
print('model translation: ', ' '.join([token for token in list(translation_sent)])
src in train data:  ein junges mädchen <unk> die <unk> von <unk> .
trg in train data:  a young girl learns how electrical circuits work .
model translation:  a young girl <unk> the <unk> of her tent .
plot_attention(itos_de(src), itos_de(src), encoder_attention_list[-1])

plot_attention(translation_sent, translation_sent, decoder_attention_his[-1][-1]

plot_attention(translation_sent, itos_de(src), src_trg_attention_his[-1][-1])

各headにおいて各tokenがほかのtokenとどのように関連性があると学習されているか確認できる。また、key方向である列方向で<pad>となっている箇所は真っ黒になっており、attentionが0になっていることが分かる。

Acknowledgment

 本記事の執筆にあたり、NAIST渡辺研究室の渡辺太郎教授、出口祥之さん、中村研究室同期の田中康紀さんにご助言いただきました。この場を借りてお礼申し上げます。ありがとうございました。

Reference

Paper

[1] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, Illia Polosukhin (2017). Attention is All you Need. Advances in Neural Information Processing Systems, 2017-Decem, 5999–6009.

GitHub