[Daily morning study] Diffusion ๋ชจ๋ธ ์›๋ฆฌ (Stable Diffusion)

#daily morning study

Image


Diffusion ๋ชจ๋ธ์ด๋ž€

Diffusion ๋ชจ๋ธ์€ ์ด๋ฏธ์ง€๋ฅผ ์ ์ง„์ ์œผ๋กœ ๋…ธ์ด์ฆˆ๋กœ ๋ง๊ฐ€๋œจ๋ฆฐ ๋’ค, ๊ทธ ์—ญ๊ณผ์ •์„ ํ•™์Šต์‹œ์ผœ ๋…ธ์ด์ฆˆ์—์„œ ์›๋ณธ ์ด๋ฏธ์ง€๋ฅผ ๋ณต์›ํ•˜๋Š” ์ƒ์„ฑ ๋ชจ๋ธ์ด๋‹ค. Stable Diffusion, DALLยทE 2, Imagen ๋“ฑ์ด ๋ชจ๋‘ ์ด ๊ณ„์—ด์ด๋‹ค.

ํ•ต์‹ฌ ์•„์ด๋””์–ด๋Š” ๋‘ ๋‹จ๊ณ„๋กœ ๋‚˜๋‰œ๋‹ค.

  1. Forward process (ํ™•์‚ฐ) โ€” ์›๋ณธ ์ด๋ฏธ์ง€์— ๊ฐ€์šฐ์‹œ์•ˆ ๋…ธ์ด์ฆˆ๋ฅผ T ์Šคํ…์— ๊ฑธ์ณ ์กฐ๊ธˆ์”ฉ ๋”ํ•ด, ๊ฒฐ๊ตญ ์ˆœ์ˆ˜ํ•œ ๊ฐ€์šฐ์‹œ์•ˆ ๋…ธ์ด์ฆˆ ์ƒํƒœ๋กœ ๋งŒ๋“ ๋‹ค.
  2. Reverse process (๋ณต์›) โ€” ๋…ธ์ด์ฆˆ ์ƒํƒœ์—์„œ ์‹œ์ž‘ํ•ด ์Šคํ…๋งˆ๋‹ค ๋…ธ์ด์ฆˆ๋ฅผ ์˜ˆ์ธกยท์ œ๊ฑฐํ•˜๋ฉฐ ์›๋ณธ์— ๊ฐ€๊นŒ์šด ์ด๋ฏธ์ง€๋ฅผ ๋ณต์›ํ•œ๋‹ค.

Forward Process

์‹œ๊ฐ„ t์—์„œ ์ด๋ฏธ์ง€ x_t๋Š” ์ด์ „ ์ด๋ฏธ์ง€ x_{t-1}์— ์•ฝ๊ฐ„์˜ ๋…ธ์ด์ฆˆ๋ฅผ ์ถ”๊ฐ€ํ•œ ๊ฒƒ์ด๋‹ค.

x_t = โˆš(1 - ฮฒt) * x_{t-1} + โˆšฮฒt * ฮต,   ฮต ~ N(0, I)

ฮฒt๋Š” ์Šคํ…๋งˆ๋‹ค์˜ ๋…ธ์ด์ฆˆ ๊ฐ•๋„(์Šค์ผ€์ค„)์ด๊ณ , T ์Šคํ…์ด ์ง€๋‚˜๋ฉด x_T๋Š” ๊ฑฐ์˜ ์ˆœ์ˆ˜ํ•œ ๊ฐ€์šฐ์‹œ์•ˆ ๋ถ„ํฌ๊ฐ€ ๋œ๋‹ค.

์ˆ˜์‹์„ ์ „๊ฐœํ•˜๋ฉด ์ž„์˜์˜ t ์Šคํ…์—์„œ์˜ x_t๋ฅผ ์›๋ณธ x_0์œผ๋กœ๋ถ€ํ„ฐ ๋‹ซํžŒ ํ˜•ํƒœ(closed form)๋กœ ๋ฐ”๋กœ ๊ณ„์‚ฐํ•  ์ˆ˜ ์žˆ๋‹ค.

x_t = โˆšแพฑt * x_0 + โˆš(1 - แพฑt) * ฮต

แพฑt๋Š” ฮฒ1~ฮฒt์˜ ๋ˆ„์ ๊ณฑ์œผ๋กœ ์ •์˜๋œ๋‹ค. ํ•™์Šต ์ค‘์— ์ž„์˜์˜ t๋ฅผ ์ƒ˜ํ”Œ๋งํ•ด์„œ x_t๋ฅผ ํ•œ ๋ฒˆ์— ๋งŒ๋“ค ์ˆ˜ ์žˆ์–ด ํšจ์œจ์ ์ด๋‹ค.


