izmyonの日記

奈良の山奥で研究にいそしむ大学院生の学習記録。

JAX学習記録②ーPytree

Preface

 この記事ではJAX学習記録シリーズ第二弾として、Pytreeについてまとめる。基本的に第一弾は読んでる前提で書くので、もし読んでない方がいたら読むことをお勧めする。Pytreeは以下の第一弾のブログでも出てきたように、モデルパラメータを格納するデータ構造(本当はそれ以外にも使える。)であり、JAXの素晴らしいポイントの一つは、jax.gradにpytreeの形式でパラメータを持つ関数を与えたとしても、各パラメータに対してその関数の勾配を計算できることである。

izmyon.hatenablog.com

 この記事は、以下のJAXおよびFlaxの公式ドキュメント読み、一部修正や改善を行いながら内容をまとめた。ちなみに、JAX編が終わったらFlax編に突入する。

目次

Pytree

What's a pytree?

 JAXでは、コンテナのようなオブジェクト(container-like object)(例えばdictのlistのlist、arrayのdictというような、ネストされた構造を持つオブジェクト)から構成されるツリー状の構造(tree-like structure)を表現するために、「pytree」という用語を用いる。クラスは、pytree container registry(デフォルトでlist, tuple, dict)にあるクラスである場合は、コンテナ的(container-like)であると見なされる。つまり、

  1. pytree container registryにないクラスのオブジェクトは、すべてpytreeのleafと見なされる。
  2. pytree container registryにあるクラスのオブジェクトは、すべてpytreeのnodeと見なされる。

 pytree container registryでは、各コンテナ型は、コンテナ型のインスタンスを(children, metadata)のペアに変換する関数と、そのペアをコンテナ型のインスタンスに変換する関数の、二つの関数のペアと共に登録されている。もちろん、ユーザーが定義したクラスをpytree container registryに新たに登録することもできる。

 pytreeの例:

[1, "a", object()]  # 3 leaves
(1, (2, 3), ())  # 3 leaves
[1, {"k1": 2, "k2": (3, 4)}, 5]  # 5 leaves

Common Pytree Functions 

 JAXは、jax.tree_util.tree_* というprefixを持つ、pytreeを扱う関数を用意している。JAXのdocumentを読むと、2022/11/09現在において、jax.tree_* を使うやり方が書いてあるが、例えばjax.tree_leaves()を使おうとすると以下の警告が出る。これによると、jax.tree_* は将来的にjax.tree_util.tree_*に移動するらしく、将来的に消される予定らしいので、jax.tree_util.tree_* をprefixに持つ関数を使うことにする。

FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.

 本筋に戻る。jax.tree_util.tree_leaves()を使うと、以下のようにtreeからleafを抽出することができる。

import jax
import jax.numpy as jnp
from jax.tree_util import tree_leaves

example_trees = [
    [1, 'a', object()],
    (1, (2, 3), ()),
    [1, {'k1': 2, 'k2': (3, 4)}, 5],
    {'a': 2, 'b': (2, 3)},
    jnp.array([1, 2, 3]),
]

# Let's see how many leaves they have:
for pytree in example_trees:
  leaves = tree_leaves(pytree)
  print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")
[1, 'a', <object object at 0x7fded60bb8c0>]   has 3 leaves: [1, 'a', <object object at 0x7fded60bb8c0>]
(1, (2, 3), ())                               has 3 leaves: [1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5]               has 5 leaves: [1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)}                         has 3 leaves: [2, 2, 3]
DeviceArray([1, 2, 3], dtype=int32)           has 1 leaves: [DeviceArray([1, 2, 3], dtype=int32)]

 先ほど紹介したようにlist, tuple, dictがnodeとして扱われ、その中の要素がleafとして認識されているが、文字列やJAX arrayなど、そのほかのオブジェクトはleafとして認識されていることが分かる。

 jax.tree_util.tree_*で最もよく使われるであろう関数は、tree_mapである。これは、Pythonネイティブのmap関数と同じように動作し、pytreeに格納したパラメータを更新する際に利用される、非常に大事な関数である。 基本的には、以下のように第一引数に関数をあたえ、第二引数以降にその入力となるpytreeを渡す。

from jax.tree_util import tree_map
list_of_lists = [
    [1, 2, 3],
    [1, 2],
    [1, 2, 3, 4]
]

