Preface
この記事ではJAX学習記録シリーズ第二弾として、Pytreeについてまとめる。基本的に第一弾は読んでる前提で書くので、もし読んでない方がいたら読むことをお勧めする。Pytreeは以下の第一弾のブログでも出てきたように、モデルパラメータを格納するデータ構造(本当はそれ以外にも使える。)であり、JAXの素晴らしいポイントの一つは、jax.gradにpytreeの形式でパラメータを持つ関数を与えたとしても、各パラメータに対してその関数の勾配を計算できることである。
この記事は、以下のJAXおよびFlaxの公式ドキュメント読み、一部修正や改善を行いながら内容をまとめた。ちなみに、JAX編が終わったらFlax編に突入する。
目次
Pytree
What's a pytree?
JAXでは、コンテナのようなオブジェクト(container-like object)(例えばdictのlistのlist、arrayのdictというような、ネストされた構造を持つオブジェクト)から構成されるツリー状の構造(tree-like structure)を表現するために、「pytree」という用語を用いる。クラスは、pytree container registry(デフォルトでlist, tuple, dict)にあるクラスである場合は、コンテナ的(container-like)であると見なされる。つまり、
- pytree container registryにないクラスのオブジェクトは、すべてpytreeのleafと見なされる。
- pytree container registryにあるクラスのオブジェクトは、すべてpytreeのnodeと見なされる。
pytree container registryでは、各コンテナ型は、コンテナ型のインスタンスを(children, metadata)のペアに変換する関数と、そのペアをコンテナ型のインスタンスに変換する関数の、二つの関数のペアと共に登録されている。もちろん、ユーザーが定義したクラスをpytree container registryに新たに登録することもできる。
pytreeの例:
[1, "a", object()] # 3 leaves (1, (2, 3), ()) # 3 leaves [1, {"k1": 2, "k2": (3, 4)}, 5] # 5 leaves
Common Pytree Functions
JAXは、jax.tree_util.tree_* というprefixを持つ、pytreeを扱う関数を用意している。JAXのdocumentを読むと、2022/11/09現在において、jax.tree_* を使うやり方が書いてあるが、例えばjax.tree_leaves()を使おうとすると以下の警告が出る。これによると、jax.tree_* は将来的にjax.tree_util.tree_*に移動するらしく、将来的に消される予定らしいので、jax.tree_util.tree_* をprefixに持つ関数を使うことにする。
FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
本筋に戻る。jax.tree_util.tree_leaves()を使うと、以下のようにtreeからleafを抽出することができる。
import jax import jax.numpy as jnp from jax.tree_util import tree_leaves example_trees = [ [1, 'a', object()], (1, (2, 3), ()), [1, {'k1': 2, 'k2': (3, 4)}, 5], {'a': 2, 'b': (2, 3)}, jnp.array([1, 2, 3]), ] # Let's see how many leaves they have: for pytree in example_trees: leaves = tree_leaves(pytree) print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")
[1, 'a', <object object at 0x7fded60bb8c0>] has 3 leaves: [1, 'a', <object object at 0x7fded60bb8c0>] (1, (2, 3), ()) has 3 leaves: [1, 2, 3] [1, {'k1': 2, 'k2': (3, 4)}, 5] has 5 leaves: [1, 2, 3, 4, 5] {'a': 2, 'b': (2, 3)} has 3 leaves: [2, 2, 3] DeviceArray([1, 2, 3], dtype=int32) has 1 leaves: [DeviceArray([1, 2, 3], dtype=int32)]
先ほど紹介したようにlist, tuple, dictがnodeとして扱われ、その中の要素がleafとして認識されているが、文字列やJAX arrayなど、そのほかのオブジェクトはleafとして認識されていることが分かる。
jax.tree_util.tree_*で最もよく使われるであろう関数は、tree_mapである。これは、Pythonネイティブのmap関数と同じように動作し、pytreeに格納したパラメータを更新する際に利用される、非常に大事な関数である。 基本的には、以下のように第一引数に関数をあたえ、第二引数以降にその入力となるpytreeを渡す。
from jax.tree_util import tree_map list_of_lists = [ [1, 2, 3], [1, 2], [1, 2, 3, 4] ] tree_map(lambda x: x*2, list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]
tree_mapは、複数の引数に対しても動作する。
another_list_of_lists = list_of_lists
tree_map(lambda x, y: x+y, list_of_lists, another_list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]
この際、引数として与えられるpytreeの構造は正確に一致している必要がある。つまり、引数として与えられたpytreeのlistは同じ数の要素を持ち、辞書は同じkeyを持つ必要がある。
Custom pytree nodes
前述したとおり、pytree container registryには自作クラスを登録できる。登録していない場合には、自作クラスのインスタンスがleafを持っていたとしても、インスタンスそのものがleafとして認識され、インスタンスが持っているleafはleafとして認識されない。登録することによりインスタンス自体はnodeとして認識され、持っているleafもleafとして認識されるようになる。例えば以下のように自作クラスを定義してインスタンスのリストをtree_leavesに渡すと、それぞれのインスタンスがleafとして認識される。
class MyContainer: """A named container.""" def __init__(self, name: str, a: int, b: int, c: int): self.name = name self.a = a self.b = b self.c = c
tree_leaves([ MyContainer('Alice', 1, 2, 3), MyContainer('Bob', 4, 5, 6) ])
[<__main__.MyContainer at 0x7fdec166ce50>, <__main__.MyContainer at 0x7fded89ba490>]
そこで、以下のようにコンテナからleafと中間的なデータを取り出す関数と、逆にleafと中間的なデータからコンテナを作製する関数を作製し、jax.tree_util.register_pytree_node()でpytree container registerに登録する。
from typing import Tuple, Iterable from jax.tree_util import register_pytree_node def flatten_MyContainer(container) -> Tuple[Iterable[int], str]: """Returns an iterable over container contents, and aux data.""" flat_contents = [container.a, container.b, container.c] # we don't want the name to appear as a child, so it is auxiliary data. # auxiliary data is usually a description of the structure of a node, # e.g., the keys of a dict -- anything that isn't a node's children. aux_data = container.name return flat_contents, aux_data def unflatten_MyContainer( aux_data: str, flat_contents: Iterable[int]) -> MyContainer: """Converts aux data and the flat contents into a MyContainer.""" return MyContainer(aux_data, *flat_contents) register_pytree_node( MyContainer, flatten_MyContainer, unflatten_MyContainer) tree_leaves([ MyContainer('Alice', 1, 2, 3), MyContainer('Bob', 4, 5, 6) ])
[1, 2, 3, 4, 5, 6]
以上のように、きちんと自作コンテナが持つleafが認識されていることが分かる。
また、以下のように@register_pytree_node_classで登録したいクラスをデコレートする方法もある。
from jax.tree_util import register_pytree_node_class from typing import Tuple, Iterable @register_pytree_node_class class MyContainer2: def __init__(self, name, a, b, c): self.name = name self.a = a self.b = b self.c = c def tree_flatten(self) -> Tuple[Iterable[int], str]: aux_data = self.name return (self.a, self.b, self.c), aux_data @classmethod def tree_unflatten(cls, aux_data: str, children: Iterable[int]): return cls(aux_data, *children)
このやり方だとスマートに書けはするが、クラスの中にそれ自身のクラスのインスタンスからleafと中間的なデータを取り出すメソッドと、leafと中間的なデータから自身のクラスのインスタンスを作り出すメソッドを定義する必要があり少しややこしい。しかもそこまでコードが短くなるわけでもないのであまり良いやり方のようには思えない。
Pytree for ML model parameters
この章ではPytreeをMLモデルのパラメータとして用いる方法を示すために、Linear RegressionとMLPの学習方法についてみてみる。
Linear Regression
ここでは、前の章に引き続きLinear Regressionを行う。ただし、問題設定がより複雑で、パラメータが行列の形となる。 データが与えられた際、パラメータを持つモデルを考え、以下の式で表される二乗誤差を最小にするようにパラメータを学習する。
まずはモデルとloss関数を定義する。
def model(params, x): return jnp.dot(x, params['W']) + params['b'] def loss(params, x_batched,y_batched): def squared_error(x,y): y_pred = model(params, x) return jnp.inner(y-y_pred, y-y_pred) / 2.0 return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)
loss関数内でjax.vmapという関数を使っているが、これは、jaxにおいて非常に重要な関数の一つであり、また別の記事で解説する。modelはパラメータと一つのxのサンプルを受けとり、出力としてm次元のベクトルを返す関数であり、その一つのサンプルに対してlossを計算するには、loss関数内で定義されたsquared_error関数を計算すればよい。ここでjax.vmapは、それらの計算を与えられたすべてのサンプル(1バッチ)に対してまとめて行っている。そして、その平均を取ってそのバッチのlossとしている。
サンプル数を20、の次元を10 ()、の次元を5 ()とする。と、真のパラメータとなる, を標準正規分布からサンプリングし、それにさらにノイズを足して、という形でを生成する。
import jax from jax import numpy as jnp, random # Set problem dimensions. n_samples = 20 x_dim = 10 y_dim = 5 # Generate random ground truth W and b. key = random.PRNGKey(0) k1, k2 = random.split(key) W_true = random.normal(k1, (x_dim, y_dim)) b_true = random.normal(k2, (y_dim,)) true_params = {'W': W_true, 'b': b_true} # Generate samples with additional noise. key_sample, key_noise = random.split(k1) x_samples = random.normal(key_sample, (n_samples, x_dim)) y_samples = model(true_params, x_samples) + 0.1 * random.normal(key_noise,(n_samples, y_dim)) print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)
x shape: (20, 10) ; y shape: (20, 5)
これでサンプル数20のxとyを生成できた。
ここでjax.gradを使いloss関数の勾配を計算してみる。ここで、true_paramsは先ほど定義したようにvalueにパラメータを持つdict形式のpytreeとなっているが、このようにloss関数のパラメータがpytreeで与えられても、それぞれのパラメータに対して勾配を計算できるのがJAXの素晴らしいポイントである。
jax.grad(loss)(true_params, x_samples, y_samples)
{'W': DeviceArray([[ 0.03074035, 0.00365056, -0.02740437, 0.03342171, 0.00768137], [ 0.00480341, -0.00601222, 0.01753974, 0.02849675, 0.01315357], [ 0.02740791, 0.00509001, 0.01633988, 0.01654152, -0.03833748], [ 0.04439825, 0.01456456, -0.01553174, 0.00361327, -0.00520618], [-0.01146152, 0.01010936, -0.02445415, -0.02912016, -0.00614015], [-0.01862438, -0.01283839, -0.00399569, -0.00660799, 0.01106221], [ 0.03790367, 0.01057525, 0.00876093, -0.00756639, 0.00785039], [ 0.0138653 , -0.00148796, -0.02050714, 0.01991827, 0.01749898], [ 0.01063178, -0.03577084, -0.02890723, 0.01893806, 0.02628964], [ 0.0198197 , 0.01517353, -0.01874115, -0.01610945, -0.02562965]], dtype=float32), 'b': DeviceArray([ 0.0193719 , -0.00509552, -0.00823335, -0.01260918, 0.00415814], dtype=float32)}
もちろん勾配は与えたパラメータと同じ形状( )となるため、tree_mapの引数に更新前のパラメータを保持したpytreeと、勾配を計算した結果を格納したpytreeを渡すことができる。
それでは早速モデルを学習してみる。バッチサイズ=サンプルサイズ=20とする。パラメータとを0で初期化して、初期化したpytreeを新たに作る。updateでは、tree_mapを使ってパラメータの更新を行い、100エポック学習する。jax.jitでupdate関数がデコレートされているが、これはjust-in-timeコンパイルをするために必要なもので、この記事では特に触れない。計算処理を早くするおまじないのようなものである。
# Initialize estimated W and b with zeros. W = jnp.zeros_like(W) b = jnp.zeros_like(b) params = {'W': W, 'b': b} # Always remember to jit! @jax.jit def update(params, learning_rate, x_samples, y_samples): params = jax.tree_util.tree_map( lambda p, g: p - learning_rate * g, params, jax.grad(loss)(params, x_samples, y_samples)) return params learning_rate = 0.3 # Gradient step size. print('Loss for "true" W,b: ', loss({'W': W, 'b': b}, x_samples, y_samples)) for i in range(101): # Perform one gradient update. params = update(params, learning_rate, x_samples, y_samples) if (i % 5 == 0): print(f"Loss step {i}: ", loss(params, x_samples, y_samples))
Loss for "true" W,b: 0.023639793 Loss step 0: 10.971409 Loss step 5: 1.0798324 Loss step 10: 0.37958255 Loss step 15: 0.17855294 Loss step 20: 0.09441521 Loss step 25: 0.054522194 Loss step 30: 0.03448924 Loss step 35: 0.024058027 Loss step 40: 0.018480862 Loss step 45: 0.015438682 Loss step 50: 0.01375394 Loss step 55: 0.0128103 Loss step 60: 0.012277315 Loss step 65: 0.011974388 Loss step 70: 0.011801447 Loss step 75: 0.011702419 Loss step 80: 0.011645544 Loss step 85: 0.011612837 Loss step 90: 0.011594015 Loss step 95: 0.011583163 Loss step 100: 0.011576912
いい感じに学習できた。
また、このようにlossの値を記録したい場合には、勾配を計算する際にjax.gradではなく、jax.value_and_gradを使うと、以下のようにlossを計算し直すことなく記録できる。
# Using jax.value_and_grad instead: loss_grad_fn = jax.value_and_grad(loss) for i in range(101): # Note that here the loss is computed before the param update. loss_val, grads = loss_grad_fn(params, x_samples, y_samples) params = jax.tree_util.tree_map( lambda p, g: p - learning_rate * g, params, grads) if (i % 5 == 0): print(f"Loss step {i}: ", loss_val)
MLP
pytreeを使って簡単なMLPを作製して、訓練してみる。今回の例ではパラメータのpytreeはもう少し複雑になる。
import jax import numpy as np def init_mlp_params(layer_widths): params = [] for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]): params.append( dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in), biases=np.ones(shape=(n_out,)) ) ) return params params = init_mlp_params([1, 128, 128, 1])
ここで、layer_widths[:-1]は最後から二番目の要素までのリストであり、ここでは[1, 128, 128]のこと。layer_widths[1:]は二番目の要素から最後までのリストであり、ここでは[128, 128, 1]のことである。また、weightsは各層において標準偏差がの正規分布からサンプリングされた値で初期化され、biasは1で初期化されている。 tree_mapを使ってパラメータの形状を確認する。
jax.tree_util.tree_map(lambda x: x.shape, params)
[{'biases': (128,), 'weights': (1, 128)}, {'biases': (128,), 'weights': (128, 128)}, {'biases': (1,), 'weights': (128, 1)}]
以上で、各次元が1, 128, 128, 1の四層のMLPが定義できた。
次に、forward関数とloss関数、update関数を定義する。
def forward(params, x): *hidden, last = params for layer in hidden: x = jax.nn.relu(x @ layer['weights'] + layer['biases']) return x @ last['weights'] + last['biases'] def loss(params, x, y): return jnp.mean((forward(params, x) - y) ** 2) LEARNING_RATE = 0.0001 @jax.jit def update(params, x, y): grads = jax.grad(loss)(params, x, y) return jax.tree_util.tree_map( lambda p, g: p - LEARNING_RATE * g, params, grads )
forward関数内で、「@」という演算が行われているが、これは行列積のこと( What’s New In Python 3.5 — Python 3.11.0 documentation)で、jnp.dotと基本的には同じ。 Linear Regressionでも説明した通り、update関数内で勾配を計算する際、第一引数として渡されているparamsについてloss関数の微分が行われ、返り値であるgradsもparamsと同じ形状をしたpytreeである。そして例のごとくupdate関数はtree_mapを使い値の更新を行い、updateされたparamsを返す。
次に、モデルの訓練を行う。正規分布からサンプリングした値からデータを作製し、1000エポック学習させる。
import matplotlib.pyplot as plt xs = np.random.normal(size=(128, 1)) ys = xs ** 2 for _ in range(1000): params = update(params, xs, ys) plt.scatter(xs, ys) plt.scatter(xs, forward(params, xs), label='Model prediction') plt.legend();