Reverse Process (U-Net์˜ ์—ญํ• )

๋ชจ๋ธ์€ x_t์™€ ์‹œ๊ฐ„ t๋ฅผ ์ž…๋ ฅ๋ฐ›์•„ ํ•ด๋‹น ์Šคํ…์—์„œ ์ถ”๊ฐ€๋œ ๋…ธ์ด์ฆˆ ฮต๋ฅผ ์˜ˆ์ธกํ•˜๋„๋ก ํ•™์Šต๋œ๋‹ค. ์‹ค์ œ ๊ตฌํ˜„์—์„œ๋Š” U-Net ์•„ํ‚คํ…์ฒ˜๊ฐ€ ์ฃผ๋กœ ์“ฐ์ธ๋‹ค.

ํ•™์Šต ์†์‹ค์€ ๋‹จ์ˆœํžˆ ์‹ค์ œ ๋…ธ์ด์ฆˆ ฮต์™€ ์˜ˆ์ธก ๋…ธ์ด์ฆˆ ฮต_ฮธ ์‚ฌ์ด์˜ MSE๋‹ค.

L = E[||ฮต - ฮต_ฮธ(x_t, t)||ยฒ]

์ถ”๋ก  ์‹œ์—๋Š” ์ˆœ์ˆ˜ ๊ฐ€์šฐ์‹œ์•ˆ ๋…ธ์ด์ฆˆ x_T์—์„œ ์‹œ์ž‘ํ•ด T โ†’ 0 ๋ฐฉํ–ฅ์œผ๋กœ ์Šคํ…๋งˆ๋‹ค ๋…ธ์ด์ฆˆ๋ฅผ ์ œ๊ฑฐํ•˜๋ฉฐ ์ตœ์ข… ์ด๋ฏธ์ง€๋ฅผ ๋งŒ๋“ ๋‹ค.


Stable Diffusion์˜ ํ•ต์‹ฌ: Latent Diffusion

ํ”ฝ์…€ ๊ณต๊ฐ„์—์„œ ์ง์ ‘ Diffusion์„ ์ˆ˜ํ–‰ํ•˜๋ฉด ๊ณ ํ•ด์ƒ๋„ ์ด๋ฏธ์ง€์ผ์ˆ˜๋ก ์—ฐ์‚ฐ ๋น„์šฉ์ด ํญ๋ฐœ์ ์œผ๋กœ ์ฆ๊ฐ€ํ•œ๋‹ค. Stable Diffusion์€ ์ด๋ฅผ ์ž ์žฌ ๊ณต๊ฐ„(latent space) ์—์„œ ์ˆ˜ํ–‰ํ•ด ํ•ด๊ฒฐํ•œ๋‹ค.

์ „์ฒด ํŒŒ์ดํ”„๋ผ์ธ์€ ์„ธ ์ปดํฌ๋„ŒํŠธ๋กœ ๊ตฌ์„ฑ๋œ๋‹ค.

์ปดํฌ๋„ŒํŠธ์—ญํ• 
VAE (Variational Autoencoder)์ด๋ฏธ์ง€๋ฅผ ์ €์ฐจ์› ์ž ์žฌ ๋ฒกํ„ฐ๋กœ ์ธ์ฝ”๋”ฉ / ๋ณต์›
U-Net (+ Attention)์ž ์žฌ ๊ณต๊ฐ„์—์„œ Diffusion์˜ ๋…ธ์ด์ฆˆ ์˜ˆ์ธก
Text Encoder (CLIP ๋“ฑ)ํ…์ŠคํŠธ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ž„๋ฒ ๋”ฉ์œผ๋กœ ๋ณ€ํ™˜ํ•ด U-Net์— ์กฐ๊ฑด ์ œ๊ณต
  1. ์ด๋ฏธ์ง€๋ฅผ VAE ์ธ์ฝ”๋”๋กœ ์••์ถ•ํ•ด latent z๋ฅผ ์–ป๋Š”๋‹ค (์˜ˆ: 512ร—512 โ†’ 64ร—64ร—4).
  2. z์— Forward diffusion์„ ์ ์šฉํ•ด ๋…ธ์ด์ฆˆ๋ฅผ ์ถ”๊ฐ€ํ•œ๋‹ค.
  3. U-Net์ด latent ๊ณต๊ฐ„์—์„œ ๋…ธ์ด์ฆˆ๋ฅผ ์˜ˆ์ธกํ•˜๊ณ  ์ œ๊ฑฐํ•œ๋‹ค.
  4. ๋ณต์›๋œ latent๋ฅผ VAE ๋””์ฝ”๋”๋กœ ๋‹ค์‹œ ํ”ฝ์…€ ๊ณต๊ฐ„์œผ๋กœ ๋ณ€ํ™˜ํ•œ๋‹ค.