tree_map(lambda x: x*2, list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

tree_mapは、複数の引数に対しても動作する。

another_list_of_lists = list_of_lists
tree_map(lambda x, y: x+y, list_of_lists, another_list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

この際、引数として与えられるpytreeの構造は正確に一致している必要がある。つまり、引数として与えられたpytreeのlistは同じ数の要素を持ち、辞書は同じkeyを持つ必要がある。

Custom pytree nodes

 前述したとおり、pytree container registryには自作クラスを登録できる。登録していない場合には、自作クラスのインスタンスleafを持っていたとしても、インスタンスそのものがleafとして認識され、インスタンスが持っているleafleafとして認識されない。登録することによりインスタンス自体はnodeとして認識され、持っているleafleafとして認識されるようになる。例えば以下のように自作クラスを定義してインスタンスのリストをtree_leavesに渡すと、それぞれのインスタンスleafとして認識される。

class MyContainer:
  """A named container."""

  def __init__(self, name: str, a: int, b: int, c: int):
    self.name = name
    self.a = a
    self.b = b
    self.c = c
tree_leaves([
    MyContainer('Alice', 1, 2, 3),
    MyContainer('Bob', 4, 5, 6)
])
[<__main__.MyContainer at 0x7fdec166ce50>,
 <__main__.MyContainer at 0x7fded89ba490>]

そこで、以下のようにコンテナからleafと中間的なデータを取り出す関数と、逆にleafと中間的なデータからコンテナを作製する関数を作製し、jax.tree_util.register_pytree_node()でpytree container registerに登録する。

from typing import Tuple, Iterable
from jax.tree_util import register_pytree_node

def flatten_MyContainer(container) -> Tuple[Iterable[int], str]:
  """Returns an iterable over container contents, and aux data."""
  flat_contents = [container.a, container.b, container.c]

  # we don't want the name to appear as a child, so it is auxiliary data.
  # auxiliary data is usually a description of the structure of a node,
  # e.g., the keys of a dict -- anything that isn't a node's children.
  aux_data = container.name
  return flat_contents, aux_data

def unflatten_MyContainer(
    aux_data: str, flat_contents: Iterable[int]) -> MyContainer:
  """Converts aux data and the flat contents into a MyContainer."""
  return MyContainer(aux_data, *flat_contents)

register_pytree_node(
    MyContainer, flatten_MyContainer, unflatten_MyContainer)

tree_leaves([
    MyContainer('Alice', 1, 2, 3),
    MyContainer('Bob', 4, 5, 6)
])
[1, 2, 3, 4, 5, 6]

以上のように、きちんと自作コンテナが持つleafが認識されていることが分かる。

 また、以下のように@register_pytree_node_classで登録したいクラスをデコレートする方法もある。

from jax.tree_util import register_pytree_node_class
from typing import Tuple, Iterable

@register_pytree_node_class
class MyContainer2:
  def __init__(self, name, a, b, c):
    self.name = name
    self.a = a
    self.b = b
    self.c = c

  def tree_flatten(self) -> Tuple[Iterable[int], str]:
    aux_data = self.name
    return (self.a, self.b, self.c), aux_data

  @classmethod
  def tree_unflatten(cls, aux_data: str, children: Iterable[int]):
    return cls(aux_data, *children)

 このやり方だとスマートに書けはするが、クラスの中にそれ自身のクラスのインスタンスからleafと中間的なデータを取り出すメソッドと、leafと中間的なデータから自身のクラスのインスタンスを作り出すメソッドを定義する必要があり少しややこしい。しかもそこまでコードが短くなるわけでもないのであまり良いやり方のようには思えない。

Pytree for ML model parameters

 この章ではPytreeをMLモデルのパラメータとして用いる方法を示すために、Linear RegressionとMLPの学習方法についてみてみる。

Linear Regression

 ここでは、前の章に引き続きLinear Regressionを行う。ただし、問題設定がより複雑で、パラメータが行列の形となる。 データ \{ (x_i, y_i), i \in \{1, \cdots, k \}, x_i \in \mathbb{R}^n, y_i \in \mathbb{R}^m \}が与えられた際、パラメータ W \in \mathbb{R}^{m×n} , b \in \mathbb{R}^mを持つモデル f_{W, b}(x) = Wx + bを考え、以下の式で表される二乗誤差を最小にするようにパラメータを学習する。

 \displaystyle  L \left( W, b \right) = \dfrac{1}{k} \sum_{i=1}^{k} \dfrac{1}{2} \left || y_i - f_{W, b}(x_i) \right ||_2^2

まずはモデルとloss関数を定義する。

def model(params, x):
  return jnp.dot(x, params['W']) + params['b']

def loss(params, x_batched,y_batched):
  def squared_error(x,y):
    y_pred = model(params, x)
    return jnp.inner(y-y_pred, y-y_pred) / 2.0
  return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)

loss関数内でjax.vmapという関数を使っているが、これは、jaxにおいて非常に重要な関数の一つであり、また別の記事で解説する。modelはパラメータと一つのxのサンプルを受けとり、出力としてm次元のベクトルを返す関数であり、その一つのサンプルに対してlossを計算するには、loss関数内で定義されたsquared_error関数を計算すればよい。ここでjax.vmapは、それらの計算を与えられたすべてのサンプル(1バッチ)に対してまとめて行っている。そして、その平均を取ってそのバッチのlossとしている。

サンプル数を20、 xの次元を10 ( n=10)、 yの次元を5 ( m=5)とする。 xと、真のパラメータとなる W_{true},  b_{true}を標準正規分布からサンプリングし、それにさらにノイズを足して、 y = W_{true}x + b_{true} + εという形で yを生成する。

import jax
from jax import numpy as jnp, random
# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b.
key = random.PRNGKey(0)
k1, k2 = random.split(key)
W_true = random.normal(k1, (x_dim, y_dim))
b_true = random.normal(k2, (y_dim,))
true_params = {'W': W_true, 'b': b_true}

# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = model(true_params, x_samples) + 0.1 * random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)
x shape: (20, 10) ; y shape: (20, 5)

これでサンプル数20のxとyを生成できた。

ここでjax.gradを使いloss関数の勾配を計算してみる。ここで、true_paramsは先ほど定義したようにvalueにパラメータを持つdict形式のpytreeとなっているが、このようにloss関数のパラメータがpytreeで与えられても、それぞれのパラメータに対して勾配を計算できるのがJAXの素晴らしいポイントである。

jax.grad(loss)(true_params, x_samples, y_samples)
{'W': DeviceArray([[ 0.03074035,  0.00365056, -0.02740437,  0.03342171,
                0.00768137],
              [ 0.00480341, -0.00601222,  0.01753974,  0.02849675,
                0.01315357],
              [ 0.02740791,  0.00509001,  0.01633988,  0.01654152,
               -0.03833748],
              [ 0.04439825,  0.01456456, -0.01553174,  0.00361327,
               -0.00520618],
              [-0.01146152,  0.01010936, -0.02445415, -0.02912016,
               -0.00614015],
              [-0.01862438, -0.01283839, -0.00399569, -0.00660799,
                0.01106221],
              [ 0.03790367,  0.01057525,  0.00876093, -0.00756639,
                0.00785039],
              [ 0.0138653 , -0.00148796, -0.02050714,  0.01991827,
                0.01749898],
              [ 0.01063178, -0.03577084, -0.02890723,  0.01893806,
                0.02628964],
              [ 0.0198197 ,  0.01517353, -0.01874115, -0.01610945,
               -0.02562965]], dtype=float32),
 'b': DeviceArray([ 0.0193719 , -0.00509552, -0.00823335, -0.01260918,
               0.00415814], dtype=float32)}

もちろん勾配は与えたパラメータと同じ形状( W \in \mathbb{R}^{m×n} , b \in \mathbb{R}^m )となるため、tree_mapの引数に更新前のパラメータを保持したpytreeと、勾配を計算した結果を格納したpytreeを渡すことができる。

それでは早速モデルを学習してみる。バッチサイズ=サンプルサイズ=20とする。パラメータ W bを0で初期化して、初期化したpytreeを新たに作る。updateでは、tree_mapを使ってパラメータの更新を行い、100エポック学習する。jax.jitでupdate関数がデコレートされているが、これはjust-in-timeコンパイルをするために必要なもので、この記事では特に触れない。計算処理を早くするおまじないのようなものである。

# Initialize estimated W and b with zeros.
W = jnp.zeros_like(W)
b = jnp.zeros_like(b)
params = {'W': W, 'b': b}

# Always remember to jit!
@jax.jit
def update(params, learning_rate, x_samples, y_samples):
  params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params,
        jax.grad(loss)(params, x_samples, y_samples))
  return params

