[Daily morning study] Diffusion ๋ชจ๋ธ ์›๋ฆฌ (Stable Diffusion)

#daily morning study

Image


Diffusion ๋ชจ๋ธ์ด๋ž€

Diffusion ๋ชจ๋ธ์€ ์ด๋ฏธ์ง€๋ฅผ ์ ์ง„์ ์œผ๋กœ ๋…ธ์ด์ฆˆ๋กœ ๋งŒ๋“  ๋‹ค์Œ, ์—ญ๋ฐฉํ–ฅ์œผ๋กœ ๋…ธ์ด์ฆˆ๋ฅผ ์ œ๊ฑฐํ•˜๋Š” ๊ณผ์ •์„ ํ•™์Šตํ•ด์„œ ์ƒˆ๋กœ์šด ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•˜๋Š” ์ƒ์„ฑ ๋ชจ๋ธ์ด๋‹ค. GAN์ด Generator์™€ Discriminator์˜ ๊ฒฝ์Ÿ์œผ๋กœ ํ•™์Šตํ•˜๋Š” ๊ฒƒ๊ณผ ๋‹ฌ๋ฆฌ, Diffusion์€ ๋…ธ์ด์ฆˆ ์˜ˆ์ธก์ด๋ผ๋Š” ๋‹จ์ˆœํ•œ ๋ชฉํ‘œ ํ•˜๋‚˜๋กœ ํ•™์Šตํ•˜๊ธฐ ๋•Œ๋ฌธ์— ํ•™์Šต์ด ์•ˆ์ •์ ์ด๊ณ  ์ƒ์„ฑ ํ’ˆ์งˆ๋„ ๋†’๋‹ค.


Forward Process (๋…ธ์ด์ฆˆ ์ถ”๊ฐ€)

์›๋ณธ ์ด๋ฏธ์ง€ xโ‚€์— T ์Šคํ…์— ๊ฑธ์ณ ๊ฐ€์šฐ์‹œ์•ˆ ๋…ธ์ด์ฆˆ๋ฅผ ์กฐ๊ธˆ์”ฉ ์ถ”๊ฐ€ํ•ด์„œ, ์ตœ์ข…์ ์œผ๋กœ ์™„์ „ํ•œ ๋…ธ์ด์ฆˆ xT๋กœ ๋งŒ๋“œ๋Š” ๊ณผ์ •์ด๋‹ค.

๊ฐ ์Šคํ…์˜ ์ˆ˜์‹:

q(x_t | x_{t-1}) = N(x_t; โˆš(1-ฮฒ_t) * x_{t-1}, ฮฒ_t * I)
  • ฮฒ_t: ๊ฐ ์Šคํ…์—์„œ ์ถ”๊ฐ€ํ•  ๋…ธ์ด์ฆˆ์˜ ๊ฐ•๋„ (noise schedule)
  • ฮฒ_t๊ฐ€ ์ž‘์œผ๋ฉด ์กฐ๊ธˆ์”ฉ, ํฌ๋ฉด ๋งŽ์ด ๋…ธ์ด์ฆˆ๋ฅผ ์ถ”๊ฐ€
  • ์ผ๋ฐ˜์ ์œผ๋กœ T=1000 ์Šคํ… ์‚ฌ์šฉ

xโ‚€์—์„œ ์ž„์˜์˜ ์Šคํ… t๋กœ ๋ฐ”๋กœ ์ ํ”„ํ•˜๋Š” ๊ฒƒ๋„ ๊ฐ€๋Šฅํ•˜๋‹ค:

q(x_t | x_0) = N(x_t; โˆšแพฑ_t * x_0, (1-แพฑ_t) * I)

์—ฌ๊ธฐ์„œ แพฑ_t = โˆ(1-ฮฒ_s) (s=1๋ถ€ํ„ฐ t๊นŒ์ง€์˜ ๋ˆ„์ ๊ณฑ). ๋•๋ถ„์— ํ•™์Šต ์‹œ ์ž„์˜์˜ t ์Šคํ… ์ƒ˜ํ”Œ์„ ํ•œ ๋ฒˆ์— ๋งŒ๋“ค ์ˆ˜ ์žˆ๋‹ค.


Reverse Process (๋…ธ์ด์ฆˆ ์ œ๊ฑฐ)

