izmyonの日記

奈良の山奥で研究にいそしむ大学院生の学習記録

JAXの自動微分の仕組み ―Dougal Maclaurin氏の博士論文およびMatthew J Johnson氏の講演から①

Preface

 本シリーズでは、JAXの開発者であり、Autogradの開発者でもあるDougal Maclaurin氏と、Matthew James Johnson氏に許可をいただき、Maclaurin氏の博士論文およびJohnson氏の講演の内容から、少しずつJAXの自動微分について解説していく。

 本記事では第一弾として、Maclaurin氏の博士論文の第2.5節、Johnson氏の講演の27分当たりまで、最初のセクションの内容について書く。特に、前者については本文自体が非常にわかりやすかったためにそのまま翻訳し、後者については筆者(izmyon)の理解に基づきスライドを抜粋して解説している。

 本記事は以下の記事のより平易な解説として、補足のような位置づけとなっている。とりあえず読むことをお勧めする。

izmyon.hatenablog.com

Dougal Maclaurin氏

dougalmaclaurin.com

博士論文:Modeling, Inference and Optimization with ComposableDifferentiable Procedures

Matthew James Johnson氏

people.csail.mit.edu

講演:Automatic Differentiation: Deep Learning (DLSS) and Reinforcement Learning (RLSS) Summer School, Montreal 2017

Modeling, Inference and Optimization with ComposableDifferentiable Procedures

2 Background

2.5 Computing Gradients in Reverse Accumulation Mode

連続最適化問題や推論問題は、目的関数や確率の対数(対数尤度、情報量など)の勾配を利用することができれば、はるかに容易になる。これは特に高次元の関数、 \mathbb{R}^n \rightarrow \mathbb{R} に当てはまり、各勾配の評価はD個の関数評価を追加することと等しいからである。時間コストは関数自体の評価コストと比べたら小さな定数のファクターでしかないため、適切な評価戦略を用いることで、勾配は非常に楽に求めることができる。本節では、この評価戦略(逆(積算)モード微分法、あるいはニューラルネットワークの世界ではバックプロパゲーションと呼ばれる)を説明する。

ベクトルからスカラーへの関数、 \mathbb{R}^D \rightarrow \mathbb{R}が、既知のヤコビアンを持つ(様々なM, Nに対する)原始関数 \mathbb{R}^M \rightarrow \mathbb{R}^Nの集合で構成されている場合、その合成の勾配は、チェインルールに従って、原始関数のヤコビアンの積で与えられる。しかし、チェインルールはヤコビアンを乗じる順序を規定していない。

具体的には、 F= D \circ C \circ B \circ Aという4つの原始関数の合成として定義された F: \mathbb{R}^D \rightarrow \mathbb{R}を考える。中間値を参照できるように、この関数を分解する。

 F(x) = y \quad(x \in \mathbb{R}^D, y \in \mathbb{R} )  where \  y = D(c), \ c = C(b) , \ b= B(a), \ a = A(x) \tag{2.14}

 F F'の勾配(またはヤコビアン)は以下のように与えられる。

 \begin{align} 
F'(x) = &\frac{\partial y}{\partial x} \\
where \  &\frac{\partial y}{\partial x} =  \frac{\partial y}{\partial c}  \frac{\partial c}{\partial b}  \frac{\partial b}{\partial a}  \frac{\partial a}{\partial x} \tag{2.15} \\
&\frac{\partial y}{\partial c} = D'(c), \  \frac{\partial c}{\partial b} = C'(b), \ \frac{\partial b}{\partial a} = B'(a), \  \frac{\partial a}{\partial x} = A'(x) \tag{2.16} 
\end{align}

ここで、 A′ B′ C′および D′ A B Cおよび Dヤコビアンを計算する関数である。我々は、入力として x \in \mathbb{R}^Dを、出力として y \in \mathbb{R}を常に用いる。ここで、 yスカラーであり、対して xは(巨大かもしれない)ベクトルである。

 yは行列の乗算は結合的であるため、式2.15のヤコビアンの積を任意の順序で評価することができる。左から順に評価することを「逆積算モード」(または単にリバースモード)、右から順に評価することを 「順積算モード」(または単にフォアワードモード)と呼ぶことにする。

計算される中間値の大きさが大きく異なることに注目してほしい。 順方向モードでは、これらは \frac{\partial b}{\partial x}のようなヤコビアンである。 x \in \mathbb{R}^Dはベクトルなので、これは対応する値 b D倍の要素を含んでいる。リバースモードでは、 \frac{\partial y}{\partial b}のような値を計算する。 y \in \mathbb{R}スカラーなので、これは対応する値 bと同じ数の要素を含んでいる。

したがって、原始関数のヤコビアン A′(x) B′(a) C′(b) D′(c)を評価した後は、リバースモードが、多変数実数値関数の勾配をより効率的に評価する方法となるのである。しかし、さらにもっと良い方法がある。そもそも原始関数のヤコビアンをはじめに評価する必要さえないのだ。ヤコビアンは非常に疎なことが多く、行列の積で使うだけである。行列は、結局のところ、線形写像の表現に過ぎないため、それらをインスタンス化するのではなく、線形写像を適用する関数を直接実装すればよい。すなわち,各原始関数 A: \mathbb{R}^M \rightarrow \mathbb{R}^Nヤコビアン A': \mathbb{R}^M \rightarrow \mathbb{R}^{N \times M}に対して、左乗算のヤコビアン-ベクトル積関数(JVP)、 J_A: \mathbb{R}^M \rightarrow (\mathbb{R}^M \rightarrow \mathbb{R}^N )を(キャリーとして)以下のように書ける。

 J_A (x, g) = A'(x)g \tag{2.19}

右乗算のベクター-ヤコビアン積関数(VJP)、 J_A^T: \mathbb{R}^M \rightarrow (\mathbb{R}^N \rightarrow \mathbb{R}^M)は、以下のように書ける。

 J_A^T (x, g) = gA'(x) \tag{2.20}

例えば、各要素を二乗する関数を考える。

 ElemSquare(x) = x \odot x \tag{2.21}

ここで,記号 \odotは要素ごとの乗算を表す。ElemSquareは非常に疎なヤコビアンを持ち、対角線上に 2x、その他の部分に 0を持つ単なる行列である。VJP関数は以下のように与えらる。

 J^T_{ElemSquare} (x, g) = 2g \odot x \tag{2.22}

ヤコビアンは対称行列であるので、左乗算JVP関数 J_{ElemSquare}も同様になる。

図2.2に示すように、JVPを連結することで、順方向と逆方向の両方の微分を実現することができる。順方向モードでは、各ステップで各入力次元ごとにJVPを適用するが、逆方向モードでは、各ステップでVJPを1回だけ適用する。JVPやVJPの評価は、通常、原始関数そのものを評価するよりも小さな定数倍(1〜3)だけ遅くなる。

図2.2:合成関数 F: \mathbb{R}^D \rightarrow \mathbb{R} F = C \cdot B \cdot Aに対するフォアワードモードとリバースモードの微分の違いを図示した。フォアワードモードは入力 xに対する各中間変数のヤコビアン、例えば \frac{\partial a}{\partial x} \frac{\partial b}{\partial x} のような値を積算する。これは、前のステップの積算ヤコビアンと現在の原始関数のヤコビアンで左掛けすることによって行う。リバースモードはヤコビアンを逆向きに乗算することで動作する。これは、前ステップの蓄積されたヤコビアンを現在の原始関数のヤコビアンで右掛けすることにより、中間変数それぞれに対する出力 yヤコビアン \frac{\partial y}{\partial a} \frac{\partial y}{\partial b} といった値を蓄積する。 y \in \mathbb{R} x \in \mathbb{R}^D Dが大きい場合、リバースモードで蓄積される値は Dの因子だけ小さくなり、計算の効率が大幅に向上する。

したがって、フォアワードモードの微分はリバースモードの微分よりも D x \in \mathbb{R}^D)倍遅くなり、リバースモードの微分は合成関数そのものを評価するより小さな定数倍だけ遅くなる。ここで、リバースモードの微分には1つの大きな欠点があることに注意しなければならない。 というのも、リバースパスでJVPを適用する前に、完全なフォワードパスで中間値を計算する必要があるため、すべての中間値をメモリに格納する必要があるからである。このことは、第六章で説明するように、時として大問題となることがある。

 逆積算モードの微分を用いて原始関数の合成鎖の勾配を効率的に計算する方法を説明したが、一般に、合成関数は原始関数の有向非巡回グラフとして記述することができる。幸いなことに、鎖として考えるための戦略は、グラフにも簡単に適用できる。鎖の場合と同様に、関数を評価するために完全なフォアワードパスを計算し、すべての中間値を格納する。次にグラフを逆に走査し、各中間値 zについてヤコビアン-ベクトル積を適用して \frac{\partial y}{\partial z} を計算する。

チェインルールによる合成では発生しないが、対処すべき追加のケースがある。ある値が複数回使用される「ファンアウト」と、関数が複数の入力を受け取る 「ファンイン」である。ファンインは、関数の各引数に対してヤコビアン-ベクトル積関数を定義することで対処する。変数 zの再利用であるファンアウトは、 zを利用する各分岐 iに対して \frac{\partial y}{\partial z}^{(i)} を計算し、その結果を合計して完全な \frac{\partial y}{\partial z} を得ることによって対処する。この場合、グラフを走査する順番に制約が生じ、すべての \frac{\partial y}{\partial z}^{(i)} が利用可能でなければ続行できない。図2.3にこの2つのケースを示す。

リバースモードの微分は、様々な量的分野で独自に何度も発見され[9]、機械学習でも何度か発見されている[114]。逆伝搬の(再)発明は Rumelhartら[93]のものが最も有名である。私自身の理解は、PearlmutterとSiskindの仕事(例えば[79])によって大きく形成されたものである。

図2.3:逆積算モードの計算グラフを図示する。ファンアウト(値 zの再利用)とファンイン(2つの引数を関数 Cに与える)の処理方法を示す。

Reference

[9] Baydin, A. G., Pearlmutter, B. A., & Radul, A. A. (2015). Automatic differentiation in machine learning: a survey. CoRR, abs/1502.05767.

[79] Pearlmutter, B. A. & Siskind, J. M. (2008). Reverse-mode AD in a functional framework: Lambda the ultimate backpropagator. ACM Transactions on Programming Languages and Systems (TOPLAS), 30(2), 7.

[93] Rumelhart, D. E., Hinton, G. E., & Williams, R. J. (1986). Learning repres by back-propagating errors. Nature, 323, 533–536.

[114] Widrow, B. & Lehr, M. A. (1990). 30 years of adaptive neural networks: perceptron, madaline, and backpropagation. Proceedings of the IEEE, 78(9), 1415–1442.

Automatic Differentiation: Deep Learning (DLSS) and Reinforcement Learning (RLSS) Summer School, Montreal 2017

1. Jacobians and the chain rule ―Forward and reverse accumulation

図のように入力 \boldsymbol{x} \in \mathbb{R}^nを受け取り、 y \in \mathbb{R}を出力する関数 Fを考える。 F A B C Dを原始関数とする合成関数である。また、 \boldsymbol{x}を入力した際の各原始関数の出力を、それぞれ、 \boldsymbol{a} = A(\boldsymbol{x})  \boldsymbol{b} = B(\boldsymbol{a}) \boldsymbol{c} = C(\boldsymbol{b}) y = D(\boldsymbol{c})とする。

関数 F微分は、チェインルールにより、 \frac{\partial y}{\partial \boldsymbol{c}} \frac{\partial \boldsymbol{c}}{\partial \boldsymbol{b}} \frac{\partial \boldsymbol{b}}{\partial \boldsymbol{a}} \frac{\partial \boldsymbol{a}}{\partial \boldsymbol{x}} の積で表せる。そして、それらのサイズは、それぞれ、 1 \times ( size \ of \ \boldsymbol{c} ) ( size \ of \ \boldsymbol{c} ) \times ( size \ of \ \boldsymbol{b} ) ( size \ of \ \boldsymbol{b} ) \times ( size \ of \ \boldsymbol{a} ) ( size \ of \ \boldsymbol{a} ) \times \ nである。入力データ \boldsymbol{x}のサイズ nは、大きくなりがちである。その場合、サイズが ( size \ of \ \boldsymbol{a} ) \times \ n \frac{\partial \boldsymbol{a}}{\partial \boldsymbol{x}}も大きくなる。

チェインルールでは計算の順序についての制約はないため、各原始関数のヤコビアンの合成順序は自由である。順積算モードでは、各ヤコビアンを右側から順に合成する。逆積算モードでは、各ヤコビアンを左から合成する。この時、各ヤコビアンの合成では単に行列積を計算することを思い出してほしい(参考:連鎖律(多変数関数の合成関数の微分) | 高校数学の美しい物語)。例えば、順積算モードでは、図のようにまず \frac{\partial \boldsymbol{b}}{\partial \boldsymbol{a}} \frac{\partial \boldsymbol{a}}{\partial \boldsymbol{x}}を計算するが、この計算はサイズ ( size \ of \ \boldsymbol{b} ) \times ( size \ of \ \boldsymbol{a} )とサイズ ( size \ of \ \boldsymbol{a} ) \times \ nの行列積であり、サイズ ( size \ of \ \boldsymbol{b} ) \times \ n \frac{\partial \boldsymbol{b}}{\partial \boldsymbol{x}}となる。その後、左から新たなヤコビアンを掛けていくことになるが、そのたびに得られる積算ヤコビアンは、行方向のサイズが常に nとなる。したがって、通常 nが大きいために、順積算モードでは、計算結果の途中で得られる積算ヤコビアンの形状が大きくなってしまう。一方、逆積算モードでは、 y \in \mathbb{R}であるために、積算ヤコビアンの列方向のサイズは常に 1となり、単なる行ベクトルとなる。ゆえに、逆積算モードの方が計算が楽に、メモリも少なく計算することができ、高速に微分することができる。

順積算モードの微分では、JVPを再帰的に計算することで求めることができ、最初に与えるベクトルをサイズ n \times n単位行列ヤコビアン \frac{\partial \boldsymbol{b}}{\partial \boldsymbol{x}}にして計算を開始し、順次、計算結果である積算ヤコビアンをベクトルに、左のヤコビアンヤコビアンとして新たにJVPを計算していくことで最終的に合成関数 Fの勾配を求めることができる。ここで、実際には計算の開始時に単位行列を行ごとのワンホットベクトルに分け、データ \boldsymbol{x}の各次元ごとにJVPを n回計算する。(図2.2上 参照)

逆積算モードの微分では、VJP積を再帰的に計算することで求めることができ、最初に与えるベクトルをサイズ 1 \times 1単位行列ヤコビアン \frac{\partial \boldsymbol{y}}{\partial \boldsymbol{c}}にして計算を開始し、順次、計算結果である積算ヤコビアンをベクトルに、右のヤコビアンヤコビアンとして新たなVJPを計算していくことで最終的に合成関数 Fの勾配を求めることができる。ここで、逆方向モードでは積算ヤコビアンが常に行ベクトルになるために、順積算モードのように \boldsymbol{x}の各次元ごとにヤコビアンを計算することはせず、一度VJPを適用するだけでよい。(図2.2下 参照)したがって、ナイーブに考えれば、順積算モードと比べて計算量は \frac{1}{n}で済む。

ある値が複数回使用される「ファンアウト」と、関数が複数の入力を受け取る 「ファンイン」では、連鎖が分岐してしまい、計算グラフが分岐してしまうため、特別な対応が必要となる。ファンインは、単に各引数に対してJVPを定義して引数ごとにヤコビアンを求め、単にスタックする。ある変数 xの再利用であるファンアウトは、 xを利用する各分岐 iに対して \frac{\partial y}{\partial x}^{(i)} を計算し、その結果を合計して完全な \frac{\partial y}{\partial x} を得ることによって対処する。ここでは、単純な線形変換を行う G(\boldsymbol{x})を考える。そのVJPを考える場合は、単純に(転置)ベクトルの各成分を足し合わせることで求めることができる。

Diffusionモデル学習記録④ ―Scalable Diffusion Models with Transformers

Preface

 せっかく著者がCC-BYにしてくれていたので、今回はクリスマスプレゼントとして、クリスマス前から巷をにぎわせているScalable Diffusion Models with Transformersを翻訳する。

www.wpeebles.com

arxiv.org

William Peebles, Saining Xie: Scalable Diffusion Models with Transformers, arXiv:2212.09748

©William Peebles, Saining Xie, Originally posted in arXiv(https://arxiv.org/abs/2212.09748), Mon, 19 Dec 2022

License: Creative Commons Attribution 4.0 International (CC-BY)

以下は、原文を翻訳したもので、以下の図はそこから引用したものです。

The following is the translation of the original content and the figures below are retrieved from it.

Scalable Diffusion Models with Transformers

図1:トランスフォーマーバックボーンを用いた拡散モデルは、SOTAの画像クオリティを実現する。 ImageNetの512×512と256×256の解像度でそれぞれで学習した二つのクラス条件付きDiT-XL/2モデルによる選択されたサンプルを示す。

Abstract

 本論文では、トランスフォーマーアーキテクチャーに基づく新しいクラスの拡散モデルについて述べる。 我々は、一般的に使用されるU-Netバックボーンを、潜在パッチ上で動作するトランスフォーマーと置き換えて、画像の潜在拡散モデル(Lattent Diffusion Model)を学習した。 我々は、Gflopsで測定される順方向の複雑さの視点から、我々の拡散トランスフォーマー(DiT)のスケーラビリティを分析した。 その結果、トランスフォーマーの深さや幅、入力トークンの数を増やすことでGflopsを高めたDiTは、一貫してFIDを低く抑えられることが分かった。また、良好なスケーラビリティ特性を有しているのに加え、クラス条件付きImageNet 512×512および256×256ベンチマークにおいて、DiT-XL/2の最大モデルはすべての先行拡散モデルよりも優れた性能を示し、後者においては2.27という最新のFIDを達成した。

1. Introduction

 機械学習トランスフォーマーによってルネッサンスを迎えている。 過去5年間で、自然言語処理[8, 39]、視覚(コンピュータヴィジョン)[10]、その他いくつかの領域のためのニューラルアーキテクチャは、トランスフォーマー[57]にほぼ吸収されてきた。しかし、画像レベルの生成モデルモデルの多くは、このトレンドにの乗れていない。トランスフォーマー自己回帰モデル[3, 6, 40, 44]では広く使われているが、他の生成モデリングフレームワークではあまり採用されていないようである。例えば、拡散モデルは、画像レベルの生成モデル[9,43]における最近の進歩の最前線にも関わらず、それらはすべて、バックボーンの事実上の選択として畳み込みU-Netアーキテクチャを採用している。

 Hoetら[19]は、拡散モデルにU-Netバックボーンを最初に導入した。 この設計上の選択は、自己回帰生成モデルである PixelCNN++ [49, 55] から継承されたものであり、いくつかのアーキテクチャ上の変更がなされている。このモデルは畳み込み型で、主にResNet [15] ブロックで構成されている。標準的なU-Net [46]とは対照的に、トランスフォーマーの重要な構成要素である空間的な自己注意ブロックが追加され、低解像度で散在している。DhariwalとNichol [9]は、条件情報と畳み込み層のチャンネル数を注入するために、適応的正規化層[37]を使用するなど、U-Netのアーキテクチャの選択についていくつか検討している。しかし、HoetらのU-Netの高いレベルのデザインはほぼそのまま残ったままである。

 この研究で、我々は拡散モデルにおけるアーキテクチャ選択の重要性を解明し、将来の生成モデル研究のための経験的なベースラインを提供することを目的とする。その結果、U-Netの帰納バイアスは拡散モデルの性能にとって重要ではなく、トランスフォーマーのような標準的な設計で容易に置き換えることができることが分かった。その結果、拡散モデルは、最近のアーキテクチャーの統一化というトレンドの恩恵を受けるのに適していることが分かった。例えば、他のドメインのベストプラクティスやトレーニングレシピを継承し、スケーラビリティ、ロバスト性、効率性などの好ましい特性を保持することができる。しかも、 標準化されたアーキテクチャは、分野横断的な研究の新たな可能性を開くものでもある。

 本論文では、トランスフォーマーに基づく新しい拡散モデルに焦点を当てる。我々はそれらをDiffusion Transformers、または短くDiTsと呼ぶ。DiTsは、従来の畳み込みネットワーク(例えば、ResNet [15])よりも視覚認識に対して効果的にスケーリングすることが示されているVision Transformers (ViTs) [10]のベストプラクティスに準拠している。

 より具体的には、我々は、ネットワークの複雑さ vs. サンプルの質に関するトランスフォーマーのスケーリング挙動を研究する。我々は、拡散モデルがVAEの潜在空間内で学習されるLatent Diffusion Models(LDMs) [45]の枠組みの下でDiTデザイン空間を構築しベンチマークすることにより、U-Netバックボーンをトランスフォーマーで置き換えることに成功することを示す。さらに、DiTが拡散モデルのためのスケーラブルなアーキテクチャであることを示す。というのも、ネットワークの複雑さ(Gflopsで測定)vs. サンプルの品質(FIDで測定)の間に強い相関があるのである。DiTをスケールアップし、大容量のバックボーン(118.6 Gflops)を持つLDMをトレーニングするだけで、クラス条件付き256×256ImageNet生成ベンチマークにおいて2.27FIDというSOTAを達成することができた。

Transformers

 Transformers [57]は言語、視覚(コンピュータヴィジョン)[10]、強化学習 [5, 23]、メタ学習 [36]において、ドメイン固有のアーキテクチャを置き換えてきた。また、言語領域[24]、汎用自己回帰モデル[17]、ViT[60]において、モデルサイズ、学習量、データが増加する中で顕著なスケーリング特性を示した。言語を超えて、トランスフォーマーピクセルを自己回帰的に予測するために訓練されてきた[6, 7, 35]。 また、自己回帰モデル[11, 44]とマスクされた生成モデル[4, 14]として離散コードブック[56]で学習され、前者は20Bパラメータまでの優れたスケーリング挙動を示した[59]。最後に、トランスフォーマーは非空間データを合成するためにDDPMで研究されてきた。例えば、DALL-E 2のCLIP画像埋め込みを生成するためである[38, 43]。本論文では、画像の拡散モデルのバックボーンとして使用された場合のトランスフォーマーのスケーリング特性について研究する。

Denoising diffusion probabilistic models (DDPMs)

 拡散モデル[19, 51]とスコアベース生成モデル[22, 53]は画像の生成モデルとして特に成功し、多くの場合、それまで最先端であったGAN[12]を凌駕している。 過去2年間のDDPMの改良は、サンプリング技術の改良[19, 25, 52]、特に分類器放棄誘導[21]、画素ではなくノイズを予測する拡散モデルの再構成[19]、アップサンプラーと並行して低解像度ベース拡散モデルを学習するカスケードDDPMパイプラインの使用により大きく推進されている[9, 20]。上記のすべての拡散モデルにおいて、畳み込みU-Net [46]がバックボーンアーキテクチャデファクトスタンダードとなっている。

Architecture complexity

 画像生成に関する文献では、アーキテクチャの複雑さを評価する際、パラメータカウントを使用することが一般的である。 一般に、パラメータカウントは、例えば、性能に大きな影響を与える画像解像度を考慮していないため、画像モデルの複雑さの代用としては不十分な場合がある[41, 42]。その代わり、本論文では、モデルの複雑さの分析の多くをGflopsの理論の視点で行っている。このことは、Gflopsが複雑さの測定に広く使用されているアーキテク チャ設計の文献と一致している。実際のところ、最良の複雑さの指標(が何であるか)は、それが特定のアプリケー ションシナリオに頻繁に依存するため、まだ議論が続いている。NicholとDhariwalの拡散モデルの改善に関する代表的な研究[9, 33]は、我々と最も関連が深く、彼らはそこで、U-NetアーキテクチャクラスのスケーラビリティとGflop特性を分析した。この論文では、トランスフォーマークラスに焦点を当てる。

Diffusion Transformers

3.1. Preliminaries

Diffusion formulation

 我々のアーキテクチャを紹介する前に、拡散モデル(DDPMs)[19, 51]を理解するために必要ないくつかの基本的な概念を簡単に復習しておく。ガウスシアン拡散モデルは、定数 \bar{\alpha}_tは超パラメーターとして、実データ x_0: q(x_t | x_0) = \mathcal{N} \left( x_t; \sqrt{\bar{\alpha}_t}x_0, (1- \bar{\alpha}_t ) \boldsymbol{\rm{I}} \right)に徐々にノイズを加える順方向ノイズプロセスを仮定する。再パラメータ化のトリックを適用すると、サンプル x_t =  \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1- \bar{\alpha}_t}\epsilon_t, \ where \ \epsilon_t \sim \mathcal{N} \left( 0, \boldsymbol{\rm{I}} \right)が得られる。    拡散モデルは、順プロセスでの劣化を反転させる逆プロセス p_{\theta} ( x_{t-1} | x_t ) = \mathcal{N} ( \mu_{\theta}(x_t) , {\sum}_{\theta} (x_t) )を学習するために訓練され、ニューラルネットワーク p_{\theta}の統計量を予測するために使用される。 逆方向モデルは x_0の対数尤度の変分下界[27]で学習され、 \mathcal{L} \left( \theta \right) = - p (x_0 | x_1 ) + {\sum}_t \mathcal{D}_{KL}
 \left( q^{\ast} (x_{t-1} | x_t, x_0 ) || p_{\theta} ( x_{t-1} | x_t ) \right) を減少させ、学習に無関係な付加項が除かれる。 q^{*} p_{\theta}はともにガウス分布であるため、 \mathcal{D}_{KL}は2つの分布の平均と共分散で評価することが可能である。 \mu{_{\theta}}をノイズ予測ネットワーク \epsilon{_{\theta}}として再パラメータ化することにより、予測ノイズ \epsilon_{\theta} (x_t )とサンプリングしたground-truthガウスノイズ \epsilon_tとの単純な平均二乗誤差 \mathcal{L}_{simple} (\theta) = || \epsilon_{\theta} (x_t) - \epsilon_t ||^2_2 を用いてモデルを学習することが可能である。しかし、逆プロセスでの共分散 {\sum}_{\theta}をフルに学習して拡散モデルを学習するためには、 \mathcal{D}_{KL}項を最適化する必要がある。そこで、NicholとDhariwal のアプローチ[33]に従い、 \epsilon_{\theta} \mathcal{L}_{simple}で学習し、 {\sum}_{\theta} \mathcal{L}フルに使い学習する。一度 p_{\theta}が学習されると、 x_{t_max} \sim \mathcal{N} \left( 0,\boldsymbol{\rm{I}} \right) を初期化し、再パラメータ化のトリックにより x_{t-1} \sim p_{\theta} (x_{t-1} | x_t )をサンプリングすることにより、新たな画像を得ることができる。

