๐ŸŒณ LLaDA; Large Language Diffusion Models (2025-02)

์ตœ๊ทผ ๋“ฑ์žฅํ•œ Diffusion Models๋“ค์€ ARMs (Autoregressive Models) ๋งŒํผ ์„ฑ๋Šฅ์ด ๊ดœ์ฐฎ์€ ํŽธ์ด๊ณ , context-awareness ์˜์—ญ์—์„œ๋Š” ์„ฑ๋Šฅ์ด ๋” ๊ฐ•ํ•˜๋‹ค๋Š” ํ‰์ด ๋‚˜์˜ค๊ณ  ์žˆ๋‹ค. โ†’ DLMs์ด ์ „ํ†ต ARMs ๋Œ€์ฒดํ•  ์ƒˆ๋กœ์šด ๋Œ€์•ˆ์œผ๋กœ ๋ถ€์ƒํ•˜๊ณ  ์žˆ๋Š” ๊ฒƒ ๊ฐ™๋‹ค.

image_1.png

1/ Background: Diffusion Models

image_2.png

Diffusion Model์€ generative models์˜ ํ•œ ์ข…๋ฅ˜๋กœ, random noise๋ฅผ ์ ์ง„์ ์œผ๋กœ ์ถ”๊ฐ€ (forward), ์ ์ง„์ ์œผ๋กœ denoise (reverse)ํ•˜๋ฉด์„œ ๋ฐ์ดํ„ฐ๋ฅผ ์ƒ์„ฑํ•œ๋‹ค.

๋‹ค์–‘ํ•œ noise๋ฅผ ์ถ”๊ฐ€ํ•˜๊ณ , ์ด๋ฅผ ๋ณต์›ํ•˜๋Š” ๋ฒ•์„ ๋ชจ๋ธ์ด ํ•™์Šตํ•˜๋„๋ก ํ›ˆ๋ จ์‹œํ‚ค๋Š” ๋ฐฉ๋ฒ•์ด๋‹ค. ์ƒ์„ฑ ์‹œ์—๋Š” random input์„ ๋ฐ˜๋ณต์ ์œผ๋กœ ์—…๋ฐ์ดํŠธํ•ด ๊ฒฐ๊ณผ๋ฅผ ์ ์ฐจ ์ •์ œํ•˜๋ฉฐ, ํšจ์œจ์ ์ธ sampling์„ ํ†ตํ•ด ์ ์€ step์œผ๋กœ๋„ ๊ณ ํ’ˆ์งˆ ๋ฐ์ดํ„ฐ๋ฅผ ์ƒ์„ฑํ•  ์ˆ˜ ์žˆ๋‹ค. ์ดํ›„ conditioning, guidance ๊ธฐ๋ฒ•์„ ํ†ตํ•ด ์ถœ๋ ฅ ์กฐ์ ˆ์ด ๊ฐ€๋Šฅํ•˜๋‹ค.

์ด๋ฏธ์ง€ ์ƒ์„ฑ ๋ถ„์•ผ์—์„œ ์„ฑ๊ณต์ ์ด์—ˆ๋Š”๋ฐ, ์ตœ๊ทผ ํ…์ŠคํŠธ ์ƒ์„ฑ์—๋„ ์ ์šฉ๋˜๋Š” ์›€์ง์ž„์„ ๋ณด์ด๊ณ  ์žˆ๋‹ค.

ARMs์€ ์ˆœ์ฐจ์ ์œผ๋กœ ์ƒ์„ฑํ•˜๋Š” ๋ฐ˜๋ฉด, diffusion ๋ชจ๋ธ์€ ํ…์ŠคํŠธ ์ „์ฒด๋ฅผ ๋™์‹œ์— ์ƒ์„ฑํ•˜๊ณ , ๋ฐ˜๋ณต์ ์ธ ์ˆ˜์ • ๊ณผ์ •์„ ๊ฑฐ์ณ ์ตœ์ข… ์ถœ๋ ฅ์„ ์™„์„ฑํ•œ๋‹ค. โ†’ ์ƒ์„ฑ ์†๋„, ํšจ์œจ์„ฑ ํ–ฅ์ƒ e.g. Inception Labs์˜ Mercury

import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler

# Set the noise scheduler
noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2"
)

# Training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)

losses = []

for epoch in range(30):
    for step, batch in enumerate(train_dataloader):
        clean_images = batch["images"].to(device)
        # Sample noise to add to the images
        noise = torch.randn(clean_images.shape).to(clean_images.device)
        bs = clean_images.shape[0]

        # Sample a random timestep for each image
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device
        ).long()

        # Add noise to the clean images according to the noise magnitude at each timestep
        noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

        # Get the model prediction
        noise_pred = model(noisy_images, timesteps, return_dict=False)[0]

        # Calculate the loss
        loss = F.mse_loss(noise_pred, noise)
        loss.backward()
        losses.append(loss.item())

        # Update the model parameters with the optimizer
        optimizer.step()
        optimizer.zero_grad()

    if (epoch + 1) % 5 == 0:
        loss_last_epoch = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
        print(f"Epoch:{epoch+1}, loss: {loss_last_epoch}")

2/ TL;DR

LLaDA๋Š” pre-training๊ณผ supervised fine-tuning (SFT) ๋ฐฉ์‹์œผ๋กœ scratch๋ถ€ํ„ฐ ํ•™์Šต๋œ diffusion-based LLM์œผ๋กœ, masking (forward process)์™€ denoising (reverse process - predicting masked tokens) ๊ณผ์ •์„ ํ†ตํ•ด ํ™•๋ฅ ์  ์ƒ์„ฑ ๋Šฅ๋ ฅ์„ ๊ฐ–์ท„๋‹ค. ARMs (e.g. LLaMA3-8B) ๋Œ€๋น„ ๋†’์€ scalability,์™€ in-context learning ์„ฑ๋Šฅ์„ ๋ณด์—ฌ์ฃผ๋ฉฐ, instruction-following ๋Šฅ๋ ฅ๋„ ์šฐ์ˆ˜ํ•˜๋‹ค. ํŠนํžˆ reversal curse ๋ฌธ์ œ๋ฅผ ๊ทน๋ณตํ•˜๋ฉฐ โ€˜Reversal Poem Completionโ€™ task์—์„œ GPT-4o๋ณด๋‹ค ๋›ฐ์–ด๋‚œ ์„ฑ๋Šฅ์„ ๋ณด์—ฌ์ค€๋‹ค. โ†’ ARMs์˜ ๋Œ€์•ˆ ๊ฐ€๋Šฅ์„ฑ์ด ๋†’์€ ๊ฒƒ ๊ฐ™๋‹ค.

Case

Prompt: *Explain what artificial intelligence is.*

https://ml-gsai.github.io/LLaDA-demo/

3/ LLaDA (Large Language Diffusion with mAsking)

image_3.png

Idea

Diffusion์€ ์ผ๋ฐ˜์ ์œผ๋กœ ์ด๋ฏธ์ง€ ์ƒ์„ฑ์— ์‚ฌ์šฉ๋˜๋Š” ๋ฐฉ์‹์ด๋‹ค. โ†’ ํ…์ŠคํŠธ์˜ ๊ฒฝ์šฐ, ์ด๋ฏธ์ง€์ฒ˜๋Ÿผ continuous space๊ฐ€ ์•„๋‹ˆ๋ฏ€๋กœ, discrete latent space์—์„œ diffusion์„ ์ˆ˜ํ–‰ํ•˜๋Š” ๋ฐฉ์‹์œผ๋กœ ์‘์šฉํ•œ๋‹ค.

  • ์ด๋ฏธ์ง€ : ์—ฐ์†์ ์ธ ํ”ฝ์…€๊ฐ’ (์—ฐ์† ํ™•๋ฅ  ๋ถ„ํฌ) ๋ณดํ†ต UNET ๊ธฐ๋ฐ˜ Gaussian Noise ์ ์ง„์ ์œผ๋กœ ์ถ”๊ฐ€ โ†’ ๋…ธ์ด์ฆˆ ์ œ๊ฑฐํ•˜๋ฉด์„œ ์ด๋ฏธ์ง€ ๋ณต์›

  • ํ…์ŠคํŠธ : ์ด์‚ฐ์ ์ธ ํ† ํฐ (์ด์‚ฐ ํ™•๋ฅ  ๋ถ„ํฌ) โ€œdiscrete diffusionโ€ Transfomer ๊ธฐ๋ฐ˜ LLaDA : 0 ~ 1 ์‚ฌ์ด์˜ ์ž„์˜์˜ masking ๋น„์œจ์„ ์‚ฌ์šฉํ•ด ํ† ํฐ์„ ์ ์ง„์ ์œผ๋กœ masking โ†’ masking๋œ token ์˜ˆ์ธกํ•˜๋ฉด์„œ ๋ณต์›ํ•œ๋‹ค.

    $$ q_{t|0}(x_t^i \mid x_0^i) = \begin{cases} 1 - t & \text{if } x_t^i = x_0^i \\ t & \text{if } x_t^i = \text{[MASK]} \end{cases} $$
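A minimal sketch of this discrete forward process in PyTorch (`MASK_ID` here is a hypothetical placeholder, not LLaDA's actual vocabulary id):

import torch

MASK_ID = 126336  # hypothetical placeholder id for the [MASK] token

def forward_mask(x0: torch.Tensor, t: float) -> torch.Tensor:
    """Discrete forward process: each token of x0 is independently
    replaced by [MASK] with probability t, kept with probability 1 - t."""
    is_masked = torch.rand(x0.shape) < t
    return torch.where(is_masked, torch.full_like(x0, MASK_ID), x0)

# e.g. at t=0.6, roughly 60% of the tokens become [MASK]
x0 = torch.randint(0, 32000, (1, 16))
x_t = forward_mask(x0, t=0.6)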

Methods

image_4.png

LLaDA๋Š” masked diffusion model๋กœ, pre-training, SFT(Supervised Fine-Tuning), sampling ์„ธ ๊ณผ์ •์œผ๋กœ ๋™์ž‘ ๊ณผ์ •์„ ๋‚˜๋ˆ  ๋ณผ ์ˆ˜ ์žˆ๋‹ค.

  • Pre-training
    • Every token of the input sequence is masked randomly and independently.

    • The masking ratio $t$ is sampled uniformly from $[0, 1]$ (each token is masked with probability $t$).

      At $t = 1$ every token is masked; the sequence at that point is $x_1$.

    • The model is trained as a mask predictor that restores the masked token at each position.

    • Analogous to the forward process of diffusion: the data is gradually turned into noise (masks).

  • SFT (Supervised Fine-Tuning)
image_5.png
    • The prompt is kept fixed; only the response part is masked (see the sketch after this list).
    • The model is fine-tuned to predict the masked response tokens.
    • Trained with the mask prediction loss.
  • Sampling

    • Starts at $t = 1$ (all tokens masked) and proceeds down to $t = 0$.
    • The mask predictor restores the masks, predicting all masked tokens simultaneously.
    • After prediction, some tokens stay unmasked and others are masked again (flexible remasking).
    • Repeating this gradually reaches $t = 0$ (all tokens unmasked), fully restoring the sequence.
  • Forward process: progressive masking. Reverse process: iterative unmasking.

  • Sampling์—์„œ remasking & refine ๊ณผ์ •์„ ๊ฑฐ์น˜๋ฉฐ ๋†’์€ ํ’ˆ์งˆ๊ณผ long-context ์ธ์ง€ ๋Šฅ๋ ฅ์„ ํ™•๋ณด.

  • ARMs์˜ next-token-generation ๋ฐฉ์‹๋ณด๋‹ค global context ์ดํ•ด๋„๊ฐ€ ๋” ๋†’์Œ.

Architecture

๋งˆ์Šคํฌ๋œ ์ดˆ๊ธฐ ์ƒํƒœ (xโ‚; t=1)
     โ”‚
     โ–ผ
Mask Predictor (Transformer)
     โ”‚
์˜ˆ์ธก๋œ ํ† ํฐ (์ผ๋ถ€๋Š” ๋‹ค์‹œ ๋งˆ์Šคํฌ๋จ)
     โ”‚
     โ–ผ
Remasking (Remasking ์ „๋žต์— ๋”ฐ๋ผ)
     โ”‚
     โ–ผ
๋ฐ˜๋ณต์  ์ •์ œ (t: 1 -> 0)
     โ”‚
     โ–ผ
์ตœ์ข… ํ…์ŠคํŠธ (xโ‚€; t=0)
  • Transformer ๊ธฐ๋ฐ˜ mask predictor & multi-head attention block ํฌํ•จ.

Training

Training Objective

$$ L(\theta) \triangleq -\mathbb{E}_{t, x_0, x_t} \left[ \frac{1}{tL} \sum_{i=1}^L \mathbb{1}[x_t^i = \text{[MASK]}] \log p_\theta (x_0^i \mid x_t) \right] $$

  • Pick a random timestep $t$.

  • Add noise (masks) to the original text.

    $q(x_t|x_0)$

  • The model predicts the masked tokens.

    Given $x_t$ as input, it predicts the original token at each masked position.

  • Loss: cross-entropy

    The cross-entropy between the predicted and ground-truth tokens is computed to update the model.

Sampling (Inference)

  1. ์ดˆ๊ธฐ ์ƒํƒœ ์„ค์ •

    ๋ชจ๋“  token์ด masking๋œ ์ดˆ๊ธฐ sequence $x_1$ ์„ค์ •.

  2. denoising ๋ฐ˜๋ณต $(t: 1 \rightarrow 0)$

    • $t$๋ฅผ 1์—์„œ 0์œผ๋กœ ์ ์ฐจ ์ค„์—ฌ๋‚˜๊ฐ€๋ฉด์„œ ๋‹จ๊ณ„ ๋ฐ˜๋ณต.

      mask predictor $p_{\theta}(x_0|x_t)$๋ฅผ ์‚ฌ์šฉํ•ด masking๋œ token ์˜ˆ์ธก. ์˜ˆ์ธก๋œ ํ™•๋ฅ  ๋ถ„ํฌ๋ฅผ ์‚ฌ์šฉํ•ด ๊ฐ token์ด ํ•ด๋‹น ์œ„์น˜์— ์žˆ์„ ํ™•๋ฅ  ํŒŒ์•…. ๊ฐ€์žฅ ๋†’์€ ํ™•๋ฅ ์„ ๊ฐ€์ง„ token์„ ์„ ํƒํ•˜๊ฑฐ๋‚˜, ํ™•๋ฅ  ๋ถ„ํฌ์— ๋”ฐ๋ผ ๋ฌด์ž‘์œ„๋กœ token ์„ ํƒ ๊ฐ€๋Šฅ.

    • Remasking : ์˜ˆ์ธก๋œ token ์ค‘ ์ผ๋ถ€๋ฅผ ๋‹ค์‹œ masking.

  3. ์ตœ์ข… ํ…์ŠคํŠธ ์ƒ์„ฑ

    $t=0$์ด๋ฉด ์ตœ์ข… ํ…์ŠคํŠธ $x_0$ ์ƒ์„ฑ.

Accelerate:

  • Use pseudo-numerical solvers (e.g. DDIM) or fewer-step sampling: numerical methods that approximate the diffusion process can produce high-quality samples in fewer steps, and optimizing or skipping unnecessary steps of the diffusion process speeds up inference.
  • There is also ongoing fast-DLM research showing that 5~10 steps can already be enough for sufficiently high-quality generation.

Scalability

image_6.png
