๐ [The Principles of Diffusion Models] ์๋ฆฌ์ฆ
- 1๏ธโฃ ( Part A ) ์ดํดํ๊ธฐ
- 2๏ธโฃ ( Part B ) CH 2. VAE
- 3๏ธโฃ ( Part B ) CH 2. DDPM (with VAE)
CH2. Variational Perspective: From VAEs to DDPMs
๋ค์ด๊ฐ๊ธฐ์ ์์... manifold hypothesis ๋ ์์๋๋ฉด ์ข์ ๊ฒ ๊ฐ์์ ๋งํฌ ํ๋ ์ฒจ๋ถํ๋ค.
https://devs0n.tistory.com/167

VAE -> DDPM ์ผ๋ก ๋ณํํ๋ ํ๋ฆ์ ์ ์ดํดํ๊ธฐ ์ํด์๋ ์ผ๋จ VAE ์ ๋ํด์ ์๊ณ ์์ด์ผํ๋ค. CH2. ์์๋ VAE ์์๋ถํฐ ์์ํด์ ๋ ผ์๋ฅผ ์ด์ด ๋๊ฐ๋ค๊ณ ํ๋ค. ๋จผ์ ์ ์ฒด์ ์ธ ํ๋ฆ์ ๋จผ์ ๋ณด๋ฉด
1) vae : ๋ฐ์ดํฐ -> z -> ๋ฐ์ดํฐ
x( ๋ฐ์ดํฐ ) -> q(z|x) ( ํ์ต๋ ์ธ์ฝ๋ ) -> z ( latent space ) -> p(x|z) -> x_hat ( ๋ณต์๋ ๋ฐ์ดํฐ ) ํํ์ด๋ค.
์ฐ๋ฆฌ๊ฐ part a ์์๋ ๋งํ๋ฏ์ด. log p(x) ๋ฅผ ๋ชป๊ตฌํ๋๊น EBLO ์๋ lower bound ๋ฅผ ์ต๋ํ ํ๋ ์์ผ๋ก ํ์ตํ๋ค.
( ์ด๋ variational inference ๊ฐ ํต์ฌ์ด๋ค. )
(cf. https://ratsgo.github.io/generative%20model/2017/12/19/vi/ )
2) Hierarchical VAE : z ๋ฅผ ์ฌ๋ฌ์ธต์ ์๋๋ค.
๋๋ต x -> z1 -> z2 .... -> zk ์ด๋ฐ์์ผ๋ก ๋ณํํ๋ค.
๋ณต์กํ ๊ตฌ์กฐ์ ๋ํด์ ํ์ธต์ z ๋ก ๋ถ์กฑํด์ latent ๋ฅผ ์ฌ๋ฌ์ธต์ผ๋ก ์์์ฌ๋ฆฌ๋ ๊ตฌ์กฐ์ด๋ค.
3) Diffusion
" Diffusion ์ ๋งค์ฐ ๊ฑฐ๋ํ Hierarchical VAE ๋ก ๋ณผ ์ ์๋ค " ๊ฐ ์ด๋ฒ ์ฑํฐ์์ ๊ฐ์ฅ ์ค์ํ ๋ด์ฉ์ธ ๊ฒ ๊ฐ๋ค.
๋๋ ๋ง์ฐํ๊ฒ vae ๋ฅผ ๋ฐฐ์ฐ๊ณ diffusion ์ ๋ฐฐ์ ์๋ ์ด๋ป๊ฒ ์ด๊ฑธ ํํํ๋ฉด ์ข์๊น ์์ฒญ ๊ณ ๋ฏผํ๋๋ฐ ๋ค์ ์ฝ๊ณ ๋๋ ์ ์ฒด์ ์ธ ํ๋ฆ์ ์ฝ์ ์ ์์ด์ ๋ง์ด ๋ช ํํด์ก๋ค.
๋ํจ์ ์์ latent ๋
x0 (data ) -> x1 -> .... -> xt ( noise )
๋ค์๊ณผ ๊ฐ์ด ์๊ฒผ๋๋ฐ , xo -> x1 ... -> xt ์ ์ฒด๊ฐ latent hierarchy ๋ผ๊ณ ์๊ฐํ๋ฉด ๋๋๊ฒ์ด๋ค. ( ๋จ๊ณ๊ฐ ๊ต์ฅํ ๋ง์.. )
์ฐจ์ด์ ์ผ๋ก
diffusion์์ ์ธ์ฝ๋๋ ๋ฐ๋ก ํ์ตํ์ง ์๋๋ค. ( ๋ํจ์ ์์ ์ธ์ฝ๋๋ ๋ ธ์ด์ฆ๋ฅผ ์ถ๊ฐํ๋ fixed gaussian noise ์ด๊ธฐ ๋๋ฌธ์ด๋ค. )
๋ง์ง๋ง ๋ฌธ๋จ์์ " ์ด ๋ชจ๋ธ๋ค์ ๋ชจ๋ ELBO ๋ฅผ ์ต์ ํํ๋ likelihood ๊ธฐ๋ฐ ์์ฑ๋ชจ๋ธ์ด๋ค" ๋ผ๋ ๋ฌธ์ฅ์ด ์๋ฏ์ด
vae , diffusion ๋ชจ๋ ELBO ๋ฅผ ์ต์ ํํ๋ ๋ชจ๋ธ์ด๋ค. ์ด์ ํ๋ฒ ์์ธํ ์์๋ณด๋๋ก ํ๊ฒ ๋ค.
2.1 Variational Autoencoder
vae ๋ ๋ณต์กํ ๊ณ ์ฐจ์ ๋ฐ์ดํฐ (x) ์ ๊ตฌ์กฐ (latent) ๋ฅผ ํ๋ฅ ์ ์ผ๋ก ์์ฑํ๋ ๋ชจ๋ธ์ด๋ค.
๊ธฐ์กด autoencoder ๋ reconstruct ๋ ํ์ง๋ง generate ํ๋ ๋ฅ๋ ฅ์ ์๋ค.
X -> encoder -> z - >Decoder -> x_hat
์ด๋ฐ ํ์์ธ๋ฐ ๋ฌธ์ ๋ ๊ณต๊ฐ z ๊ฐ ์๋ฌด๋ฐ ๊ตฌ์กฐ์์ด ํฉ์ด์ ธ์๋ค. ( ๊ฐ ๋ฐ์ดํฐ x ๋ฅผ ์๋ฌด๊ณณ (z) ์๋ค ๋ฃ์ด๋๊ณ ๊ทธ๋ฅ ๊ทธ ์์น์์ ์ ๋ณต์๋๊ธฐ๋ง ํ๋ฉด๋จ -> ์ด๋ฌ๋ฉด ๋ถํฌ๊ฐ scatter ํ๊ณ hole ์ด ๋งค์ฐ ๋ง์ด ์๊ธธ ๊ฒ์ด๋ค.๋ฐ๋ผ์interpolation ํ๋ฉด ์ด์ํ ์ํ์ด ๋์ค๊ธฐ ์ฝ๋ค.)
๊ทธ๋์ random z -> ์๋ฏธ ์๋ ์ด๋ฏธ์ง ์ผ ๋ฟ๋ง ์๋๋ผ z ๊ณต๊ฐ์ด ๋ถ์ฐ์ ์ ์ด๊ณ interpolation ๋ ๋ถ๊ฐ๋ฅํ๋ค. ์ด๋ฐ ๋ฌธ์ ์ ์ ํด๊ฒฐํ๊ฒ vae ์ด๋ค.