์™„์ „ํ•œ ๋…ธ์ด์ฆˆ xT์—์„œ ์ถœ๋ฐœํ•ด์„œ ๋‹จ๊ณ„์ ์œผ๋กœ ๋…ธ์ด์ฆˆ๋ฅผ ์ œ๊ฑฐํ•˜๋ฉฐ ์ด๋ฏธ์ง€๋ฅผ ๋ณต์›ํ•˜๋Š” ๊ณผ์ •์ด๋‹ค. ๊ฐ ์Šคํ…์—์„œ ์‹ ๊ฒฝ๋ง(U-Net)์ด ํ˜„์žฌ ๋…ธ์ด์ฆˆ ์ด๋ฏธ์ง€ x_t์™€ ์‹œ๊ฐ„ ์Šคํ… t๋ฅผ ๋ฐ›์•„ ์ถ”๊ฐ€๋œ ๋…ธ์ด์ฆˆ ฮต๋ฅผ ์˜ˆ์ธกํ•œ๋‹ค.

p_ฮธ(x_{t-1} | x_t) = N(x_{t-1}; ฮผ_ฮธ(x_t, t), ฮฃ_ฮธ(x_t, t))

ํ•™์Šต ๋ชฉํ‘œ

์‹ค์ œ๋กœ ์ถ”๊ฐ€๋œ ๋…ธ์ด์ฆˆ ฮต์™€ ์‹ ๊ฒฝ๋ง์ด ์˜ˆ์ธกํ•œ ๋…ธ์ด์ฆˆ ฮต_ฮธ ์‚ฌ์ด์˜ MSE๋ฅผ ์ตœ์†Œํ™”ํ•œ๋‹ค.

L = E[||ฮต - ฮต_ฮธ(x_t, t)||ยฒ]

ํ•™์Šต ํ๋ฆ„:

  1. ์›๋ณธ ์ด๋ฏธ์ง€ xโ‚€ ์ƒ˜ํ”Œ๋ง
  2. ๋žœ๋ค ํƒ€์ž„์Šคํ… t, ๊ฐ€์šฐ์‹œ์•ˆ ๋…ธ์ด์ฆˆ ฮต ์ƒ˜ํ”Œ๋ง
  3. x_t = โˆšแพฑ_t * xโ‚€ + โˆš(1-แพฑ_t) * ฮต ๊ณ„์‚ฐ
  4. U-Net์— x_t, t๋ฅผ ์ž…๋ ฅํ•ด ฮต_ฮธ ์˜ˆ์ธก
  5. L =ย ฮต - ฮต_ฮธย ยฒ ๋กœ ํŒŒ๋ผ๋ฏธํ„ฐ ์—…๋ฐ์ดํŠธ

Stable Diffusion ์•„ํ‚คํ…์ฒ˜

๊ธฐ๋ณธ Diffusion ๋ชจ๋ธ์€ pixel space์—์„œ ๋™์ž‘ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ๊ณ ํ•ด์ƒ๋„ ์ด๋ฏธ์ง€์—์„œ ๊ณ„์‚ฐ ๋น„์šฉ์ด ๋งค์šฐ ํฌ๋‹ค. Stable Diffusion์€ ์ด๋ฅผ latent space์—์„œ ์ˆ˜ํ–‰ํ•˜๋„๋ก ๊ฐœ์„ ํ•œ Latent Diffusion Model (LDM) ์ด๋‹ค.

4๊ฐ€์ง€ ํ•ต์‹ฌ ์ปดํฌ๋„ŒํŠธ

1. VAE (Variational Autoencoder)

์ด๋ฏธ์ง€๋ฅผ ์••์ถ•๋œ latent space๋กœ ๋ณ€ํ™˜ํ•˜๊ณ  ๋ณต์›ํ•˜๋Š” ์—ญํ• ์ด๋‹ค.

  • Encoder: ์ด๋ฏธ์ง€(512ร—512ร—3) โ†’ latent(64ร—64ร—4), 8๋ฐฐ ์••์ถ•
  • Decoder: latent(64ร—64ร—4) โ†’ ์ด๋ฏธ์ง€(512ร—512ร—3)
  • Diffusion์€ ์ด latent space์—์„œ๋งŒ ๋™์ž‘ํ•˜๋ฏ€๋กœ ๊ณ„์‚ฐ๋Ÿ‰์ด ๋Œ€ํญ ์ค„์–ด๋“ ๋‹ค