learning_rate = 0.3  # Gradient step size.
print('Loss for "true" W,b: ', loss({'W': W, 'b': b}, x_samples, y_samples))
for i in range(101):
  # Perform one gradient update.
  params = update(params, learning_rate, x_samples, y_samples)
  if (i % 5 == 0):
    print(f"Loss step {i}: ", loss(params, x_samples, y_samples))
Loss for "true" W,b:  0.023639793
Loss step 0:  10.971409
Loss step 5:  1.0798324
Loss step 10:  0.37958255
Loss step 15:  0.17855294
Loss step 20:  0.09441521
Loss step 25:  0.054522194
Loss step 30:  0.03448924
Loss step 35:  0.024058027
Loss step 40:  0.018480862
Loss step 45:  0.015438682
Loss step 50:  0.01375394
Loss step 55:  0.0128103
Loss step 60:  0.012277315
Loss step 65:  0.011974388
Loss step 70:  0.011801447
Loss step 75:  0.011702419
Loss step 80:  0.011645544
Loss step 85:  0.011612837
Loss step 90:  0.011594015
Loss step 95:  0.011583163
Loss step 100:  0.011576912

いい感じに学習できた。

また、このようにlossの値を記録したい場合には、勾配を計算する際にjax.gradではなく、jax.value_and_gradを使うと、以下のようにlossを計算し直すことなく記録できる。