1. VAE ์ ํต์ฌ ์์ด๋์ด
" latent space ์ ๊ตฌ์กฐ๋ฅผ ๋ฃ์ "
z ๊ณต๊ฐ์ด ์ฐ์์ ์ด๊ณ smooth ํ๊ฒ ์๊ธฐ๋๋ก ๊ฐ์ ํ๋ค. ์ด๋ฅผ ์ํด์ latent ๋ฅผ ํ๋ฅ ๋ณ์๋ก ์ ์ธํ๋ค.
z ~ N( 0 , I )
x ~ pฯ(x | z) ( decoder )
์ด๋ ๊ฒ ๋๋ฉด prior p(z) = N(0,I) ๋ latent ๋ฅผ ์์ถํ๋ coordinate ๋ฅผ ์ ๊ณตํ๋ค.
์์ฑํ ํ๋ก์ธ์ค์์๋
1. ๋ฌด์์๋ก latent z ๋ฅผ ๋ฝ์ ( z~ N(0,I) )
2. z ๋ฅผ decoder ์ ํต๊ณผ์์ผ ์ด๋ฏธ์ง ์์ฑ ( x ~ pฯ(x | z) )
z ๋ฅผ ํ์ค ์ ๊ท ๋ถํฌ์์ ๋ฝ์๊ฑธ decoder p ๋ฅผ ํตํด x ๋ฅผ ์์ฑํ๋ ๊ฒ์ด ๋ชฉ์ ์ด๋ค.
์ด๋ ๊ฒ ๋ง๋ฌด๋ฆฌ๊ฐ ๋๋ฉด ์ข๊ฒ ์ง๋ง pฯ(x) ๋ฅผ ๊ณ์ฐํ๋ ๊ฒ์ด ๋๋ฌด ์ด๋ ต๋ค.

