[Daily morning study] Diffusion ๋ชจ๋ธ ์๋ฆฌ (Stable Diffusion)
#daily morning study
Diffusion ๋ชจ๋ธ์ด๋
Diffusion ๋ชจ๋ธ์ ์ด๋ฏธ์ง๋ฅผ ์ ์ง์ ์ผ๋ก ๋ ธ์ด์ฆ๋ก ๋ง๋ ๋ค์, ์ญ๋ฐฉํฅ์ผ๋ก ๋ ธ์ด์ฆ๋ฅผ ์ ๊ฑฐํ๋ ๊ณผ์ ์ ํ์ตํด์ ์๋ก์ด ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ ์์ฑ ๋ชจ๋ธ์ด๋ค. GAN์ด Generator์ Discriminator์ ๊ฒฝ์์ผ๋ก ํ์ตํ๋ ๊ฒ๊ณผ ๋ฌ๋ฆฌ, Diffusion์ ๋ ธ์ด์ฆ ์์ธก์ด๋ผ๋ ๋จ์ํ ๋ชฉํ ํ๋๋ก ํ์ตํ๊ธฐ ๋๋ฌธ์ ํ์ต์ด ์์ ์ ์ด๊ณ ์์ฑ ํ์ง๋ ๋๋ค.
Forward Process (๋ ธ์ด์ฆ ์ถ๊ฐ)
์๋ณธ ์ด๋ฏธ์ง xโ์ T ์คํ ์ ๊ฑธ์ณ ๊ฐ์ฐ์์ ๋ ธ์ด์ฆ๋ฅผ ์กฐ๊ธ์ฉ ์ถ๊ฐํด์, ์ต์ข ์ ์ผ๋ก ์์ ํ ๋ ธ์ด์ฆ xT๋ก ๋ง๋๋ ๊ณผ์ ์ด๋ค.
๊ฐ ์คํ ์ ์์:
q(x_t | x_{t-1}) = N(x_t; โ(1-ฮฒ_t) * x_{t-1}, ฮฒ_t * I)
ฮฒ_t: ๊ฐ ์คํ ์์ ์ถ๊ฐํ ๋ ธ์ด์ฆ์ ๊ฐ๋ (noise schedule)- ฮฒ_t๊ฐ ์์ผ๋ฉด ์กฐ๊ธ์ฉ, ํฌ๋ฉด ๋ง์ด ๋ ธ์ด์ฆ๋ฅผ ์ถ๊ฐ
- ์ผ๋ฐ์ ์ผ๋ก T=1000 ์คํ ์ฌ์ฉ
xโ์์ ์์์ ์คํ t๋ก ๋ฐ๋ก ์ ํํ๋ ๊ฒ๋ ๊ฐ๋ฅํ๋ค:
q(x_t | x_0) = N(x_t; โแพฑ_t * x_0, (1-แพฑ_t) * I)
์ฌ๊ธฐ์ แพฑ_t = โ(1-ฮฒ_s) (s=1๋ถํฐ t๊น์ง์ ๋์ ๊ณฑ). ๋๋ถ์ ํ์ต ์ ์์์ t ์คํ
์ํ์ ํ ๋ฒ์ ๋ง๋ค ์ ์๋ค.
Reverse Process (๋ ธ์ด์ฆ ์ ๊ฑฐ)
์์ ํ ๋ ธ์ด์ฆ xT์์ ์ถ๋ฐํด์ ๋จ๊ณ์ ์ผ๋ก ๋ ธ์ด์ฆ๋ฅผ ์ ๊ฑฐํ๋ฉฐ ์ด๋ฏธ์ง๋ฅผ ๋ณต์ํ๋ ๊ณผ์ ์ด๋ค. ๊ฐ ์คํ ์์ ์ ๊ฒฝ๋ง(U-Net)์ด ํ์ฌ ๋ ธ์ด์ฆ ์ด๋ฏธ์ง x_t์ ์๊ฐ ์คํ t๋ฅผ ๋ฐ์ ์ถ๊ฐ๋ ๋ ธ์ด์ฆ ฮต๋ฅผ ์์ธกํ๋ค.
p_ฮธ(x_{t-1} | x_t) = N(x_{t-1}; ฮผ_ฮธ(x_t, t), ฮฃ_ฮธ(x_t, t))
ํ์ต ๋ชฉํ
์ค์ ๋ก ์ถ๊ฐ๋ ๋ ธ์ด์ฆ ฮต์ ์ ๊ฒฝ๋ง์ด ์์ธกํ ๋ ธ์ด์ฆ ฮต_ฮธ ์ฌ์ด์ MSE๋ฅผ ์ต์ํํ๋ค.
L = E[||ฮต - ฮต_ฮธ(x_t, t)||ยฒ]
ํ์ต ํ๋ฆ:
- ์๋ณธ ์ด๋ฏธ์ง xโ ์ํ๋ง
- ๋๋ค ํ์์คํ t, ๊ฐ์ฐ์์ ๋ ธ์ด์ฆ ฮต ์ํ๋ง
- x_t = โแพฑ_t * xโ + โ(1-แพฑ_t) * ฮต ๊ณ์ฐ
- U-Net์ x_t, t๋ฅผ ์ ๋ ฅํด ฮต_ฮธ ์์ธก
L = ย ฮต - ฮต_ฮธ ย ยฒ ๋ก ํ๋ผ๋ฏธํฐ ์ ๋ฐ์ดํธ
Stable Diffusion ์ํคํ ์ฒ
๊ธฐ๋ณธ Diffusion ๋ชจ๋ธ์ pixel space์์ ๋์ํ๊ธฐ ๋๋ฌธ์ ๊ณ ํด์๋ ์ด๋ฏธ์ง์์ ๊ณ์ฐ ๋น์ฉ์ด ๋งค์ฐ ํฌ๋ค. Stable Diffusion์ ์ด๋ฅผ latent space์์ ์ํํ๋๋ก ๊ฐ์ ํ Latent Diffusion Model (LDM) ์ด๋ค.
4๊ฐ์ง ํต์ฌ ์ปดํฌ๋ํธ
1. VAE (Variational Autoencoder)
์ด๋ฏธ์ง๋ฅผ ์์ถ๋ latent space๋ก ๋ณํํ๊ณ ๋ณต์ํ๋ ์ญํ ์ด๋ค.
- Encoder: ์ด๋ฏธ์ง(512ร512ร3) โ latent(64ร64ร4), 8๋ฐฐ ์์ถ
- Decoder: latent(64ร64ร4) โ ์ด๋ฏธ์ง(512ร512ร3)
- Diffusion์ ์ด latent space์์๋ง ๋์ํ๋ฏ๋ก ๊ณ์ฐ๋์ด ๋ํญ ์ค์ด๋ ๋ค
2. U-Net
latent space์์ ๊ฐ ์คํ ์ ๋ ธ์ด์ฆ๋ฅผ ์์ธกํ๋ ๋ฉ์ธ ์ ๊ฒฝ๋ง์ด๋ค.
- Encoder + Bottleneck + Decoder ๊ตฌ์กฐ, Skip Connection์ผ๋ก ์ฐ๊ฒฐ
- ResBlock: ๊ฐ ๋ธ๋ก์์ ์์ฐจ ์ฐ๊ฒฐ๋ก ํ์ต ์์ ํ
- Attention Layer: Self-Attention๊ณผ Cross-Attention ํฌํจ
- Time Embedding: ํ์ฌ ์คํ t๋ฅผ ์ฌ์ธํ๋ก ์ธ์ฝ๋ฉํด ๋คํธ์ํฌ์ ์ ๋ฌ
3. CLIP Text Encoder
ํ ์คํธ ํ๋กฌํํธ๋ฅผ ๋ฒกํฐ๋ก ๋ณํํ๋ ๋ชจ๋์ด๋ค.
- OpenAI CLIP ๋ชจ๋ธ์ ํ ์คํธ ์ธ์ฝ๋ ํ์ฉ
- ํ ํฐํ โ transformer ์ธ์ฝ๋ฉ โ text embedding
- ์ด embedding์ด U-Net์ Cross-Attention์ condition์ผ๋ก ๋ค์ด๊ฐ๋ค
4. Noise Scheduler
๊ฐ ์คํ ์ ๋ ธ์ด์ฆ ์ถ๊ฐ/์ ๊ฑฐ ๋น์จ์ ๊ด๋ฆฌํ๋ค.
- DDPM: ์๋ 1000 ์คํ , Markovian ํ๋ฅ ์ํ๋ง
- DDIM: 20~50 ์คํ ์ผ๋ก ์ค์ธ ๊ฒฐ์ ๋ก ์ ์ํ๋ง (๊ฐ์ seed โ ๊ฐ์ ์ด๋ฏธ์ง)
- DPM-Solver: ๋ ๋น ๋ฅธ ๊ณ ์ฐจ ODE ๊ธฐ๋ฐ ์๋ฒ
์ด๋ฏธ์ง ์์ฑ ํ๋ฆ
ํ
์คํธ ํ๋กฌํํธ
โ
CLIP Text Encoder
โ
text embedding โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ
๋๋ค ๊ฐ์ฐ์์ ๋
ธ์ด์ฆ (latent 64ร64ร4) โ U-Net (Cross-Attention) โ denoised latent
โ (N๋ฒ ๋ฐ๋ณต)
Noise Scheduler
โ
VAE Decoder
โ
์ต์ข
์ด๋ฏธ์ง (512ร512)
CFG (Classifier-Free Guidance)
ํ๋กฌํํธ๋ฅผ ์ผ๋ง๋ ๊ฐํ๊ฒ ๋ฐ์ํ ์ง ์กฐ์ ํ๋ ํต์ฌ ๊ธฐ๋ฒ์ด๋ค.
ํ์ต ๋จ๊ณ: ๋์ผํ U-Net์ด ์กฐ๊ฑด๋ถ(ํ ์คํธ ์์)์ ๋น์กฐ๊ฑด๋ถ(ํ ์คํธ ์์, null ํ ํฐ) ์์ธก์ ๋ชจ๋ ํ์ตํ๋ค.
์ถ๋ก ๋จ๊ณ: ๋ ์์ธก์ ์กฐํฉํด์ ์ต์ข ๋ ธ์ด์ฆ๋ฅผ ๊ณ์ฐํ๋ค.
noise_pred = uncond_pred + guidance_scale ร (cond_pred - uncond_pred)
- guidance_scale = 1: ๋น์กฐ๊ฑด๋ถ ์์ธก๋ง ์ฌ์ฉ (ํ๋กฌํํธ ๋ฌด์)
- guidance_scale = 7.5: ์ผ๋ฐ์ ์ผ๋ก ๋ง์ด ์ฐ๋ ๊ฐ
- guidance_scale์ด ๋์์๋ก ํ๋กฌํํธ๋ฅผ ๋ ๊ฐํ๊ฒ ๋ฐ๋ฅด์ง๋ง ๋ค์์ฑ์ ๊ฐ์
DDIM Sampling
DDPM์ ๋๋ฆฐ ์ํ๋ง ์๋๋ฅผ ํด๊ฒฐํ ๋ฐฉ๋ฒ์ด๋ค.
| ย | DDPM | DDIM |
|---|---|---|
| ์ํ๋ง ๋ฐฉ์ | Markovian (ํ๋ฅ ์ ) | non-Markovian (๊ฒฐ์ ๋ก ์ ) |
| ํ์ ์คํ | 1000 | 20~50 |
| ์ฌํ์ฑ | ๊ฐ์ seed์ฌ๋ ๋ค๋ฅผ ์ ์์ | ๊ฐ์ seed โ ํญ์ ๊ฐ์ ์ด๋ฏธ์ง |
| ์๋ | ๋๋ฆผ | ๋น ๋ฆ |
DDIM์ U-Net ๊ฐ์ค์น๋ฅผ ๊ทธ๋๋ก ์ฐ๋ฉด์ ์ํ๋ง ๊ณต์๋ง ๋ฐ๊พผ ๊ฒ์ด๊ธฐ ๋๋ฌธ์, ์ถ๊ฐ ํ์ต ์์ด ์๋๋ฅผ ํฌ๊ฒ ๊ฐ์ ํ ์ ์๋ค.
ControlNet
ํ ์คํธ ํ๋กฌํํธ ์ธ์ ์ถ๊ฐ ์ด๋ฏธ์ง ์กฐ๊ฑด(pose, edge, depth ๋ฑ)์ผ๋ก ์์ฑ์ ๋ ์ธ๋ฐํ๊ฒ ์ ์ดํ๋ ๊ธฐ๋ฒ์ด๋ค.
์
๋ ฅ: [ํ
์คํธ ํ๋กฌํํธ] + [์ ์ด ์ด๋ฏธ์ง (์: ํฌ์ฆ ์ค์ผ๋ ํค)]
โ
ControlNet (U-Net ๋ณต์ฌ๋ณธ, ํ์ต ๊ฐ๋ฅ) + ์๋ณธ U-Net (๋๊ฒฐ)
โ
์ ์ด ์ ํธ๋ฅผ U-Net ๋์ฝ๋์ ์ฃผ์
โ
์ ์ด๋ ์ด๋ฏธ์ง ์์ฑ
์๋ณธ U-Net ๊ฐ์ค์น๋ ๊ณ ์ (freeze)ํ๊ณ , ๊ทธ ๋ณต์ฌ๋ณธ์ ์ ์ด ์กฐ๊ฑด์ผ๋ก ํ์ต์ํจ๋ค. ๋๋ถ์ ๊ธฐ์กด ๋ชจ๋ธ ํ์ง์ ์ ์งํ๋ฉด์ ํฌ์ฆยท๊ตฌ๋ยท์ ํ๋ฅผ ์ ํํ ๋ฐ๋ฅด๋ ์ด๋ฏธ์ง๋ฅผ ๋ง๋ค ์ ์๋ค.
GAN vs Diffusion ๋น๊ต
| ํน์ง | GAN | Diffusion |
|---|---|---|
| ํ์ต ์์ ์ฑ | mode collapse ๋ฌธ์ ์์ | ์์ ์ |
| ์์ฑ ํ์ง | ๋์ | ๋ ๋์ |
| ์ํ๋ง ์๋ | ๋น ๋ฆ (๋จ์ผ forward pass) | ๋๋ฆผ (iterative) |
| ์ถ๋ ฅ ๋ค์์ฑ | ์ ํ์ | ๋์ |
| ํ ์คํธ ์กฐ๊ฑด ์ ์ด | ์ด๋ ค์ | CFG๋ก ์ฝ๊ฒ ์ ์ด |
์ ๋ฆฌ
Diffusion ๋ชจ๋ธ์ ํต์ฌ์ โ๋ ธ์ด์ฆ๋ฅผ ์ถ๊ฐํ๋ ๊ณผ์ ์ ๋ค์ง๋ ๊ฒโ ์ด๋ค. Forward process๋ ์์์ผ๋ก ์ ์๋ ๊ณ ์ ๊ณผ์ ์ด๊ณ , Reverse process๋ง ์ ๊ฒฝ๋ง์ด ํ์ตํ๋ค. Stable Diffusion์ ์ฌ๊ธฐ์ VAE๋ก latent ์์ถ, CLIP์ผ๋ก ํ ์คํธ ์กฐ๊ฑด, CFG๋ก ๊ฐ์ด๋์ค๋ฅผ ์ถ๊ฐํด์ ์ค์ฉ์ ์ธ text-to-image ์์ฑ ์์คํ ์ ๊ตฌ์ถํ ๊ฒ์ด๋ค.
ํ์ต ๋๋ ๋จ์ํ ๋ ธ์ด์ฆ ์์ธก MSE๋ฅผ ์ต์ํํ๋๋ฐ, ์ถ๋ก ๋๋ ๊ทธ U-Net์ ์์ญ ๋ฒ ๋ฐ๋ณต ํธ์ถํด์ ์ ์ง์ ์ผ๋ก ์ด๋ฏธ์ง๋ฅผ ๋ง๋ค์ด ๋ธ๋ค๋ ์ ์ด GAN๊ณผ ๊ฐ์ฅ ํฐ ๊ตฌ์กฐ์ ์ฐจ์ด๋ค.