Classifier-free guidance

 条件付き拡散モデルは、クラスラベル cのような追加的な情報を入力として受け取ることができる。この場合、逆プロセスは p_{\theta} (x_{t-1} | x_t, c)となり、 \epsilon_{\theta} {\sum}_{\theta} cに条件づけられる。 この設定において、分類器放棄誘導は、 \log p (c | x )が高くなるような xを見つけるようにサンプリング手順を促すために使用することができる[21]。ベイズ則では、 \ log p(c|x) \propto \log p(x|c) - \log p(x) であるので、 \nabla_x \log p(c|x) \propto \nabla_x \log p(x|c) - \nabla_x \log p(x)となる。このように拡散モデルの出力をスコア関数として解釈することで、DDPMのサンプリング手順を、 p(x|c) が高いサンプル xに誘導することができる。このことは次のような式で表せる。  \hat{\epsilon}{_{\theta}} \left( x{_t}, c \right) = \epsilon{_{\theta}} \left( x{_t}, {\emptyset} \right) + s \cdot \nabla{_x} \log p \left( x | c \right) \propto \epsilon{_{\theta}} \left( x{_t}, {\emptyset} \right) + s \cdot \left( \epsilon{_{\theta}} \left( x{_t} , c \right) - \epsilon{_{\theta}} \left( x{_t}, {\emptyset} \right) \right) 、ここで s >1は誘導規模を表す( s=1は標準サンプリングであることに注意する。)。 c={\emptyset}を用いた拡散モデルの評価は、学習中にランダムに cを削除し、学習した「ヌル」埋め込み {\emptyset}に置き換えることによって行われる。 分類器放棄誘導は、一般的なサンプリング技術よりも大幅に改善されたサンプルをもたらすことが広く知られており[21, 32, 43]、この傾向は我々のDiTモデルでも同様である。

図2:Diffusion Transformer (DiT)を用いたImageNetの生成。バブル領域は、拡散モデルのフロップを示す。 左:400K学習イタレーションにおけるDiTモデルのFID-50K(低いほど良い)。 モデルのフロップ数が増加するにつれて、FIDの性能は着実に向上している。 右:我々の最も良いモデルである、DiT-XL/2は計算効率が高く、ADMやLDMといった先行するU-Netベースの拡散モデルを凌駕している。

Latent diffusion models

 高解像度の画素空間で拡散モデルを正しく学習させることは、計算量的に困難である。潜在拡散モデル(LDM)[45]は、2段階のアプローチでこの問題に取り組む。(1)学習したエンコーダ Eで画像をより小さな空間表現に圧縮するオートエンコーダを学習する、(2)画像 xの拡散モデルの代わりに、表現 z=E(x)の拡散モデルを学習する。( Eは固定されているとする。)図2に示すように、LDMはADMのようなピクセル空間拡散モデルの数分の一のGflopsで良好な性能を達成することができる。新しい画像は、拡散モデルから表現 zをサンプリングし、学習したデコーダ x=D(z)を用いて画像にデコードすることで生成できる。

 図2に示すように、LDMはADMのようなピクセル空間拡散モデルの数分の1のGflopsで良好な性能を達成することができる。 本論文では、DiTsを潜在空間に適用しているが、そのままピクセル空間に適用することも可能である。このため、我々の画像生成パイプラインは、既製の畳み込みVAEとトランスフォーマーベースのDDPMを用いたハイブリッド型のアプローチとなっている。

図3:拡散トランスフォーマー(DiT)アーキテクチャ。左:条件付き潜在DiTモデルを学習する。入力潜在変数はパッチに分解され、複数のDiTブロックによって処理される。右:我々のDiTブロックの詳細。適応的レイヤ正規化、クロスアテンション、追加入力トークンによる条件付けを組み込んだ様々な標準的なトランスフォーマーブロックで実験した。適応的レイヤノルムが最も効果的であった。

3.2. Diffusion Transformer Design Space

 我々は、拡散モデルのための新しいアーキテクチャである拡散トランスフォーマー(DiTs)を紹介する。我々は、スケーリング特性を保持するために、標準的な変換器アーキテクチャにできるだけ忠実であることを目指している。我々は画像(特に画像の空間表現)のDDPMの学習をすることに焦点を当てているため、DiTはパッチのシーケンスに対して動作するVision Transformer(ViT)アーキテクチャをベースにしている[10]。DiTは、ViTのベストプラクティスの多くを継承している。 図3は、DiTのアーキテクチャの概要を示している。 本節では、DiTの順方向過程と、DiTクラスの設計空間の構成要素を説明する。

図4:DiTの入力仕様。パッチサイズ p×pが与えられたとき、形状 I×I×Cの空間表現(VAEからのノイズ潜在変数)は、隠れ次元 dで長さ[tex: T= (\frac{I}{p})2]のシーケンスに「パッチ」化される。パッチサイズ pが小さいとシーケンス長が長くなり、Gflopsが増加する。

Patchify

 DiTの入力は空間表現 z(256×256×3画像の場合、 zは32×32×4の形状を持つ)である。DiTの最初の層は "patchify "であり、空間入力を、各パッチを線形に埋め込むことによって、各次元 d Tトークンのシーケンスに変換する。 パッチ化の後、標準的なViT周波数ベースの位置埋め込み(サイン・コサイン版)をすべての入力トークンに適用する。パッチ化で生成されるトークンの数は、パッチサイズのハイパーパラメータ pによって決定される。 図4に示すように、 pを半分にすると Tは4倍になり、したがってトランスフォーマーの総Gflopsは少なくとも4倍になる。Gflopsに大きな影響を与えるが、 pを変更しても下流のパラメータ数には意味がないことに注意してください。DiTの設計空間に p=2,4,8を追加した。

DiT block design

 パッチファイに続いて、入力トークンは一連のトランスフォーマーブロックによって処理される。 ノイズを含む画像入力に加え、拡散モデルはノイズの時間ステップ、クラスラベル、自然言語などの追加の条件情報を処理することがある。我々は、条件付き入力を異なる方法で処理する4種類のトランスフォーマーブロックを検討した。 これらの設計は、標準的なViTブロックの設計に、小さいが重要な変更を導入している。 すべてのブロックの設計を図3に示す。

  • 文脈内条件付け:入力シーケンスに2つの追加トークンとして t cの埋め込みベクトルを単純に追加し、画像トークンと同様に扱う。これはViTsのclsトークンと同様であり、標準的なViTブロックをそのまま使用することができる。最終ブロックの後、シーケンスから条件トークンを除去する。このアプローチでは、モデルへの新たなGflopsの導入はごくわずかである。

  • クロスアテンションブロック:画像トークン列とは別に、 t cの埋め込みを長さ2のシーケンスに連結する。トランスフォーマーブロックは、マルチヘッドセルフアテンションブロックの後に、マルチヘッドクロスアテンション層を追加するように変更されている。これは、Vaswaniet et al.[57]のオリジナルの設計と同様であり、またLDMがクラスラベルの条件付けのために用いるものと同様である。クロスアテンションは最も多くのGflopsをモデルに追加し、およそ15%のオーバーヘッドとなる。

  • 適応的レイヤー正規化( adaLN)ブロック:GAN [2, 26] や U-Netバックボーンによる拡散モデル [9] において、適応的正規化層 [37] が広く用いられていることを受けて、我々は、トランスフォーマーブロックの標準的なレイヤー正規化層を適応的レイヤー正規化(adaLN)に置き換えることを検討する。次元ごとのスケールとシフトのパラメータ \gamma \betaを直接学習するのではなく、埋め込みベクトル t cの和からそれらを回復させるのである。我々が探求する三つのブロックの設計の中では、adaLNがGflopsを最も加えないので、最も計算効率が良い。さらに、全てのトークンに同じ関数を適用するよう再制約された唯一の条件付け機構である。

-addLN-Zeroブロック:ResNetsに関する先行研究では、各残差ブロックを恒等関数として初期化することが有効であることが分かっている。例えば、Goyalらは、各ブロックの最終バッチノルムスケールファクター \gammaをゼロに初期化することで、強化学習設定における大規模学習が加速されることを発見した[13]。拡散U-Netモデルも同様の初期化戦略を用いており、各ブロックの最後の畳み込み層を、任意の残留接続の前にゼロ初期化する。我々は、同じことを行うAdaLN DiTブロックの改良を探求している。また、 \gamma \betaの回帰に加えて、DiTブロック内の任意の残留接続の直前に適用される次元単位のスケーリングパラメータ \alphaも回帰させる。すべての \alphaに対してゼロベクトルを出力するようにMLPを初期化し,これにより,完全なDiTブロックを恒等関数として初期化する。バニラAdaLNブロックと同様に、AdaLN-Zeroはモデルへの無視できるほどのGflopsを増やす。

我々はDiTの設計空間には、文脈内条件付け、クロスアテンション、適応的レイヤー正規化、adaLN-Zeroブロックが含めた。

表1:DiTモデルの詳細。ViT[10]モデルのスモール(S)、ベース(B)、ラージ(L)の構成に準じ、最大のモデルとしてXラージ(XL)構成を導入している。

Model size

 我々は、それぞれ潜在次元 dで動作する N個のDiTブロックのシーケンスを適用する。ViTに倣い、 N dとアテンションヘッドを共同でスケールする標準的なトランスフォーマー構成を用いる[10, 60]。具体的には、4つの構成、DiT-S、DiT-B、DiT-L、DiT-XLである。0.3から118.6Gflopsの幅広いモデルサイズとフロップ数で構成されており、スケーリング性能を測定することができる。表1には、構成の詳細を示す。

 我々は、DiTの設計空間にB、S、L、XLの構成を追加した。

Transformer decoder

 最後のDiTブロックの後、一連の画像トークンを出力ノイズの予測と出力対角共分散の予測にデコードする必要がある。これらの出力は、いずれも元の空間入力と同じ形状をしている。そのために、標準的な線形デコーダを使用している。最終層の正規化(adaLNの場合は適応的)を適用し、各トークンを p×p×2Cテンソル(ここで CはDiTへの空間入力のチャンネル数)に線形復号する。最後に、デコードされたトークンを元の空間レイアウトに並べ替えて、予測されるノイズと共分散を得る。  我々が探求するDiTの完全な設計空間は、パッチサイズ、トランスフォーマーブロックアーキテクチャとモデルサイズである。

4. Experimental Setup

 DiTのデザイン空間を探索し、モデルクラスのスケーリング特性を研究する。モデルの名前は、その構成と潜在パッチの大きさに従って付けられている。 例えば、DiT-XL/2はXLarge構成で p=2である。

Training

 本論文では、高い競争力を持つ生成モデリングベンチマークであるImageNetデータセット[28]を用いて、256×256および512×512画像解像度でクラス条件付き潜在DiTモデルの学習を行った。最終線形層はゼロで初期化し、それ以外はViTの標準的な重み初期化手法を使用する。全てのモデルはAdamW [27, 30]で学習する。  学習率は 1×10^{-4}で、重みの除去は行わず、バッチサイズは256である。データ拡張に用いたのは水平フリップのみである。ViTsの先行研究[54, 58]とは異なり、DiTsを高性能に学習させるために、学習率のウォームアップや正則化は必要ないと考えた。これらの手法を用いない場合でも、学習は全てのモデル構成で非常に安定しており、トランスフォーマーを学習する際によく見られる損失スパイクも観察されなかった。生成モデリングの文献によく見られるように、我々はDiTの重みの指数移動平均(EMA)を0.9999の減衰率で維持しながら学習を進めた。すべての結果はEMAモデルを用いている。学習用ハイパーパラメータはADMからほぼ完全に保持されている。DiTモデルのサイズとパッチサイズに関わらず同一の学習ハイパーパラメータを使用した。学習率、減衰/ウォームアップスケジュール、Adamβ1/β2、重み減衰は調整しなかった。

Diffusion

 Stable Diffusion [45]の既製の学習済みVariational Autoencoder (VAE)モデル[27]を使用した。VAEエンコーダは、形状256×256×3のRGB画像 xを8倍ダウンサンプルし、 z=E(x)は形状32×32×4とする。本節の全ての実験において、我々の拡散モデルはこの Z空間で動作する。 拡散モデルから新しい潜在変数をサンプリングした後、VAEデコーダ x=D(z)を用いてそれをピクセルにデコードする。具体的には、 1×10^{-4}から 2×10^{-2}の範囲の t_{max}=1000の線形分散スケジュール、ADMの共分散のパラメータ化 {\sum}_{\theta}、入力タイムステップとラベルを埋め込む方法を用いた。

Evaluation metrics

 スケーリング性能は、画像の生成モデルを評価するための標準的な指標であるFID(Frechet Inception Distance)[18]を用いて測定する。  先行研究との比較では、慣例に従い、250 DDPMのサンプリングステップを用いたFID-50Kを報告する。FIDは小さな実装のディテールに敏感であることが知られている[34]が、正確な比較を確実にするために、この論文で報告されたすべての値は、サンプルをエクスポートし、ADMのTensorFlow評価[9]を使用して得られたものである。 このセクションで報告されたFID数は、特に明記された場合を除き、分類器放棄誘導を使用していない。また、二次評価指標としてInception Score [48]、sFID [31]、 Precision/Recall [29]を報告している。

Compute

 すべてのモデルをJAX[1]で実装し、TPU-v3ポッドを用いて学習させた。最も計算量の多いDiT-XL/2は、TPU v3-256ポッドを用いて、グローバルバッチサイズ256で約5.7回/秒の速度で学習する。

図5:様々な条件付けの戦略を比較。adaLN-Zeroは、トレーニングのすべての段階において、クロスアテンション条件付けとインコンテキスト条件付けを凌駕している。

5. Experiments

DiT block design

 Gflop DiT-XL/2モデルのうち、最高レベルの4つのモデルを、それぞれ異なるブロックデザインー文脈内条件付け(119.4 Gflops)、クロスアテンション(137.6 Gflops)、適応的レイヤー正規化(AdaLN、118.6 Gflops)、AdaLN-zero(118.6 Gflops)で学習させた。トレーニングの過程でFIDを測定しており、図5はその結果である。 adaLN-Zeroブロックはクロスアテンションや文脈内条件付けよりもFIDが低く、計算効率も最も良い。400Kのイタレーションにおいて、adaLN-Zeroモデルで達成されたFIDは文脈内条件付けのほぼ半分であり、条件付けメカニズムがモデルの品質に決定的な影響を与えることが証明された。初期化も重要であり、各DiTブロックを恒等関数として初期化するadaLN-Zeroは、バニラadaLNを大幅に上回る性能を示した。本稿では、これ以降すべてのモデルにadaLN-Zero DiTブロックを使用する。

図6:DiTモデルをスケールアップすることで、全ての学習段階においてFIDが改善される。12個のDiTモデルの学習イタレーションにおけるFID-50Kを示す。 上段:パッチサイズを一定にした場合のFIDの比較。下段:モデルサイズを一定にした場合のFIDを比較。トランスフォーマーバックボーンをスケールアップすることで、全てのモデルサイズ、パッチサイズにおいて、より優れた生成モデルが得られる。

Scaling model size and patch size

 モデル構成(S, B, L, XL)とパッチサイズ(8, 4, 2)にわたって、12のDiTモデルを学習させた。DiT-LとDiT-XLは、他の構成に比べて相対的なGflopsの点で互いにかなり接近していることに注意してください。 図2(左)は、40万回の学習イタレーションにおける各モデルのGflopsとFIDの概要を示している。いずれの場合も、モデルサイズを大きくし、パッチサイズを小さくすることで、拡散モデルがかなり改善されることがわかった。

 図6(上)は、モデルサイズを大きくし、パッチサイズを一定にした場合のFIDの変化を示している。4つの構成すべてにおいて、トランスフォーマーを深く、広くすることによって、学習のすべての段階でFIDが大幅に改善されることがわかる。同様に、図6(下)は、パッチサイズを小さくし、モデルサイズを一定にしたときのFIDを示したものである。このように、DiTのパラメータをほぼ固定したまま、処理するトークン数をスケールアップさせるだけで、学習全体を通してFIDが大幅に改善されることがわかる。

図8:トランスフォーマーのGflopsはFIDと強い相関がある。各DiTモデルのGflopsと、400K学習ステップ後の各モデルのFID-50Kをプロットしている。

DiT Gflops are critical to improving performance

 図6の結果は、DiTモデルの品質を決定する上で、パラメータ数はほとんど重要でないことを示唆している。モデルサイズを一定に保ち、パッチサイズを小さくすると、トランスフォーマーの全パラメータは実質的に変化せず、Gflopsのみが増加する。これらの結果は、モデルGflopの拡張が性能向上の鍵であることを示している。さらに調査するために、図8に400K学習ステップのFID-50KをモデルGflopsに対してプロットした。その結果、サイズやトークンの異なるDiTモデルでも、総Gflopsが同程度であれば、最終的に同じようなFID値が得られることが分かった(例えば、DiT-S/2とDiT-B/4)。実際、モデルのGflopsとFID-50Kの間には強い負の相関が見られ、モデルの追加計算がDiTモデルの改良に不可欠な要素であることが示唆されている。図12 (付録) では、この傾向が Inception Score などの他の指標にも見られることが分かる。

図9:大きなDiTモデルは大きな計算量資源をより効率的に使用する。 FIDを総トレーニング量に対する関数としてプロットした。

Larger DiT models are more compute-efficient

 図9では、すべてのDiTモデルについて、FIDを総トレーニング量に対する関数としてプロットしている。学習量は、モデルGflop・バッチサイズ・学習ステップ・3とし、係数3は、バックパスがフォワードパスの2倍の計算量になることをおおよそ示している。その結果、小さなDiTモデルは、たとえ長く学習しても、より少ないステップで学習した大きなDiTモデルに比べて、最終的に計算効率が悪くなることが分かった。同様に、パッチサイズ以外は同一であるモデルは、学習Gflopsを制御した場合でも、性能のプロファイルが異なることがわかる。 例えば、XL/4は約 10^{10} GflopsでXL/2より性能が向上する。

図7:トランスフォーマーフォワードパスのGflopsを増加させると、サンプルの品質が向上する。拡大表示でご覧ください。同じ入力潜在ノイズとクラスラベルを用いて400Kの学習を行った後、12個のDiTモデルすべてからサンプリングしている。 トランスフォーマーの深さ/幅を増やすか、入力トークンの数を増やすかして、モデルのGflopsを増加させると、大幅に改善される。

Visualizing scaling

 図7は、スケーリングがサンプルの品質に与える影響を可視化したものである。400K学習ステップにおいて、12個のDiTモデルのそれぞれから、同じ(identical)開始ノイズ x_{t_{max}}、サンプリングノイズ、クラスラベルを用いて画像をサンプリングする。これにより、スケーリングがDiTサンプルの品質にどのように影響するかを視覚的に解釈することができる。実際、モデルサイズとトークン数の両方を拡張することで、視覚的品質が顕著に向上する。

