當前位置: 華文問答 > 科學

Diffusion 理論學習4

2024-05-15科學

關於sde的討論在前面章節

ODE視角下的Diffusion Model

理論推導

主要參考:https:// kexue.fm/archives/9228

首先先對dirac函數做個分析

p(y) = E_{y\sim p(y)}[\delta(x-y)]

這是因為

p(\mathbf{x}) = \int \delta(\mathbf{x} - \mathbf{y}) p(\mathbf{y}) d\mathbf{y} = \mathbb{E}_{\mathbf{y}}[\delta(\mathbf{x} - \mathbf{y})]

直觀理解,當我們想知道一個分布在某個取值處的概率,可以看作是delta函數的期望值,相當於你采樣這個分布,只有當x=y的時候才是1,那麽求期望就恰好是取值為y的概率值。

接下來,考慮p_{t+\Delta t}(\mathbf{x})

\begin{aligned} p_{t+\Delta t}(\mathbf{x}) \\ =&\,\mathbb{E}_{\mathbf{x}_{t+\Delta t}}\left[\delta(\mathbf{x} - \mathbf{x}_{t+\Delta t})\right] \\ =&\mathbb{E}[\delta(\mathbf{x} - (\mathbf{x}_t + \mathbf{f}_t(\mathbf{x}_t) \Delta t + g_t \sqrt{\Delta t}\mathbf{\varepsilon}))] \\ \approx&\, \mathbb{E}_{\mathbf{x}_t, \mathbf{\varepsilon}}\left[\delta(\mathbf{x} - \mathbf{x}_t) - \left(\mathbf{f}_t(\mathbf{x}_t) \Delta t + g_t \sqrt{\Delta t}\mathbf{\varepsilon}\right)\cdot \nabla_{\mathbf{x}}\delta(\mathbf{x} - \mathbf{x}_t) + \frac{1}{2} \left(g_t\sqrt{\Delta t}\mathbf{\varepsilon}\cdot \nabla_{\mathbf{x}}\right)^2\delta(\mathbf{x} - \mathbf{x}_t)\right] \\ =&\, \mathbb{E}_{\mathbf{x}_t}\left[\delta(\mathbf{x} - \mathbf{x}_t) - \mathbf{f}_t(\mathbf{x}_t) \Delta t\cdot \nabla_{\mathbf{x}}\delta(\mathbf{x} - \mathbf{x}_t) + \frac{1}{2} g_t^2\Delta t \nabla_{\mathbf{x}}\cdot \nabla_{\mathbf{x}}\delta(\mathbf{x} - \mathbf{x}_t)\right] \\ =&\,p_t(\mathbf{x}) - \nabla_{\mathbf{x}}\cdot\left[\mathbf{f}_t(\mathbf{x})\Delta t\, p_t(\mathbf{x})\right] + \frac{1}{2}g_t^2\Delta t \nabla_{\mathbf{x}}\cdot\nabla_{\mathbf{x}}p_t(\mathbf{x}) \end{aligned}

其中第一行是Dirac函數的性質,第二行是diffusion的SDE的定義(dx = f(x,t)dt + g(t) dw ),第三行是dirac函數在x-x_t 處的展開,並且忽略o(\Delta t) 以上的項。第四行對\epsilon 求期望,註意一階矩是0,二階矩是1。最後一行是Dirac函數分布性質的逆運用。

兩邊除以\Delta t 取極限,就得到Focker-Plank方程式

\frac{\partial}{\partial t} p_t(\mathbf{x}) = - \nabla_{\mathbf{x}}\cdot\left[\mathbf{f}_t(\mathbf{x}) p_t(\mathbf{x})\right] + \frac{1}{2}g_t^2 \nabla_{\mathbf{x}}\cdot\nabla_{\mathbf{x}}p_t(\mathbf{x})

我們把g_t 裏拆分出一部份\sigma_t 來,得到

\begin{aligned} \frac{\partial}{\partial t} p_t(\mathbf{x}) =&\, - \nabla_{\mathbf{x}}\cdot\left[\mathbf{f}_t(\mathbf{x})p_t(\mathbf{x}) - \frac{1}{2}(g_t^2 - \sigma_t^2)\nabla_{\mathbf{x}}p_t(\mathbf{x})\right] + \frac{1}{2}\sigma_t^2 \nabla_{\mathbf{x}}\cdot\nabla_{\mathbf{x}}p_t(\mathbf{x}) \\ =&\,- \nabla_{\mathbf{x}}\cdot\left[\left(\mathbf{f}_t(\mathbf{x}) - \frac{1}{2}(g_t^2 - \sigma_t^2)\nabla_{\mathbf{x}}\log p_t(\mathbf{x})\right)p_t(\mathbf{x})\right] + \frac{1}{2}\sigma_t^2 \nabla_{\mathbf{x}}\cdot\nabla_{\mathbf{x}}p_t(\mathbf{x}) \end{aligned}

接下來,考慮FP方程式的逆運用,我們上面算FP方程式的時候,除了x_{t+1} 是按照SDE寫的,剩下的全是dirac函數的技巧,因此FP方程式其實和SDE是對應的,現在我們重新整理之後,得到了一個新的FP方程式,那麽回去看對應的SDE又是啥呢?答案是

d\mathbf{x} = \left(\mathbf{f}_t(\mathbf{x}) - \frac{1}{2}(g_t^2 - \sigma_t^2)\nabla_{\mathbf{x}}\log p_t(\mathbf{x})\right) dt + \sigma_t d\mathbf{w}

但是要註意,剛才得到新FP的過程是恒等變換(拆出一部份\sigma^2 ,另外別忘了FP的左側描述的是p_t(x) ,所以說不同的SDE是可以有相同的p_t(x) 的。另外,這個新過程和原來過程相比,隨機部份的系數從原來的g_t^2 縮小為\sigma^2 ,這就說明了保持p_t(x) 不變可以實作縮小變異數,我們發現極限情況下,如果讓\sigma^2=0 ,那就是上一小節末尾提到的確定性過程。

所以極端情況的ODE為

d\mathbf{x} = (f_t(\mathbf{x}) - \frac{1}{2}g_t^2\nabla_{\mathbf{x}}\log p_t(\mathbf{x})) dt

另外註意ODE因為變化過程是確定的,因此不存在什麽正向逆向的問題,只需關於t按照不同的方向積分即可。

直觀理解

x 服從的分布是p_0(x) ,p_t(x) 則是對x做變換後t 時刻x 的邊緣分布,我們希望定義一個變換過程,使得p_t(x) 在不斷變化直到最後分布變成Gaussian。而FP方程式的左邊恰好就是這個變換過程所對應的x 的分布隨時間變化的導數。這個導數的值由sde的f和g的兩項決定的。因此,只要能保證FP-方程式左邊不變,就說明p_t(x) 的變化方式一樣。

而保證FP-方程式左邊不變,相應的方程式右邊可以做很多恒等變換,每個恒等變換的背後都產生了新的f和g,背後是一個新的sde,甚至極限情況是ODE。