# Using jax.value_and_grad instead:
loss_grad_fn = jax.value_and_grad(loss)
for i in range(101):
  # Note that here the loss is computed before the param update.
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads)
    if (i % 5 == 0):
        print(f"Loss step {i}: ", loss_val)

MLP

 pytreeを使って簡単なMLPを作製して、訓練してみる。今回の例ではパラメータのpytreeはもう少し複雑になる。

import jax
import numpy as np

def init_mlp_params(layer_widths):
  params = []
  for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
    params.append(
        dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
             biases=np.ones(shape=(n_out,))
            )
    )
  return params

params = init_mlp_params([1, 128, 128, 1])

ここで、layer_widths[:-1]は最後から二番目の要素までのリストであり、ここでは[1, 128, 128]のこと。layer_widths[1:]は二番目の要素から最後までのリストであり、ここでは[128, 128, 1]のことである。また、weightsは各層において標準偏差 \sqrt{\frac{2}{n_{in}}}正規分布からサンプリングされた値で初期化され、biasは1で初期化されている。 tree_mapを使ってパラメータの形状を確認する。

jax.tree_util.tree_map(lambda x: x.shape, params)
[{'biases': (128,), 'weights': (1, 128)},
 {'biases': (128,), 'weights': (128, 128)},
 {'biases': (1,), 'weights': (128, 1)}]

以上で、各次元が1, 128, 128, 1の四層のMLPが定義できた。

次に、forward関数とloss関数、update関数を定義する。

def forward(params, x):
  *hidden, last = params
  for layer in hidden:
    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
  return x @ last['weights'] + last['biases']

def loss(params, x, y):
  return jnp.mean((forward(params, x) - y) ** 2)

LEARNING_RATE = 0.0001

@jax.jit
def update(params, x, y):
  grads = jax.grad(loss)(params, x, y)
  return jax.tree_util.tree_map(
      lambda p, g: p - LEARNING_RATE * g, params, grads
  )

forward関数内で、「@」という演算が行われているが、これは行列積のこと( What’s New In Python 3.5 — Python 3.11.0 documentation)で、jnp.dotと基本的には同じ。 Linear Regressionでも説明した通り、update関数内で勾配を計算する際、第一引数として渡されているparamsについてloss関数の微分が行われ、返り値であるgradsもparamsと同じ形状をしたpytreeである。そして例のごとくupdate関数はtree_mapを使い値の更新を行い、updateされたparamsを返す。

次に、モデルの訓練を行う。正規分布からサンプリングした値からデータを作製し、1000エポック学習させる。

import matplotlib.pyplot as plt

xs = np.random.normal(size=(128, 1))
ys = xs ** 2

for _ in range(1000):
  params = update(params, xs, ys)

plt.scatter(xs, ys)
plt.scatter(xs, forward(params, xs), label='Model prediction')
plt.legend();

いい感じに学習できた。 やはりJAXの自動微分は素晴らしい!!次回は自動微分について書く。