ํ…์ŠคํŠธ ํ”„๋กฌํ”„ํŠธ๋Š” Cross-Attention์„ ํ†ตํ•ด U-Net์˜ ๊ฐ ๋ ˆ์ด์–ด์— ์กฐ๊ฑด์œผ๋กœ ์ฃผ์ž…๋œ๋‹ค.


Classifier-Free Guidance (CFG)

ํ…์ŠคํŠธ ์กฐ๊ฑด์„ ์–ผ๋งˆ๋‚˜ ๊ฐ•ํ•˜๊ฒŒ ๋ฐ˜์˜ํ• ์ง€ ์กฐ์ ˆํ•˜๋Š” ๊ธฐ๋ฒ•์ด๋‹ค. ๊ฐ™์€ U-Net์„ ๋‘ ๋ฒˆ ์‹คํ–‰ํ•œ๋‹ค.

  1. ํ…์ŠคํŠธ ์กฐ๊ฑด ์žˆ์Œ โ†’ ์กฐ๊ฑด๋ถ€ ๋…ธ์ด์ฆˆ ์˜ˆ์ธก ฮต_c
  2. ํ…์ŠคํŠธ ์กฐ๊ฑด ์—†์Œ(๋นˆ ํ”„๋กฌํ”„ํŠธ) โ†’ ๋ฌด์กฐ๊ฑด ๋…ธ์ด์ฆˆ ์˜ˆ์ธก ฮต_u

์ตœ์ข… ๋…ธ์ด์ฆˆ ์˜ˆ์ธก์€ ๋‘ ๊ฒฐ๊ณผ๋ฅผ ์„ ํ˜• ๋ณด๊ฐ„ํ•œ๋‹ค.

ฮต_final = ฮต_u + guidance_scale * (ฮต_c - ฮต_u)

guidance_scale(CFG ์Šค์ผ€์ผ)์ด ๋†’์„์ˆ˜๋ก ํ…์ŠคํŠธ์— ๋” ์ถฉ์‹คํ•œ ์ด๋ฏธ์ง€๊ฐ€ ์ƒ์„ฑ๋˜์ง€๋งŒ, ๋„ˆ๋ฌด ๋†’์œผ๋ฉด ๊ณผํฌํ™”ยท์™œ๊ณก์ด ๋ฐœ์ƒํ•œ๋‹ค. ๋ณดํ†ต 7~12 ์‚ฌ์ด๋ฅผ ์‚ฌ์šฉํ•œ๋‹ค.


์ƒ˜ํ”Œ๋ง ์Šค์ผ€์ค„๋Ÿฌ

Reverse process๋ฅผ T ์Šคํ… ์ „๋ถ€ ์ˆ˜ํ–‰ํ•˜๋ฉด ๋А๋ฆฌ๋‹ค. ์‹ค์šฉ์ ์œผ๋กœ๋Š” ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ํ†ตํ•ด ์Šคํ… ์ˆ˜๋ฅผ ์ค„์ธ๋‹ค.

์Šค์ผ€์ค„๋ŸฌํŠน์ง•
DDPM์›๋ž˜ ๋…ผ๋ฌธ ๋ฐฉ์‹, 1000 ์Šคํ… ํ•„์š”
DDIM๊ฒฐ์ •๋ก ์  ์ƒ˜ํ”Œ๋ง, 20~50 ์Šคํ…์œผ๋กœ ๊ฐ€๋Šฅ
DPM++๊ณ ์ฐจ ์†”๋ฒ„, 15~25 ์Šคํ…์—์„œ ์ข‹์€ ํ’ˆ์งˆ
Euler / Euler A๋น ๋ฅด๊ณ  ์•ˆ์ •์ , ์‹ค์ œ๋กœ ๋งŽ์ด ์‚ฌ์šฉ