๋ค์๊ณผ ๊ฐ์ ๊ผด๋ก ์ ๋ถ์ ๊ณ์ฐํด์ผ ํ๋๋ฐ , ์ด๋ถ๋ถ์ ๊ณ์ฐ์ ํ์ค์ ์ผ๋ก ๋ถ๊ฐ๋ฅํ๋ค. ์กฐ๊ธ ๋ ์์ธํ ์์๋ณด์.
Construction of Encoder (Inference Network).
vae ์์๋ x ๋ก๋ถํฐ z๊ฐ ๋ญ์๋์ง ์๊ณ ์ถ์ ๊ฒ์ด๋ค. ( ๊ธฐ๋ณธ ๊ฐ์ ์ด ๋ฐ์ดํฐ x ๊ฐ latent space z ์์ ๋์๋ค๊ณ ๊ฐ์ ํ์ผ๋ )
๊ทธ๋ฌ๋ฉด ์์ฐ์ค๋ฝ๊ฒ x(์ด๋ฏธ์ง) ๊ฐ ์ฃผ์ด์ก์๋, ์ด๋ค z ๊ฐ x ๋ฅผ ๋ง๋ ๊ฒ์ธ์ง ์๊ณ ์ถ์ด์ง๋ค. (x->z)
-> ์กฐ๊ธ ๋ ์ฝ๊ฒ ์ด์ผ๊ธฐ ํ์๋ฉด ์ฐ๋ฆฌ๋ ์ด๋ฏธ์ง๋ ๊ด์ธก์ ํ๊ณ , ์ด๋ฏธ์ง๋ฅผ ๋ง๋ z ( latent ) ๋ ๋ญ์์ง ?
(p (z|x) ( posterior ๋ฅผ ์๊ณ ์ถ์ ๊ฒ์ด๋ค. )

์๋๋ ๋ฒ ์ด์ง์ ์ ํตํด์ ํด๋น ๊ณ์ฐ์ ํ๊ณ ์ถ์ผ๋, ํ์ค์ ์ผ๋ก ๊ตฌํ๋ ๊ฒ์ด ๋ถ๊ฐ๋ฅํ๋ค.
์ฌ๊ธฐ์ p(x) ๊ฐ ๊ณ์ฐ์ด ๋ถ๊ฐ๋ฅํ๋ฐ

1. z ๊ฐ ๋งค์ฐ high dimension ์ด๊ธฐ ๋๋ฌธ์ ํ์ค์ ์ผ๋ก ์ ๋ถ์ด ๋ถ๊ฐ๋ฅํ๋ค.
2. p(x|z) ๋ neural network ์ธ๋ฐ ๊ณ์ฐ์ด ๋งค์ฐ ๋ณต์กํ๋ค.
๋ ์ํ์ ์ธ ์ด์ ๋ก ๋ถ๊ฐ๋ฅํ๋ค.
( ์ง๊ด์ ์ผ๋ก ์ดํด๋ฅผ ํด๋ณด์๋ฉด ์ด๋ฏธ์ง (x) ๊ฐ ์๊ณ ์ด ์ด๋ฏธ์ง๋ฅผ ์ด๋ฃจ๊ณ ์๋ค๊ณ ๊ฐ์ ๋๋ ์ฌ๋ฌ๊ฐ์ง ์์ (ex : ์ , ํฌ๊ธฐ , ๊ตฌ๋ ,,,,, ) ๋ฑ๋ฑ ์จ๊ฒจ์ง ์์ (Z) ๊ฐ ์๋ค๊ณ ๊ฐ์ ํด๋ณด์. ์ฐ๋ฆฌ๋ ๊ทธ๋ฌ๋ฉด p(z|x) ๊ทธ๋๊น ์ด๋ฏธ์ง๋ฅผ ๋ณด๊ณ ๋์ ์ด๋ฐ ์์๊ฐ ์ด๋ค๊ฒ์ธ์ง ๊ตฌํ๊ณ ์ถ์ ๊ฒ์ด๋ค.
bayes rule ๊ณ์ฐ์์ ๋ณด๋ฉด , ๋ชจ๋ ์กฐํฉ (z) ๋ฅผ ํ๋์ฉ ๋ค ๋ฃ์ด๋ณด๊ณ -> ์ด๋ ์กฐํฉ์ด ์ด๋ฏธ์ง(x) ๋ฅผ ๋ง๋ค ํ๋ฅ ์ด ๋์์ง๋ฅผ ๊ณ์ฐํ๋๊ฒ

์๋ ์ด ์์์ด๋ค. z ๊ฐ ๊ฑฐ์ 32~256 ์ฐจ์์ด๋ฉด ์กฐํฉ์๊ฐ ์ด๋ง์ด๋งํ๊ฒ ๋ง์์ง ๋ฟ๋ง ์๋๋ผ decoder ๋ ์ ๊ฒฝ๋ง์ด๋ผ ๊ณ์ฐ์ด ๋งค์ฐ ๋ณต์กํ๋ค. ์ฆ, ๊ฐ๋ฅํ z๋ฅผ ํ๋ํ๋ ๋ฃ์ด๋ณด๋๊ฑด ๋๋ฌด ๋นํจ์จ์ ์ด๋๋ค.
--> ๋ชจ๋ ๊ฐ๋ฅํ z์ ๋ํด์, ๊ทธ z๊ฐ x๋ฅผ ๋ง๋ค์ด๋ผ ํ๋ฅ ์ ๋ค ๋ํ ๊ฐ์ด ์ปค์ง๋๋ก ๋ชจ๋ธ์ ํ์ต์ํค๊ณ ์ถ์๊ฒ. )
๊ทธ๋์ vae ์์๋ ๋ค์๊ณผ ๊ฐ์ด ํด๊ฒฐํ๋ ค๊ณ ํ๋ค.
1. ์ง์ง ์ ๋ต ๋ถํฌ : p(z|x) -> ํ์ง๋ง ์์์ ์ ๊ฐํ๋ฏ์ด ๊ณ์ฐ์ด ๋ถ๊ฐ๋ฅ.
2. x->z mapping ์ด ๋ฐ๋์ ํ์ํ๊ธดํจ.( ๋ฐ์ดํฐ๋ฅผ ๋ณด๊ณ latent space ๋ฅผ ํ์ต์์ผ์ผ ํ๋๊ฑด ์๋ช ํ๋ฏ๋ก)
3. ๊ทธ๋ผ p(z|x) ์ ์ ์ฌํ q(z|x) ๋ฅผ ์ฌ์ฉํ์.

๋ค์ vae ํ์ดํ๋ผ์ธ์ ๋ณด๋ฉด
x ----> [encoder] ----> z ----> [decoder] ----> xฬ
๋ค์๊ณผ ๊ฐ์ ํ๋ฆ์ ๊ฐ์ง๊ณ ์๋๋ฐ encoder ๋ ์ค์ x ๋ฅผ ๋ฐ์์ z ๋ฅผ ๋ง๋๋ ๋คํธ์ํฌ์ด๋ค. ์ด๋ฅผ ์ด์ฉํด์ p(z|x) ๊ฐ ์ง์ง ์์ฑ๋ชจ๋ธ์ด๊ธฐ๋ ํ์ง๋ง ์ฐ๋ฆฌ๋ ํ๋์ ๊ฐ์ ์ฆ, encoder ์์ ์ฌ์ฉํ๋ q(z|x) ๋ฅผ inference ์ญํ ๋ก ๋ณด๋ ๊ฒ์ด๋ค.
๋ค์ ์ ๋ฆฌํ์๋ฉด ์ฐ๋ฆฌ๋ ์ฌ์ค ๋์ฝ๋๋ก z ์์ x ๋ฅผ ์์ฑํ๋๊ฒ๋ง ์๊ณ ์ถ์๋ฐ , ํ์ค์ ์ผ๋ก ์ด๋ฅผ ๊ณ์ฐํ๊ธฐ ์ด๋ ค์ฐ๋ ์ธ์ฝ๋๋ฅผ ๋ถ์ฌ์ ์ฌ๊ธฐ์ ๋์จ q(z|x) ๋ฅผ ์์ฑ ๋ชจ๋ธ๋ก ๊ทผ์ฌ์์ผ์ ์ฌ์ฉํ๊ธฐ ์ํด encoder ๋ฅผ ๋ถ์๋ค ๋ผ๊ณ ์ดํดํ๋ฉด ๋๊ฒ ๋ค. ( ์ค์ ๋ก vae ๋ ผ๋ฌธ์๋ ๊ทธ๋ ๊ฒ ๋์ ์์๋ ๊ฒ ๊ฐ๊ธฐ๋ ํ๋ค. )
์ ๋ฆฌํ๊ณ ๋ค์ ์ฝ์ด๋ณด๋ ์คํ๋ ค ๋ ํท๊ฐ๋ฆฌ๋ ๊ฒ๋ ๊ฐ์์ ์ผ๋จ ์งง๊ฒ ์ ๋ฆฌ๋ฅผํ๋ฉด
- ๋๋ ์๋ decoder p(x|z) ๋ง ํ์ํจ -> ํ์ค์ ์ธ ๋ฌธ์ ๋ก ๋ถ๊ฐ๋ฅ
- p(x|z) ๊ทผ์ฌํ q(z|x) ๋ฅผ ์ฌ์ฉํ๊ธฐ ์ํด encoder ๋ฅผ ๋ง๋ ๋ค.
q1. ๊ทธ๋ ๋ค๋ฉด ์ q(z|x) ๋ ํ๋ฅ ๋ถํฌ ํํ์ฌ์ผ ํ๋ ?
a1. ๋งจ ์ฒ์ ์ด์ผ๊ธฐํ๋ autoencoder ์ vae ๊ฐ ๋ค๋ฅธ ์ด์ ๋
" latent space ๋ฅผ ํ๋ฅ ์ ์ผ๋ก ์ ๋๋ , smooth ํ ๊ณต๊ฐ์ผ๋ก ๋ง๋๋ ๊ฒ " ์ด๋ค.
์ด๋ฅผ ์ํด์ 2๊ฐ์ง๊ฐ ํ์ํ๋ค.
1. x ๋ฅผ ํ๋์ z ๋ก ๋ณด๋ด์ง ์๊ณ , z์ "๋ถํฌ" ๋ก ๋ณด๋ด์ผํจ
2. latent ๊ฐ prior N ( 0 ,I ) ์ align ๋๋๋ก KL term ์ด ํ์ํ๋ค ( ์ด๋ถ๋ถ์ ๋ค์์ ๋ ์์ธํ ๋ค๋ฃฌ๋ค. )
2.1.2 Training via the Evidence Lower Bound (ELBO)

ELBO ์ ๋ํ ์์ฝ์ ์ด๋ ๋ค.
์๋ ๋ชฉํ: log pฯ(x) (log-likelihood)๋ฅผ ์ต๋ํํ๊ณ ์ถ๋ค.
→ ์์์ ์ค๋ช
ํ๋ฏ์ด ๊ณ์ฐ์ด ์ ๋๋ค.
→ ๊ทธ๋์ ๋์ ELBO๋ผ๋ “๊ณ์ฐ ๊ฐ๋ฅํ lower bound”๋ฅผ ๋ง๋ค์ด์ ๊ทธ๊ฑธ ์ต๋ํํ๋ค.
→ ์ด ELBO๋
- ์ฌ๊ตฌ์ฑ ํญ (reconstruction)
- ์ ์ฌ KL (latent regularization) ๋์ ํฉ์ผ๋ก ์ชผ๊ฐ์ง๋ค.
๊ทธ๋ฆฌ๊ณ log pฯ(x) − ELBO = KL(q(z|x) || pฯ(z|x)) ๋ผ๋ ์ฌ์ค ๋๋ฌธ์,
ELBO๋ฅผ ํค์ด๋ค๋ ๊ฑด
๊ณง
(1) log-likelihood๋ฅผ ํค์ฐ๋ฉด์,
(2) ๋์์ q(z|x)๊ฐ ์ง์ง posterior p(z|x)์ ๊ฐ๊น์์ง๋๋ก ๋ง๋๋ ๊ฒ์ด๋ค.

์ฐ๋ฆฌ๊ฐ ์๋ ์ต๋ํ ํ๊ณ ์ถ์ ์์์ ๋ค์๊ณผ ๊ฐ๋ค. ํ์ง๋ง ์์์๋ ์ด์ผ๊ธฐ ํ๋ฏ์ด pฯ(x, z) = p(z) pฯ(x|z) ์ ์์์ด high dimension z ์ ๋ํด์ ์ ๋ถ์ด ๋ถ๊ฐ๋ฅํ๋ค
-> ์์ฐ์ค๋ฝ๊ฒ log p(x) ์ ์ต๋ํ๊ฐ ๋ถ๊ฐ๋ฅํด์ง๋ค.
๊ทธ๋์ ์ง์ ๊ณ์ฐ์ ๋ชปํ๋ฉด ๊ณ์ฐ ๊ฐ๋ฅํ lower bound ๋ฅผ ๋ง๋ค์๋ ๊ฒ์ด eblo ์ ๊ฐ์ฅ ํฐ ๋ชฉ์ .
<์์์ ๊ฐ>


- ๊ธฐ์กด ์์์ ๋ถ์/๋ถ๋ชจ์ qθ(z|x)๋ฅผ ๊ณฑํ๋ค๊ฐ ๋๋๋ค.

๊ทธ๋ฌ๋ฉด ๊ฒฐ๊ตญ log E[...] ๊ผด๋ก ๋ณํ์ ํ ์ ์๋ค.
-> ๊ทธ๋ฆฌ๊ณ JENSEN ๋ถ๋ฑ์์ ์ฌ์ฉํ๋ค. ( log E[Y] > E[logY] )
๊ทธ๋ฆฌ๊ณ ์์์ ๋ค์๊ณผ ๊ฐ์ด ์ ์ํ๋ค ( ELBO ์ ์ ์ )

<ELBO>
ELBO ์์์ ๋ํด์ ์กฐ๊ธ ๋ ํ์ด๋ณด์.

- pฯ(x,z) = p(z)pฯ(x|z)
- ๋ฐ๋ผ์ log pฯ(x,z) = log pฯ(x|z) + log p(z)
- ์ด๋ ๊ฒํ๋ฉด 3๊ฐ ํญ์ผ๋ก ์ชผ๊ฐค ์ ์๋ค.

์ฌ๊ธฐ์ ๋๋ฒ์จฐ์ ์ธ๋ฒ์งธ ํญ์ ํฉ์น๋ฉด KL ์ด ๋์จ๋ค.
( kl(p||q ) -> sum ( p(x) log ( p(x) / q(x) ) ) == E[log(q) / log(p) ] ) ์ ์๋๋ก ์ฐ๋ฉด ์ด๋ ๊ฒ ์ ๋ฆฌ๊ฐ๋๋ค.




๋ฐ๋ผ์ ์ต์ข ํํ๋ ๋ค์๊ณผ ๊ฐ๋ค. ( ์ฑ ์ ์์๊ณผ ๋์ผํด์ง )
์ด๋ ๊ฒ ๊ตฌ์ฑํ๋ฉด ๊ณ์ฐ์ด ๋งค์ฐ ์ฌ์์ง๋ค.
(1) Reconstruction term : z ๋ก๋ถํฐ ์ผ๋ง๋ x ๋ฅผ ์ ๋ณต์ํ๋์ง ?
-> ์ค์ํ์ :
๋ง์ฝ์ reconstruction term ์ผ๋ก๋ง ์ต์ ํ ํ๋ฉด ๋ฌธ์ ๊ฐ ์๊ธฐ๋๋ฐ ์ด ๋ฌธ์ ๋ ๋ฐ๋ก "Memorization" ์ด๋ค.
์ด ํญ๋ง ์์ผ๋ฉด encoder ๊ฐ q(z|x) ๋ฅผ ๊ทธ๋ฅ ์๋ฌด๋ ๊ฒ๋ ์ธ์ฝ๋ฉํด๋ decoder ๋ ๊ทธ๋ฅ p(x|z) ๋ฅผ ํตํด์ ๊ทธ๋ฅ x ๋ฅผ ๋ณต์ฌ๋งํด๋ loss ๊ฐ ์์์ง๊ฒ ๋๋๋ฐ ์ด๋ ๊ฒ ๋๋ฉด latent space ์ ๊ตฌ์กฐ๊ฐ ์์ด์ง๊ฒ ๋๋ค.
-> autoencoder ๋ฅผ ์์ฑํ ๋ชจ๋ธ๋ก ์ธ์ ์๋ ์ด์ ์ด๊ธฐ๋ ํ๋ค.
-> vae ์์๋ kl regularization ์ ๋ฐ๋์ ๋ฃ์ด์ผ ํ๋ค. ( vq-vqe ์์๋ ์ด๋ฐ ํ์์ผ๋ก ๊ตฌ์ฑ๋๋ค. )
(2) Latent Regularization (KL Divergence) : q(z|x) ๊ฐ prior (z) ์ ์ผ๋ง๋ ๊ฐ๊น์ด์ง ?
q(z|x) = N(0,I) ์ ์ ์ฌํด์ง๋๋ก ๊ฐ์ ํ๋ Regularization ์ด๋ค.
์๋ฌธ : This regularization shapes the latent space into a smooth and continuous structure, enabling meaningful generation by ensuring that samples drawn from the prior can be reliably decoded. This regularization shapes the latent space into a smooth and continuous structure, enabling meaningful generation by ensuring that samples drawn from the prior can be reliably decoded.
์ฌ๊ธฐ์ ๋จผ์ 2๊ฐ์ง ๊ฐ์ ์ด ํ์ํ๋ค.
1. p(z) ๋ ์ ๊ท ๋ถํฌ๋ก ๊ฐ์ ํ๋ค.
- ์ด๋ ๊ฒ prior ์ ๋จ์ํ gaussian ์ผ๋ก ๊ณ ์ ํด์ผ KL ๊ณ์ฐ์ด ์ ๋๊ณ , latent space ๊ฐ ์ ์ ๋๋ ๋ฟ๋ง ์๋๋ผ ์ํ๋ง์ด ์ฝ๋ค.
- ์ฌ๊ธฐ์ log p(z) ( ์ํ z ๋ฅผ ์ ๊ท๋ถํฌ p ์ ๋ฃ์ ๊ฐ. )
2. Encoder q(z|x) ๋ tractable ํ ๋ถํฌ์ฌ์ผ ํจ.

- encoder ๊ฐ ์ ๊ท๋ถํฌ ํํ๋ฅผ ์ถ๋ ฅํ๋ค๊ณ ๊ฐ์ ํ๋ค.
- ์ด๋ฐ ๊ฐ์ ์ด ์์ด์ผ KL(q||p) ๋ฅผ closed form ์ผ๋ก ๊ณ์ฐ์ด ๊ฐ๋ฅํจ

( ์์ธํ๊ฑด cf. https://di-bigdata-study.tistory.com/5
๊ทธ๋ผ ๋ฌด์จ ์ญํ ์ ํ๋ ?
1)
latent space ๋ฅผ N(0,I) ์ฃผ๋ณ์ผ๋ก ์ ๋ ฌ ์ํฌ ์ ์๋ค.
๋ชจ๋ x ๋ฅผ encoding ํ q(z|x) ๊ฐ gaussian prior ์ฃผ๋ณ์ ๋ชจ์ด๊ฒ ๋๋ค.
encoder ๊ฐ ๋ง๋ ๋ถํฌ ( q(z|x) ) ๊ฐ prior p(z) ์์ ๋ฉ์ด์ง ์๋ก ํจ๋ํฐ๊ฐ ์ปค์ง๋๊น ๋์ด ๋ถํฌ๊ฐ ๋น์ทํด์ง.
์์ฐ์ค๋ฝ๊ฒ q(z|x) = N(0,I) ์ ๋ถํฌ์ ๋น์ทํด์ง. ( ์์์ ๋ณด๋ฉด Mu ์ 0์ผ๋ก ๊ฐ๊ณ , sigma ๋ 1๋ก ๊ฐ์ผ ๊ฐ์ฅ ๊ฐ์ด ์์์ง๋ค. )
๋ฐ๋ผ์ latent space ๋ฅผ ๋งค๋๋ฌ์ด ๊ตฌ์กฐ๋ฅผ ๋ง๋ ๋ค. ( smooth + continuous )
2)
์์์ ์ค๋ช ํ memorization๋ฅผ ๋ฐฉ์งํ ์ ์๋ค.
q(z|x) ์ variance ๊ฐ 0 ์ ๊ฐ๊น์ฐ๋ฉด ์ฌ์ค์ latent space ๊ฐ ๋ฐ์ดํฐ๋ณ๋ก ๋ฉ๋ฆฌ ํฉ์ด์ ธ์๋ ๋ชจ์์ผํ ๋ฐ
์ด๋ฌ๋ฉด KL term ์ด 0์ผ๋ก ๊ฐ๋ฉด loss ๊ฐ ์ฌ์ค์ ๋ฌดํ๋๋ก ์ปค์ง๋ค.
๋ฐ๋ผ์ ๋๋ฌด ์คํํ ์ธ์ฝ๋ฉ์ด ๊ธ์ง๋๋ค.
( ์ถ๊ฐ : memorization ์ ๊ฒฐ๊ตญ ๋ฐ์ดํฐ๋ณ๋ก ๋ฉ๋ฆฌ ํฉ์ด์ ธ์ ํด๋น ๋ฐ์ดํฐ๋ง ์ธ์ฐ๋ latent space ๋ฅผ ๊ฐ์ง๊ฒ์ด์ง๋ง , ์ฐ๋ฆฌ๋ ๊ทธ๋ฌ์ง์๊ณ ๋๋ฌด ๋ฐ์ดํฐ๋ฅผ ๋ฉ๋ฆฌ ( var ) ๋์ง ์๊ธฐ ๋๋ฌธ์ธ ๊ฒ์ผ๋ก ์ดํดํ๋ฉด ์ฝ๋ค. )
Information-Theoretic View: ELBO as a Divergence Bound.
ELBO ๋ฅผ ์ ๋ณด์ด๋ก - ๋ถํฌ ๊ด์ ์์ ์ด๋ป๊ฒ ํด์ํ๋ ์ง์ ๋ํ ์น์ ์ด๋ค.
์ด์ฏค์์ ํ๋ฒ ๋ค์ remind ( with gpt )