2. U-Net

latent space์—์„œ ๊ฐ ์Šคํ…์˜ ๋…ธ์ด์ฆˆ๋ฅผ ์˜ˆ์ธกํ•˜๋Š” ๋ฉ”์ธ ์‹ ๊ฒฝ๋ง์ด๋‹ค.

  • Encoder + Bottleneck + Decoder ๊ตฌ์กฐ, Skip Connection์œผ๋กœ ์—ฐ๊ฒฐ
  • ResBlock: ๊ฐ ๋ธ”๋ก์—์„œ ์ž”์ฐจ ์—ฐ๊ฒฐ๋กœ ํ•™์Šต ์•ˆ์ •ํ™”
  • Attention Layer: Self-Attention๊ณผ Cross-Attention ํฌํ•จ
  • Time Embedding: ํ˜„์žฌ ์Šคํ… t๋ฅผ ์‚ฌ์ธํŒŒ๋กœ ์ธ์ฝ”๋”ฉํ•ด ๋„คํŠธ์›Œํฌ์— ์ „๋‹ฌ

3. CLIP Text Encoder

ํ…์ŠคํŠธ ํ”„๋กฌํ”„ํŠธ๋ฅผ ๋ฒกํ„ฐ๋กœ ๋ณ€ํ™˜ํ•˜๋Š” ๋ชจ๋“ˆ์ด๋‹ค.

  • OpenAI CLIP ๋ชจ๋ธ์˜ ํ…์ŠคํŠธ ์ธ์ฝ”๋” ํ™œ์šฉ
  • ํ† ํฐํ™” โ†’ transformer ์ธ์ฝ”๋”ฉ โ†’ text embedding
  • ์ด embedding์ด U-Net์˜ Cross-Attention์— condition์œผ๋กœ ๋“ค์–ด๊ฐ„๋‹ค

4. Noise Scheduler

๊ฐ ์Šคํ…์˜ ๋…ธ์ด์ฆˆ ์ถ”๊ฐ€/์ œ๊ฑฐ ๋น„์œจ์„ ๊ด€๋ฆฌํ•œ๋‹ค.

  • DDPM: ์›๋ž˜ 1000 ์Šคํ…, Markovian ํ™•๋ฅ  ์ƒ˜ํ”Œ๋ง
  • DDIM: 20~50 ์Šคํ…์œผ๋กœ ์ค„์ธ ๊ฒฐ์ •๋ก ์  ์ƒ˜ํ”Œ๋ง (๊ฐ™์€ seed โ†’ ๊ฐ™์€ ์ด๋ฏธ์ง€)
  • DPM-Solver: ๋” ๋น ๋ฅธ ๊ณ ์ฐจ ODE ๊ธฐ๋ฐ˜ ์†”๋ฒ„

์ด๋ฏธ์ง€ ์ƒ์„ฑ ํ๋ฆ„

ํ…์ŠคํŠธ ํ”„๋กฌํ”„ํŠธ
      โ†“
CLIP Text Encoder
      โ†“
text embedding โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
                                                     โ†“
๋žœ๋ค ๊ฐ€์šฐ์‹œ์•ˆ ๋…ธ์ด์ฆˆ (latent 64ร—64ร—4) โ†’ U-Net (Cross-Attention) โ†’ denoised latent
                                              โ†‘ (N๋ฒˆ ๋ฐ˜๋ณต)
                                         Noise Scheduler
                                                     โ†“
                                            VAE Decoder
                                                     โ†“
                                          ์ตœ์ข… ์ด๋ฏธ์ง€ (512ร—512)

CFG (Classifier-Free Guidance)

ํ”„๋กฌํ”„ํŠธ๋ฅผ ์–ผ๋งˆ๋‚˜ ๊ฐ•ํ•˜๊ฒŒ ๋ฐ˜์˜ํ• ์ง€ ์กฐ์ ˆํ•˜๋Š” ํ•ต์‹ฌ ๊ธฐ๋ฒ•์ด๋‹ค.