DDIM์€ ๋™์ผํ•œ ๋…ธ์ด์ฆˆ ์‹œ๋“œ์—์„œ ์ผ๊ด€๋œ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•˜๊ธฐ ๋•Œ๋ฌธ์— img2img๋‚˜ inpainting์—๋„ ์ž์ฃผ ์“ฐ์ธ๋‹ค.


img2img์™€ Inpainting

img2img: ์ž…๋ ฅ ์ด๋ฏธ์ง€๋ฅผ ์ผ๋ถ€ ๋…ธ์ด์ฆˆํ™”(denoise_strength๋กœ ์กฐ์ ˆ)ํ•œ ๋’ค Reverse process๋ฅผ ์‹คํ–‰ํ•œ๋‹ค. strength๊ฐ€ ๋‚ฎ์„์ˆ˜๋ก ์›๋ณธ์— ๊ฐ€๊นŒ์šด ์ด๋ฏธ์ง€๊ฐ€ ๋‚˜์˜จ๋‹ค.

Inpainting: ๋งˆ์Šคํฌ ์˜์—ญ๋งŒ ๋…ธ์ด์ฆˆํ™”ํ•˜๊ณ  ๋‚˜๋จธ์ง€๋Š” ์›๋ณธ์„ ์œ ์ง€ํ•œ ์ฑ„ Reverse process๋ฅผ ์‹คํ–‰ํ•œ๋‹ค. ํŠน์ • ๋ถ€๋ถ„๋งŒ ์ˆ˜์ •ํ•  ๋•Œ ์‚ฌ์šฉํ•œ๋‹ค.


GAN๊ณผ ๋น„๊ต

ํ•ญ๋ชฉGANDiffusion
ํ•™์Šต ์•ˆ์ •์„ฑ๋ถˆ์•ˆ์ • (mode collapse)์•ˆ์ •์ 
์ƒ์„ฑ ํ’ˆ์งˆ๋‚ ์นด๋กญ์ง€๋งŒ ๋‹ค์–‘์„ฑ ๋ถ€์กฑ๋†’์€ ๋‹ค์–‘์„ฑ
์†๋„๋น ๋ฆ„ (๋‹จ์ผ forward pass)๋А๋ฆผ (๋‹ค๋‹จ๊ณ„ reverse)
์กฐ๊ฑด๋ถ€ ์ƒ์„ฑ์ถ”๊ฐ€ ๊ตฌ์กฐ ํ•„์š”CFG๋กœ ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ํ†ตํ•ฉ

Diffusion ๋ชจ๋ธ์ด ์ด๋ฏธ์ง€ ์ƒ์„ฑ ํ’ˆ์งˆ๊ณผ ํ…์ŠคํŠธ ์กฐ๊ฑด ์ œ์–ด ๋ฉด์—์„œ GAN์„ ๋Œ€๋ถ€๋ถ„์˜ ํƒœ์Šคํฌ์—์„œ ์•ž์ง€๋ฅด๋ฉด์„œ ์ฃผ๋ฅ˜๊ฐ€ ๋๋‹ค.


์ •๋ฆฌ

  • Diffusion ๋ชจ๋ธ์€ Forward(๋…ธ์ด์ฆˆ ์ถ”๊ฐ€) + Reverse(๋…ธ์ด์ฆˆ ์ œ๊ฑฐ)๋ฅผ ํ•™์Šตํ•˜๋Š” ์ƒ์„ฑ ๋ชจ๋ธ์ด๋‹ค.
  • Stable Diffusion์€ VAE๋กœ ํ”ฝ์…€์„ latent๋กœ ์••์ถ•ํ•œ ๋’ค latent ๊ณต๊ฐ„์—์„œ Diffusion์„ ์ˆ˜ํ–‰ํ•ด ํšจ์œจ์„ ๋†’์ธ๋‹ค.
  • CLIP ๊ธฐ๋ฐ˜ ํ…์ŠคํŠธ ์ธ์ฝ”๋”์™€ Cross-Attention์œผ๋กœ ํ…์ŠคํŠธ ์กฐ๊ฑด๋ถ€ ์ƒ์„ฑ์„ ๊ตฌํ˜„ํ•œ๋‹ค.
  • CFG ์Šค์ผ€์ผ๋กœ ํ…์ŠคํŠธ ์ถฉ์‹ค๋„๋ฅผ ์กฐ์ ˆํ•˜๊ณ , DDIM/DPM++ ๊ฐ™์€ ์Šค์ผ€์ค„๋Ÿฌ๋กœ ์ƒ˜ํ”Œ๋ง ์†๋„๋ฅผ ๋†’์ธ๋‹ค.