๊ฒฐ๊ตญ MLE ๋๊น , ๋ชจ๋ธ ๋ถํฌ ๊ฐ ๋ฐ์ดํฐ ๋ถํฌ p_data(x)๋ฅผ ์ผ๋ง๋ ์ ๊ทผ์ฌํ๋์ง๋ฅผ ๋ํ๋ด๋ ๊ฒ์ด๋ค.
ํ์ง๋ง p_phi(x) ๊ฐ KL ๋ ์ง์ ๊ณ์ฐํ๊ธฐ ํ๋ค๊ธฐ ๋๋ฌธ์ -> ELBO ๋ฅผ ์ฌ์ฉํ๋ค.

-----------------------------------------------------------------------------------------------------------------------------------------------------------------

๋ค์ ์ฑ ์ ๋ด์ฉ์ผ๋ก ๋์๊ฐ ๋ณด์๋ฉด , ์์ ์์ MIN ํ๋ ๊ฒ์ ๊ฒฐ๊ตญ
" ๋ชจ๋ธ๋ถํฌ P_phi(x) ๊ฐ ๋ฐ์ดํฐ ๋ถํฌ p_data(x) ๋ฅผ ์ผ๋ง๋ ์ ๊ทผ์ฌํ๋๊ฐ "
๋ฅผ ๋ํ๋ด๋ error ์ด๋ค. ( ์ ์์ ์ต์ํ ํ๋๊ฒ ์๋ ๋ชฉ์ ์ด์์ง๋ง )
ํ์ง๋ง ์์์ ์ค๋ช ํ๋ฏ์ด P_phi(x) ๊ฐ ๊ตฌํ๊ธฐ ํ๋ค์ด์ ์ด KL ๋ ์ง์ ๊ณ์ฐํ๊ธฐ ํ๋ค๋ค. -> ๋ฐ๋ผ์ ELBO ๋ฅผ ์ฌ์ฉํ๋ค.
์ฑ ์์ ๋ joint ์ ๋ํด์ ์ด์ผ๊ธฐํ๋ค