ํ•™์Šต ๋‹จ๊ณ„: ๋™์ผํ•œ U-Net์ด ์กฐ๊ฑด๋ถ€(ํ…์ŠคํŠธ ์žˆ์Œ)์™€ ๋น„์กฐ๊ฑด๋ถ€(ํ…์ŠคํŠธ ์—†์Œ, null ํ† ํฐ) ์˜ˆ์ธก์„ ๋ชจ๋‘ ํ•™์Šตํ•œ๋‹ค.

์ถ”๋ก  ๋‹จ๊ณ„: ๋‘ ์˜ˆ์ธก์„ ์กฐํ•ฉํ•ด์„œ ์ตœ์ข… ๋…ธ์ด์ฆˆ๋ฅผ ๊ณ„์‚ฐํ•œ๋‹ค.

noise_pred = uncond_pred + guidance_scale ร— (cond_pred - uncond_pred)
  • guidance_scale = 1: ๋น„์กฐ๊ฑด๋ถ€ ์˜ˆ์ธก๋งŒ ์‚ฌ์šฉ (ํ”„๋กฌํ”„ํŠธ ๋ฌด์‹œ)
  • guidance_scale = 7.5: ์ผ๋ฐ˜์ ์œผ๋กœ ๋งŽ์ด ์“ฐ๋Š” ๊ฐ’
  • guidance_scale์ด ๋†’์„์ˆ˜๋ก ํ”„๋กฌํ”„ํŠธ๋ฅผ ๋” ๊ฐ•ํ•˜๊ฒŒ ๋”ฐ๋ฅด์ง€๋งŒ ๋‹ค์–‘์„ฑ์€ ๊ฐ์†Œ

DDIM Sampling

DDPM์˜ ๋А๋ฆฐ ์ƒ˜ํ”Œ๋ง ์†๋„๋ฅผ ํ•ด๊ฒฐํ•œ ๋ฐฉ๋ฒ•์ด๋‹ค.

ย DDPMDDIM
์ƒ˜ํ”Œ๋ง ๋ฐฉ์‹Markovian (ํ™•๋ฅ ์ )non-Markovian (๊ฒฐ์ •๋ก ์ )
ํ•„์š” ์Šคํ…100020~50
์žฌํ˜„์„ฑ๊ฐ™์€ seed์—ฌ๋„ ๋‹ค๋ฅผ ์ˆ˜ ์žˆ์Œ๊ฐ™์€ seed โ†’ ํ•ญ์ƒ ๊ฐ™์€ ์ด๋ฏธ์ง€
์†๋„๋А๋ฆผ๋น ๋ฆ„

DDIM์€ U-Net ๊ฐ€์ค‘์น˜๋ฅผ ๊ทธ๋Œ€๋กœ ์“ฐ๋ฉด์„œ ์ƒ˜ํ”Œ๋ง ๊ณต์‹๋งŒ ๋ฐ”๊พผ ๊ฒƒ์ด๊ธฐ ๋•Œ๋ฌธ์—, ์ถ”๊ฐ€ ํ•™์Šต ์—†์ด ์†๋„๋ฅผ ํฌ๊ฒŒ ๊ฐœ์„ ํ•  ์ˆ˜ ์žˆ๋‹ค.


ControlNet

ํ…์ŠคํŠธ ํ”„๋กฌํ”„ํŠธ ์™ธ์— ์ถ”๊ฐ€ ์ด๋ฏธ์ง€ ์กฐ๊ฑด(pose, edge, depth ๋“ฑ)์œผ๋กœ ์ƒ์„ฑ์„ ๋” ์„ธ๋ฐ€ํ•˜๊ฒŒ ์ œ์–ดํ•˜๋Š” ๊ธฐ๋ฒ•์ด๋‹ค.

์ž…๋ ฅ: [ํ…์ŠคํŠธ ํ”„๋กฌํ”„ํŠธ] + [์ œ์–ด ์ด๋ฏธ์ง€ (์˜ˆ: ํฌ์ฆˆ ์Šค์ผˆ๋ ˆํ†ค)]
      โ†“
ControlNet (U-Net ๋ณต์‚ฌ๋ณธ, ํ•™์Šต ๊ฐ€๋Šฅ) + ์›๋ณธ U-Net (๋™๊ฒฐ)
      โ†“
