[Daily morning study] Diffusion ๋ชจ๋ธ ์๋ฆฌ (Stable Diffusion)
#daily morning study
Diffusion ๋ชจ๋ธ์ด๋
Diffusion ๋ชจ๋ธ์ ์ด๋ฏธ์ง๋ฅผ ์ ์ง์ ์ผ๋ก ๋ ธ์ด์ฆ๋ก ๋ง๊ฐ๋จ๋ฆฐ ๋ค, ๊ทธ ์ญ๊ณผ์ ์ ํ์ต์์ผ ๋ ธ์ด์ฆ์์ ์๋ณธ ์ด๋ฏธ์ง๋ฅผ ๋ณต์ํ๋ ์์ฑ ๋ชจ๋ธ์ด๋ค. Stable Diffusion, DALLยทE 2, Imagen ๋ฑ์ด ๋ชจ๋ ์ด ๊ณ์ด์ด๋ค.
ํต์ฌ ์์ด๋์ด๋ ๋ ๋จ๊ณ๋ก ๋๋๋ค.
- Forward process (ํ์ฐ) โ ์๋ณธ ์ด๋ฏธ์ง์ ๊ฐ์ฐ์์ ๋ ธ์ด์ฆ๋ฅผ T ์คํ ์ ๊ฑธ์ณ ์กฐ๊ธ์ฉ ๋ํด, ๊ฒฐ๊ตญ ์์ํ ๊ฐ์ฐ์์ ๋ ธ์ด์ฆ ์ํ๋ก ๋ง๋ ๋ค.
- Reverse process (๋ณต์) โ ๋ ธ์ด์ฆ ์ํ์์ ์์ํด ์คํ ๋ง๋ค ๋ ธ์ด์ฆ๋ฅผ ์์ธกยท์ ๊ฑฐํ๋ฉฐ ์๋ณธ์ ๊ฐ๊น์ด ์ด๋ฏธ์ง๋ฅผ ๋ณต์ํ๋ค.
Forward Process
์๊ฐ t์์ ์ด๋ฏธ์ง x_t๋ ์ด์ ์ด๋ฏธ์ง x_{t-1}์ ์ฝ๊ฐ์ ๋ ธ์ด์ฆ๋ฅผ ์ถ๊ฐํ ๊ฒ์ด๋ค.
x_t = โ(1 - ฮฒt) * x_{t-1} + โฮฒt * ฮต, ฮต ~ N(0, I)
ฮฒt๋ ์คํ ๋ง๋ค์ ๋ ธ์ด์ฆ ๊ฐ๋(์ค์ผ์ค)์ด๊ณ , T ์คํ ์ด ์ง๋๋ฉด x_T๋ ๊ฑฐ์ ์์ํ ๊ฐ์ฐ์์ ๋ถํฌ๊ฐ ๋๋ค.
์์์ ์ ๊ฐํ๋ฉด ์์์ t ์คํ ์์์ x_t๋ฅผ ์๋ณธ x_0์ผ๋ก๋ถํฐ ๋ซํ ํํ(closed form)๋ก ๋ฐ๋ก ๊ณ์ฐํ ์ ์๋ค.
x_t = โแพฑt * x_0 + โ(1 - แพฑt) * ฮต
แพฑt๋ ฮฒ1~ฮฒt์ ๋์ ๊ณฑ์ผ๋ก ์ ์๋๋ค. ํ์ต ์ค์ ์์์ t๋ฅผ ์ํ๋งํด์ x_t๋ฅผ ํ ๋ฒ์ ๋ง๋ค ์ ์์ด ํจ์จ์ ์ด๋ค.
Reverse Process (U-Net์ ์ญํ )
๋ชจ๋ธ์ x_t์ ์๊ฐ t๋ฅผ ์ ๋ ฅ๋ฐ์ ํด๋น ์คํ ์์ ์ถ๊ฐ๋ ๋ ธ์ด์ฆ ฮต๋ฅผ ์์ธกํ๋๋ก ํ์ต๋๋ค. ์ค์ ๊ตฌํ์์๋ U-Net ์ํคํ ์ฒ๊ฐ ์ฃผ๋ก ์ฐ์ธ๋ค.
ํ์ต ์์ค์ ๋จ์ํ ์ค์ ๋ ธ์ด์ฆ ฮต์ ์์ธก ๋ ธ์ด์ฆ ฮต_ฮธ ์ฌ์ด์ MSE๋ค.
L = E[||ฮต - ฮต_ฮธ(x_t, t)||ยฒ]
์ถ๋ก ์์๋ ์์ ๊ฐ์ฐ์์ ๋ ธ์ด์ฆ x_T์์ ์์ํด T โ 0 ๋ฐฉํฅ์ผ๋ก ์คํ ๋ง๋ค ๋ ธ์ด์ฆ๋ฅผ ์ ๊ฑฐํ๋ฉฐ ์ต์ข ์ด๋ฏธ์ง๋ฅผ ๋ง๋ ๋ค.
Stable Diffusion์ ํต์ฌ: Latent Diffusion
ํฝ์ ๊ณต๊ฐ์์ ์ง์ Diffusion์ ์ํํ๋ฉด ๊ณ ํด์๋ ์ด๋ฏธ์ง์ผ์๋ก ์ฐ์ฐ ๋น์ฉ์ด ํญ๋ฐ์ ์ผ๋ก ์ฆ๊ฐํ๋ค. Stable Diffusion์ ์ด๋ฅผ ์ ์ฌ ๊ณต๊ฐ(latent space) ์์ ์ํํด ํด๊ฒฐํ๋ค.
์ ์ฒด ํ์ดํ๋ผ์ธ์ ์ธ ์ปดํฌ๋ํธ๋ก ๊ตฌ์ฑ๋๋ค.
| ์ปดํฌ๋ํธ | ์ญํ |
|---|---|
| VAE (Variational Autoencoder) | ์ด๋ฏธ์ง๋ฅผ ์ ์ฐจ์ ์ ์ฌ ๋ฒกํฐ๋ก ์ธ์ฝ๋ฉ / ๋ณต์ |
| U-Net (+ Attention) | ์ ์ฌ ๊ณต๊ฐ์์ Diffusion์ ๋ ธ์ด์ฆ ์์ธก |
| Text Encoder (CLIP ๋ฑ) | ํ ์คํธ ํ๋กฌํํธ๋ฅผ ์๋ฒ ๋ฉ์ผ๋ก ๋ณํํด U-Net์ ์กฐ๊ฑด ์ ๊ณต |
- ์ด๋ฏธ์ง๋ฅผ VAE ์ธ์ฝ๋๋ก ์์ถํด latent z๋ฅผ ์ป๋๋ค (์: 512ร512 โ 64ร64ร4).
- z์ Forward diffusion์ ์ ์ฉํด ๋ ธ์ด์ฆ๋ฅผ ์ถ๊ฐํ๋ค.
- U-Net์ด latent ๊ณต๊ฐ์์ ๋ ธ์ด์ฆ๋ฅผ ์์ธกํ๊ณ ์ ๊ฑฐํ๋ค.
- ๋ณต์๋ latent๋ฅผ VAE ๋์ฝ๋๋ก ๋ค์ ํฝ์ ๊ณต๊ฐ์ผ๋ก ๋ณํํ๋ค.
ํ ์คํธ ํ๋กฌํํธ๋ Cross-Attention์ ํตํด U-Net์ ๊ฐ ๋ ์ด์ด์ ์กฐ๊ฑด์ผ๋ก ์ฃผ์ ๋๋ค.
Classifier-Free Guidance (CFG)
ํ ์คํธ ์กฐ๊ฑด์ ์ผ๋ง๋ ๊ฐํ๊ฒ ๋ฐ์ํ ์ง ์กฐ์ ํ๋ ๊ธฐ๋ฒ์ด๋ค. ๊ฐ์ U-Net์ ๋ ๋ฒ ์คํํ๋ค.
- ํ ์คํธ ์กฐ๊ฑด ์์ โ ์กฐ๊ฑด๋ถ ๋ ธ์ด์ฆ ์์ธก ฮต_c
- ํ ์คํธ ์กฐ๊ฑด ์์(๋น ํ๋กฌํํธ) โ ๋ฌด์กฐ๊ฑด ๋ ธ์ด์ฆ ์์ธก ฮต_u
์ต์ข ๋ ธ์ด์ฆ ์์ธก์ ๋ ๊ฒฐ๊ณผ๋ฅผ ์ ํ ๋ณด๊ฐํ๋ค.
ฮต_final = ฮต_u + guidance_scale * (ฮต_c - ฮต_u)
guidance_scale(CFG ์ค์ผ์ผ)์ด ๋์์๋ก ํ
์คํธ์ ๋ ์ถฉ์คํ ์ด๋ฏธ์ง๊ฐ ์์ฑ๋์ง๋ง, ๋๋ฌด ๋์ผ๋ฉด ๊ณผํฌํยท์๊ณก์ด ๋ฐ์ํ๋ค. ๋ณดํต 7~12 ์ฌ์ด๋ฅผ ์ฌ์ฉํ๋ค.
์ํ๋ง ์ค์ผ์ค๋ฌ
Reverse process๋ฅผ T ์คํ ์ ๋ถ ์ํํ๋ฉด ๋๋ฆฌ๋ค. ์ค์ฉ์ ์ผ๋ก๋ ์ค์ผ์ค๋ฌ๋ฅผ ํตํด ์คํ ์๋ฅผ ์ค์ธ๋ค.
| ์ค์ผ์ค๋ฌ | ํน์ง |
|---|---|
| DDPM | ์๋ ๋ ผ๋ฌธ ๋ฐฉ์, 1000 ์คํ ํ์ |
| DDIM | ๊ฒฐ์ ๋ก ์ ์ํ๋ง, 20~50 ์คํ ์ผ๋ก ๊ฐ๋ฅ |
| DPM++ | ๊ณ ์ฐจ ์๋ฒ, 15~25 ์คํ ์์ ์ข์ ํ์ง |
| Euler / Euler A | ๋น ๋ฅด๊ณ ์์ ์ , ์ค์ ๋ก ๋ง์ด ์ฌ์ฉ |
DDIM์ ๋์ผํ ๋ ธ์ด์ฆ ์๋์์ ์ผ๊ด๋ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๊ธฐ ๋๋ฌธ์ img2img๋ inpainting์๋ ์์ฃผ ์ฐ์ธ๋ค.
img2img์ Inpainting
img2img: ์ ๋ ฅ ์ด๋ฏธ์ง๋ฅผ ์ผ๋ถ ๋ ธ์ด์ฆํ(denoise_strength๋ก ์กฐ์ )ํ ๋ค Reverse process๋ฅผ ์คํํ๋ค. strength๊ฐ ๋ฎ์์๋ก ์๋ณธ์ ๊ฐ๊น์ด ์ด๋ฏธ์ง๊ฐ ๋์จ๋ค.
Inpainting: ๋ง์คํฌ ์์ญ๋ง ๋ ธ์ด์ฆํํ๊ณ ๋๋จธ์ง๋ ์๋ณธ์ ์ ์งํ ์ฑ Reverse process๋ฅผ ์คํํ๋ค. ํน์ ๋ถ๋ถ๋ง ์์ ํ ๋ ์ฌ์ฉํ๋ค.
GAN๊ณผ ๋น๊ต
| ํญ๋ชฉ | GAN | Diffusion |
|---|---|---|
| ํ์ต ์์ ์ฑ | ๋ถ์์ (mode collapse) | ์์ ์ |
| ์์ฑ ํ์ง | ๋ ์นด๋กญ์ง๋ง ๋ค์์ฑ ๋ถ์กฑ | ๋์ ๋ค์์ฑ |
| ์๋ | ๋น ๋ฆ (๋จ์ผ forward pass) | ๋๋ฆผ (๋ค๋จ๊ณ reverse) |
| ์กฐ๊ฑด๋ถ ์์ฑ | ์ถ๊ฐ ๊ตฌ์กฐ ํ์ | CFG๋ก ์์ฐ์ค๋ฝ๊ฒ ํตํฉ |
Diffusion ๋ชจ๋ธ์ด ์ด๋ฏธ์ง ์์ฑ ํ์ง๊ณผ ํ ์คํธ ์กฐ๊ฑด ์ ์ด ๋ฉด์์ GAN์ ๋๋ถ๋ถ์ ํ์คํฌ์์ ์์ง๋ฅด๋ฉด์ ์ฃผ๋ฅ๊ฐ ๋๋ค.
์ ๋ฆฌ
- Diffusion ๋ชจ๋ธ์ Forward(๋ ธ์ด์ฆ ์ถ๊ฐ) + Reverse(๋ ธ์ด์ฆ ์ ๊ฑฐ)๋ฅผ ํ์ตํ๋ ์์ฑ ๋ชจ๋ธ์ด๋ค.
- Stable Diffusion์ VAE๋ก ํฝ์ ์ latent๋ก ์์ถํ ๋ค latent ๊ณต๊ฐ์์ Diffusion์ ์ํํด ํจ์จ์ ๋์ธ๋ค.
- CLIP ๊ธฐ๋ฐ ํ ์คํธ ์ธ์ฝ๋์ Cross-Attention์ผ๋ก ํ ์คํธ ์กฐ๊ฑด๋ถ ์์ฑ์ ๊ตฌํํ๋ค.
- CFG ์ค์ผ์ผ๋ก ํ ์คํธ ์ถฉ์ค๋๋ฅผ ์กฐ์ ํ๊ณ , DDIM/DPM++ ๊ฐ์ ์ค์ผ์ค๋ฌ๋ก ์ํ๋ง ์๋๋ฅผ ๋์ธ๋ค.