generative joint ๋ prior p(z) = N(0,I) / p_phi(x|z) ๋ vae decoder
inference joint ๋ ๋ฐ์ดํฐ + encoder ์ฎ์ด์ ( p_data(x) ) / (q(z|x) )
( ์ฝ๊ฒ์ด์ผ๊ธฐํ๋ฉด generatvie joint = decoder ์์ ์์ฑํ๋๊ฑฐ / inference joint = ์ธ์ฝ๋์์ z ์ถ๋ก ํ๋.. )
์ฌ๊ธฐ์
- qθ(x,z)๋ ์ง์ง ๋ฐ์ดํฐ p_data(x)์
๊ทธ ๋ฐ์ดํฐ์ ๋ํ ์ถ๋ก ๋ latent qθ(zโฃx)๋ฅผ ๋ถ์ธ joint. - pฯ(x,z)๋ prior + decoder๋ก ์ ์๋๋ ๋ชจ๋ธ์ joint.
์ด ๋์ ์ฐจ์ด๋ฅผ joint KL๋ก ๋น๊ตํ๊ฒ ๋ค๋ ๊ฒ ์ถ๋ฐ์ ์ด๋ค.
" chain rule for KL divergence "

์ง๊ด์ ์ผ๋ก ์ดํด๋ฅผ ํด๋ณด์๋ฉด
(์ข์ธก marginal kl) : ๊ฒ์ผ๋ก ๋ณด์ด๋ x ๋ถํฌ๋ผ๋ฆฌ๋ง ๋น๊ต
(์ฐ์ธก joint kl) : (x,z) joint ๊น์ง ๋น๊ตํ๋ KL ์ x ๋ฟ๋ง์๋๋ผ x์ ์ฐ๊ฒฐ๋ z ๊น์ง ํฌํจํด์ ๋ ํ๋ํ๊ฒ ๋น๊ตํจ.
--> ์ด๊ฑธ ์ ๋ณด ์ด๋ก ์ผ๋ก ์ดํด๋ฅผ ํด๋ณด๋ฉด
joint KL ์ด marginal KL ๋ณด๋ค ํฌ๋ค
์ฆ, ์ ๋ณด๊ฐ ๋ง์ผ๋ฉด -> mismatch ๋ฅผ ๋ ๋ง์ด ์ก์๋ธ๋ค.
์ด ๋ง์ ์ข ํ์ด์ ์จ๋ณด๋ฉด
์ฐ๋ฆฌ ๊ด์ธก ๊ฐ๋ฅ ๋ฐ์ดํฐ : x ( ์ด๋ฏธ์ง )
latent : z ( ์ ์ฌ ์ฝ๋ )
๋๊ฐ์ง ํ๊ฐ ๋ฐฉ๋ฒ์ด ์กด์ฌํ๋ค
1. ๊ฒ๋ชจ์ต๋ง ํ๊ฐํ๊ธฐ ( marginal KL )
-> ์ด๋ฏธ์ง x ์์ฒด์ ๋ถํฌ๋ง ๋ง๋์ง ๋ณธ๋ค
-> ๊ทธ๋์ z ์ธ์๊ฐ ์๋ ๊ฒ์ด๊ณ .
2. ๊ฒ + ์ ( joint KL )
-> ์ด๋ฏธ์ง x ์ ๋ณธํฌ๋ ๋ณด๊ณ ๊ฐ ์ด๋ฏธ์ง x ์ ๋ํด์ ์ด๋ค z ๊ฐ ๋ถ๋์ง ๋ณธ๋ค
so -> D_kl ( q (x,z ) | p (x,z) )
2.
์ง๊ด์ ์ผ๋ก ์ joint ๊ฐ ํญ์ ํฌ๊ฑฐ๋ ๊ฐ๋๋ฉด
joint kl ์ ์ ๋ณด๋ฅผ ๋ ๋ง์ด ๋น๊ตํ ์ธ์์ด๋ค ( ๋งํ์๋ฉด ๋ฌธ์ ๋ต + ํ์ด๊ณผ์ ์ ๋ณธ ๊ฒฝ์ฐ์ด๊ณ )
marginal kl ์ ๋ต๋ง ๋ณธ ๊ฒฝ์ฐ์ด๋ค.
๋น์ฐํ ๋ต๋ง ๋ณธ ๊ฒฝ์ฐ์ ๊น์ด๋ ์ ์๊ฐ ๋ํ์ง ์๊ฒ ๋๊ฐ ?? -> ๊ทธ๋ ๋ค๋๊ฑด loss ๊ฐ ๋ ์๋ค๋ ์๋ฏธ๊ฐ ๋๋ค ! ( ์๋ฌ๋ฅผ ๋ ์ก์๋)
๋ฐ๋ผ์ joint kl ์ด ์ง๊ด์ ์ผ๋ก ํญ์ ๋ ํฌ๊ฑฐ๋ ๊ฐ์ ์ ๋ฐ์ ์๋ค.
3.
์ ๋ณด๊ฐ ๋ง์ผ๋ฉด -> mismatch ๋ฅผ ๋ ์ ์ก๋๋ค.
์ด ์ญ์ ๊ฝค๋ ์ง๊ด์ ์ผ๋ก ๋น์ฐํ๋ฐ,
marginal ๊ฐ์ ๊ฒฝ์ฐ ( x ๋ง ๋ณผ๋ ) ๊ฒฐ๊ตญ ๋ต๋ง ๋ณด๋ ๊ฒฝ์ฐ์ด์ง๋ง
joint kl ์ ๊ฒฝ์ฐ ( x + latent z ๋ฅผ ๋ณด๋๊น ) ๋ ์ ๋ณด๊ฐ ๋ง์ ๊ฒ์ด๋ค. ๋ latent z ๊ฐ์ ๊ฒฝ์ฐ๋ x ๋ฅผ ์ผ์ข ์ ์ถ์ํํด์ ์์ถ์์ผ ๋ ๋ถ๋ถ์ด๋ผ๊ณ ์๊ฐํ์๋ ๋ผ๊ณ ์๊ฐํด๋ณธ๋ค๋ฉด ๋ class ๋ฅผ ์ธ๋ฐํ๊ฒ ์กฐ์ ํ ์ ์๋ ๋ถ๋ถ์ด ๋ง์๊ฒ์ด๋ค.
๋ฐ๋ผ์ mismatch ๋ฅผ ์ก๊ธฐ ๋ ์ฝ๋ค. !
4.
total error bound ์ธ ์ด์
์ฌ์ค ์ฐ๋ฆฌ๋ DKLโ(pdataโ(x)โฅpฯโ(x)) ๋ฅผ ์ค์ด๊ณ ์ถ์๋ฐ ์ด๊ฒ ์ด๋ ค์ฐ๋ DKLโ(qθโ(x,z)โฅpฯโ(x,z)) (joint KL) ์ ๋ค๋ฃจ๋ ๊ฒ์ด๋ค.
joint KL ์์์ ๋ณด๋ฉด ๋ชจ๋ธ๋ง์๋ฌ + ์ถ๋ก ์๋ฌ๊ฐ ํฉ์ณ์ ๋ค์ด๊ฐ ์๋ค
๋์ด ํฉ์ณ์ ๋ค์ด๊ฐ ์์ผ๋ TOTAL ERROR ๋ก ๋ณด๋๋ฏ.
๋,
๋ชจ๋ธ๋ง์๋ฌ ( ์๋ ์๊ณ ์ถ์๊ฐ ) <= JOINT KL (TOTAL ERROR )
์ด๋๊น joint KL ์ ๋ชจ๋ธ๋ง ์๋ฌ์ ๋ํ upperbound ์ด๋ค.
( ์ด๊ฑธ kl ์ chain rule ์ด๋ผ๊ณ ๋ถ๋ฅธ๋ค๊ณ ํ๋ค. )
์์ ์ ๊ฐ๋ฅผ ํด๋ณด๋ฉด.