์ œ์–ด ์‹ ํ˜ธ๋ฅผ U-Net ๋””์ฝ”๋”์— ์ฃผ์ž…
      โ†“
์ œ์–ด๋œ ์ด๋ฏธ์ง€ ์ƒ์„ฑ

์›๋ณธ U-Net ๊ฐ€์ค‘์น˜๋Š” ๊ณ ์ •(freeze)ํ•˜๊ณ , ๊ทธ ๋ณต์‚ฌ๋ณธ์„ ์ œ์–ด ์กฐ๊ฑด์œผ๋กœ ํ•™์Šต์‹œํ‚จ๋‹ค. ๋•๋ถ„์— ๊ธฐ์กด ๋ชจ๋ธ ํ’ˆ์งˆ์„ ์œ ์ง€ํ•˜๋ฉด์„œ ํฌ์ฆˆยท๊ตฌ๋„ยท์„ ํ™”๋ฅผ ์ •ํ™•ํžˆ ๋”ฐ๋ฅด๋Š” ์ด๋ฏธ์ง€๋ฅผ ๋งŒ๋“ค ์ˆ˜ ์žˆ๋‹ค.


GAN vs Diffusion ๋น„๊ต

ํŠน์ง•GANDiffusion
ํ•™์Šต ์•ˆ์ •์„ฑmode collapse ๋ฌธ์ œ ์žˆ์Œ์•ˆ์ •์ 
์ƒ์„ฑ ํ’ˆ์งˆ๋†’์Œ๋” ๋†’์Œ
์ƒ˜ํ”Œ๋ง ์†๋„๋น ๋ฆ„ (๋‹จ์ผ forward pass)๋А๋ฆผ (iterative)
์ถœ๋ ฅ ๋‹ค์–‘์„ฑ์ œํ•œ์ ๋†’์Œ
ํ…์ŠคํŠธ ์กฐ๊ฑด ์ œ์–ด์–ด๋ ค์›€CFG๋กœ ์‰ฝ๊ฒŒ ์ œ์–ด

์ •๋ฆฌ

Diffusion ๋ชจ๋ธ์˜ ํ•ต์‹ฌ์€ โ€œ๋…ธ์ด์ฆˆ๋ฅผ ์ถ”๊ฐ€ํ•˜๋Š” ๊ณผ์ •์„ ๋’ค์ง‘๋Š” ๊ฒƒโ€ ์ด๋‹ค. Forward process๋Š” ์ˆ˜์‹์œผ๋กœ ์ •์˜๋œ ๊ณ ์ • ๊ณผ์ •์ด๊ณ , Reverse process๋งŒ ์‹ ๊ฒฝ๋ง์ด ํ•™์Šตํ•œ๋‹ค. Stable Diffusion์€ ์—ฌ๊ธฐ์— VAE๋กœ latent ์••์ถ•, CLIP์œผ๋กœ ํ…์ŠคํŠธ ์กฐ๊ฑด, CFG๋กœ ๊ฐ€์ด๋˜์Šค๋ฅผ ์ถ”๊ฐ€ํ•ด์„œ ์‹ค์šฉ์ ์ธ text-to-image ์ƒ์„ฑ ์‹œ์Šคํ…œ์„ ๊ตฌ์ถ•ํ•œ ๊ฒƒ์ด๋‹ค.

ํ•™์Šต ๋•Œ๋Š” ๋‹จ์ˆœํžˆ ๋…ธ์ด์ฆˆ ์˜ˆ์ธก MSE๋ฅผ ์ตœ์†Œํ™”ํ•˜๋Š”๋ฐ, ์ถ”๋ก  ๋•Œ๋Š” ๊ทธ U-Net์„ ์ˆ˜์‹ญ ๋ฒˆ ๋ฐ˜๋ณต ํ˜ธ์ถœํ•ด์„œ ์ ์ง„์ ์œผ๋กœ ์ด๋ฏธ์ง€๋ฅผ ๋งŒ๋“ค์–ด ๋‚ธ๋‹ค๋Š” ์ ์ด GAN๊ณผ ๊ฐ€์žฅ ํฐ ๊ตฌ์กฐ์  ์ฐจ์ด๋‹ค.