表2:クラス条件付き画像生成のベンチマークをImageNet 256×256で実施。DiT-XL/2はSOTAのFIDを実現した。

5.1. State-of-the-Art Diffusion Models

256×256 ImageNet.

 スケーリング解析に続いて、最も高いGflopを持つモデルDiT-XL/2を7Mステップ学習させる。 このモデルのサンプルを図1に示し、SOTAのクラス条件付き生成モデルとの比較を行う。 また、表2に結果を示す。 DiT-XL/2は、分類器放棄誘導を用いた場合、すべての事前拡散モデルを上回り、LDMが達成したFID-50Kの最高値3.60を2.27に減少させることができた。図2(右)は、DiT-XL/2(118.6 Gflops)がLDM-4(103.6 Gflops)などの潜在空間U-Netモデルに対して計算効率が高く、ADM(1120 Gflops)やADM-U(742 Gflops)などのピクセルスペースU-Netモデルよりも実質的に効率が良いことを示している。本手法は、先行するSOTAであったStyleGAN-XL[50]を含むすべての生成モデルの中で最も低いFIDを達成した。最後に、DiT-XL/2はLDM-4やLDM-8と比較して、テストしたすべての分類器放棄誘導のスケールでより高いリコール値を達成することが確認された。2.35Mステップ(ADMと同様)のみで学習した場合でも、XL/2は2.55のFIDでいまだ全ての先行拡散モデルを凌駕している。

表3:ImageNet 512×512におけるクラス条件付き画像生成のベンチマーク。 先行研究[9]では、512×512の解像度で1000個の実サンプルを用いてPrecisionとRecallを測定していることに注意。一貫性を保つために、我々も同じことをしている。

512×512 ImageNet

 512×512解像度のImageNetに対して、256×256モデルと同じハイパーパラメータで3Mイタレーションの新しいDiT-XL/2モデルを学習させる。 パッチサイズ2の場合、64×64×4入力の潜在能力をパッチ処理した後、このXL/2モデルは合計1024トークンを処理する(524.6 Gflops)。表3は、SOTA手法との比較である。XL/2は、この解像度においても、ADMが達成した従来の最高値3.85のFIDを3.04に改善し、すべての先行する拡散モデルを再び上回った。トークンの数が増えても、XL/2は計算効率を維持したままである。 例えば、ADMでは1983 Gflops、ADM-Uでは2813 Gflops、XL/2では524.6 Gflopsを使用している。 高解像度XL/2モデルのサンプルを図1および付録に示す。

図10:サンプリングの計算量が多くても、モデルの計算量が少なければ補うことはできない。 400Kイタレーションで学習したDiTモデルそれぞれについて、[16、32、64、128、256、1000]のサンプリングステップを使用してFID-10Kを計算した。各ステップ数について、FIDと、各画像のサンプリングに使用された総Gflopsをプロットする。 小型モデルは、大型モデルより多くのテスト時間Gflopsでサンプリングしても、大型モデルとの性能差を縮めることができない。

5.2. Model Compute vs. Sampling Compute

 多くの生成モデルと異なり、拡散モデルは、画像を生成する際のサンプリングステップ数を増やすことで、学習後に追加の計算を行うことができる点が特徴である。このセクションでは、モデルGflopsのサンプル品質の重要性を考慮し、より小さなモデル計算DiTが、より多くのサンプリング計算を使用することによって、より大きなものを上回ることができるかどうかを研究する。 我々は、400Kの学習ステップの後、12個のDiTモデルすべてについて、1画像あたり[16, 32, 64, 128, 256, 1000]のサンプリングステップを用いて、FIDを計算した主な結果は図10に示す通りである。1000サンプリングステップのDiT-L/2と128サンプリングステップのDiT-XL/2を比較する。この場合、L/2は各画像のサンプリングに80.7 Tflopsを使用し、XL/2は各画像のサンプリングに5倍少ない15.2 Tflopsの計算を使用する。それにもかかわらず、FID-10KはXL/2の方が優れている(23.7 vs 25.9)。一般に、サンプリング計算ではモデル計算の不足を補うことはできない。

6. Conclusion

 本論文では、従来のU-Netモデルよりも優れた性能を持ち、トランスフォーマーモデルクラスの優れたスケーリング特性を継承した、シンプルなトランスフォーマーベースの拡散モデル用バックボーンである、Diffusion Transformers (DiTs) を紹介した。 本論文で得られた有望なスケーリング結果を考慮すると、今後、より大きなモデルやトークン数にDiTsをスケーリングする研究を継続する必要がある。また、DiTはDALL-E 2やStable Diffusionのようなtext-to-imageモデルのためのドロップインバックボーンとして検討される可能性がある。

Acknowledgements.

 Kaiming He、Ronghang Hu、Alexander Berg、Shoubhik Debnath、Tim Brooks、Ilija RadosavovicそしてTete Xiaoに有益な議論を頂いたことに感謝する。William Peeblesは、NSF GRFPの支援を受けている。

Apendix (The figures in the appendix are only part of the original contents.)

*付録の図は原文の一部のみを掲載しています。

図11:512×512と256×256解像度のDiT-XL/2モデルからの追加サンプル。512×512モデルで6.0、256×256モデルで4.0という分類放棄誘導スケ ールを使用している。両モデルともft-EMA VAEデコーダを使用している。

表4:DiTモデルの詳細。本論文では、すべてのDiTモデルについて詳細な情報を報告する。なお、ここでのFID-50Kは分類器放棄誘導なしで計算されている。 パラメータとフロップ数は、エンコーダとデコーダで84Mのパラメータを含むVAEモデルを除いている。 256×256および512×512 DiT-XL/2モデルともに、FIDが飽和することはなく、可能な限り学習を継続した。この表で報告されている数値は、ft-MSE VAEデコーダーを使用している。

表6:U-Netバックボーンを使用したベースライン拡散モデルのGflop数。

A. Additional Implementation Details

 表4には、256×256および512×512の両モデルを含む、すべてのDiTモデルに関する情報を記載している。Gflop数、パラメータ、トレーニングの詳細、FIDなどを含む。 また、表6にはADMとLDMのDDPM U-NetモデルのGflopカウントも含まれている。

DiT model details

 入力タイムステップを埋め込むために、我々は256次元周波数埋め込み[9]と、トランスフォーマーの隠れサイズとSiLU活性化関数に等しい次元を持つ2層MLPを使用している。各adaLN層はタイムステップとクラス埋め込みの和をSiLU非線形と線形層に送り、出力ニューロントランスフォーマーの隠れサイズに等しい4×(adaLN)か6×(adaLN-Zero)のいずれかに等しい。コアトランスフォーマーにはGELU非線形(tanhで近似)を用いる[16]。

表5:デコーダ除去。https://huggingface.co/stabilityai/sd-vae-ft-mseで入手可能な様々な事前学習済みVAEデコーダの重みをテストした。ImageNet 256×256において、異なる事前学習済みデコーダの重みで同等の結果を得ることができた。

B. VAE Decoder Ablations

 この実験では、既製の事前学習済みVAEを使用した。VAEモデル(ft-MSEとft-EMA)は、オリジナルのLDM「f8」モデルをファインチューンしたものである(デコーダの重みのみをファインチューン)。第5節のスケーリング解析では、ft-MSE デコーダを使用して指標を監視し、表2と表3に示した最終指標では、ft-EMA デコーダを使用した。このセクションでは、LDMで使用されるオリジナルのVAEデコーダと、Stable Diffusionで使用される2つのファインチューンされたデコーダの3つの異なる選択肢について説明する。エンコーダはモデル間で同一であるため、拡散モデルを再学習することなくデコーダを交換することができる。表5に結果を示す。LDMデコーダを用いた場合、XL/2はすべての先行拡散モデルを凌駕している。

C. Model Samples

 DiT-XL/2の3Mステップと7Mステップで学習した512×512と256×256解像度のモデルのサンプルを示す。図1および図11は、両モデルから選択されたサンプルである。図13から図32は、分類器放棄誘導スケールと入力クラスラベル(250 DDPMサンプリングステップとft-EMA VAEデコーダで生成)の範囲における2つのモデルのサンプルを切断したものである。その結果、スケールを大きくすると、視覚的な忠実度が増し、サンプルの多様性が低下することがわかった。

図12:DiTのスケーリング挙動といくつかの生成モデリング指標。左:FID、sFID、Inception Score、Precision、Recallについて、モデル性能を総トレーニング量の関数としてプロットしている。右:12種類のDiTについて、400Kトレーニングステップにおけるモデル性能を、トランスフォーマーのGflopsに対してプロットしたところ、メトリクス間で強い相関があることが分かった。 すべての値は、ft-MSE VAEデコーダを使用して計算された。

References

[1] JamesBradbury,RoyFrostig,PeterHawkins,Matthew James Johnson, Chris Leary, Dougal Maclau-rin, George Necula, Adam Paszke, Jake VanderPlas, SkyeWanderman-Milne, and Qiao Zhang.JAX: composabletransformations of Python+NumPy programs, 2018. 6

[2] Andrew Brock, Jeff Donahue, and Karen Simonyan. Largescale GAN training for high fidelity natural image synthesis.InICLR, 2019. 5, 9

[3] Tom B Brown, Benjamin Mann, Nick Ryder, Melanie Sub-biah, Jared Kaplan, Prafulla Dhariwal, Arvind Neelakantan,Pranav Shyam, Girish Sastry, Amanda Askell, et al. Lan-guage models are few-shot learners. InNeurIPS, 2020. 1

[4] Huiwen Chang, Han Zhang, Lu Jiang, Ce Liu, and William TFreeman. Maskgit: Masked generative image transformer. InCVPR, pages 11315–11325, 2022. 2

[5] Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee,Aditya Grover, Misha Laskin, Pieter Abbeel, Aravind Srini-vas, and Igor Mordatch. Decision transformer: Reinforce-ment learning via sequence modeling. InNeurIPS, 2021. 2

[6] Mark Chen, Alec Radford, Rewon Child, Jeffrey Wu, Hee-woo Jun, David Luan, and Ilya Sutskever. Generative pre-training from pixels. InICML, 2020. 1, 2

[7] Rewon Child, Scott Gray, Alec Radford, and Ilya Sutskever.Generating long sequences with sparse transformers.arXivpreprint arXiv:1904.10509, 2019. 2

[8] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and KristinaToutanova. Bert: Pre-training of deep bidirectional trans-formers for language understanding. InNAACL-HCT, 2019.1

[9] Prafulla Dhariwal and Alexander Nichol. Diffusion modelsbeat gans on image synthesis. InNeurIPS, 2021. 1, 2, 3, 5,6, 9, 12

[10] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov,Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner,Mostafa Dehghani, Matthias Minderer, Georg Heigold, Syl-vain Gelly, et al. An image is worth 16x16 words: Trans-formers for image recognition at scale. InICLR, 2020. 1, 2,4, 5

[11] Patrick Esser, Robin Rombach, and Bj ̈orn Ommer. Tamingtransformers for high-resolution image synthesis, 2020. 2

[12] Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, BingXu, David Warde-Farley, Sherjil Ozair, Aaron Courville, andYoshua Bengio. Generative adversarial nets. InNIPS, 2014.3

[13] Priya Goyal, Piotr Doll ́ar, Ross Girshick, Pieter Noord-huis, Lukasz Wesolowski, Aapo Kyrola, Andrew Tulloch,Yangqing Jia, and Kaiming He. Accurate, large minibatchsgd: Training imagenet in 1 hour.arXiv:1706.02677, 2017.5

[14] Shuyang Gu, Dong Chen, Jianmin Bao, Fang Wen, BoZhang, Dongdong Chen, Lu Yuan, and Baining Guo. Vec-tor quantized diffusion model for text-to-image synthesis. InCVPR, pages 10696–10706, 2022. 2

[15] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun.Deep residual learning for image recognition. InCVPR,2016. 2

[16] Dan Hendrycks and Kevin Gimpel. Gaussian error linearunits (gelus).arXiv preprint arXiv:1606.08415, 2016. 12

[17] Tom Henighan, Jared Kaplan, Mor Katz, Mark Chen,Christopher Hesse, Jacob Jackson, Heewoo Jun, Tom BBrown, Prafulla Dhariwal, Scott Gray, et al. Scaling lawsfor autoregressive generative modeling.arXiv preprintarXiv:2010.14701, 2020. 2

[18] Martin Heusel, Hubert Ramsauer, Thomas Unterthiner,Bernhard Nessler, and Sepp Hochreiter. Gans trained by atwo time-scale update rule converge to a local nash equilib-rium. 2017. 6

[19] Jonathan Ho, Ajay Jain, and Pieter Abbeel. Denoising diffu-sion probabilistic models. InNeurIPS, 2020. 2, 3

[20] Jonathan Ho, Chitwan Saharia, William Chan, David JFleet, Mohammad Norouzi, and Tim Salimans.Cas-caded diffusion models for high fidelity image generation.arXiv:2106.15282, 2021. 3, 9

[21] Jonathan Ho and Tim Salimans. Classifier-free diffusionguidance. InNeurIPS 2021 Workshop on Deep GenerativeModels and Downstream Applications, 2021. 3, 4

[22] Aapo Hyv ̈arinen and Peter Dayan.Estimation of non-normalized statistical models by score matching.Journalof Machine Learning Research, 6(4), 2005. 3

[23] Michael Janner, Qiyang Li, and Sergey Levine. Offline rein-forcement learning as one big sequence modeling problem.InNeurIPS, 2021. 2

[24] Jared Kaplan, Sam McCandlish, Tom Henighan, Tom BBrown, Benjamin Chess, Rewon Child, Scott Gray, AlecRadford, Jeffrey Wu, and Dario Amodei. Scaling laws forneural language models.arXiv:2001.08361, 2020. 2

[25] Tero Karras, Miika Aittala, Timo Aila, and Samuli Laine.Elucidating the design space of diffusion-based generativemodels. InProc. NeurIPS, 2022. 3

[26] Tero Karras, Samuli Laine, and Timo Aila. A style-basedgenerator architecture for generative adversarial networks. InCVPR, 2019. 5

[27] Diederik Kingma and Jimmy Ba. Adam: A method forstochastic optimization. InICLR, 2015. 3, 5, 6

[28] Alex Krizhevsky, Ilya Sutskever, and Geoffrey E Hinton.Imagenet classification with deep convolutional neural net-works. InNeurIPS, 2012. 5

[29] Tuomas Kynk ̈a ̈anniemi, Tero Karras, Samuli Laine, JaakkoLehtinen, and Timo Aila. Improved precision and recall met-ric for assessing generative models. InNeurIPS, 2019. 6

[30] Ilya Loshchilov and Frank Hutter. Decoupled weight decayregularization.arXiv:1711.05101, 2017. 5

[31] Charlie Nash, Jacob Menick, Sander Dieleman, and Peter WBattaglia. Generating images with sparse representations.arXiv preprint arXiv:2103.03841, 2021. 6

[32] Alex Nichol, Prafulla Dhariwal, Aditya Ramesh, PranavShyam, Pamela Mishkin, Bob McGrew, Ilya Sutskever,and Mark Chen.Glide: Towards photorealistic imagegeneration and editing with text-guided diffusion models.arXiv:2112.10741, 2021. 3, 4

[33] Alexander Quinn Nichol and Prafulla Dhariwal. Improveddenoising diffusion probabilistic models. InICML, 2021. 3

[34] Gaurav Parmar, Richard Zhang, and Jun-Yan Zhu.Onaliased resizing and surprising subtleties in gan evaluation.InCVPR, 2022. 6

[35] Niki Parmar, Ashish Vaswani, Jakob Uszkoreit, LukaszKaiser, Noam Shazeer, Alexander Ku, and Dustin Tran. Im-age transformer. InInternational conference on machinelearning, pages 4055–4064. PMLR, 2018. 2

[36] William Peebles, Ilija Radosavovic, Tim Brooks, AlexeiEfros, and Jitendra Malik. Learning to learn with genera-tive models of neural network checkpoints.arXiv preprintarXiv:2209.12892, 2022. 2

[37] Ethan Perez, Florian Strub, Harm De Vries, Vincent Du-moulin, and Aaron Courville. Film: Visual reasoning with ageneral conditioning layer. InAAAI, 2018. 2, 5

[38] Alec Radford, Jong Wook Kim, Chris Hallacy, AdityaRamesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry,Amanda Askell, Pamela Mishkin, Jack Clark, et al. Learn-ing transferable visual models from natural language super-vision. InICML, 2021. 2

[39] Alec Radford, Karthik Narasimhan, Tim Salimans, and IlyaSutskever. Improving language understanding by generativepre-training. 2018. 1

[40] Alec Radford, Jeffrey Wu, Rewon Child, David Luan, DarioAmodei, Ilya Sutskever, et al. Language models are unsu-pervised multitask learners. 2019. 1

[41] Ilija Radosavovic, Justin Johnson, Saining Xie, Wan-Yen Lo,and Piotr Doll ́ar. On network design spaces for visual recog-nition. InICCV, 2019. 3

[42] Ilija Radosavovic, Raj Prateek Kosaraju, Ross Girshick,Kaiming He, and Piotr Doll ́ar. Designing network designspaces. InCVPR, 2020. 3

[43] Aditya Ramesh, Prafulla Dhariwal, Alex Nichol, Casey Chu,and Mark Chen. Hierarchical text-conditional image gener-ation with clip latents.arXiv:2204.06125, 2022. 1, 2, 3, 4

[44] Aditya Ramesh, Mikhail Pavlov, Gabriel Goh, Scott Gray,Chelsea Voss, Alec Radford, Mark Chen, and Ilya Sutskever.Zero-shot text-to-image generation. InICML, 2021. 1, 2

[45] Robin Rombach, Andreas Blattmann, Dominik Lorenz,Patrick Esser, and Bj ̈orn Ommer. High-resolution image syn-thesis with latent diffusion models. InCVPR, 2022. 2, 3, 4,6, 9

[46] Olaf Ronneberger, Philipp Fischer, and Thomas Brox. U-net: Convolutional networks for biomedical image segmen-tation. InInternational Conference on Medical image com-puting and computer-assisted intervention, pages 234–241.Springer, 2015. 2, 3

[47] Chitwan Saharia, William Chan, Saurabh Saxena, LalaLi, Jay Whang, Emily Denton, Seyed Kamyar SeyedGhasemipour, Burcu Karagol Ayan, S. Sara Mahdavi,Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho, David JFleet, and Mohammad Norouzi.Photorealistic text-to-image diffusion models with deep language understanding.arXiv:2205.11487, 2022. 3

[48] Tim Salimans, Ian Goodfellow, Wojciech Zaremba, VickiCheung, Alec Radford, Xi Chen, and Xi Chen. Improvedtechniques for training GANs. InNeurIPS, 2016. 6

[49] Tim Salimans, Andrej Karpathy, Xi Chen, and Diederik PKingma. PixelCNN++: Improving the pixelcnn with dis-cretized logistic mixture likelihood and other modifications.arXiv preprint arXiv:1701.05517, 2017. 2

[50] Axel Sauer, Katja Schwarz, and Andreas Geiger. Stylegan-xl: Scaling stylegan to large diverse datasets. InSIGGRAPH,2022. 9

[51] Jascha Sohl-Dickstein, Eric Weiss, Niru Maheswaranathan,and Surya Ganguli.Deep unsupervised learning usingnonequilibrium thermodynamics. InICML, 2015. 3

[52] Jiaming Song, Chenlin Meng, and Stefano Ermon. Denois-ing diffusion implicit models.arXiv:2010.02502, 2020. 3

[53] Yang Song and Stefano Ermon. Generative modeling by es-timating gradients of the data distribution. InNeurIPS, 2019.3

[54] Andreas Steiner, Alexander Kolesnikov, Xiaohua Zhai, RossWightman, Jakob Uszkoreit, and Lucas Beyer. How to trainyour ViT? data, augmentation, and regularization in visiontransformers.TMLR, 2022. 6

[55] Aaron Van den Oord, Nal Kalchbrenner, Lasse Espeholt,Oriol Vinyals, Alex Graves, et al. Conditional image genera-tion with pixelcnn decoders.Advances in neural informationprocessing systems, 29, 2016. 2

[56] Aaron Van Den Oord, Oriol Vinyals, et al. Neural discreterepresentation learning.Advances in neural information pro-cessing systems, 30, 2017. 2

[57] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszko-reit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and IlliaPolosukhin. Attention is all you need. InNeurIPS, 2017. 1,2, 5

[58] Tete Xiao, Piotr Dollar, Mannat Singh, Eric Mintun, TrevorDarrell, and Ross Girshick. Early convolutions help trans-formers see better. InNeurIPS, 2021. 6

[59] Jiahui Yu, Yuanzhong Xu, Jing Yu Koh, Thang Luong,Gunjan Baid, Zirui Wang, Vijay Vasudevan, Alexander Ku,Yinfei Yang, Burcu Karagol Ayan, et al. Scaling autore-gressive models for content-rich text-to-image generation.arXiv:2206.10789, 2022. 2

[60] Xiaohua Zhai, Alexander Kolesnikov, Neil Houlsby, and Lu-cas Beyer. Scaling vision transformers. InCVPR, 2022. 2,5

Diffusionモデル学習記録③ ―Score-based Generative Model and Guidance

Preface

 このシリーズでは、Diffusionモデルについて学習する時にノート代わりに記事を書いていく。これはその第二弾で以下の第一弾の続き。手始めにめちゃ分かりやすいと巷で話題の(そして実際分かりやすかった)以下のDiffusionモデルの解説論文を少しずつ翻訳していき、脳に焼き付けていく。後々より詳しい解説とか、自分でJAXで実装とかができたらいいなと思っている。

izmyon.hatenablog.com

arxiv.org

Calvin Luo: Understanding Diffusion Models: A Unified Perspective, arXiv: 2208:11970, doi: 10.48550/ARXIV.2208.11970