True modeling error :
- ๋ชจ๋ธ์ marginal pฯ(x)๊ฐ ์ง์ง ๋ฐ์ดํฐ ๋ถํฌ๋ฅผ ์ผ๋ง๋ ์/๋ชป ๋ง์ถ๋์ง ( ์ฐ๋ฆฌ๊ฐ ์๋ ๊ตฌํ๊ณ ์ ํ๋ ๊ฐ )
- ์ด๋ฏธ์ง x ์์ฒด์ ๋ถํฌ๋ง ๋ง๋์ง ๋ณธ๋ค <-> ( joint kl ์ ์ด๋ฏธ์ง ๋ถํฌ x ๋ ๋ณด๊ณ / ์ด๋ฏธ์ง x ์ ๋ํด์ ์ด๋ค z ๊ฐ ๋ถ๋์ง ๊ฐ์ด ๋ณธ๋ค. )
inference error :
- ์ง์ง posterior pฯ(zโฃx)์ encoder๊ฐ ๊ทผ์ฌํ qθ(zโฃx)์ฌ์ด์ mismatch ํ๊ท ( posterior p ์ encoder ๊ฐ ๊ทผ์ฌํ ๊ฐ q ์ฌ์ด์ ํ๊ท )
์ด๋ inference error ๋ ํญ์ 0๋ณด๋ค ํฌ๋๊น ๊ฒฐ๊ตญ