©Calvin Luo, Originally posted in arXiv(https://arxiv.org/abs/2208.11970), 25 Aug 2022

License: Creative Commons Attribution 4.0 International (CC-BY)

以下は、原文の一部を翻訳したもので、以下の図はそこから引用したものです。

The following is the translation of part of the original content and the figures below are retrieved from it.

Understanding Diffusion Models: A Unified Perspective

Variational Diffusion Model

Learning Diffusion Noise Parameters

ここで、VDMのノイズパラメータをどのように同時学習させるかを検討しよう。一つの方法として、パラメータ \etaを持つニューラルネットワーク \hat{\alpha}_{\eta} (t)を用いて \alpha_tをモデル化することが考えられる。しかし、これは \hat{\alpha}_t を計算するたびに推論を複数回行わなければならず、非効率的である。キャッシュはこの計算コストを軽減することができるが、我々は拡散ノイズのパラメータを学習する別の方法を導き出すことができる。式(85)の分散方程式を式(99)のタイムステップごとの目的関数に代入することで、削減することができる。

式(70)から、 q(x_t | x_0) \mathcal{N} \left( x_t; \sqrt{\alpha_t} x_0, (1- \alpha_t ) \boldsymbol{\rm{I}} \right) 形式のガウスであることを想起してほしい。そして、信号対雑音比(SNR)の定義である \rm{SNR}=\frac{ μ^ 2 }{ \sigma^ 2}に従って、各タイムステップにおけるSNRを以下のように書ける。

  \rm{SNR} (t) = \frac{\bar{\alpha}_t}{1- \bar{\alpha}_t} 
\tag{109}

そうすると、私たちが導き出した式(108)(および式(99))は、次のように簡略化できる。

  \frac{1}{2 \sigma_q^2 (t) } \frac{ \bar{\alpha}_{t-1} \left(1- \alpha_t \right)^2 }{ \left( 1 - \bar{\alpha}_t \right)^2 } \left[ \left|| \hat{x}_{\theta} (x_t, t) - x_0 \right||_2^2 \right] = \frac{1}{2} \left( \rm{SNR} (t-1) - \rm{SNR} (t) \right) \left[ \left|| \hat{x}_{\theta} (x_t, t) - x_0 \right||_2^2 \right]
\tag{110}

その名が示す通り、SNRは、元の信号と含有されるノイズ量の比率を表している。SNRが高いほど元の信号が多く、SNR が低いほどノイズが多いことを意味する。拡散モデルでは、SNRは時間と共に単調に減少させる必要がある。これは、摂動入力 x_tは時間とともにノイズが多くなり、 t=Tで標準ガウスと同じになるという概念を正式に示すものである。

式(110)の目的関数を単純化すると、ニューラルネットワークを用いて各タイムステップのSNRを直接パラメータ化し、拡散モデルと同時に学習させることができる。SNRは時間と共に単調に減少しなければならないので、次のように表すことができる。

  \rm{SNR} (t) = exp(-\omega_{\eta} (t) ) 
\tag{111}

ここで \omega_{\eta} (t)はパラメータ \etaを持つ単調増加ニューラルネットワークとしてモデル化されている。 \omega_{\eta} (t)を判定にすると単調減少する関数になるが、指数関数では結果として生じる項が正になるように強制される。式(100)の目的関数は \etaについても最適化しなければならないことに注意してほしい。式(111)のSNRのパラメータ化と式(109)のSNRの定義を組み合わせることにより、 \alpha_tの値と 1- \bar{α}_tの値についてエレガントな形を明示的に導き出すことも可能である。

  \begin{align}
\frac { \bar{\alpha_t} }{1- \bar{\alpha_t}} &= \rm{exp} \left( - \omega_{\eta} (t) \right) \tag{112} \\
\therefore  \bar{\alpha_t} &= \rm{sigmoid} \left( - \omega_{\eta} (t) \right)  \tag{113} \\
\therefore 1- \bar{\alpha_t} &= \rm{sigmoid} \left( \omega_{\eta} (t) \right) \tag{114}
\end{align}

これらの項は様々な計算に必要であり、例えば最適化の際には、式(69)で導かれるように、再パラメータ化トリックを用いて、入力 x_0から任意のノイズの多い x_tを作成するために使用される。

Three Equivalent Interpretations

先に証明したように、変分拡散モデルの学習は、単に任意のノイズ処理された画像 x_tとその時刻 tから、元の自然画像 x_0を予測するニューラルネットワークを学習するだけで可能である。しかし、 x_0は他に2つの等価なパラメタリゼーションを持つので、VDMはさらに2つの解釈をすることができる。

まず、再パラメータ化のトリックを用いる。 q_(x_t|x_0)の形の導出において、式(69)を並べ替えて、次のように示すことができる。

  x_0 = \frac{x_t - \sqrt{1- \bar{\alpha}_t}\epsilon}{\sqrt{ \bar{\alpha}_t}} 
\tag{115}

これを先に導いた真のノイズ除去遷移における平均 μ_q(x_t,x_0)に代入すると、次のように再導出できる。

したがって、近似的なノイズ除去遷移の平均 μ_{\theta} (x_t,t)を次のように設定することができる。

  \mu_{\theta} (x_t, t) = \frac{1}{ \sqrt{\alpha_t} } x_t - \frac{ 1-\alpha_t }{ \sqrt{1- \bar{\alpha}_t} \sqrt{\alpha_t} } \hat{\epsilon}_{\theta} (x_t, t)
\tag{125}

そうすると、対応する最適化問題は次のようになる。

ここで、 \hat{\epsilon}_{\theta} (x_t, t)  x_0から x_tを決定するソース雑音 \boldsymbol{\epsilon}_0 \sim \mathcal{N} \left( \boldsymbol{ \epsilon ; 0, \rm{I}} \right) を予測するように学習するニューラルネットワークである。したがって、元画像 x_0を予測してVDMを学習することは、ノイズを予測するように学習することと同等であることを示したが、経験的には、ノイズを予測した方が性能が良いという研究もある[5, 7]。

変分拡散モデルの三つ目の共通解釈を導出するため、我々はTweedieの公式[8]に訴える。Tweedieの公式は、指数関数型分布の真の平均は、その分布から得られた標本が与えられたとき、標本の最尤推定値(別名:経験平均)に推定値のスコアを含む何らかの補正項を加えたもので推定できることを述べている。観測された標本が1つだけの場合、経験平均は標本そのものである。もし、観測された標本がすべて基本分布の一方の端にある場合、負のスコアが大きくなり、標本のナイーブな最尤推定値を真の平均に向かって修正する。これは、標本の偏りを軽減するためによく使われる。

数学的には、ガウス変数 z \sim \mathcal{N} \left( z; \mu_ z, {\sum}_ z \right) に対して、Tweedieの公式は次のようになる。

  \mathbb{E} \left[ \mu_z | z \right] = z + {\sum}_z \nabla_z \log p (z)

この場合、その標本が与えられたときの x_tの真の事後平均を予測するために適用する。式(70)から、次のことがわかる。

  q (x_t | x_0) = N \left( x_t; \sqrt{\bar{\alpha}_t} x_0, (1- \bar{\alpha}_t) \boldsymbol{\rm{I}} \right)

すると、Tweedieの式により、次のようになる。

   \mathbb{E} \left[ \mu_z | x_t \right] = x_t + (1- \bar{\alpha}_t) \nabla_{x_t} \log p( x_t )
\tag{131}

ここで、表記を簡単にするために、 \nabla_{x_t} \log p (x_t)  \nabla \log p(x_t)と書く。Tweedieの式によれば、 x_tが生成される真の平均の最良推定値 \mu_{x_t}= \sqrt{\alpha_t} x_0は、次式で定義される。

  \begin{align}
\sqrt{\bar{\alpha}_t} x_0 = x_t + \left( 1- \bar{\alpha}_t \right) \nabla \log p(x_t) \tag{132} \\
\therefore x_0 = \frac{ x_t + \left( 1- \bar{\alpha}_t \right) \nabla \log p(x_t)}{\sqrt{\bar{\alpha}_t}} \tag{133}
\end{align}

次に、式(133)を、ground-truthのノイズ除去遷移の平均 μ_q (x_t, x_0)にもう一度代入すると、新しい形が得られる。

したがって、近似的なノイズ除去遷移の平均 \mu_{\theta} (x_t, t)を次のように設定することもできる。

  \mu_{\theta} (x_t, t) = \frac{1}{ \sqrt{\alpha_t} } x_t + \frac{ 1- \alpha_t }{ \sqrt{\alpha_t} } s_{\theta} (x_t, t)
\tag{143}

すると、対応する最適化問題は次のようになる。

ここで、 s_{\theta} (x_t, t)は任意のノイズレベル tに対して、データ空間における x_tの勾配であるスコア関数 \nabla_{x_t} \log p(x_t)を予測するように学習するニューラルネットである。

鋭い読者なら、スコア関数 \nabla \log p(x_t)がソース雑音 \epsilon_0と非常によく似た形をしていることに気づくだろう。このことは、Tweedieの公式(式(133))とパラメータ化のトリック(式(115))を組み合わせることで、明示的に示すことができる。

  \begin{align}
x_0 = \frac{x_t + (1- \bar{\alpha}_t) \nabla \log p(x_t) }{ \sqrt{ \bar{\alpha}_t } } &= \frac{x_t - \sqrt{1- \bar{\alpha}_t} \epsilon_0}{ \sqrt{ \bar{\alpha}_t } } \tag{149} \\ 
\therefore  (1- \bar{\alpha}_t) \nabla \log p(x_t) &= - \sqrt{1- \bar{\alpha}_t} \epsilon_0 \tag{150} \\
 \nabla \log p(x_t) &= - \frac{1}{\sqrt{1- \bar{\alpha}_t}} \epsilon_0 \tag{151}
\end{align}

その結果、2つの項は時間と共に変化する定数だけずれていることがわかる!スコア関数は、対数確率を最大化するようにデータ空間をどのように移動すればよいかを測るのである。直感的には、ノイズは自然な画像に付加され、画像を劣化させるので、その反対方向に移動することが画像を「ノイズ除去」し、その後の対数確率を増加させるための最善の更新となる。 我々の数学的証明はこの直感を正当化するもので、スコア関数をモデル化する学習は(スケーリングファクターまで)ソースノイズの負のモデルを作ることと等価であることを明示的に示している。

したがって我々は、VDMを最適化するための3つの等価な目的を導き出した。それは、元の画像 x_0、ソース雑音 \epsilon_0、または任意のノイズレベル \nabla \log p(x_t)における画像のスコアを予測するニューラルネットワークを学習することであった。VDMは、確率的に時間ステップ tをサンプリングし、予測値とground-truth目標値との差を最小化することで、スケーラブルに学習することができる。

Score-based Generative Models

我々は、単にスコア関数 \nabla \log p(x_t) を予測するニューラルネットワーク s_{\theta} (x_t, t)を最適化することによって、変分拡散モデルが学習できることを示した。しかし、この導出ではスコア項はTweedieの公式を応用したものであり、スコア関数とは何か、なぜそれがモデル化する価値があるのかについて、必ずしも大きな直感や洞察を与えてくれるものではない。幸いなことに、この直感を得るために、別のクラスの生成モデルであるスコアベース生成モデル[9, 10, 11] に注目することができる。その結果、我々が以前に導出したVDMの定式化は、同等のスコアベース生成モデリングの定式化を持つことを示すことができ、この2つの解釈を自在に切り替えることができるようになったのである。

図6:ランジュバン動力学で生成された三つのランダムなサンプリング軌道の視覚化。すべてのサンプルは混合ガウシアンに対して、同じ初期化点から出発している。左図はこれらのサンプリング軌道を三次元コンター上にプロットしたものであり、右図はサンプリング軌道をground-truthのスコア関数に対してプロットしたものである。 同じ初期化点から、異なるモードのサンプルを生成できるのは、ランジュバン動力学のサンプリング手順に含まれる確率的ノイズ項によるものである。これがなければ、固定点からのサンプリングは常に各試行ごとに同じモードへのスコアを決定論的に追従することになる。

なぜスコア関数を最適化することが意味を持つのかを理解するために、回り道をしてエネルギーベースのモデル[12, 13]を再検討してみることにする。任意に柔軟な確率分布は次のような形で書くことができる。

  p_{\theta} (x) = \frac{1}{Z_{\theta}} e^{- f_{\theta} (x)}
\tag{152}

ここで、 f_{\theta} (x) はエネルギー関数と呼ばれる任意に柔軟でパラメータ化可能な関数で、しばしばニューラルネットワークによってモデル化され、 Z_{\theta}  \int p_{\theta} (x) dx = 1を保証するための正規化定数である。このような分布を学習する一つの方法は最尤法であるが、これは正規化定数 Z_{\theta} = \int e^{ - f_{\theta} (x)}を扱いやすく計算する必要があり、これは複雑な f_{\theta} (x) を持つ関数では不可能かもしれない。

正規化定数を計算またはモデル化するのを避ける一つの方法は、ニューラルネットワーク s_{\theta} (x)を、分布 p(x)のスコア関数 \nabla \log p(x)を学習するために代わりに使用することである。これは、式(152)の両辺のlogの導関数を取ると、得られるという観察によって動機づけられている。

  \begin{align}
\nabla_x \log p_{\theta} (x) &= \nabla_x \log \left( \frac{1}{Z_{\theta}} e^{- f_{\theta} (x)} \right)  \tag{153} \\
&= \nabla_x \log \frac{1}{Z_{\theta}} + \nabla_x \log e^{- f_{\theta} (x)} \tag{154} \\
&= - \nabla_x f_{\theta} (x) \tag{155} \\
&\approx s_{\theta} (x) \tag{156}
\end{align}

これは、正規化定数を必要とせず、自由にニューラルネットワークとして表現することができる。また、スコアモデルは、Fisherダイバージェンスをground-truthのスコア関数と比較して最小化することにより最適化することができる。

  \mathbb{E}_{p(x)} \left[ \left|| s_{\theta} (x) - \nabla \log p(x) \right||^2_2 \right]
\tag{157}

スコア関数は何を表しているだろうか?すべての xについて、 xに関する対数尤度の勾配をとることで、尤度をさらに高めるためにデータ空間のどの方向に移動すべきかを本質的に記述する。直感的には、スコア関数はデータ xが存在する空間全体に対してベクトル場を定義しており、そのベクトルはモードを指示している。視覚的には、図6の右図のようになる。そして、真のデータ分布のスコア関数を学習することで、同じ空間内の任意の点から出発し、モードに達するまでスコアを繰り返し追いながらサンプルを生成することができる。このサンプリング方法はランジュバン動力学として知られており、数学的には次のように記述される。

 x_{i+1} \leftarrow x_i + c \nabla \log p(x_i) + \sqrt{2c} \epsilon, \quad i=0, 1, \ldots , K
\tag{158}

ここで、 x_0は事前分布(一様分布など)からランダムにサンプリングされ、 \epsilon \sim \mathcal{N} (\boldsymbol{\epsilon ;0,I})は生成されたサンプルが常にあるモードに収束せず、その近傍を推移して多様性を確保するための追加的なノイズ項である。さらに、学習されたスコア関数は決定論的であるため、ノイズ項を導入したサンプリングは生成過程に確率性を与え、決定論的な軌道を回避することができる。これは、複数のモードの間にある位置からサンプリングを初期化する場合に特に有効である。図6に、ランジュバン動力学のサンプリングとノイズ項の効果を視覚的に表現したものを示す。

式(157)の目的は、ground-truthスコア関数へのアクセスに依存しており、これは自然画像をモデル化するような複雑な分布では利用できないことに注意してほしい。幸い、スコアマッチング[14, 15, 16, 17]として知られる代替技術は、ground-truthのスコアを知らなくてもこのFisherダイバージェンスを最小化するように導かれ、確率的勾配降下で最適化することができる。

まとめて、分布をスコア関数として表すことを学び、それを使ってランジュバン動力学などのMCMCの技術によりサンプルを生成することは、スコアベースの生成モデリング[9, 10, 11] として知られている。

バニラスコアマッチング(vanilla score matching)には、Song&Ermon[9]が詳述しているように、3つの主要な問題がある。まず、スコア関数は高次元空間の低次元多様体に適用された場合、定義があいまいである。これは数学的に見ることができ、低次元多様体上にない点はすべて確率0となり、その対数は定義されない。これは、アンビエント空間全体の低次元多様体上にあることが知られている自然画像に対する生成モデルを学習しようとする場合に特に不都合である。

第二に、バニラスコアマッチングによって学習され推定されたスコア関数は、低密度領域では正確でない。このことは、式(157)で最小化する目的関数から明らかである。それは p(x)上の期待値であり、明示的にそこからのサンプルで訓練されるので、モデルはほとんど見られないか未見のサンプルに対して正確な学習シグナルを受け取らないであろう。これは、我々のサンプリング戦略では、高次元空間のランダムな位置(ランダムノイズである可能性が高い)から出発し、学習されたスコア関数に従って移動することになるので、問題である。ノイズの多い、あるいは不正確なスコア推定に従うため、最終的に生成されるサンプルも最適でない可能性があり、正確な出力に収束するためにさらに多くのイタレーションを必要とします。

最後に、ランジュバン動力学サンプリングは、たとえground-truthスコアを用いたとしても、混合しない可能性がある。真のデータ分布が2つの不連続な分布の混合物であるとする。

  p(x) = c_1 p_1(x) + c_2 p_2(x)
\tag{159}

次に、スコアが計算されるとき、log演算が分布から係数を分割し、勾配演算がそれをゼロにするので、これらの混合係数は失われる。これを視覚化するために、図6右に示すground-truthスコア関数は、3つの分布間の異なる重みに関係ないことに注意してほしい。描かれた初期化点からサンプリングするランジュバン動力学は、実際の混合ガウス分布では右下のモードが高い重みを持っているにもかかわらず、それぞれのモードに到達するチャンスがほぼ同じである。

この3つの欠点は、データに複数レベルのガウスノイズを加えることで、同時に解決できることがわかった。第一に、ガウスノイズ分布のサポートは空間全体であるため、摂動されたデータサンプルはもはや低次元の多様体に限定されることはない。次に、大きなガウスノイズを加えることで、各モードがデータ分布に占める面積が大きくなり、低密度の領域でより多くの学習信号が追加されます。最後に、分散を大きくした複数レベルのガウスノイズを加えることで、ground-truthの混合係数に対する中間的な分布になる。

形式的には、ノイズレベル \left\{ \sigma_t \right\}_{t=1}^Tの正の列を選び、漸次的に摂動されたデータ分布の系列を定義することができる。

  p_{\sigma_t} (x_t) = \int p(x) \mathcal{N} (x_t; x, \sigma_t^2 \boldsymbol{\rm{I}} ) dx
\tag{160}

そして、スコアマッチングを用いて、すべてのノイズレベルに対して同時にスコア関数を学習するニューラルネットワーク s_{\theta}(x,t)を学習する。

  \arg \min_{\theta} \sum_{t=1}^T \lambda (t) \mathbb{E}_{p_{\sigma_t} (x_t)} \left[ \left|| s_{\theta} (x, t) - \nabla \log p_{\sigma_t}(x_t) \right||^2_2 \right]
\tag{161}

ここで λ(t)はノイズレベル tを条件とする正の重み付け関数である。この目的関数は、変分拡散モデルを学習するために式(148)で導かれた目的関数とほぼ一致することに注意されたい。さらに、著者らは生成手続きとして、アニールされたランジュバン動力学サンプリングを提案している。これは、各 T,T-1, \ldots ,2,1について順にランジュバン動力学を実行することによりサンプルを生成するものである。初期化は固定された事前分布(例えば一様分布)から選択され、続く各サンプリングステップは前回のシミュレーションの最終サンプルから開始されます。ノイズレベルは時間ステップ tの間に着実に減少し、時間ステップのサイズを小さくしていくので、サンプルは最終的に真のモードに収束していきます。これは、変分拡散モデルのマルコフ型HVAEの解釈で行われるサンプリング手順(ランダムに初期化されたデータベクトルが、減少するノイズレベル上で反復的に改良される)に直接類似している。

したがって、変分拡散モデルとスコアベース生成モデルの間には、その学習目的とサンプリング手順の両方において、明確な関連が確立されている。

一つは、拡散モデルを無限のタイムステップに自然に一般化するにはどうしたらいいかという問題である。マルコフ型HVAEでは、階層数を無限大 T→∞に拡張すると解釈できる。このことは、同等のスコアベース生成モデルの観点から表現するとより明確である。無限大のノイズスケールの下では、連続時間における画像の摂動は確率過程として表現でき、したがって確率微分方程式(SDE)で記述できる。サンプリングはSDEを反転することで行われ、当然ながら各連続値ノイズレベルにおけるスコア関数を推定する必要がある[10]。SDEの異なるパラメタリゼーションは本質的に異なる摂動スキームを時間経過とともに記述し、ノイズ処理手順の柔軟なモデリングを可能にする[6]。

Guidance

これまで、我々はデータ分布 p(x)のみをモデル化することに焦点をあててきた。しかし、条件付き分布 p(x|y)の学習にもしばしば関心があり、条件付け情報 yによって生成されるデータを明示的に制御することができるようになる。 これは、カスケード拡散モデル(Cascaded Diffusion Models)[18]などの画像超解像モデルや、DALL-E 2[19] やImagen[7] などの最先端の画像テキストモデルのバックボーンを形成する。

条件情報を追加する自然な方法は、各反復においてタイムステップ情報を並べるだけである。式(32)の同時分布を思い出してほしい。

  p(x_{0:T}) = p(x_T) \prod_{t=1}^T p_{\theta} (x_{t-1} | x_t )

そして、これを条件付き拡散モデルにするには、各遷移ステップで任意の条件付け情報 yを追加するだけでよく、次のようになる。

  p(x_{0:T} | y) = p(x_T) \prod_{t=1}^T p_{\theta} (x_{t-1} | x_t , y )
\tag{162}

例えば、 yは画像テキスト生成におけるテキスト符号化であったり、超解像処理を行うための低解像度画像であったりする。このように、VDMのコアとなるニューラルネットワークは、従来通り、各解釈・実装に対して、 \hat{x}_{\theta} (x_t, t, y) \approx x_0, \hat{\epsilon}_ {\theta} (x_t, t, y) \approx \epsilon_ 0, \  \rm{or} \ s_{\theta} (x_t, t, y) \approx \nabla \log p(x_t | y) と予測して学習することが可能である。

このバニラ定式化の注意点は、この方法で訓練された条件付き拡散モデルは、与えられた条件付け情報を無視したり、軽視したりするようになる可能性があるということである。そこで、サンプルの多様性を犠牲にして,モデルが条件付け情報に与える重みの量をより明示的に制御する方法として,誘導(Guidance)が提案されています。最も一般的な誘導は,分類器誘導(Classifier Guidance) [10, 20] と分類器放棄誘導(Classifier-Free Guidance) [21] の2つである。

Classifier Guidance

ここで、我々の目標は、任意のノイズレベル tにおいて、条件付きモデルのスコアである \nabla \log p(x_t|y)を学習することであるとする。ここで、 \nablaは簡潔さのために \nabla_{x_t}の省略形であることを想起してほしい。ベイズの定理により、以下の等価形式を導出することができる。

 \begin{align}
\nabla \log p(x_t | y) &= \nabla \log \left( \frac{p(x_t)p(y|x_t)}{p(y)} \right) \tag{163} \\
&= \nabla \log p(x_t) + \nabla \log p(y | x_t) - \nabla \log p(y) \tag{164} \\ 
&= \nabla \log p(x_t) + \nabla \log p(y|x_t) \tag{165}
\end{align}

ここで、 \log p(y) x_tに対する勾配が0であることを利用している。

最終的に得られた結果は、条件付けられないスコア関数と分類器 p(y|x_t)の逆勾配を組み合わせた学習と解釈することができる。したがって、分類器誘導[10, 20]では、任意のノイズ x_tを取り込み、条件付き情報 yを予測しようとする分類器とともに、先に導いたように条件付けられない拡散モデルのスコアが学習されることになる。そして、サンプリング処理中に、アニールされたランジュバン動力学に使用されるオーバーオール条件付きスコア関数が、条件づけられないスコア関数とノイズの多い分類器の逆勾配の和として計算される。

条件付き情報を考慮することをモデルに奨励または抑制するような、きめの細かいコントロールを導入するために、分類器誘導はハイパーパラメータ項 γでノイジーな分類器の逆勾配をスケールする。分類器誘導のもとで学習されたスコア関数は以下のようにまとめられる。

  \nabla \log p(x_t | y) = \nabla \log p(x_t) + \gamma \nabla \log p(y|x_t)
\tag{166}

直感的には、 γ=0のとき、条件付き拡散モデルは条件付け情報を完全に無視することを学習し、 γが大きいとき、条件付き拡散モデルは条件付け情報に大きく依存するサンプルを生成するように学習する。この場合、ノイズが多い場合でも、与えられた条件情報を再生成しやすいデータしか生成しないため、サンプルの多様性が犠牲になる。

分類器誘導の欠点として、別途学習した分類器に依存することが挙げられます。分類器は任意にノイズの多い入力を扱わなければならず、既存の事前学習済み分類モデルのほとんどは最適化されていないため、拡散モデルと並行してアドホックに学習させる必要があります。

Classifier-Free Guidance

分類器放棄誘導(Classifier-Free Guidance)[21]では、条件付けられない拡散モデルと条件付き拡散モデルを用いて別の分類器モデルを学習することを放棄している。分類器放棄誘導におけるスコア関数を導出するために、まず、式(165)を変形し、以下のように示すことができる。

  \nabla \log p(y|x_t) = \nabla \log p(x_t | y) - \nabla \log p(x_t) 
\tag{167}

そして、これを式(166)に代入すると、次のようになる。

  \begin{align}
\nabla \log p(x_t | y) &= \nabla \log p (x_t) + \gamma \left( \nabla \log p(x_t | y) - \nabla \log p(x_t) \right) \tag{168} \\
&= \nabla \log p(x_t) + \gamma \log p(x_t | y) - \gamma \nabla \log p(x_t) \tag{169} \\
&= \gamma \nabla \log p (x_t | y ) + (1- \gamma) \nabla \log p (x_t) \tag{170}
\end{align}

繰り返しになるが、 γは学習した条件付きモデルが条件付け情報をどの程度気にするかを制御する項である。 γ= 0のとき、学習済み条件付きモデルは条件付けを完全に無視し、条件付けられない拡散モデルを学習する。 γ= 1のとき、モデルは明示的にバニラ条件分布を学習し、誘導は行わない。また、 γ>1のとき、拡散モデルは条件付きスコア関数を優先させるだけでなく、無条件スコア関数から離れる方向に動く。つまり、条件情報を用いないサンプルの生成確率を下げ、条件情報を明示的に用いるサンプルを優先させる。これはまた、条件付け情報に正確に一致するサンプルを生成する代償として、サンプルの多様性を減少させる効果がある。

2つの別々の拡散モデルを学習することは高価であるので、我々は条件付き拡散モデルと条件付けられない拡散モデルの両方を特異条件付きモデルとして一緒に学習することが可能である。条件付け情報をゼロなどの固定定数に置き換えることで、条件づけられない拡散モデルを照会することができる。これは、本質的に、条件付け情報に対してランダムドロップアウトを実行することである。分類器放棄誘導は、条件付き生成の手順をより細かく制御できる一方で、特異な拡散モデルの訓練以上のものを必要としないため、エレガントである。

Closing

ここで、我々の研究過程で得られた知見を再掲する。まず、マルコフ型階層変分オートエンコーダの特殊なケースとして変分拡散モデルを導出し、3つの重要な仮定によりELBOの計算とスケーラブルな最適化が可能であることを説明する。VDMの最適化は、三つの潜在的な目的関数の一つを予測するニューラルネットワークの学習に帰着することを証明する。その三つとは、任意のノイズ処理からの元のソース画像、任意のノイズ処理画像からの元のソースノイズ、任意のノイズレベルにおけるノイズ処理画像のスコア関数、のいずれかである。次に、スコア関数を学習することの意味を深く掘り下げ、スコアベース生成モデリングの観点と明示的に結びつける。最後に、拡散モデルを用いた条件付き分布の学習方法について説明する。

要約すると、拡散モデルは生成モデルとして驚くべき能力を示しており、実際、ImagenやDALL-E 2などのテキスト条件付き画像生成に関する現在の最先端モデルで威力を発揮している。さらに、これらのモデルを可能にする数学は、非常にエレガントである。しかし、まだいくつかの欠点が残っている。

  • この手法は、私たち人間が、自然にデータをモデリングし、生成する方法とは思えない。我々は、何度もノイズを除去するランダムノイズとしてサンプルを生成しない。

  • VDMは解釈可能な潜在能力を生成しない。VAEではエンコーダの最適化により構造化潜在空間を学習するが、VDMでは各タイムス テップのエンコーダは既に線形ガウスモデルとして与えられており、柔軟に最適化するこ とはできない。したがって、中間潜在変数は、元の入力のノイズの多いバージョンに過ぎないものに制限される。

  • 潜在変数は元入力と同じ次元に制限され、意味のある圧縮された潜在量構造を学習する努力をさらに怠らせることになる。

  • サンプリングは、両方の定式化の下で複数のノイズ除去ステップを実行しなければならないため、高価な手順である。仮定の1つに、最終的な潜在が完全にガウスノイズであることを保証するために、十分な数のタイムステップ Tを選択することを思いだして欲しい。サンプリングの際には、これらのタイムステップをすべて繰り返し、サンプルを生成する必要がある。

最後に、拡散モデルの成功は、生成モデルとしての階層VAEを強調するものである。我々は、エンコーダが些細で、潜在次元が固定され、マルコフ遷移が仮定されている場合でも、無限の潜在階層に一般化すると、データの強力なモデルを学習することができることを明らかにした。このことは、複雑なエンコーダと意味的に意味のある潜在空間を学習できる可能性のある、一般的な深いHVAEの場合に、さらなる性能向上が達成できることを示唆している。

謝辞:この研究のドラフトを見直し、多くの有益な編集とコメントを提供してくれたJosh Dillon、Yang Song、Durk Kingma、Ben Poole、Jonathan Ho、Yiding Jiang、Ting Chen、Jeremy CohenそしてChen Sunに感謝したい。本当にありがとう!

Reference

[7] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar SeyedGhasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding.arXivpreprintarXiv:2205.11487, 2022.

[8] Bradley Efron. Tweedie’s formula and selection bias. Journal of the American Statistical Association, 106(496):1602–1614, 2011.

[9] Yang Song and Stefano Ermon. Generative modeling by estimating gradients of the data distribution. Advances in Neural Information Processing Systems, 32, 2019.

[10] Yang Song, Jascha Sohl-Dickstein, Diederik P Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. Score-based generative modeling through stochastic differential equations. arXiv preprint arXiv:2011.13456, 2020.

[11] Yang Song and Stefano Ermon. Improved techniques for training score-based generative models. Advances in neural information processing systems, 33:12438–12448, 2020.

[12] Yann LeCun, Sumit Chopra, Raia Hadsell, M Ranzato, and F Huang. A tutorial on energy-based learning. Predicting structured data, 1(0), 2006.

[13] Yang Song and Diederik P Kingma. How to train your energy-based models. arXiv preprint arXiv:2101.03288,2021.

[14] Aapo Hyvärinen and Peter Dayan. Estimation of non-normalized statistical models by score matching. Journal of Machine Learning Research, 6(4), 2005.

[15] Saeed Saremi, Arash Mehrjou, Bernhard Schölkopf, and Aapo Hyvärinen. Deep energy estimator networks. arXiv preprint arXiv:1805.08306, 2018.

[16] Yang Song, Sahaj Garg, Jiaxin Shi, and Stefano Ermon. Sliced score matching: A scalable approach to density and score estimation. In Uncertainty in Artificial Intelligence, pages 574–584. PMLR, 2020.

[17] Pascal Vincent. A connection between score matching and denoising autoencoders. Neural computation, 23(7):1661–1674, 2011.

[18] Jonathan Ho, Chitwan Saharia, William Chan, David J Fleet, Mohammad Norouzi, and Tim Salimans. Cascaded diffusion models for high fidelity image generation. J.Mach.Learn.Res., 23:47–1, 2022.

[19] Aditya Ramesh, Prafulla Dhariwal, Alex Nichol, Casey Chu, and Mark Chen. Hierarchical text-conditional image generation with clip latents. arXiv preprint arXiv:2204.06125, 2022.

[20] Prafulla Dhariwal and Alexander Nichol. Diffusion models beat gans on image synthesis. Advances in Neural Information Processing Systems, 34:8780–8794, 2021.

[21] Jonathan Ho and Tim Salimans. Classifier-free diffusion guidance. In NeurIPS 2021 Workshop on Deep Generative Models and Downstream Applications, 2021.

Diffusionモデル学習記録② ―Variational Diffusion Model

Preface

 このシリーズでは、Diffusionモデルについて学習する時にノート代わりに記事を書いていく。これはその第二弾で以下の第一弾の続き。手始めにめちゃ分かりやすいと巷で話題の(そして実際分かりやすかった)以下のDiffusionモデルの解説論文を少しずつ翻訳していき、脳に焼き付けていく。後々より詳しい解説とか、自分でJAXで実装とかができたらいいなと思っている。

izmyon.hatenablog.com

arxiv.org

Calvin Luo: Understanding Diffusion Models: A Unified Perspective, arXiv: 2208:11970, doi: 10.48550/ARXIV.2208.11970

©Calvin Luo, Originally posted in arXiv(https://arxiv.org/abs/2208.11970), 25 Aug 2022

License: Creative Commons Attribution 4.0 International (CC-BY)

以下は、原文の一部を翻訳したもので、以下の図はそこから引用したものです。

The following is the translation of part of the original content and the figures below are retrieved from it.

Understanding Diffusion Models: A Unified Perspective

Variational Diffusion Model

 変分拡散モデル (VDM: Variational Diffusion Model) [4, 5, 6] を考える最も簡単な方法は、単純に以下の三つの重要な制約を持つマルコフ型階層変分オートエンコーダとして考えることである。

  • 潜在次元はデータ次元と正確に等しい。

  • 各タイムステップにおける隠れエンコーダの構造は学習されず、線形ガウスモデルとして予め定義されている。言い換えれば、それは前のタイムステップの出力を中心とするガウス分布である。

  • 隠れエンコーダのガウスパラメータは、最終タイムステップTにおける潜在変数 x_Tの分布が標準ガウス分布になるように時間的に変化する。

さらに、標準的なマルコフ型階層変分オートエンコーダにおいける階層的遷移間のマルコフ特性は、明示的に維持するものとする。 これらの仮定が意味するところを拡大してみよう。最初の仮定から、多少の表記法の乱用はあるが、真のデータサンプルと潜在変数の両方を x_tで表すものとする。ここで、 t = 0は真のデータサンプルを表し、 t∈[1,T]はインデックス tの階層を持つ対応する潜在変数を表す。VDMの事後分布は、MHVAEの事後分布(式(24))と同じだが、今は次のように書き換えることができる。

 \displaystyle q(x_{1:T} | x_0 ) = \prod_{t=1}^T q( x_t | x_{t-1} ) \tag{30}

第二の仮定から、エンコーダの各潜在変数の分布は、その前のHVAEの出力を中心とするガウス分布であることがわかる。マルコフ型HVAEとは異なり、各タイムステップ tにおけるエンコーダの構造は学習されない。それは線形ガウスモデルとして固定され、平均と標準偏差はハイパーパラメータとしてあらかじめ設定されるか[5]、またはパラメータとして学習される[6]。ここでは、平均 \boldsymbol{μ_ t}  (x_ t) = \sqrt{ α_ t } x_ {t-1} 、分散 \sum_t (x_t) = (1 - α_t) \boldsymbol{\rm{I}}のガウシアンエンコーダとしてパラメータ化する。ここで係数の形式は,潜在変数の分散が同じようなスケールにとどまるように選択される。言い換えれば、このエンコーディングの手順は”分散保存的(variance-preserving)”である。ここで、ガウス分布のパラメータ化は他の方法も可能であり、同様の導出ができることに注意。このやり方では、 α_tは(潜在的に学習可能な)係数であり、階層的深さ tに応じて柔軟に変化することである。数学的には、エンコーダの遷移は次のように表される。

 \displaystyle q(x_t | x_{t-1} ) = \mathcal{N}(x_t; \sqrt{α_t} x_{t-1}, (1 - α_t) \boldsymbol{\rm{I}} ) \tag{31}

第三の仮定から、 α_tは固定された、あるいは学習可能なスケジュールに従って時間と共に進化し、最終的な潜在変数 p(x_T)の分布が標準ガウスとなることがわかる。そして、マルコフ型HVAEの同時分布(式(23))を更新して、VDMの同時分布を次のように書くことができる。

 \displaystyle \begin{align} 
p(x_{0:T}) &= p(x_T) \prod_{t=1}^T p_{\theta}( x_{t-1} | x_t ) \tag{32} \\
where, \\
p(x_T) &= \mathcal{N}(x_T ; \boldsymbol{0, \rm{I} } ) \tag{33}
\end{align}

つまり、これらの制約が意図しているのは、時間経過とともに入力画像を徐々にノイズに変えていくということであり、より詳しく言うと、ガウスノイズを加えていくことで入力画像を徐々に崩壊させ、最終的には純粋なガウスノイズと完全に同じにさせるということである。この過程を視覚的に表現したのが図3である。

図3:変分拡散モデルの視覚的表現。 x_0は自然画像などの真のデータ観測、 x_Tは純粋なガウスノイズ、 x_t x_0にノイズを加えた x_Tとの中間的な状態の画像である。各 q(x_t | x_{t-1})は直前の状態の出力を平均とするガウス分布としてモデル化される。

なお、エンコーダの分布 q(x_t | x_{t-1})は、各タイムステップにおいて、定義済みの平均と分散のパラメータを持つガウシアンとして完全にモデル化されるため、もはや φによってパラメータ化されないことに注意する。したがって、VDMでは、新しいデータをシミュレートできるように、条件 p_{\theta} (x_{t-1} | x_t ) を学習することにのみ興味がある。VDM を最適化した後のサンプリング手順は、ガウシアンノイズを p(x_T)からサンプリングし、繰り返しノイズ除去遷移 p_{\theta} (x_{t-1} | x_t) をTステップ通し、新しい x_0を生成する。

他のHVAEと同様に、VDMはELBOを最大化することで最適化でき、次のように導出される。

ELBOの導出形式は、その個々の構成要素で解釈することができる。

  1.  \mathbb{E}_{q(x_0 | x_1)} \left[ \log p_{\theta} ( x_0 | x_1 ) \right] は再構成項と解釈され、一段階目の潜在変数が与えられたときの元のデータサンプルの対数尤度を予測する。この項はバニラVAEにも現れ、同様に学習させることができる。
  2.  \mathbb{E}_{q(x_{T-1} | x_0 )} \left[ D_{KL} (q(x_T | x_{T-1} ) || p(x_T) ) \right] は事前マッチング項であり、最終的な潜在分布がガウス事前分布と一致するときに最小化される。この項は、学習可能なパラメータを持たないため、最適化の必要がない。さらに、最終的な分布がガウス分布となるように、十分に大きな Tを仮定しているため、この項は実質的にゼロになる。
  3.  \mathbb{E}_{q \left( x_{t-1}, x_{t+1} | x_0 \right) } \left[ D_{KL} \left( q(x_t | x_{t-1} ) || p_{\theta}(x_t | x_{t+1} ) \right) \right] は整合項であり、 x_tの分布が前方、後方の両プロセスで整合するように努める。つまり、各中間タイムステップごとに、ノイズの多い画像からのノイズ除去ステップは、対応するよりクリーンな画像からのノイズ加算ステップと一致すべきであり、これはKLダイバージェンスによって数学的に反映される。この項は、式(31)で定義されているように、 p_{\theta} (x_t | x_{t+1})ガウス分布 q(x_t| x_{t-1})と一致するように学習されるとき最小となる。

ELBOのこの解釈は、図4に視覚的に描かれている。すべてのタイムステップ tに対して最適化しなければならないため、VDMを最適化するコストは、主に第3項によって支配される。

図4:最初の導出のもとで、VDMは、各潜在変数 x_tについて、その後の潜在変数からの事後分布 p_{\theta} (x_t | x_{t+1})が、その前の潜在変数からのガウシアンによる劣化 q( x_t |x_{t-1} )と一致するように最適化することができる。この図では、各潜在変数 x_tについて、ピンクと緑の矢印で示される分布の差を最小化している。

この導出の下では、ELBOのすべての項は期待値として計算され、したがって、モンテカルロ推定を使用して近似することができる。しかし、今導出した項を用いてELBOを実際に最適化すると、最適とは言えないかもしれない。整合項は、すべてのタイムステップごとに2つの確率変数 \left\{ x_{t-1}, x_{t+1} \right\} に対する期待値として計算されるので、モンテカルロ推定の分散は、タイムステップごとに1つの確率変数のみを用いて推定される項よりも大きくなる可能性がある。 T-1個の整合項の合計で計算されるため、最終的なELBOの推定値は Tの値が大きいと分散が大きくなる可能性がある。

その代わりに、各項が一度に1つの確率変数に対する期待値として計算されるELBOの形式を導出することを試みましょう。重要なのは、エンコーダの遷移を q(x_t | x_{t-1}) = q(x_t | x_{t-1}, x_0)と書き換えることで、マルコフ特性により、余分な条件項は不要になることである。そして、ベイズ則に従って、各遷移を次のように書き換えることができる。

 \displaystyle 
q\left( x_t \mid x_{t-1}, x_{0} \right) = \frac{ q \left( x_{t-1} \mid x_{t}, x_{0}\right) q\left( x_{t} \mid x_{0} \right) }{q \left( x_{t-1} \mid x_{0} \right)} \tag{46}

この新しい式を使って、式(37)のELBOから再開して導出を試みることができる。

以上より、低い分散で推定できるELBOの解釈を導き出すことに成功し、各項が一度に最大でも一つの確率変数の期待値として計算されることがわかる。この定式化は、個々の項を調べることでエレガントに解釈することができる。

  1.  \mathbb{E}_{q(x_1 | x_0)} \left[ \log p_{\theta} \left( x_0 | x_1 \right) \right] は再構成項と解釈できる。バニラ VAEのELBO における類似の項と同様、この項はモンテカルロ推定を用いて近似および最適化できる。
  2.  D_{KL} \left( q \left( x_T | x_0 \right) ‖ p ( x_T)  \right) はノイズ化された入力の最後の分布が標準ガウス事前分布にどれくらい近いかを示す。学習可能なパラメータは無く、ここでの仮定の下では0に等しい。

  3.  \mathbb{E}_{q \left( x_t | x_0 \right) } \left[ D_{KL} ( q \left( x_{t-1} | x_t, x_0 \right) \right] は"ノイズ除去マッチング項"である。望みのノイズ除去遷移ステップ p_{\theta} ( x_{t-1} | x_t ) を、ground-truthのノイズ除去遷移ステップ q \left( x_{t-1} | x_t, x_0 \right) の扱いやすい近似として学習する。 q \left( x_{t-1} | x_t, x_0 \right) はノイズの多い画像 x_tをどのようにノイズ除去するかを定義し、最終的に完全にノイズ除去された画像 x_0がどうあるべきかを知っているため、ground-truth信号として機能することができる。したがって、この項は、KLダイバージェンスによって測定されているように、2つのノイズ除去ステップができるだけ一致するときに最小化される。

余談だが、二つのELBOの導出過程(式(45)および式(58))において、マルコフ仮定のみが用いられており、その結果、これらの式は任意のマルコフ型HVAEに対して成り立つことがわかる。さらに、 T=1とすると、二つのVDMのELBOの解釈は、いずれも式(19)で書かれるように、バニラVAEのELBO方程式を正確に再現する。

このELBOの導出では、最適化コストの大部分が再び総和項にあり、再構成項に対して支配的であった。各KLダイバージェンスの項 D_{KL} \left( q \left( x_{t-1} \mid x_t, x_0 \right) ‖ p_{\theta} \left( x_{t-1}|x_t \right) \right) は、エンコーダを同時に学習するという複雑さが加わるため、任意に複雑なマルコフ型HVAEでは任意の事後分布に対して最小化することが難しいが、VDMではガウス推移仮定を利用して最適化を扱いやすくすることが可能である。ベイズの定理により、以下のようになる。

 \displaystyle 
q \left( x_{t-1} \mid x_{t}, x_{0} \right) = \frac{ q \left( x_{t} \mid x_{t-1}, x_{0} \right) q \left( x_{t-1} \mid x_{0} \right) } { q \left( x_{t} \mid x_{0} \right)}

エンコーダの遷移に関する仮定(式(31))から、 q( x_t | x_{t-1}, x_0 ) = q( x_t | x_{t-1} ) = N \left( x_t ; \sqrt{α_t} x_{t-1}, (1-α_t) \boldsymbol{\rm{I}} \right) が既に分かっているので、後は q( x_t | x_0 ) q(x_{t-1} | x_0 )の形を導けば良い。幸運なことに、VDMのエンコーダ遷移は直線ガウスモデルであるという事実を利用して、これも扱いやすくすることができる。再パラメータ化トリックの下で、サンプル x_t \sim q ( x_t | x_{t-1} ) は以下のように書き換えることができる。

 \displaystyle  x_t = \sqrt{α_t} x_{t-1} + \sqrt{1 - α_t} \epsilon \quad with \quad \epsilon \sim \mathcal{N}(\boldsymbol{\epsilon ; 0, \rm{I}} ) \tag{59}

同様に、サンプル x_{t-1} \sim q \left( x_{t-1} | x_{t-2} \right) は次のように書き換えることができることがわかる。

 \displaystyle  x_t = \sqrt{α_{t-1}} x_{t-2} + \sqrt{1 - α_{t-1}} \epsilon \quad with \quad \epsilon \sim \mathcal{N}(\boldsymbol{\epsilon ; 0, \rm{I}} ) \tag{60}

図5:VDMを小さい分散で最適化するための代替方法を図示した。ベイズの定理を用いてground-truthのノイズ除去ステップ q(x_{t-1} | x_t, x_0 )を計算し、そのKLダイバージェンスを近似ノイズ除去ステップ p_{\theta} (x_{t-1} | x_t )で最小化する。ここでは再び、マッチングさせる分布を緑色の矢印とピンク色の矢印で視覚的に表している。ここで、全体像を正確に書くと、各ピンクの矢印は、条件付けの項でもあるため、 x_0からも派生しているはずであるが、ここでは割愛した。

そして、 q( x_t | x_0 )の形式は再パラメータ化のトリックを繰り返し適用することで再帰的に導出することができる。ここで、 2Tのランダムノイズ変数 \left\{ \epsilon_t^{*}, \epsilon_{t=0} \right\}_{t=0}^T \sim i.i.d \quad \mathcal{N} \left( \boldsymbol{\epsilon ; 0, \rm{I}} \right) を取得できたとする。そして、任意のサンプル x_t  \sim q( x_t | x_0)に対して、以下のように書き換えることができる。

ここで、式(64)では2つの独立なガウス確率変数の和がいまだガウシアンであり、平均は2つの平均の和、分散は2つの分散の和であることを利用している。 \sqrt{1- α_t} \epsilon_{t-1}^{*}をガウシアン \mathcal{N} \left( \boldsymbol{0}, (1-α_t) \boldsymbol{\rm{I}} \right) からのサンプル、そして \sqrt{ α_t - α_t α_{t-1}} \epsilon_{t-2}^{*}をガウシアン \mathcal{N} \left( \boldsymbol{0}, (α_t - α_t α_{t-1} ) \boldsymbol{\rm{I}} \right)の標本とすると、それらの和はガウシアン \mathcal{N}\left( \boldsymbol{0}, (1-α_t + α_t - α_t α_{t-1}) \boldsymbol{\rm{I}} \right) =\mathcal{N} \left( \boldsymbol{0}, (1-α_t α_{t-1}) \boldsymbol{\rm{I}} \right) からサンプリングした確率変数として扱うことができる。この分布からのサンプルは、再パラメータ化のトリックを使って、 \sqrt{1-α_t α_{t-1}} \epsilon_{t-2}として、式(66)のように表現される。

したがって、ガウシアン形式 q(x_t | x_0 )を導出しました。この導出は q(x_{t-1} | x_0) を記述するガウス分布のパラメータを得るために修正することができる。ここで、 q(x_t | x_0 ) q(x_{t-1} | x_0 )の両方の形式を知っているので、ベイズ則展開に代入して q \left( x_{t-1} | x_t, x_0 \right) の形式の計算に進むことができる。

ここで式(75)の C(x_t, x_0) x_t x_0 αの値のみの組み合わせとして計算された x_{t-1}のそれぞれに関する定数項であり、この項は式(84)に暗黙的に返されて平方完成される。

したがって、各ステップにおいて x_{t-1}  \sim q( x_{t-1} | x_t, x_0) x_t x_0の関数である平均 μ_q (x_t, x_0)と係数 αの関数である分散 \sum_q (t) 正規分布することが示された。 これらの係数 αは既知で各タイムステップで固定されており、ハイパーパラメータとしてモデル化された場合は恒常的に固定化されるか、またはモデル化しようとするネットワークの現在の推論出力として扱われる。式(84)に従うと、分散方程式を {\sum}_q (t) = {\sigma}_q^{2} (t)と書き直すことができ、以下が成り立つ。

 \displaystyle 
 \sigma_q^2 (t) = \frac {(1-α_t)(1-α_{t-1})}{1-α_t} \tag{85}

近似ノイズ除去遷移ステップ p_{\theta} ( x_{t-1} | x_t )をground-truthのノイズ除去遷移ステップ q_(x_{t-1} | x_t, x_0 )にできるだけ近づけるために、ガウシアンとしてモデル化することもできる。さらに、すべてのα項は各タイムステップで凍結されることが知られているため、近似されたノイズ除去遷移ステップの分散も \sum_ q (t) =  \sigma_ q ^2 (t) \boldsymbol{\rm{I}}となるように直ちに構築することができる。しかし、 p_{\theta} (x_{t-1} |x_t ) x_0を条件としないので、その平均 μ_{\theta} ( x_t, t ) x_tの関数としてパラメータ化しなければならない。

ここで、2つのガウス分布の間のKL収束は次の通りであることを思い出してほしい。

 \displaystyle 
D_{KL} \left( \mathcal{N} (x; μ_x, {\sum}_x ) || \mathcal{N} ( y; μ_y, {\sum}_y ) \right) = \frac{1}{2} \left[ \log \frac{|{\sum}_y|}{|{\sum}_x|} - d + tr({\sum}_y^{-1} {\sum}_x ) + (μ_y - μ_x ) ^T {\sum}_y^{-1} ( μ_y - μ_x ) \right] \tag{86}

この場合、2つのガウス分布の分散を正確に一致させることができるので、KLダイバージェンス項の最適化は、2つの分布の平均の差を最小化させるために減少させることになる。

ここで、 μ_q μ_q(x_t, x_0 )の略記、 μ_{\theta}  μ_{\theta} (x_t, t) の略記として簡略化して書いている。言い換えれば、我々は、 μ_q(x_t,x_0)に一致する μ_{\theta}(x_t,t)を最適化したいのであって、我々の導き出した式(84)から、次のような形をとる。

 \displaystyle
\boldsymbol{μ}_{q} (x_t, x_0) = \frac{ \sqrt{ α_t }(1- \hat{α_{t-1}} ) x_t + \sqrt{ \hat{α_{t-1}} } (1-α_t)x_0}{1- \hat{α_{t}} } \tag{93}

 μ_{\theta} (x_t, t)  x_tを条件としているので、以下の形に設定することで μ_q(x_t,x_0)に近い形で一致させることができる。

 \displaystyle
\boldsymbol{μ}_{\theta} (x_t, t) = \frac{\sqrt{α_t}(1- \hat{α_{t-1}} ) x_t + \sqrt{\hat{α_{t-1}}} (1-α_t) \hat{x_{\theta}}(x_t, t) }{1- \hat{α_{t}} } \tag{94}

ここで、 \hat{x_{\theta}} (x_t, t) は、ノイズの多い画像 x_tと時間インデックス tから x_0を予測しようとするニューラルネットワークによってパラメータ化される。そして、最適化問題は次のように単純化される。

したがって、VDMを最適化することは、任意にノイズ化された画像から元のground-truthとなる画像を予測するニューラルネットワークの学習に帰結する[5]。さらに、すべてのノイズレベルにわたって、我々の導き出したELBO目的関数(式(58))の総和項を最小化することは、すべてのタイムステップにわたって、この式を最小化することによって近似できる。

 \displaystyle
\arg \min_{\theta} \mathbb{E}_{t \sim U \left[ 2, T \right] } \left[ \mathbb{E}_{q(x_t | x_0 )} \left[ D_{KL} \left( q ( x_{t-1} | x_t, x_0 ) || p_{\theta} ( x_{t-1} | x_t ) \right) \right] \right] \tag{100}

これは次に、時間ステップにわたる確率的なサンプルを使うことで最適化される。

Reference

[4] ] Jascha Sohl-Dickstein, Eric Weiss, Niru Maheswaranathan, and Surya Ganguli. Deep unsupervised learning using nonequilibrium thermodynamics. In International Conference on Machine Learning, pages 2256–2265.PMLR, 2015.

[5] Jonathan Ho, Ajay Jain, and Pieter Abbeel. Denoising diffusion probabilistic models. Advances in Neural Information Processing Systems, 33:6840–6851, 2020.

[6] Diederik Kingma, Tim Salimans, Ben Poole, and Jonathan Ho. Variational diffusion models. Advances in neural information processing systems, 34:21696–21707, 2021.

Diffusionモデル学習記録① ―Variational Autoencoder

Preface

 このシリーズでは、Diffusionモデルについて学習する時にノート代わりに記事を書いていく。手始めにめちゃ分かりやすいと巷で話題の(そして実際分かりやすかった)Diffusionモデルの解説論文を少しずつ翻訳していき、脳に焼き付けていく。後々より詳しい解説とか、自分でJAXで実装とかができたらいいなと思っている。

arxiv.org

Calvin Luo: Understanding Diffusion Models: A Unified Perspective, arXiv: 2208:11970, doi: 10.48550/ARXIV.2208.11970

©Calvin Luo, Originally posted in arXiv(https://arxiv.org/abs/2208.11970), 25 Aug 2022

License: Creative Commons Attribution 4.0 International (CC-BY)

以下は、原文の一部を翻訳したもので、以下の図はそこから引用したものです。

The following is the translation of part of the original content and the figures below are retrieved from it.

Understanding Diffusion Models: A Unified Perspective

Introduction: Generative Models

 注目する分布から観測されたサンプル xが与えられたとき、生成モデルの目標は、その真のデータ分布 p(x)をモデルを学習することである。一度学習すれば、近似モデルから新しいサンプルを自由に生成することができる。さらに、ある定式化の下では、学習したモデルを用いて、観測データやサンプリングされたデータの尤度を評価することも可能である。

 ここでは、現在の文献でよく知られているいくつかの方向性を、高いレベルで簡単に紹介する。敵対的生成ネットワーク (GAN)は、複雑な分布のサンプリング手順をモデル化し、敵対的な方法で学習させる。"尤度ベース"と呼ばれる生成モデルの別のクラスは、観測されたデータサンプルに高い尤度を割り当てるモデルを学習しようとするものである。これには自己回帰モデル、正規化フロー、変分オートエンコーダ(VAE)などがある。また、類似の手法としてエネルギーベースモデリングがあり、これは分布を任意に柔軟なエネルギー関数として学習し、それを正規化するものである。

 スコアベース生成モデルは、エネルギー関数そのものをモデル化するのではなく、エネルギーベースモデルの"スコア"をニューラルネットワークとして学習するもので、非常に関連性の高いモデルである。この研究では、尤度ベースとスコアベースの両方の解釈を持つ拡散モデル(Diffusion Model)について検討し、レビューする。 このようなモデルの背後にある数学は、誰もが拡散モデルとは何か、どのように機能するかを理解できることを目的として、非常に詳細に紹介する。

Background: ELBO, VAE, and Hierarchical VAE

 多くのモダリティでは、観測されたデータは、関連する見えない潜在変数(確率変数 zで表すことができる)によって表される、あるいは生成されると考えることができる。この考えを最も良く直観的に表現するのは、プラトンの洞窟の寓話である。この寓話では、ある人々が生涯洞窟の中に鎖でつながれ、目の前の壁に映し出される二次元の影しか見ることができず、それは火の前を通る目に見えない三次元の物体によって生成されるとされている。そのような人たちにとって、自分が見ているものはすべて、実は自分には決して見ることのできない高次元の抽象的な概念によって決定されているのである。

 同様に、私たちが現実の世界で出会う物体も、例えば、色や大きさ、形といった抽象的な性質が内包された高次元の表現(representation)の関数として生成されている可能性がある。 洞窟につながれた人が観測しているものが実は三次元の物体の二次元の投影であったように、私たちが観察しているものは、その抽象概念の三次元的な投影、あるいは実体、として解釈することができる。洞窟につながれた人たちは、隠されたものを決して見ることはできず、あるいは完全に理解することもできないが、それでも私たちが観測したデータを記述する潜在的表現を近似するのと同じ方法で、それについて推論を巡らせることはできる。

 プラトンの寓話は、潜在変数が観測を決定する潜在的に観測不可能な表現であるという考え方を示しているが、この類推の注意点は、生成モデルにおいては、一般的に高次元の潜在的表現よりも低次元の潜在的表現を学ぼうとすることである。これは、観測よりも高い次元の表現を学習しようとしても、強い事前分布を持たなければ実りのない努力になるからである。一方、低次元の潜在的表現の学習は圧縮の一形態とみなすこともでき、観測を記述する意味的(Semantically)に有意義な構造を発見することができる可能性がある。

Evidence Lower Bound

 数学的には、潜在変数と我々が観測するデータは、同時分布 p(x, z)によってモデル化されると想像することができる。”尤度ベース”と呼ばれる生成モデリングの1つのアプローチは、すべての観測された xの尤度 p(x)を最大化するモデルを学習することであることを思い出してほしい。純粋な観測データの尤度 p(x)を復元するために、この同時分布を操作する方法が2つある。まず、以下のように潜在変数 zを明示的に周辺化する方法がある。

 \displaystyle p(x) = \int p(x, z) dz \tag{1}

もう一つの方法は、確率の連鎖法則を用いて以下のように求めることである。

 \displaystyle p(x) = \frac {p(x,z)}{p(z|x)} \tag{2}

尤度 p(x)を直接計算し最大化することは、式(1)においては、すべての潜在変数 z積分する必要があり、複雑なモデルでは困難であり、また式(2)においても、ground-truthであるエンコーダ p(z|x)を求める必要があるため、困難である。しかし、この2つの式を用いると、その名の通りエビデンスの下界であるEvidence Lower Bound (ELBO)という項を導き出すことができる。エビデンスはこの場合、観測データの対数尤度として定量化される。そして、ELBOを最大化することは、潜在変数のモデルを最適化するための代理目的となる。ELBOが強力にチューニングされ、完全に最適化された最良の場合、それはエビデンスと厳密に等価となる。形式的には、ELBOは以下の式で表される。

 \displaystyle  \mathbb E_{q_{φ}(z|x)} \left[ \frac {p(x,z)}{q_{φ}(z|x)} \right]  \tag{3}

エビデンスとの関係を明示的に表したい場合は、以下のように書ける。

 \displaystyle \log p(x) \geq \mathbb E_{q_{φ}(z|x)} \left[ \frac {p(x,z)}{q_{φ}(z|x)} \right]  \tag{4}

ここで、 q_{φ}(z|x)は,パラメータ φを持つ柔軟な近似変分であり、我々が最適化しようとする分布である。直感的には,与えられた観測 xの潜在変数上の真の分布を推定するために学習されるパラメータを持つモデルとして考えることができる。次節で変分オートエンコーダについて探る際に見るように、ELBO を最大化するためにパラメータ  φをチューニングして下限を大きくすると、真のデータ分布をモデル化し、そこからサンプリングするために使用できる要素にアクセスできるようになり、生成モデルを学習することができる。ここで、なぜELBOが最大化したい目的関数であるのかについて、もう少し掘り下げて考える。 まず、式 (1)を用いて 、ELBO を導出する。

 \displaystyle \begin{align} 
\log p(x)
&= \log \int p(x, z) dz &(式(1)を用いた) \tag{5} \\
&= \log \int \frac{p(x, z) q_{φ}(z|x)}{ q_{φ}(z|x)}dz &(1=\frac {q_{φ}(z|x)}{ q_{φ}(z|x)}を掛けた) \tag{6} \\
&= \log \mathbb E_{q_{φ}(z|x)} \left[ \frac {p(x,z)}{q_{φ}(z|x)} \right] &(期待値の定義) \tag{7} \\
& \geq \mathbb E_{q_{φ}(z|x)} \left[ \log \frac {p(x,z)}{q_{φ}(z|x)} \right] &(Jensenの不等式を用いた) \tag{8}
\end{align}

この導出では、Jensenの不等式を適用することにより、直接的に下界に到達している。しかし、これは、実際に何が起こっているかについての有益な情報をあまり与えていない。決定的なのは、この証明が、ELBOが実際にエビデンスの下界であるのは何故か、という問いに対し、正確な直感を与えないことであり、Jensenの不等式がそれを遠ざけてしまうことである。さらに、ELBOが本当にデータの下界であることを単に知るだけでは、何故それを目的関数として最大化したいのかがよく分からない。エビデンスと ELBO の関係をよりよく理解するために、今度は式(2)を使って別の導出を行う。

 \displaystyle \begin{align} 
\log p(x)
&= \log p(x) \int q_{φ}(z|x) dz &( 1= \int q_{φ}(z|x) dz を掛けた) \tag{9} \\
&= \int q_{φ}(z|x) (\log p(x)) dz &(エビデンスを積分の中に入れた) \tag{10} \\
&= \mathbb E_{q_{φ}(z|x)} \left[ \log p(x) \right] &(期待値の定義) \tag{11} \\
&= \mathbb E_{q_{φ}(z|x)} \left[ \log \frac {p(x,z)}{p(z|x)} \right] &(式(2)を用いた) \tag{12} \\
&= \mathbb E_{q_{φ}(z|x)} \left[ \log \frac {p(x,z)q_{φ}(z|x)}{p(z|x)q_{φ}(z|x)} \right] &(1=\frac {q_{φ}(z|x)}{ q_{φ}(z|x)}を掛けた) \tag{13} \\
&= \mathbb E_{q_{φ}(z|x)} \left[ \log \frac {p(x,z)}{q_{φ}(z|x)} \right] + \mathbb E_{q_{φ}(z|x)} \left[ \log \frac {q_{φ}(z|x)}{p(z|x)} \right] &(期待値を分割した) \tag{14} \\
&= \mathbb E_{q_{φ}(z|x)} \left[ \log \frac {p(x,z)}{q_{φ}(z|x)} \right] + D_{\rm{KL}} ( q_{φ}(z|x) || p(z|x) ) &(\rm{KL}ダイバージェンスの定義) \tag{15} \\
& \geq \mathbb E_{q_{φ}(z|x)} \left[ \log \frac {p(x,z)}{q_{φ}(z|x)} \right] &(\rm{KL}ダイバージェンスは常に\geq 0) \tag{16}
\end{align}

この導出より、式(15)から、エビデンスがELBOに近似事後分布 q_{φ}(z|x)と真の事後分布 p(z|x)間のKLダイバージェンスを加えたものであることは明らかである。実は、最初の導出の式(8)でJensenの不等式によって魔法のように取り除かれていたのは、このKLダイバージェンスの項なのである。この項を理解することは、ELBOとエビデンスの関係だけでなく、ELBOを最適化することが全く適切な目的である理由を理解するための鍵となる。

まず、我々はここでELBOが本当に下界である理由がわかった。エビデンスとELBOの差は厳密に非負のKL項であり、したがってELBOの値は決してエビデンスを超えることができないのである。

図1:変分オートエンコーダを視覚的に表した。ここで、エンコーダ q(z|x)は観測 xに対する潜在変数 zの分布を表しており、 p(x|z)は潜在変数を観測にデコードする。

次に、何故ELBOを最大化しようとするのかを探る。モデル化したい潜在変数 zを導入した我々の目標は、観測データを記述するこの潜在的な構造を学習することである。言い換えれば、我々は変分事後分布 q_{φ}(z|x)のパラメータを最適化し、真の事後分布 p(z|x)に正確に一致させたいのであり、これはKLダイバージェンスを最小化(理想的には0に)することによって実現されるのである。残念ながら、我々はground-truthである p(z|x)の分布を求めることはできないため、このKLダイバージェンス項を直接最小化することは困難である。しかしながら、式(15)の左側で、我々のデータの尤度(したがって我々のエビデンス \log p(x))は、同時分布 p(x,z)から全ての潜在変数 zを周辺化して計算されるので、 φに対して常に定数であり、全く φに依存していないことに注目されたい。ELBO項とKLダイバージェンス項の和は定数になるので、φに関してELBO項を最大化すると、必然的にKLダイバージェンス項が等しく最小化されることになる。したがって、ELBOは真の潜在的な事後分布を完全にモデル化する方法を学習するために、代理として最大化する目的関数とすることができる。ELBOを最適化すればするほど、我々の近似的な事後分布は真の事後分布により近くなる。さらに、一度学習すると、モデルのエビデンス \log p(x)を近似するように学習するので、ELBOは観測データまたは生成データの尤度を推定するためにも使用することができる。

Variational Autoencoders

Variational Autoencoder (VAE) [1]のデフォルトの定式化では、ELBOを直接的に最大化します。このアプローチは変分法であり、 φによってパラメータ化された潜在的な事後分布族の中から最も良い q_{φ}(z|x)に最適化する。オートエンコーダと呼ばれるのは、入力データが中間的なボトルネック表現を経た後、入力データ自身を予測するように学習されるという、従来のオーエンコーダモデルを彷彿とさせるからである。この関係を明示するために、ELBO項をさらに分解してみる。

 \displaystyle \begin{align} 
\mathbb E_{q_{φ}(z|x)} \left[ \log \frac {p(x,z)}{q_{φ}(z|x)} \right]
&=\mathbb E_{q_{φ}(z|x)} \left[ \log \frac {p_θ(x|z)p(z)}{q_{φ}(z|x)} \right] \quad &(確率の連鎖率) \tag{17} \\
&=\mathbb E_{q_{φ}(z|x)} \left[ \log p_θ(x|z) \right] + \mathbb E_{q_{φ}(z|x)} \left[ \log \frac {p(z)}{q_{φ}(z|x)} \right] \quad &(期待値を分割) \tag{18} \\
&=\mathbb E_{q_{φ}(z|x)} \left[ \log p_θ(x|z) \right] - D_{\rm{KL}} ( q_{φ}(z|x) || p(z) ) \quad &(\rm{KL}ダイバージェンスの定義) \tag{19} 
\end{align}

この場合、我々は中間的なボトルネック分布 q_{φ}(z|x)はエンコーダとして扱うことができ、入力を潜在変数が取りうる値の分布に変換する。同時に、決定論的関数 p_{θ}(x|z)を学習し、与えられた潜在ベクトル  zを観測値 xに変換するが、これはデコーダと解釈される。 式(19)の2項はそれぞれ直感的に説明できる。第一項は我々の変分分布からデコーダの再構成された尤度を測定する。これは学習した分布が元のデータを再生成できる有効な潜在空間をモデル化していることを保証するものである。第二項は、学習された変分分布が潜在変数に関する事前信念にどれだけ似ているかを測定する。この項を最小化することで、エンコーダはディラックデルタ関数に陥ることなく、実際に分布を学習するようになる。したがって、ELBOを最大化することは、その第一項を最大化し、第二項を最小化することと等価である。VAEを定義づける特徴は、パラメータ φ θに対して同時にELBOを最適化する方法である。一般的には、VAEのエンコーダには対角共分散を持つ多変量ガウシアンをモデル化するもの、事前分布は標準的な多変量ガウシアンが用いられることが多い。

 \displaystyle \begin{align} 
\displaystyle q_{φ}(z|x) &= \mathcal{N}(z; \boldsymbol{μ}_φ(x), \boldsymbol{\sigma}_φ^2 (x) \boldsymbol {x} )\tag{20} \\
p(z) &= \mathcal{N}(z; \boldsymbol{0}, \boldsymbol{I} \tag{21})
\end{align}

そして、ELBOのKLダイバージェンス項は解析的に計算され、再構成項はモンテカルロ推定で近似できる。そして、我々の目的関数は次のように書き換えられる。

 \displaystyle \arg \max_{φ, θ} \mathbb E_{q_{φ}(z|x)} \left[ \log p_θ(x|z) \right] - D_{\rm{KL}} ( q_{φ}(z|x) || p(z) )  \simeq  \arg \max_{φ, θ} \sum_{l=1}^L log p_θ(x|z^{(l)}) -  D_{\rm{KL}} ( q_{φ}(z|x) || p(z) ) \tag{22}

ここで潜在変数 \left\{ z^{(l)} \right\} _{l=1}^{L} は、データセット中のすべての観測 xに対して q_{φ}(z|x)からサンプリングされる。しかし、このデフォルトの設定では問題が発生する。というのも、我々の損失が計算される各 z^{(l)}は確率的サンプリング手順で生成され、それは一般的に微分可能ではないのである。幸い、 q_{φ}(z|x)が多変量ガウス分布など、ある種の分布をモデル化するように設計されている場合、”再パラメータ化のトリック”によって対処することができる。

再パラメータ化のトリックとは、ランダムな変数をノイズ変数の決定論的関数に書き換えることである。これにより、勾配降下法により非確率項を最適化することを可能にする。例えば、任意の平均 μと分散 σ^2を持つ正規分布 x ∼N (x ; μ, σ^2) からのサンプルは以下のように書き換えることができる。

 \displaystyle x = μ + \sigma \epsilon \quad with \quad \epsilon \sim \mathcal{N}(\epsilon; 0, \boldsymbol{I})

つまり,任意のガウス分布は,その平均を0から目標の平均 μに加算シフトし,分散を目標の分散 σ^2だけ伸ばした標準ガウス分布(のサンプル)と解釈することができる。したがって、再パラメータ化のトリックにより、任意のガウス分布からのサンプリングは、標準ガウスからサンプリングし、その結果を目標の標準偏差でスケーリングし、目標の平均でシフトすることで行うことができる。

VAEでは、このように各 zは入力 xと補助雑音変数 \epsilon決定論的関数として計算される。

 \displaystyle \boldsymbol{z} =\boldsymbol{μ_φ} (x)+ \boldsymbol{\sigma_φ} (x) \odot \epsilon \quad with \quad \boldsymbol{\epsilon} \sim \boldsymbol{\mathcal{N}} (\boldsymbol{ \epsilon }; \boldsymbol{0, I})

ここで、 \odotは要素ごとの積を表す。このように再パラメータ化された zのもとで、 φに対して勾配を計算し、 μ_φ \sigma_{φ}を最適化する。それにより、VAEは再パラメータ化のトリックとモンテカルロ推定を利用し、 φ θに対してELBOを同時に最適化することができる。 VAEの学習後、潜在空間 p(z)から直接サンプリングし、デコーダを通すことで、新しいデータを生成することができる。変分オートエンコーダは、 zの次元が入力 xの次元よりも小さい場合に特に興味深いもので、コンパクトで有用な表現を学習することができる。さらに、意味的に(semantically)有用な潜在空間を学習した場合、潜在ベクトルをデコーダに渡す前に編集し、生成されるデータをより正確に制御することができる。

Hierarchical Variational Autoencoders

階層的オートエンコーダ(HVAE)[2, 3] は、潜在変数に対する複数個の階層に拡張されたVAE の一般化である。この定式化のもとでは、潜在変数それ自体は、他のより高いレベルの、より抽象的な潜在変数から生成されたものとして解釈される。直感的には、我々が三次元の観測対象をより抽象的な潜在表現から生成されたものとして扱うように、プラトンの洞窟の人々は三次元の対象を二次元の観測を生成する潜在表現として扱っている。したがって、プラトンの洞窟の住人から見れば、彼らの観測は深さ2(またはそれ以上)の潜在的階層によってモデル化されたものとして扱えられる。

図2: T個の隠れ階層を持つマルコフ階層変分オートエンコーダ。生成プロセスはマルコフ連鎖としてモデル化され、それぞれの潜在変数 z_tは前の潜在変数 z_{t+1}によってのみ条件付けられる。

 T個の階層を持つ一般的なHVAEでは、各潜在変数は以前の全ての潜在変数に条件付けすることができるが、本研究では、マルコフHVAE(MHVAE)と呼ぶ特別なケースに注目する。MHVAEでは、生成過程がマルコフ連鎖である。つまり、階層を下る各遷移はマルコフ連鎖であり、各潜在変数 z_{t}は前の潜在変数 z_{t+1}にのみ条件付けられる。直感的、視覚的には、これは図2に描かれているように、単にVAEを積み重ねたものと見ることができる。このモデルを表現する別の適切な用語は再帰的VAEである。数学的には、マルコフHVAEの同時分布と事後分布を次のように表現する。

 \displaystyle \begin{align} 
p(x, z_{1:T}) &= p(x_T)p_{\theta}(x|z_1) \prod_{t=2}^T p_{\theta}(z_{t-1} | z_t) \tag{23} \\
q_{φ}(z_{1:T}|x) &= q_{φ}(z_1|x) \prod_{t=2}^T q_{φ}(z_t | z_{t-1}) \tag{24}
\end{align}

そして以下のようにして、ELBOを簡単に拡張することができる。

 \displaystyle \begin{align} 
\log p(x)
&= \log  \int p( \boldsymbol{x}, z_{1:T}) dz_{1:T} &(式(1)を用いた。) \tag{25} \\
&= \log \int \frac{p( \boldsymbol{x}, z_{1:T}) q_{φ}(z_{1:T}|\boldsymbol{x}) }{ q_{φ}(z_{1:T}|\boldsymbol{x}) }dz &(1=\frac {q_{φ}(z_{1:T}|\boldsymbol{x})}{ q_{φ}(z_{1:T}|\boldsymbol{x})}を掛けた) \tag{26} \\
&= \log \mathbb E_{q_{φ}(z_{1:T}|\boldsymbol{x})} \left[ \frac{p( \boldsymbol{x}, z_{1:T})}{q_{φ}(z_{1:T}|\boldsymbol{x})} \right] &(期待値の定義) \tag{27} \\
& \geq \mathbb E_{q_{φ}(z_{1:T}|\boldsymbol{x})} \left[ \log  \frac{p( \boldsymbol{x}, z_{1:T})}{q_{φ}(z_{1:T}|\boldsymbol{x})} \right] &(Jensenの不等式を用いた) \tag{28}
\end{align}

次に、同時分布(式(23))と事後分布(式(24))を式(28)に代入すると、別の形になる。

 \displaystyle \mathbb E_{q_{φ}(z_{1:T}|\boldsymbol{x})} \left[ \log  \frac{p( \boldsymbol{x}, z_{1:T})}{q_{φ}(z_{1:T}|\boldsymbol{x})} \right] = \mathbb E_{q_{φ}(z_{1:T}|\boldsymbol{x})} \left[ \log  \frac{p(x_T)p_{\theta}(x|z_1) \prod_{t=2}^T p_{\theta}(z_{t-1} | z_t) }{q_{φ}(z_1|x) \prod_{t=2}^T q_{φ}(z_t | z_{t-1})} \right] \tag{29}

後述するように、変分拡散モデル(Variational Diffusion Model)を検討する場合、この目的関数はさらに解釈可能な要素に分解することが可能である。

Reference

[1] Diederik P Kingma and Max Welling. Auto-encoding variational bayes. arXiv preprint, arXiv:1312.6114, 2013.

[2] Durk P Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya Sutskever, and Max Welling. Improved variational inference with inverse autoregressive flow. Advances in neural information processing systems, 29, 2016.

[3] Casper Kaae Sønderby, Tapani Raiko, Lars Maaløe, Søren Kaae Sønderby, and Ole Winther. Ladder variational autoencoders. Advances in neural information processing systems, 29, 2016.

メンタルヘルスケアのための対話システム:レビュー論文のレビュー② ―開発のための技術および評価指標

Preface

 Alaa A. Abd-alrazaqらによる、メンタルヘルスのための対話システムに関する、一連のレビュー論文についてのまとめの第二弾。第一弾は、以下。

izmyon.hatenablog.com

それぞれのレビュー論文について、とりあえずAbstractとPrincipal findingsを訳し、それ以外の個人的に面白そうな知見はOther Interesting Findingsにまとめた。今回は、開発に用いた技術についてのレビューと、評価指標に関するレビューの二つについてまとめる。

Technical Aspects of Developing Chatbots for Medical Applications: Scoping Review

www.jmir.org

Safi Z, Abd-Alrazaq A, Khalifa M, Househ M Technical Aspects of Developing Chatbots for Medical Applications: Scoping Review J Med Internet Res 2020;22(12):e19127, doi: 10.2196/19127, PMID: 33337337, PMCID: 7775817

©Zeineb Safi, Alaa Abd-Alrazaq, Mohamed Khalifa, Mowafa Househ. Originally published in the Journal of Medical Internet Research (http://www.jmir.org), 18.12.2020.

License: Creative Commons Attribution 4.0 (CC-BY)

The following is the edited translation.

Abstract

Background

 チャットボットは、ユーザーと自然言語による対話を行うことができるアプリケーションである。医療分野では、さまざまな目的でチャットボットが開発・利用されている。患者にタイムリーな情報を提供するなど、メンタルヘルスの治療者へのアクセスなどといった文脈で重要な役割を果たしいる。1960年代後半に最初のチャットボットであるELIZAが開発されて以来、さまざまな健康目的のチャットボットをさまざまな方法で開発するための多くの努力が続いている。

Objective

 本研究は、最適な開発方法を説明し、チャットボット開発研究者の将来の研究開発を支援するため、医療分野で使用されるチャットボットに関する技術的側面と開発方法論を探求することを目的としたものである。

Methods

 8つの文献データベース(IEEEACM、Springer、ScienceDirect、Embase、MEDLINE、PsycINFO、Google Scholar)で関連論文を検索した。また、選択した論文の前方および後方参照チェックを行った。研究の選択は1人の査読者が行い、選択された研究の50%は2人目の査読者が無作為にチェックした。結果の統合にはナラティブアプローチを用いた。チャットボットは、開発における異なる技術的側面に基づいて分類された。各モジュールを実装するためのさまざまな技術に加え、主なチャットボットの構成要素が特定された。

Results

 最初の検索で2481件の論文が見つかり、その中から包含基準および除外基準に合致する45件の研究を特定した。ユーザーとチャットボットの間のコミュニケーションで最も一般的な言語は英語であった(n=23)。テキスト理解モジュール、対話管理モジュール、データベース層、テキスト生成モジュールの4つの主要なモジュールを特定した。テキスト理解と対話管理の開発手法で最も多いのはパターンマッチング法である(それぞれn=18、n=25)。テキスト生成は固定出力が最も多い(n=36)。オリジナルの出力生成に依存する研究は非常に少なかった。ほとんどの研究は、会話を通してチャットボットが異なる目的で使用するために、医療知識ベースを保持していた。いくつかの少数の研究は、対話履歴を保持し、ユーザーデータと以前の会話を収集した。

Conclusions

 多くのチャットボットが医療用として開発され、そのスピードは増している。最近、チャットボットのシステム開発機械学習ベースのアプローチを採用するシフトが見受けられる。臨床成果をさまざまなチャットボット開発技術や技術的特徴と関連付けるために、さらなる研究を行うことができる。

Principal Findings

 チャットボットの主要コンポーネントと、これらのコンポーネントの連携方法について報告する。チャットボットは通常、テキスト理解モジュール、対話管理モジュール、データ管理層、テキスト生成モジュールの4つの主要コンポーネントで構成されている。

 チャットボットの開発で採用されている最も一般的な設計手法は、テキスト理解と応答生成のためにパターンマッチングを用いることである。一方、機械学習や生成手法は、医療領域におけるチャットボットの開発で最も一般的に使用されていない手法の一つである。これは主に2つの理由に起因する。まず、機械学習に基づく手法よりもパターンマッチングの手法に依存する一つ目の理由は、パターンマッチングの手法は、明確に定義されたクエリに対して正確なレスポンスを生成し、結果としてミスが少なくなるため、より信頼性が高いからである。機械学習ベースの手法は、通常、さまざまな種類のエラーを発生させるため、医療アプリケーションでは許容されない。2つ目の理由は、機械学習分野の状況が急速に発展し、特に深層学習の出現によってその手法の頑健性が高まったのは、ここ数年のことだからである。古い手法ではルールベースのチャットボットやパターンマッチングアルゴリズムに依存していたが、テキスト理解や応答生成に機械学習を利用した手法はすべて2017年から2019年の間に提案されたものである。また、機械学習手法を使用していない可能性がある理由として、機械学習ベースのアプローチは、大量のドメイン固有のデータを使用して学習する必要があり、医療分野では不足していて入手が困難な可能性があるという事実が考えられる。全体的に、機械学習アプローチとアルゴリズムは、メンタルヘルス自閉症などの特定の医療状態に使用するチャットボットの開発により適しており、ルールベースのアプローチは、一般的な医療目的に使用するチャットボットの開発により適していることが分かった。一方、パターンマッチング手法やアルゴリズムは、特殊な医療と一般的な医療の両方に利用されるチャットボットの開発により広く利用されていた。

 データ管理の面では、開発したチャットボットは、医学的事実の辞書を含む医学知識ベース、ユーザーの属性や好みに関する詳細を含むユーザー情報データベース、ユーザーに応答する会話文の全エントリーを含む対話スクリプトデータベースの3種類のデータベースを記録していた。どのようなデータベースを保持するかは、チャットボットの種類とターゲットとする機能によって異なる。教育用チャットボットは通常、医療用知識データベースを保持する。ユーザーの感情に基づいてコンテキストを切り替えるチャットボットは、通常、ユーザー情報データベースを保持する。

 開発されたチャットボットの多くは、ユーザーとのコミュニケーション言語として英語を使用しており、ドイツ語、中国語、アラビア語などの他の言語はあまり見られなかった。これは、発表の多くが米国発であり、次いで英語を第一言語とするオーストラリア発であることと整合的である。

Other Interesting Findings

 動的な対話管理はより自然なユーザー体験を提供するにもかかわらず、開発されたシステムのほとんどは静的な対話管理手法に依存している。ユーザーの感情に基づいて、あるいはユーザー入力のトピックの変化を検出して対話のコンテキストを変更することは、チャットボットの開発において考慮すべき重要な点である。

 近年、さまざまな応用分野において、対話エージェントの開発に機械学習人工知能の手法を用いることが増えている。チャットボットの開発における機械学習ベースの手法の採用率は、近年増加傾向にあるとはいえ、まだ比較的低い水準にある。教師ありの機械学習アルゴリズムは、特殊な病状や疾患を対象としたチャットボットの開発に適しているようであり、ルールベースの手法は、一般的な医療目的で使用するチャットボットの開発に多く利用されているようである。機械学習の手法を用いることで、よりダイナミックで柔軟な対話の管理、幅広い応答の生成など、より優れたテキスト理解を実現し、よりリアルなユーザー体験を提供できるエージェントを開発することができる。

 よりオープンに、対話管理[4]、テキスト理解[5]、テキスト生成[6]の方法におけるstate-of-the-art手法のより幅広い適応を文献により公開することは、医療分野における会話型エージェントの開発に本当に有益である。

 注目すべきは、チャットボット開発の技術的な側面が、必ずしも研究において明確に言及されていないことである。考案されたアーキテクチャは一般的なものであり、必ずしもすべての開発されたチャットボットに適用されるわけではない。1つ以上のコンポーネントが省略されても、チャットボットは正常に機能する可能性がある。

Technical Metrics Used to Evaluate Health Care Chatbots: Scoping Review

www.jmir.org

Abd-Alrazaq A, Safi Z, Alajlani M, Warren J, Househ M, Denecke K Technical Metrics Used to Evaluate Health Care Chatbots: Scoping Review J Med Internet Res 2020;22(6):e18301, doi: 10.2196/18301, PMID: 32442157, PMCID: 7305563

©Alaa Abd-Alrazaq, Zeineb Safi, Mohannad Alajlani, Jim Warren, Mowafa Househ, Kerstin Denecke. Originally published in the Journal of Medical Internet Research (http://www.jmir.org), 05.06.2020.

License: Creative Commons Attribution 4.0 (CC-BY)

The following is the edited translation.

Abstract

Background

 対話エージェント(チャットボット)は、ヘルスケア分野での応用の歴史が長く、患者の自己管理支援やカウンセリングなどのタスクに利用されてきた。医療システムへの需要の高まりと人工知能(AI)能力の向上に伴い、その利用は拡大すると予想されている。しかし、ヘルスケア用チャットボットの評価に対するアプローチは多様かつ行き当たりばったりに見えるため、この分野の進歩の妨げになる可能性がある。

Objective

 本研究は、ヘルスケアチャットボットを評価するために先行研究が使用した技術的(非臨床的)な指標を特定することを目的としている。

Methods

 7つの書誌データベース(例:MEDLINE、PsycINFO)を検索し、さらに含まれる研究および関連するレビューの後方および前方参照リストチェックを行うことで研究を特定した。2人の査読者が独立して研究を選択し、含まれる研究からデータを抽出した。抽出されたデータは、特定された指標を、指標が評価するチャットボットの観点に基づくカテゴリにグループ化することによって、ナラティブアプローチにより統合された。

Results

 検索された1498件の引用のうち、65件の研究がこのレビューに含まれた。チャットボットは27の技術的メトリクスで評価され、それらはチャットボット全体(例:ユーザビリティ、分類器の性能、速度)、応答生成(例:理解度、リアルさ、反復性)、応答理解(例:ユーザーが評価したチャットボットの理解力、単語エラー率、概念エラー率)、美観(例:仮想エージェントの外観、背景色、コンテンツ)に関連するものであった。

Conclusions

 ヘルスチャットボットの研究の技術的指標は多様であり、調査デザインとグローバルユーザビリティの指標が主流であった。標準化の欠如と客観的な指標の少なさは、ヘルスチャットボットの性能を比較することを困難にし、この分野の発展を阻害する可能性がある。私たちは、研究者が会話ログから計算されたメトリクスをより頻繁に含めることを提案する。さらに、チャットボットの研究に含めるための特定の状況に対する推奨事項を備えた技術的なメトリックのフレームワークを開発することを勧める。

Principal Findings

 現在、ヘルスチャットボットの評価には標準的な方法がないことが明らかになった。ほとんどの観点は、自述式アンケートやユーザーインタビューを使って研究されている。一般的な測定基準は、応答速度、単語エラー率、概念エラー率、対話効率、注意推定、タスク完了度である。様々な研究がチャットボットの異なる観点を評価し、直接の比較を複雑にしている。このばらつきの一部は、チャットボットの実装とその明確なユースケースの個々の特性によるものかもしれないが、応答の適切さ、理解度、現実性、応答速度、共感性、反復性などの指標が、それぞれごく一部のケースにしか適用できないとは考えづらい。また、客観的な定量指標(例えば、ログレビューに基づく指標)は、報告された研究において比較的稀にしか使用されていない。したがって、我々は、チャットボットの研究に含めるための特定の状況に対する勧告を伴う技術的なメトリックの評価フレームワークに向けて研究開発を継続することを提案する。

 Jadejaら[81]は、チャットボットの評価について、情報検索(IR)視点、ユーザー体験(UX)視点、言語視点、AI(人間らしさ)視点の4つの次元を紹介している。先行研究[14]では、ヘルスチャットボットは必ずしも情報を取得するためだけに設計されていないため、IRの視点をタスク指向の視点に修正し、さらにシステム品質とヘルスケア品質の視点を加えて、この分類を適応・拡大した。技術的なメトリクスの定義から外れるヘルスケア品質の観点を除くと、このスコーピングレビューの結果は、これらすべての次元が実際にヘルスチャットボットの評価で表現されていることを示している。むしろ問題は、自己報告とUXの視点に偏っていることに加え、何をどのように測定するかが一貫していないことにあるようである。ヘルスチャットボットの品質に特化した標準的な指標と対応する評価ツールを考え出すには、さらなる研究が必要である。

 我々は、ユーザビリティがヘルスチャットボットの最も一般的に評価される観点であることを発見した。System Usability Scale (SUS [82,83])は、ユーザビリティを評価する研究の大部分では使用されていなかった(多くの場合、単一の調査質問が代わりに使用されていた。)が、我々が繰り返し使用されていることを観察した確立されたユーザビリティ尺度である。SUSは、非独占的であり、技術にとらわれず、製品間の比較をサポートするように設計されている[82]。そのため、研究者が評価にSUSを含めることを標準化することで、ヘルスチャットボットのUXのグローバルな評価は、品質と比較可能性が向上する可能性がある。しかし、Holmesらの研究[84]は、ユーザビリティとUXを評価する従来の方法をヘルスチャットボットに適用した場合、それほど正確ではない可能性があることを示した。そのため、ヘルスチャットボットのための適切な指標に向けては、まだ研究が必要である。

 XiaoIce[85]に代表されるように、ソーシャルチャットボットの成功指標としてConversational-turns Per Session(CPS)が提案されている。ヘルスチャットボットの目的はソーシャルチャットボットと同一ではないが、CPSがソーシャルチャットボット領域で標準的な指標として受け入れられるようになれば、健康チャットボットの評価でソーシャルエンゲージメントの次元を評価するための標準指標の有力候補になるだろう。社会的次元に関連する代替・補足的な指標としては、ユーザーにチャットボットの共感度を採点してもらう方法があるが、CPSは客観的・定量的な指標であるという利点がある。インタラクション時間やタスクにかかる時間など、他の客観的かつ定量的な指標もCPSの代替となり得るが、例えばユーザーが他のタスクとチャットボットのインタラクションをマルチタスクしている場合、CPSよりもエンゲージメントの代表度が低くなる可能性がある。ソーシャルエンゲージメントの他に、タスク完了度(会話ログの分析により評価されることが多い)も有望なグローバル指標である。

 標準化のためのさらなる領域は、応答の質であろう。我々は、応答生成は広く評価されているが、非常に多様な方法であることを確認した。反応の生成と理解に関する標準的な尺度が出現すれば、研究の比較可能性が大きく向上する。この分野で有効性が検証された尺度を開発することは、チャットボット研究への有用な貢献となるであろう。

 我々は、適用可能で実用的な評価であれば、健康チャットボットの研究に分類器の性能を含めることを称賛する。難易度の違いにより、ドメイン間で生のパフォーマンス(例えば、曲線下の面積)を比較することはあまり意味がないかもしれない。理想的には、チャットボットのパフォーマンスは、手元のタスクに対する人間の専門家のパフォーマンスと比較されるだろう。さらに、我々は、製品が成熟するにつれて、ヘルスチャットボットの研究における性能測定の進歩の機会があると認識している。初期段階の良い評価指標は、プロダクトが良く機能するために、応答品質と応答理解を評価するものとなるだろう。その後の実験では、自己申告によるユーザビリティや、ソーシャルエンゲージメントの指標の評価を進めることができる。分類器の性能は、臨床結果を評価する試験が必要かどうかを判断するための技術的な性能評価となる。

メンタルヘルスケアのための対話システム:レビュー論文のレビュー① ―患者の見方と意見および有効性と安全性

Preface

 Alaa A. Abd-alrazaq(College of Science and Engineering, Hamad Bin Khalifa University, Qatar)らによる、メンタルヘルスのための対話システムに関する、一連のレビュー論文についてまとめる。それぞれのレビュー論文について、とりあえずAbstractとPrincipal findingsを訳し、それ以外の個人的に面白そうな知見はOther Interesting Findingsにまとめた。今回は、患者の見方と意見に着目したレビューと、有効性と安全性に関するレビューの二つについてまとめる。

(ちなみに、今回扱っている論文はいずれもOpen Accessであり、Creative Commons License (CC-BY)なので、適切な引用方法及びライセンス表示により利用できるが、出版社が権利を有する論文をこのような形でブログで転載すると著作権侵害になり出版社に怒られる可能性があるので気を付けましょう。)

Perceptions and Opinions of Patients About Mental Health Chatbots: Scoping Review

www.jmir.org

Abd-Alrazaq AA, Alajlani M, Ali N, Denecke K, Bewick BM, Househ M Perceptions and Opinions of Patients About Mental Health Chatbots: Scoping Review J Med Internet Res 2021;23(1):e17828, doi: 10.2196/17828, PMID: 33439133, PMCID: 7840290

©Alaa A Abd-Alrazaq, Mohannad Alajlani, Nashva Ali, Kerstin Denecke, Bridgette M Bewick, Mowafa Househ. Originally published in the Journal of Medical Internet Research (http://www.jmir.org), 13.01.2021.

License: Creative Commons Attribution 4.0 (CC-BY)

The following is the edited translation.

Abstract

Background

 チャットボットは、メンタルヘルスケアサービスへのアクセスを改善するために、過去10年間使用されてきた。患者の認識や意見は、ヘルスケアへのチャットボットの導入に影響を与える。メンタルヘルスチャットボットに関する患者の認識や意見を評価するために、多くの研究が行われてきた。著者らの知る限り、メンタルヘルスチャットボットに関する患者の認識や意見をめぐるエビデンスのレビューはない。

Objective

 本研究は、メンタルヘルス用チャットボットに関する患者の認識と意見に関するスコーピングレビューを行うことを目的とする。

Methods

 PRISMA(Preferred Reporting Items for Systematic reviews and Meta-Analyses)extension for scoping reviewsガイドラインに沿ってスコーピングレビューを実施した。研究は、8つの電子データベース(例えば、MEDLINEとEmbase)を検索し、さらに、このレビューに含まれる研究と関連する他のレビューの後方および前方参照リストチェックを行うことによって同定された。合計で2名の査読者が独立して研究を選択し、含まれる研究からデータを抽出した。データは主題分析により統合された。

Results

 検索された1072件の引用のうち、37件のユニークな研究がレビューに含まれた。主題分析では、研究の結果から、有用性、使いやすさ、応答性、理解度、受容性、魅力、信頼性、楽しさ、内容、比較の10のテーマが生成された。

Conclusions

 メンタルヘルスのためのチャットボットについて、患者が全体的に肯定的な認識や意見を持っていることが示された。今後取り組むべき重要な課題は,チャットボットの言語能力であり,想定外のユーザー入力に適切に対処できること,高品質の応答を提供できること,応答に高い多様性があることなどが求められる。臨床に役立てるためには、チャットボットのコンテンツを個人の治療勧告と調和させる方法を見つけなければならない。つまり、チャットボットの会話のパーソナライゼーションが必要である。

Principal Findings

 このレビューの主な発見は、ヘルスケアプロバイダーが長期にわたって提供できないチャットボットの機能があるということだ。これらの機能は、メンタルヘルスのチャットボットにおいて有用であると認識されており、リアルタイムフィードバック、ウィークリーサマリー、日記のような継続的なデータ収集が挙げられる。有用性と使いやすさは、分析された論文で最も包括的に研究されている。全体的に、メンタルヘルスチャットボットの有用性は、患者に高く認識されている。これらの研究によると、患者はチャットボットシステムを使いやすいと感じている。インタラクションが楽しいと思われることと信頼できると認識されることは、チャットボットとインタラクションする際の重要な仲介者である[70]。また同時に、チャットボットは便利で使いやすいと認識されているが,報告された研究の参加者は,それらのシステムの既存の会話の限界も認識していた:会話は浅く,混乱し,または短すぎると認識されていた。これは、今後のメンタルヘルスチャットボットの開発で取り組むべき重要な課題を指摘している。会話の質はまだ改善する必要がある。この文脈では、応答性と応答の多様性という点でのチャットボットの品質が重要な課題である。現在、システムは応答回数がかなり制限されているが、これはLaranjoら[71]がすでに報告しているように、多くのチャットボットの開発初期段階であるためである可能性がある。関連する重要と判断されるもう一つの側面は、提供された情報の品質と治療医の勧告との一貫性である。

Other Interesting Findings

 有用であるためには、ユーザーに複数の方法で応答できる高品質のチャットボットを作成する必要がある。メンタルヘルスチャットボットは、やる気と魅力があると認識され、ユーザーとの関係を構築するために共感的でなければならない。de Gennaroによる研究[76]は、共感的なチャットボットが社会的排除の犠牲者に感情的なサポートを提供する可能性があることを実証し、これを支持している。

 標準的な医療環境における患者-医師、患者-セラピストの関係は、信頼と忠誠心によって特徴付けられる。チャットボットと患者の関係も信頼できるものにするための方策を講じる必要がある。これは、収集した患者データの二次利用について、データ保存や分析手順に関する情報を提供することで実現できるだろう。もう一つのアプローチは、対面式とウェブベースまたはデジタルセラピーを組み合わせたブレンドセラピー[77]で、認知行動療法における費用対効果が高く、利用しやすい形式の可能性を示している。これは、チャットボットがセラピーに関連していなければならないという、もう一つの実用的な意味合いにも対応することになる。特に、チャットボットが提供する推奨事項は、治療を行う医療従事者の推奨事項と一致していなければならない。このため、チャットボットを医療プロセスに統合することが求められ、チャットボットは医療従事者の推奨事項や治療計画を知っておく必要がある。最後に、患者におけるチャットボット利用の受容性を高めるには、医師がそれらのシステムの有用性を納得し、患者に推奨するようにする必要がある。研究によると、有用性を確信している医師がすでに存在することが示唆されている[72]。患者の医師に対する信頼の絆が強いことを考えると,医師がアプリを推奨すれば,患者もその有用性に納得するはずである。

 メンタルヘルスチャットボットの言語能力を向上させる必要性はまだある[71]。ユーザーの入力を理解し、適切に応答する能力を高めなければならない。さらに、チャットボットの応答の多様性を確保するために、動的な応答を生成する方法が必要である。言語的または語彙的な可変性は、ルールベースのチャットボットの知識ベースに追加することができるが、その能力は常に知識ベースの完全性に依存する。知識ベースから応答をわずかに適応または再定式化する方法は、この問題への対処に役立つ可能性がある。ヘルスケア以外のドメインでは、対話の質を向上させるためにクラウドソーシングが適用されている[78]。しかし,ヘルスケアでは,応答や勧告が臨床的なエビデンスに沿ったものであることを保証しなければならないため,データからの学習には注意を払わなければならない。ヘルスチャットボットを学習させるために、どのように臨床的なエビデンスを学習させるかについては、まだ未解決の研究課題である。

 さらに、予期せぬユーザーの入力に対処し、危機的状況を検知する方法も開発しなければならない。メンタルヘルスでは、自殺や自傷の危険性がある人に適切に対応することが重要である[79]。センチメント分析は、自殺や自傷行為に関するソーシャルメディアメッセージの分析に成功したことが証明されている[80]。これらの手法は、健康チャットボットにおいても有用である可能性がある。主な課題は、緊急事態が検出された後に適切な反応をすることである。もう一つの興味深い研究トピックは、個々のユーザーに対するチャットボットのカスタマイズやパーソナライゼーションである。このトピックはまだ初期段階にある[81]。メンタルヘルスのチャットボットが決定木や固定的に実装されたルールベースに依存している限り、特定のユーザーのニーズに適応することはできないだろう。さまざまなタイプのユーザーに対する応答があるように知識ベースを構築することはできるが、これには時間がかかり、常に不完全なものになるだろう。これには、ユーザーとの会話から学習することが助けになるだろう。言語のスタイルや複雑さは、与えられたユーザー入力に基づいて適合させることができる。患者固有の知識、例えば、治療計画に関する知識は、医療記録から取得することができる。このような知識をチャットボットに動的に取り込む方法が求められている。このようにして、チャットボットのコンテンツは、個人のニーズに合わせて適応される。

 メンタルヘルスチャットボットを評価するためには、ベンチマークを作成し、一貫した指標と方法を開発する必要がある。Laranjoら[71]は、ヘルスチャットボットの特徴、現在のアプリケーション、評価指標をレビューした。評価指標は、技術的性能、ユーザーエクスペリエンス、健康研究指標の3種類に大別された。デジタルヘルス介入[82]とヘルスチャットボット[83,84]の評価フレームワークに向けた最初の試みは、最近発表された。考慮したい観点に応じて、異なる指標を用いることができる。例えば、システムの性能と有効性は異なる計算指標により評価される(例えば、使い勝手usability、使いやすさeasse of use、有用性usefulness)。ソフトウェアの品質は、ソフトウェア工学の指標を用いた信頼性、セキュリティ、保守性、効率性によって測定することができる[86]。システムがAIや機械学習の技術を用いる場合、指標は予測や推奨の精度や正確さで構成される。さらに、システムの効率性は、既存のケアモデルと評価・比較されなければならない。アプリの安全な使用に関しては、(1)治療内容の質、(2)機能性、(3)データの安全性と保護という3つの基準で評価する必要がある[87]。

Effectiveness and Safety of Using Chatbots to Improve Mental Health: Systematic Review and Meta-Analysis

www.jmir.org

Abd-Alrazaq AA, Rababeh A, Alajlani M, Bewick BM, Househ M Effectiveness and Safety of Using Chatbots to Improve Mental Health: Systematic Review and Meta-Analysis J Med Internet Res 2020;22(7):e16021, doi: 10.2196/16021, PMID: 32673216, PMCID: 7385637

©Alaa Ali Abd-Alrazaq, Asma Rababeh, Mohannad Alajlani, Bridgette M Bewick, Mowafa Househ. Originally published in the Journal of Medical Internet Research (http://www.jmir.org), 13.07.2020.

License: Creative Commons Attribution 4.0 (CC-BY)

The following is the edited translation.

Abstract

Background

 世界的な精神医療を提供する人手の不足により、精神疾患を持つ人々のニーズに応えるために、チャットボットなどの技術的進歩の活用が求められている。チャットボットは、話し言葉、書き言葉、視覚的な言語を用いて人間のユーザーと会話し、対話することができるシステムである。メンタルヘルスにおけるチャットボット使用の有効性と安全性を評価した研究は数多くあるが、それらの研究結果を概観したレビューはない。

Objective

 本研究は、先行研究の結果をまとめ、概観することで、メンタルヘルスの改善にチャットボットを用いることの有効性と安全性を評価することを目的とした。

Methods

 この目的を達成するために、システマティックレビューを実施した。検索には7つの書誌データベース(例:MEDLINE、EMBASE、PsycINFO)、検索エンジンGoogle Scholar」、収録研究および関連レビューの後方および前方参照リストチェックとした。2名の査読者が独立して研究を選択し、含まれる研究からデータを抽出し、バイアスのリスクを評価した。研究から抽出されたデータは、適宜、ナラティブアプローチおよび統計的手法により統合された。

Results

 検索された1048件の引用のうち、8つのアウトカム(うつ病の重症度: Severity of depression、 心理的幸福度: Psychological wellbeing, 不安の重症度: Severity of anxiety, ポジティブおよびネガティブな感情: Positive and negative affect, 苦痛: Distress, ストレス: Stress, 安全性: Safety, 高所恐怖症の重症度: Severity of acrophobia)に対するチャットボットの使用効果を検討した12件の研究を同定した。弱いエビデンスでは、チャットボットがうつ病、苦痛、ストレス、高所恐怖症の改善に効果的であることが示された。一方、同様のエビデンスによると、主観的な心理的幸福度に対するチャットボットの使用は、統計的に有意な効果を示さなかった。不安の重症度やポジティブおよびネガティブな感情に対するチャットボットの効果については、結果が相反していた。チャットボットの安全性を評価した研究は2件のみで、有害事象や害は報告されていないことから、メンタルヘルスにおいて安全であると結論づけられた。

Conclusions

 チャットボットはメンタルヘルスを改善する可能性がある。しかし、その効果が臨床的に重要であるというエビデンスがないこと、各アウトカムを評価する研究が少ないこと、それらの研究においてバイアスのリスクが高いこと、いくつかのアウトカムで結果が矛盾していることなどから、このレビューにおけるエビデンスは、これを確実に結論づけるには不十分であった。チャットボットの有効性と安全性について確かな結論を出すためには、さらなる研究が必要である。

Principal Findings

 この研究では、チャットボットを使用してメンタルヘルスを改善することの有効性と安全性に関するエビデンスを体系的にレビューした。 8 つのアウトカムに対するチャットボットの使用の効果を調べた 12 の研究を特定した。最初のアウトカム (うつ病) については、4 つの RCT から得られた質の低いエビデンスにより、通常の治療やうつ病の重症度に関する情報よりもチャットボットを支持する統計的に有意な差が示されたが、この差は臨床的に重要ではなかった。 2 つの準実験では、チャットボットの使用後にうつ病のレベルが低下したと結論付けられた。 2 つの研究から得られたエビデンスがナラティブアプローチにより統合されたため、このうつ度合の減少が臨床的に重要であるかどうかは特定できなかった。また 2 つの研究で得られた知見は、アウトカムの測定において深刻なバイアスの影響を受けた可能性がある。メンタルヘルスにおけるチャットボットの有効性を評価したレビューがないことを考慮して、結果を同様の介入 (すなわち、インターネットベースの精神療法的介入) に関する他のレビューと比較した。このレビューにおけるうつ病に対する全体的な効果 (–0.55) は、他のレビューと同等であった。具体的には、Andersson と Cuijpers によって実施されたメタ分析では、セラピストのサポートなしでインターネットベースおよびコンピューター化されたうつ病心理的介入の全体的な効果は 0.25 (95% CI 0.14-0.35) であったが [39]、別のメタ分析では、うつ病のインターネットベースの精神療法介入の合計効果は0.32であった[40]。

 不安に関しては、2 つの RCT から得られた非常に質の低いエビデンスでは、チャットボットと不安の重症度に関する情報との間に統計的に有意な差は示されなかった。対照的に、ある準実験では、チャットボットを使用した後、不安レベルが大幅に低下したと結論付けられた。これらの相反する調査結果は、2 つの理由に起因する可能性がある。第一に、プレテスト-ポストテストの準実験は、選択バイアスに起因する内部妥当性が低いため、介入の効果を見つけるための RCT ほど信頼性が高くない[35,41]。第 2 に、2 つの RCT とは対照的に、準実験 [32] のチャットボットにはバーチャルな外見(身体) が含まれていたため、チャットボットは口頭および非言語的に (体の動きや顔の表情を通じて) ユーザーとコミュニケーションをとることができる。この身体化により、チャットボットとの会話がより共感的になり、ユーザーとの効果的なラポールの構築が促進されたと考えられる[19,42,43]。このレビューのメタアナリシスの結果と、スマートフォンメンタルヘルス介入に関連する別のレビューの結果は矛盾していた。 9 件の RCT のメタアナリシスでは、スマートフォンメンタルヘルス介入を行った後、介入を行わなかった場合と比較して、不安レベルが大幅に減少したことが示された (SMD 0.325、95% CI 0.17-0.48) [44]。これらの相反する結果は、両方のレビューにおける介入の違い (チャットボットと異なるモバイルによる介入) またはメタ分析された研究の数 (2 対 9) のいずれかの結果である可能性がある。

 ポジティブとネガティブな感情に対するチャットボットの効果に関する調査結果は矛盾していた。ある研究では、チャットボットが 2 週間の追跡調査でポジティブおよびネガティブな感情を改善したと結論付けているが [29]、別の研究では、2 週間の追跡調査でチャットボットの有意な影響は見られなかった[28]。 2 つの研究は、研究デザイン、サンプルの特性、コンパレータの特性、および結果の測定に関しては非常に同質であったが、チャットボットの種類とデータ分析の方法が異なっており、これらの違いが矛盾した結果につながった可能性がある。具体的には、最初の研究 [29] のチャットボットは、2 番目の研究 [28] のものよりも高度であった。ユーザーへの応答を生成するために人工知能機械学習を用いており、これにより、より人間らしくなり、ユーザーはより社会的につながっていると感じることができた[5]。 2 つ目の違いについては、最初の研究ではチャットボットがポジティブとネガティブな感情に与える影響をまとめて評価しているのに対し [29]、2番目の研究ではチャットボットがポジティブな感情とネガティブな感情に及ぼす影響を別々に調べていた[28]。

 3つの研究のナラティブアプローチによる統合は、チャットボットと対照群との間で主観的な心理的幸福度に関して統計的に有意な差がないことを示した。有意な差が出なかった理由は、3 つの研究で非臨床サンプルを使用したことで説明できる。言い換えれば、参加者はすでに心理的に良好な状態にあるため、チャットボットを使用する効果はそれほど大きくない可能性がある。

ナラティブアプローチにより統合された 2 つの研究によると、チャットボットは苦痛のレベルを大幅に低下させた。どちらの研究もバイアスのリスクが高かった。したがって、この調査結果は注意して解釈する必要がある。同様の文脈での研究では、私たちの調査結果に匹敵する調査結果が報告された。より正確に言えば、RCT は、オンラインチャットカウンセリングが時間の経過とともに心理的苦痛を大幅に改善したと結論付けた [45]。

このレビューでは、チャットボットは時間の経過とともにストレスレベルを大幅に低下させた。残念ながら、エビデンスにバイアスのリスクが高いため、チャットボットの効果に関して決定的な結論を出すことはできない。

ある RCT によると、チャットボットは高所恐怖症の重症度を軽減するのに効果的だった。 この RCT における高所恐怖症に対するチャットボットの効果サイズ [38] は、メタアナリシスによって報告された恐怖症に対するセラピスト支援の曝露治療の合計効果サイズよりもかなり高かった (2.0 対 1.1) [46]。 これは、チャットボットが、恐怖症の治療においてセラピストが提供する暴露治療と同等か、それよりも優れている可能性があることを示している。

チャットボットの安全性を測定する 2 つの RCT のうち、どちらも、チャットボットがうつ病や高所恐怖症のユーザーの治療に使用された場合、有害事象や害は報告されなかったため、チャットボットはメンタルヘルスに安全に使用できると結論付けた。 ただし、2 つの研究でバイアスのリスクが高いことを考えると、この証拠はチャットボットが安全であると結論付けるのに十分ではない。

Other Interesting Findings

 このレビューでは、チャットボットがうつ病、苦痛、ストレス、および高所恐怖症を改善する可能性があることが分かったが、含まれている研究におけるバイアスのリスクが高く、エビデンスの質が低く、各結果を評価する研究が不足しているため、これらの結果に関する決定的な結論を引き出すことはできなかった。含まれている研究のサンプルサイズが小さいこと、および含まれているいくつかの研究の結果に矛盾があること、このため、ユーザー、医療提供者、治療者、政策立案者、およびチャットボット開発者は、結果を慎重に表示する必要がある。

 このレビューで見つかった弱くて相反する証拠を考えると、ユーザーはメンタルヘルスの専門家の代わりにチャットボットを使用すべきではない。代わりに、医療専門家は、個人が必要に応じて医学的アドバイスを求めることを奨励するために、すでに利用可能な介入の補助として、また利用可能なサポートと治療への道しるべとしてチャットボットを提供することを検討する必要がある。

 このレビューのチャットボットの 3 分の 2 は、定義済みのルールと決定木を使用して応答を生成したが、残りのチャットボットは人工知能を使用していた。ルールベースのチャットボットとは対照的に、人工知能チャットボットは複雑なクエリへの応答を生成し、ユーザーが会話を制御できるようにすることができる [13]。人工知能チャットボットは、ルールベースのチャットボットよりも共感的な行動と人間のようなフィラー言語を示すことができる [19]。これにより、人工知能チャットボットがユーザーとのラポールをより効果的に構築し、それによってユーザーのメンタルヘルスを改善する可能性がある[42]。人工知能チャットボットはルールベースのチャットボットよりもエラーが発生しやすいと主張できるが、これらのエラーは、広範なトレーニングとより多くの使用によって最小限に抑え、減少させることができる [49]。したがって、開発者は人工知能チャットボットに集中して有効性を向上させることが望まれる。

 バイアスの全体的なリスクは、主にアウトカムの測定、報告された結果の選択、および交絡の問題が原因で、含まれるほとんどの研究で高かった。今後の研究は、そのようなバイアスを避けるために、研究を実施および報告する際に、推奨されるガイドラインまたはツール (RoB 2 および ROBINS-I など) に従う必要がある。

 報告方法が不十分なため、メタ分析に多くの研究を含めることができなかった。より高度な研究 (すなわち RCT) を奨励するだけでなく、著者は試験結果の報告においてより一貫性を保つ必要がある。たとえば、このレビューでは、多くの研究が平均、SD、サンプルサイズなどの基本的な記述統計を報告できていなかった。 RCT を報告するための承認されたガイドライン (例: CONSORT-EHEALTH [50]) に研究が準拠していることを確認することは、この分野に大きな利益をもたらすだろう。

 現在のレビューでは、すべての 2 グループ試験の比較対象は介入なしまたは教育のいずれかであった。有望な結果 (例: うつ病、苦痛、高所恐怖症) については、チャットボットを、非同期電子介入や他のタイプのチャットボット (例: ルールベースのチャットボットvs人工知能チャットボット または 身体化されたチャットボット vs 身体化されていないチャットボット) などの他の能動的介入と比較することが望まれる。

 Abd-alrazaq らが実施したスコーピング レビュー [13] によると、チャットボットは、自閉症心的外傷後ストレス障害、物質使用障害、統合失調症認知症など、多くの精神障害に使用されている。現在のレビューでは、これらの障害に使用されるチャットボットの有効性または安全性を評価する研究は見つからなかった。これは、自閉症心的外傷後ストレス障害、物質使用障害、統合失調症、および認知症の患者を対象としたチャットボットの有効性と安全性を調べる差し迫った必要性を浮き彫りにしている。

 このレビューでは、同じ結果を測定するために使用されるツールと研究デザインの不均一性が特定された。たとえば、うつ病の重症度は、PHQ-9、Beck Depression Inventory II、または Hospital Anxiety and Depression Scale を使用して測定された。さらに、介入前後の結果を評価した研究もあれば、介入後にのみ評価した研究もあった。この分野は、研究間の結果の比較と解釈を容易にするために、将来的に共通の一連の結果測定を使用することで前進するだろう。チャットボットの長期的な有効性と安全性を評価した研究は 1 つだけで、参加者は 12 週間追跡された。チャットボットの有効性と安全性の結果は、短期的な調査結果と比較して長期的な調査結果を考慮すると、異なる場合があるため、長期的な結果を評価することが不可欠である。