๋ค์๊ณผ ๊ฐ์ ์์์ด ๋ง์กฑ๋๋ ๊ฒ์ด๋ค
( joint KL = modelling error + inference error )
-> joint KL ์ด ํฌ๋ค๋๊ฑด
1. ๋ชจ๋ธ ์์ฒด์ ์ค๋ฅ๊ฐ ์๊ฑฐ๋
2. encoder ์ถ๋ก ์ด ๋ณ๋ก์ฌ์ posterior ๋ฅผ ์ ๋ชป๋ฐ๋ผ ๊ฐ๊ฑฐ๋
๋์ค ํ๋ ํน์ ๋๊ฐ๋ผ๋ ์๋ฏธ๋ฅผ ๊ฐ์ง๊ฒ ๋๋ค.
ELBO and inference error
๊ฒฐ๊ตญ ์์ ์์์ ์ ๋ฆฌํ๋ฉด ๋ค์๊ณผ ๊ฐ์ด ์ ๋ฆฌํ ์ ์๋ค.
( cf.


์ด ์์ ๋ป์ ์ฆ
"log likelihood ๋ ELBO ์ฌ์ด์ gap ์ด posterior ๋ฅผ ์ผ๋ง๋ ์ ๊ทผ์ฌํ๋์ง ( inference error ) ๋ ์ ํํ ๊ฐ๋ค "
log p_phi(x) : ์ฐ๋ฆฌ๊ฐ ์ ๋ง ์ต์ ํ ํ๊ณ ์ถ์ likelihood
L_elbo : ์ฐ๋ฆฌ๊ฐ ์ค์ ๋ก ์ต์ ํ ํ ์ ์๋ ๊ฒ
๋์ ์ฐจ์ด๊ฐ D_KL = ( encoder posterior ๊ทผ์ฌ ์ค์ฐจ )
๊ทธ๋์ ๊ฒฐ๊ตญ elbo ๋ฅผ ์ต๋ํ ํ๋ค๋๊ฑด
1. ๋ชจ๋ธ๋ง ์๋ฌ๋ฅผ ์ค์ด๋ ๋ฐฉํฅ์ผ๋ก ํ์ตํ๋ฉด์
2. inference error ๋ ์ค์ด๋ ๊ฒ์ด๋ค.
ELBO ๊ฐ logp(x) ์ ๊ฐ๊น์ ์ง์๋ก -> q(z|x) ๊ฐ ์ง์ง posterior p(z|x) ์ ๊ฐ๊น์ ์ง๊ณ ์๋ค๋ ์๋ฏธ์ด๋ค.
2.1.3 Gaussian VAE

์๋ฌธ ๊ทธ๋๋ก๋ผ ์ค๋ช ํ ๊ฒ ๋ง์ง๋ ์๋ค
(1) ์ธ์ฝ๋ qθ(zโฃx)
- latent z ๋ฅผ ๋์ ํด์ ๋ฐ์ดํฐ๋ฅผ ์ค๋ช ํ๋ ค๊ณ ํจ
- x ( ์ด๋ฏธ์ง ) -> z ์ ๋ํ ๋ถํฌ๋ฅผ ๋ฑ๋๋ค.
- ๋ค๋ง ์ฌ๊ธฐ์ ์ธ์ฝ๋๋ ํ๋์ ๋ฒกํฐ๊ฐ ์๋๋ผ ์ ๊ท๋ถํฌ ( ํ๊ท + ๋ถ์ฐ ) ์ ๋ด๋ณด๋ธ๋ค๋์ .
(2) ๋์ฝ๋ p(x|z)

- latent -> data ๊ณต๊ฐ์ผ๋ก ๊ฐ๋ ๋์ฝ๋
- ๋ถ์ฐ์ ๊ณ ์ ๋ ์ค์นผ๋ผ + isotropic ๊ฐ์ฐ์์ ( ์ฆ , ๋์ฝ๋๋ ํ๊ท ๋ง ํ์ตํ๊ณ , ๋ถ์ฐ์ ๊ทธ๋ฅ ์์๋ก ๊ณ ์ ํจ. )
Reconstrction term ์ด MSE ๊ฐ ๋๋ ์ด์ ?

์์์ ์ดํด๋ณธ Reconstrcution term ์ด๋ค. pฯ(xโฃz)๊ฐ ๊ฐ์ฐ์์์ด๋ผ๊ณ ํ์ผ๋ ๋ก๊ทธ๋ฅผ ํ๋ฉด

๋ค์๊ณผ ๊ฐ์ด ๋๋๋ฐ C ๋ ์์๋๊น ๋ฌด์ ( ๋ณ์ phi ๋ theta ๋ ๊ด๋ จ์ )
so ,

๋ฐ๋ผ์ MSE ๊ผด์ด ๋๋ค.
๊ฒฐ๊ตญ EBLO ๋ฅผ ์ต๋ํ ํ๋ ค๋ฉด -> Reconstruction term ์ ์ต๋ํ ํด์ผํ๋๊น -> (์๊ฐ ๋ง์ด๋์ค mse ์ด๋ฏ๋ก ) -> ์ ํญ์ ์ต์ํ ํ๋ ๊ฒ๊ณผ ๊ฐ๋ค๊ณ ๋ณผ ์ ์๋ค.
( ์๋ฏธ์ ์ผ๋ก๋ ๋์ฝ๋๊ฐ x ๋ฅผ ์ ์ฌ๊ตฌ์ฑํ๋๋ก mse recontrctuion loss ๋ฅผ ์ต์ํ ํ๋๊ฑฐ๋ ๊ฐ๋ค๊ณ ์ดํดํ๋ฉด ๋๊ฒ ๋ค. )

๋ฐ๋ผ์ ELBO ์์์ ๋ค์ ์ ๋ฆฌํ๋ฉด ๋ค์๊ณผ ๊ฐ๋ค.
1 : mse
2 : KL
( KL ์ ์์์๋ ์์ ํ์ง๋ง q ์ p ๋๋ค ๊ฐ์ฐ์์ ์ด๋๊น closed-form ์ผ๋ก ๊ณ์ฐํ๊ธฐ ํธํ๋ค. )
<์ ๋ฆฌ>
VAE loss๋ฅผ ์ง๊ด์ ์ผ๋ก ๋ณด๋ฉด:
- ์ฌ๊ตฌ์ฑ ํญ:
- ์คํ ์ธ์ฝ๋์ฒ๋ผ x๋ฅผ ์ ๋ณต์ํ๋๋ก MSE๋ฅผ ์ค์.
- z์์ ๋ค์ x๋ก ๊ฐ ๋ ์ ๋ณด ์์ค์ด ์ ๊ฒ.
- KL ์ ๊ทํ ํญ:
- ๊ฐ x์ ๋ํด ๋์ค๋ qθ(zโฃx)๊ฐ prior N(0,I)์ ๋๋ฌด ๋ค๋ฅด์ง ์๊ฒ ๋ฌถ์ด์ค.
- ๊ทธ๋์ “latent space ์ ์ฒด๊ฐ ๋ถ๋๋ฝ๊ณ ์ฐ์์ ”์ด ๋๊ณ ,
- ๋ฌด์์๋ก z∼N(0,I)์ํ ๋ฝ์์ ๋์ฝ๋์ ๋ฃ์ด๋ ๊ทธ๋ด๋ฏํ x๊ฐ ๋์ค๊ฒ ๋จ.
์์ฝํ๋ฉด:
VAE ํ์ต = “์ฌ๊ตฌ์ฑ MSE๋ฅผ ์ค์ด๋ฉด์, latent ๋ถํฌ๋ฅผ ํ์ค์ ๊ท์ ๊ฐ๊น๊ฒ ๋ง๋๋ ์ต์ ํ”
2.1.4 Drawbacks of Standard VAE
2.1.5 (Optional) From Standard VAE to Hierarchical VAEs

Hierarchical Variational Autoencoders (HVAEs) ๋ latent ๋ฅผ ๋จ๊ณ์ ์ผ๋ก ์์ ๋ชจ๋ธ์ด๋ค. ์์ชฝ latent ๊ฐ ์ ๋ฐ์ ์ธ ๊ตฌ์กฐ ( ๊ฑฐ์น์ ๋ณด ) ๋ฅผ ๊ฐ์ง๊ณ ์๋์ชฝ latent ๊ฐ ์ ์ ๋๋ํ ์ผ์ ๋ด๋นํ๋ ํ์์ด๋ค.( ๋ฒ์จ ๋ํจ์ ๊ณผ ๋ชจ๋ธ์ด ๋งค์ฐ ๋น์ทํ๊ฒ ์๊ฒผ๋ค) ์กฐ๊ธ ๋ ์์ธํ ์์ ํด ๋ณด๊ฒ ๋ค.
HVAE vs VAE
๊ธฐ์กด vae
- latent ๊ฐ ํ๊ฐ

HAVE
- latent ๋ฅผ ์ฌ๋ฌ ์ธต์ผ๋ก ์์ ( z1 ,,,, zl )
- ์๋ฏธ๋ฅผ ์ด๋ ๊ฒ ํ์ ํ๋ฉด ํธํ๋ค
-> z_L : ์ ์ผ ์ , ๊ฐ์ฅ ์ถ์์ ์ธ ์ ๋ณด ( ์ ์ฒด ๋๋ )
-> z_1 : ๊ฐ์ฅ ์๋ , ๋ํ ์ผ์ ๊ฐ๊น์ด ์ ๋ณด
-> x ์ค์ ์ด๋ฏธ์ง
HVAE's Modeling
(Decoding)

HVAE ์์ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ ์์์ ๋ค์๊ณผ ๊ฐ๋ค. ๋ณต์กํด๋ณด์ด์ง๋ง ๋งค์ฐ ๋จ์ํ๋ฐ ( ์์ชฝ์ ๊ทธ๋ฆผ์ ๋ณด๋ฉด ๋ฐ๋ก ์ดํด๊ฐ ๋๋ค )
๋งจ์ Latent ๋ฅผ ์ํ๋งํ๋ค => z_L ~ p(z_L)
๊ทธ ๋ค์ ์์์๋ถํฐ ์๋๋ก ์ฐจ๋ก๋๋ก latent ๋ฅผ ์์ฑํ๋ค

๊ทธ๋ฆฌ๊ณ ๋ง์ง๋ง์ผ๋ก ์ค์ ๋ฐ์ดํฐ ์์ฑ์ ํ๋ฉด ( p(x|z1) ) ๋์ด๋ค. ( ๋งค์ฐ ์ง๊ด์ ์ด๋ค. )
coarse -> fine ๋ฐฉํฅ์ผ๋ก ์ ๋ณด๋ฅผ ์ ์ ํ์ด๋ด๋ฉด์ x ๋ฅผ ๋ง๋ ๋ค.
(top-down)
(Encoding)

์ธ์ฝ๋ฉ์ ๋ฐ๋๋ก x ์์ ์์ํด z ๋ฅผ ์ถ์ถํ๋ค์ ๋จ๊ณ๋ฅผ ์ฌ๋ผ๊ฐ z_l ๊น์ง ๋๋ฌํ๋ค.
( bottom-up )
Observation 2.1.1
๋ ์ด์ด๋ฅผ ์ฌ๋ฌ ์ธต ์์ผ๋ฉด, ๋ฐ์ดํฐ ์์ฑ์ ๊ฑฐ์น ๊ฒ ์์ํด์ ์ ์ ๋ํ ์ผ์ ์ถ๊ฐํ๋ ๋ฐฉ์์ผ๋ก ๋ง๋ค ์ ์๋ค.
์ด๋ ๊ฒ ๊ตฌ์ฑํ๋ฉด
1. ๊ณ ์ฐจ์ ๋ณต์กํ ๊ตฌ์กฐ๋ฅผ ๋๋์ด์ ์ฒ๋ฆฌํ ์์๋ค
- ์ด๋ฏธ์ง ์ ์ฒด ๊ตฌ์กฐ / ๋ก์ปฌ ์ ๊ฐ์ z ์ ๋ฃ์ผ๋ ค๋ฉด vae ๋ก๋ ํํ๋ ฅ์ด ๋ถ์กฑํ ์ ์์
2. Neural net ๊ณผ ์์ด๋์ด๊ฐ ๋ง์
- ์ฐ๋ฆฌ๊ฐ ๊ธฐ์กด์๋ ์ด๋ค ๋ณต์กํ ๊ฒฝ๊ณ๋ฅผ ํํํ๊ธฐ ์ํด์ ๋คํธ์ํฌ๋ฅผ ๊น๊ฒ ์์๋๋ฐ vae ๋ ๋ง์ฐฌ๊ฐ์ง๋ก latent ๊ณต๊ฐ ์์ฒด๋ฅผ ๊น๊ฒ ๋ง๋ค์๋ค๊ณ ์๊ฐํ๋ฉด ๋๋ค. ( ๊ฝค๋ ๋๋ผ์ด ์ง๊ด์ธ๊ฑฐ๊ฐ๋ค )
์ด๋ฐ ์์ด๋์ด๊ฐ ๊ฒฐ๊ตญ score-base / diffusion ๋ชจ๋ธ์์ ์ด๋ฐ coarse to fine ( or multi scale ) ๊ตฌ์กฐ๊ฐ ๋ง์ด ๋์ค๊ฒ ๋๋ ์์ด์ด๋ค.ใ ฃ
HVAE ์ EBLO ์ ๋

( ์์์ ๋จธ๋ฆฌ๊ฐ ์ด์ง ์ํ๋ ์ ๊น ๊ฑด๋ ๋๊ฒ ๋ค )
Why Deeper Networks in a Flat VAE are Not Enough.
HVAE๋ latent ๊ณต๊ฐ์ ๊น๊ฒ ๋ง๋ ๊ฑฐ๋ผ๊ณ ๋ณผ ์ ์๋ค → ๊ทธ๋ ๋ค๋ฉด encoder/decoder๋ง ๊น๊ฒ ๋ง๋ค๋ฉด ์ ๋๋? → ์ ๊ตณ์ด z๋ฅผ ์ฌ๋ฌ ์ธต์ผ๋ก ์ชผ๊ฐ์ ์ฐ๋๊ฐ? ๋ผ๋ ์๋ฌธ์ผ๋ก ์ด์ด์ง ์ ์๋ค. ( ๋๋ ๊ทธ๋ฌ๊ณ )
์ด์ฐจํผ ๊น๊ฒ ๋ง๋๋ ๊ฒ ๋ชฉ์ ์ด๋ผ๋ฉด, ๊ตณ์ด z1,…,zL๋ฅผ ๋์ ํ ํ์ ์์ด encoder์ decoder ๋คํธ์ํฌ๋ง ๊น๊ฒ ๋ง๋ค๋ฉด ๋๋ ๊ฒ ์๋๊น? ๋ผ๋ ์๋ฌธ์ด ์๊ธด๋ค. ๋ณธ ๋ ผ๋ฌธ์์๋ ๋จ์ํ encoder / decoder ๋ฅผ ๊น๊ฒ ๋ง๋ ๋ ๊ฒ๊ณผ latent ๊ณ์ธต ์์ฒด๋ฅผ ๋์ ํ๋๊ฒ ๋ณธ์ง์ ์ผ๋ก ๋ค๋ฅด๋ค๊ณ ์ค๋ช ํ๋ค.
1. variational family ๊ฐ ๋๋ฌด ๋จ์ํ๋ค.

๊ธฐ์กด vae ์์ encoder ๋ ๋ณดํต ์ด๋ ๊ฒ ๋๋๋ฐ ์๋ฌด๋ฆฌ encoder ๋คํธ์ํฌ(µ, σ๋ฅผ ์ถ๋ ฅํ๋ NN)๋ฅผ ๊น๊ฒ ๋ง๋ค์ด๋
a. ๋ถํฌ์ ํํ ์์ฒด๋ " ๋๊ฐ ๊ณต๋ถ์ฐ ๊ฐ์ง ํ๋์ ๊ฐ์ฐ์์ " ์์ ์๋ฒ์ด๋๊ณ
b. ์ผ๋ง๋ ๋ณต์กํ ํจ์๋ฅผ ์ฐ๋ ๊ฒฐ๊ตญ ํ ๋ฉ์ด๋ฆฌ๋ก ๊ทผ์ฌํ๊ฒ ๋๋ค.
( ๋ง์น ๋ณต์กํ ํจ์๋ฅผ ๋ง๋๋ ค๋ฉด activate function ์ ๋ฌ์ผ ํ๋ฏ์ด .. )
ํ์ง๋ง ์ง์ง posterior P_(z|x) ๋ ๋ณดํต Multi-peaked ์ผ ์์๋ค.
ํ์ง๋ง ์ฐ๋ฆฌ๊ฐ ์ฐ๋ encoder q(z|x) ๋ ํ๋์ ๊ฐ์ฐ์์์ด๋ผ ์ฌ๋ฌ ํผํฌ๋ชจ์์ธ p(z|x) ๋ฅผ ์ ํํ ๋ฐ๋ผ๊ฐ ์ ๊ฐ ์๋ค.
์ฆ, ๋คํธ์ํฌ๋ฅผ ์๋ฌด๋ฆฌ ๊น๊ฒ ์ฌ์ฉํด๋ M , var ๋ง ๋ ์ ์ถ์ ํ๊ฒ ํ ๋ฟ์ด์ง ๋ถํฌ์ family ์์ฒด๋ ์๋ฐ๊พผ๋ค.
2. ๊ทธ๋์ Hierarchy ๋ ๊ธฐ์กด ๋ฐฉ์๊ณผ ๋ฌด์์ด ๋ค๋ฅธ๊ฐ ?

1. Reconstruction loss
E[logp (x|z)] -> ์ฌ์ ํ ์ผ๋ง๋ z1 ๋ก x ๋ฅผ ์ ์ฌ๊ตฌ์ฑ ํ๋๋ ?
2. KL
KL divergence ์ชฝ์ ํ์ธํด๋ณด๋ฉด ๊ฐ ๋ ๋ฒจ๋ณ๋ก 1:1 ๋ก KL ๋ก ๋ฌถ์ฌ์ ์๋ก ๋ง์ถ๋๋ก ํ์ตํ๋ค. ( ์์์ ํ์ด๋ณด์๋ฉด.. )

์ด๋ ๊ฒ ๋๋ฉด ๊ฐ inference conditional ์ด ์์ ์ top-down ์ด๋ ์ง๊ณผ aligned ๋๋๋ก KL ์ธต์ด ๋๋ ์ง๋ค.
์ด๋ ๊ฒ ๋๋ฉด
1. ์ ๋ณด ํจ๋ํฐ๊ฐ ์ธต๋ณ๋ก ๋ถ์ฐ๋๊ณ
2. ํ์ต ์ ํธ๊ฐ local ํ๊ฒ ์ ๋ฌ๋๋ค. ( ๊ฐ ๋ ๋ฒจ์์ ์๊ธฐ ์/์๋๋ก๋ง ๋ง์ถ๋ฉด ๋จ )
๊ทธ๋์ ์ด๋ฐ ํจ๊ณผ๋ encoder/decoder ๋ฅผ ๊น๊ฒ ์๋๊ฒ์ผ๋ก๋ ์ป์ ์ ์๊ณ
latent ๋ณ์๋ค์ ๊ทธ๋ํ ๊ตฌ์กฐ๋ฅผ ๋ฐ๊ฟ์ผ์ง๋ง ์ป์ ์ ์๋ค.
ํ์ง๋ง ์์ง ํ๊ณ๊ฐ ์กด์ฌํ๋๋ฐ
1. joint optimize ๋ฌธ์
- ์ต์ ํ ํ๊ธฐ ํ๋ค๋ค.
2. gradient ๊ฐ ๊น์ latent ๊น์ง ์ฝํ๊ฒ ๋๋ฌํ๋ค. ( gradient vanshing ์ฒ๋ผ.. )
3. conditionals ๊ฐ ๋๋ฌด ๊ฐ๋ ฅํ๋ฉด ์ ๋ ๋ฒจ์ด ์ฃฝ๋๋ค. (์กฐ๊ฑด ๋ถํฌ๊ฐ ๋๋ฌด expressive ํ๋ฉด ๋ฎ์ ๋ ๋ฒจ์์ ๋ค ํด๊ฒฐํด๋ฒ๋ฆฌ๋๊น ์์ ๋ ๋ฒจ์์ ํ ๊ฒ ์๋ค. )
๊ทธ๋์ ์ด๋ฐ HVAE ๊ฐ ์ด๋ป๊ฒ Diffusion ์ผ๋ก ์ด์ด์ง๋์ง ๋ง๋ณด๊ธฐ๋ก ์ดํด๋ณด์๋ฉด
1. HVAE ์ coarse -> fine ์์ด๋์ด๋ ์ฑ์ฉ
2. ํ์ง๋ง posterior ๋ ๊ฐ์ด ํ์ตํด์ผํด์ ๋ถ์์ ํ๋ค.
Diffusion ์
1. enoding ์ ๊ณ ์ ํด๋ฒ๋ฆฐ๋ค -> encoding ์ noise ๊ณผ์ ์ผ๋ก ํต์ผ
2. ํ์ต์ ์ค๋ก์ง reverse generative ๊ณผ์ ๋ง ํ๋ค. -> ( ๊ทธ๋์ posterior collapse ๋ฌธ์ ์ ๊ฐ์ ํ์ต ๋ถ์์ ์ฑ์์ ๋ง์ด ์์ ๋กญ๋ค )
3. ๊ทธ๋ผ์๋ ๋ถ๊ตฌํ๊ณ progressive , multi-step (HAVE ์คํ์ผ )์ ์ ์ง๋ฅผ ํ๋ค.
๋ด์ฉ์ด ๋งค์ฐ ๊ธธ์ด ์ก๊ธฐ ๋๋ฌธ์ ๊ธ์ ํ๋ ๋ ์์ฑํด์ ์ดํ์ DDPM ๋ถ๋ถ์ ํ์ด๋๊ฐ๋ณผ๊น ํ๋ค.
VAE ๋ฅผ ๊น๊ฒ ๋ฐฐ์ฐ๊ณ ๋์๋ ์ ๋ฐ์ ์ผ๋ก ์ด๋ค ๋๋์ผ๋ก latent ๋ฅผ ๋ค๋ฃจ๋์ง ๊ฐ์ ์ตํ๊ณ , HVAE ๋ฅผ ํตํด์ ์ด latent ๋ฅผ ์๋๋ค๋๊ฒ ์ด๋ค ์๋ฏธ์ธ์ง ๋งค์ฐ ๋ช ํํ๊ฒ ์๊ฒ ๋์๋ค.
์ด๋ฅผ ๋ฐํ์ผ๋ก DDPM ์ ์ดํดํด๋ณด์.
cf. vae ์ ๋ํ ์ ํ๋ธ ์์
'Paper' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
| Why Diffusion Models Donโt Memorize ๋ฐํ์๋ฃ ๋ฐ ์๋ฌธํด๊ฒฐ.. (0) | 2025.12.09 |
|---|---|
| TRELLIS ๋ฐํ์๋ฃ (0) | 2025.11.26 |
| [The Principles of Diffusion Models] ( Part A ) (0) | 2025.11.03 |
| DINO v3 ๋ฐํ์๋ฃ (0) | 2025.08.30 |
| VGGT ๋ฐํ์๋ฃ (0) | 2025.08.16 |