Block Diffusion: Interpolating Between Autoregressive and Diffusion Language Models

Published 12 Mar 2025 in cs.LG and cs.AI | (2503.09573v3)

Abstract: Diffusion LLMs offer unique benefits over autoregressive models due to their potential for parallelized generation and controllability, yet they lag in likelihood modeling and are limited to fixed-length generation. In this work, we introduce a class of block diffusion LLMs that interpolate between discrete denoising diffusion and autoregressive models. Block diffusion overcomes key limitations of both approaches by supporting flexible-length generation and improving inference efficiency with KV caching and parallel token sampling. We propose a recipe for building effective block diffusion models that includes an efficient training algorithm, estimators of gradient variance, and data-driven noise schedules to minimize the variance. Block diffusion sets a new state-of-the-art performance among diffusion models on language modeling benchmarks and enables generation of arbitrary-length sequences. We provide the code, along with the model weights and blog post on the project page: https://m-arriola.com/bd3lms

Summary

  • The paper introduces BD3-LMs that integrate autoregressive block-level generation with discrete denoising diffusion, enabling flexible-length text synthesis.
  • It leverages a two-pass and vectorized training strategy using custom attention masks and KV caching to boost efficiency by 20–25%.
  • Experimental results demonstrate improved perplexity on benchmarks like LM1B and OpenWebText, with enhanced sample quality and accelerated inference.

This paper introduces Block Discrete Denoising Diffusion Language Models (BD3-LMs), a novel class of models that bridges the gap between autoregressive (AR) and discrete denoising diffusion models for language generation. The primary motivation is to overcome key limitations of both paradigms: diffusion models often struggle with likelihood modeling, are restricted to fixed-length generation, and lack efficient inference mechanisms like KV caching, while AR models generate tokens sequentially, limiting speed.

BD3-LMs operate by being autoregressive over blocks of tokens while performing discrete denoising diffusion within each block. This hybrid approach allows for flexible-length generation and improves inference efficiency through KV caching and parallel token sampling within blocks.

Core Concepts and Implementation

1. Model Architecture and Likelihood:

  • A sequence of $L$ tokens $\mathbf{x}$ is divided into $B$ blocks, each of length $L'$, so $L = B \cdot L'$.
  • The log-likelihood is factorized autoregressively over these blocks:

    $$\log p_\theta(\mathbf{x}) = \sum_{b=1}^{B} \log p_\theta(\mathbf{x}^{b} \mid \mathbf{x}^{<b})$$

  • Each conditional probability $p_\theta(\mathbf{x}^{b} \mid \mathbf{x}^{<b})$ is modeled by a discrete diffusion process specific to block $b$, conditioned on the previously generated blocks $\mathbf{x}^{<b}$.
  • A single transformer network $f_\theta$ parameterizes the base denoiser for all blocks. It uses a block-causal attention mask: tokens in block $\mathbf{x}^b$ attend to the other tokens within the (potentially noised) block $\mathbf{x}^b_t$ and to all clean tokens in the preceding blocks $\mathbf{x}^{<b}$.
  • The model supports KV caching:

    $$\mathbf{x}^b_{\text{logits}}, \mathbf{K}^b, \mathbf{V}^b \leftarrow f_\theta(\mathbf{x}^b_t, \mathbf{K}^{1:b-1}, \mathbf{V}^{1:b-1})$$

    where $\mathbf{x}^b_t$ is the noised version of block $\mathbf{x}^b$ at timestep $t$, and $\mathbf{K}^{1:b-1}, \mathbf{V}^{1:b-1}$ are the cached keys and values from previous blocks.
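
The block factorization and the block-causal mask can be made concrete with a small pure-Python sketch. It is illustrative only: `block_log_prob` is a hypothetical stand-in for the per-block diffusion model, and the mask is a plain nested list rather than a tensor fed to an attention kernel.

```python
import math


def split_into_blocks(tokens, block_len):
    """Split a length-L sequence into B = L / L' contiguous blocks of length L'."""
    if len(tokens) % block_len != 0:
        raise ValueError("sequence length must be a multiple of the block length")
    return [tokens[i:i + block_len] for i in range(0, len(tokens), block_len)]


def block_factorized_log_likelihood(tokens, block_len, block_log_prob):
    """log p(x) = sum_b log p(x^b | x^{<b}).

    block_log_prob(block, prefix) is a hypothetical stand-in for the per-block
    diffusion model; prefix holds the clean tokens of all earlier blocks.
    """
    blocks = split_into_blocks(tokens, block_len)
    total = 0.0
    for b, block in enumerate(blocks):
        prefix = [tok for earlier in blocks[:b] for tok in earlier]  # x^{<b}
        total += block_log_prob(block, prefix)
    return total


def block_causal_mask(seq_len, block_len):
    """mask[i][j] is True if token i may attend to token j, i.e. j lies in the same
    block as i (bidirectional within the block) or in an earlier block."""
    return [[j // block_len <= i // block_len for j in range(seq_len)]
            for i in range(seq_len)]
```

With a uniform toy scorer over 8 tokens, `lambda block, prefix: len(block) * math.log(1 / 8)`, the factorized sum over 3 blocks of 2 tokens matches `6 * math.log(1 / 8)`, the same value a token-by-token factorization of that toy model would give. Setting `block_len=1` turns `block_causal_mask` into an ordinary causal mask, which is the AR special case.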

2. Training Objective:

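  • The training objective is obtained by bounding each block-conditional term with its negative ELBO (NELBO):

    $$-\log p_\theta(\mathbf{x}) \le \mathcal{L}_{\text{BD}}(\mathbf{x}; \theta) := \sum_{b=1}^{B} \mathcal{L}(\mathbf{x}^{b}, \mathbf{x}^{<b}; \theta)$$

    where $\mathcal{L}(\mathbf{x}^{b}, \mathbf{x}^{<b}; \theta)$ is the standard diffusion NELBO for block $b$ conditioned on $\mathbf{x}^{<b}$.

  • For masked BD3-LMs (using a masking noise process), a simplified objective is adopted:

    $$\mathcal{L}_{\text{BD}}(\mathbf{x}; \theta) := \sum_{b=1}^{B} \mathbb{E}_{t \sim \mathcal{U}[0,1]} \, \mathbb{E}_{q} \, \frac{\alpha'_t}{1 - \alpha_t} \log p_\theta(\mathbf{x}^{b} \mid \mathbf{x}^{b}_t, \mathbf{x}^{<b})$$

    where $\alpha_t$ defines the noise schedule (the probability that a token is not masked at time $t$), and $\alpha'_t$ is its derivative.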
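
As a rough illustration of the masked objective, the sketch below forms a one-sample estimate of a single block's term. It assumes a linear schedule $\alpha_t = 1 - t$ (so $\alpha'_t = -1$ and the weight is $-1/t$), and `log_prob_fn` is a hypothetical model call; neither is taken from the paper's code.

```python
import math
import random

MASK = -1  # hypothetical mask-token id for this sketch


def mask_block(block, mask_rate, rng):
    """Forward process q(x_t^b | x^b): mask each token independently with prob 1 - alpha_t."""
    return [MASK if rng.random() < mask_rate else tok for tok in block]


def block_loss_estimate(block, prefix, t, rng, log_prob_fn):
    """One-sample estimate of one block's term, assuming alpha_t = 1 - t.

    The weight alpha'_t / (1 - alpha_t) equals -1 / t, and the log-likelihood is
    summed over masked positions only, so the term becomes
    (1 / t) * sum over masked l of -log p_theta(x^l | x_t^b, x^{<b}).
    log_prob_fn(token, pos, noised_block, prefix) is a hypothetical model call.
    """
    if not 0.0 < t <= 1.0:
        raise ValueError("t must lie in (0, 1]")
    noised = mask_block(block, t, rng)  # mask rate 1 - alpha_t = t
    nll = 0.0
    for pos, (clean_tok, noisy_tok) in enumerate(zip(block, noised)):
        if noisy_tok == MASK:  # unmasked positions contribute nothing
            nll -= log_prob_fn(clean_tok, pos, noised, prefix)
    return nll / t
```

For a uniform toy model over 8 tokens, the expected number of masked tokens in a block of 4 is $4t$, so the $1/t$ weight makes the expected estimate $4 \log 8$ at every $t$. Small values of $t$ make individual estimates very noisy, however, which is the variance issue discussed in Section 5.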

3. Efficient Training Algorithm (Algorithm 1):

  • Naively computing the loss would require $B$ separate forward passes, one per block, because denoising block $b$ uses a noised $\mathbf{x}^{b}_t$ while conditioning on the clean previous blocks $\mathbf{x}^{<b}$.
  • Two-Pass Approach:

    1. First Pass (KV Cache Precomputation): Compute keys and values $\mathbf{K}^{1:B}, \mathbf{V}^{1:B}$ for the entire clean sequence $\mathbf{x}$ in one forward pass: $\mathbf{K}^{1:B}, \mathbf{V}^{1:B} \leftarrow f_\theta(\mathbf{x})$.
    2. Second Pass (Denoising): For each block $b$, sample a noise level $t_b \sim \mathcal{U}[0,1]$ and create the noised block $\mathbf{x}^{b}_{t_b} \sim q(\cdot \mid \mathbf{x}^{b})$. Compute denoised predictions for all blocks simultaneously using the precomputed KV cache: $\mathbf{x}^{b}_{\text{logits}} \leftarrow f_\theta(\mathbf{x}^{b}_{t_b}, \mathbf{K}^{1:b-1}, \mathbf{V}^{1:b-1})$ for $b = 1, \dots, B$.
  • Vectorized Single-Pass Training: An even more efficient method concatenates the noisy data $\mathbf{x}_{\text{noisy}}$ and the clean data $\mathbf{x}$ into a single input sequence of length $2L$. A custom attention mask (detailed in the paper's appendix) lets noisy tokens attend to the other noisy tokens in their own block and to clean tokens in preceding blocks, while clean tokens attend block-causally to clean tokens only. This leverages efficient attention kernels such as FlashAttention or FlexAttention (also discussed in the paper's appendix), yielding a 20-25% training speed-up over the two-pass approach.
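
The mask used by vectorized training can be sketched as a boolean predicate over the concatenated $2L$ positions. This is a pure-Python illustration of the structure described above (block-diagonal among noisy tokens, strictly earlier clean blocks for noisy queries, block-causal among clean tokens); the function name is hypothetical and this is not the paper's kernel code.

```python
def vectorized_training_mask(seq_len, block_len):
    """Mask for the concatenated input [x_noisy (positions 0..L-1), x (positions L..2L-1)].

    Returns a 2L x 2L nested list; entry [q][k] is True if query q may attend to key k.
    """
    if seq_len % block_len != 0:
        raise ValueError("seq_len must be a multiple of block_len")

    def allowed(q, k):
        q_noisy, k_noisy = q < seq_len, k < seq_len
        q_block = (q % seq_len) // block_len
        k_block = (k % seq_len) // block_len
        if q_noisy and k_noisy:
            return q_block == k_block   # noisy -> noisy: same block (block diagonal)
        if q_noisy:
            return k_block < q_block    # noisy -> clean: strictly earlier blocks
        if not k_noisy:
            return k_block <= q_block   # clean -> clean: block-causal
        return False                    # clean -> noisy: never

    size = 2 * seq_len
    return [[allowed(q, k) for k in range(size)] for q in range(size)]
```

Because clean queries never see noisy keys, the clean half produces the same keys and values as the first pass of the two-pass method, so one forward pass does the work of both. The nested `allowed(q, k)` predicate has roughly the shape of a FlexAttention `mask_mod` (which in PyTorch also receives batch and head indices); that adaptation is left out here.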

4. Efficient Sampling Algorithm (Algorithm 2):

  • Blocks are generated sequentially.
  • For each block $b = 1, 2, \dots$:
    1. Sample the clean block $\mathbf{x}^{b}$ using a diffusion sampling procedure (e.g., the D3PM sampler) conditioned on the previously generated clean blocks $\mathbf{x}^{<b}$ (via their cached keys and values $\mathbf{K}^{1:b-1}, \mathbf{V}^{1:b-1}$). This step involves multiple denoising steps within the block.
    2. Compute and cache keys and values for the newly sampled block $\mathbf{x}^{b}$: $\mathbf{K}^{b}, \mathbf{V}^{b} \leftarrow f_\theta(\mathbf{x}^{b}, \mathbf{K}^{1:b-1}, \mathbf{V}^{1:b-1})$.
    3. Append $\mathbf{x}^{b}$ to the generated sequence and update the overall KV cache.
  • This allows for arbitrary-length sequence generation and benefits from parallel generation within each block.
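
The sampling loop can be sketched with a toy denoiser. This assumes a linear schedule $\alpha_t = 1 - t$ and the standard masked-diffusion reveal rule: a still-masked token becomes clean with probability $(\alpha_s - \alpha_t)/(1 - \alpha_t)$ when stepping from $t$ to $s$. The `denoise` callable is hypothetical and returns tokens directly instead of a categorical distribution. A plain token list stands in for the KV cache.

```python
import random

MASK = None  # hypothetical placeholder for a masked position


def sample_block(prefix, block_len, num_steps, denoise, rng):
    """Reverse masked diffusion inside one block, assuming alpha_t = 1 - t.

    Going from time t to s = t - 1/num_steps, each still-masked token is revealed with
    probability (alpha_s - alpha_t) / (1 - alpha_t) = (t - s) / t. At the last step
    s = 0, so that probability is 1 and no masks remain.
    """
    block = [MASK] * block_len
    for step in range(num_steps, 0, -1):
        t, s = step / num_steps, (step - 1) / num_steps
        reveal_prob = (t - s) / t
        preds = denoise(prefix, block)  # one denoiser call (NFE) per step, all positions at once
        for pos in range(block_len):
            if block[pos] is MASK and rng.random() < reveal_prob:
                block[pos] = preds[pos]
    return block


def generate(num_blocks, block_len, num_steps, denoise, seed=0):
    """Block-autoregressive generation: any number of blocks, each sampled by diffusion."""
    rng = random.Random(seed)
    context = []  # stands in for the KV cache; a real model would store K^{1:b-1}, V^{1:b-1}
    for _ in range(num_blocks):
        block = sample_block(context, block_len, num_steps, denoise, rng)
        context.extend(block)  # "cache" the clean block and append it to the output
    return context
```

Output length is set only by `num_blocks`, which is what allows arbitrary-length generation. Within a block, each step updates several positions from a single denoiser call, which is where the intra-block parallelism comes from.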

5. Addressing Gradient Variance and Improving Performance:

  • A key finding is that the perplexity gap between diffusion models and AR models can be attributed to high variance in the gradients of the diffusion objective during training.
  • Case Study ($L' = 1$): With a block size of 1, BD3-LM is theoretically equivalent to AR. However, standard masked diffusion (masking about 50% of tokens on average) yields higher perplexity than AR, because the diffusion objective effectively trains on fewer tokens per step. With a "full masking" schedule (mask rate $1 - \alpha_t = 1$), the BD3-LM ($L' = 1$) matches AR performance, and gradient variance is significantly reduced.
  • Clipped Noise Schedules: To minimize gradient variance for $L' > 1$, the paper proposes "clipped" noise schedules in which mask rates ($1 - \alpha_t$) are sampled uniformly from a sub-interval $[\beta, \omega]$ instead of $[0, 1]$. This avoids extreme masking rates (very few or very many masks), which provide poor learning signals and lead to high-variance gradients.
  • Data-Driven Schedule Optimization: The optimal $\beta$ and $\omega$ depend on the block size. They are learned adaptively during training by periodically running a grid search for the values that minimize the variance of the NELBO estimator (used as a proxy for gradient variance):

    $$\min_{\beta, \omega} \operatorname{Var}_{\mathbf{x}, t}\left[\mathcal{L}_{\text{BD}}(\mathbf{x}; \theta, \beta, \omega)\right]$$
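
A minimal sketch of the clipped schedule and the variance-minimizing grid search follows. `estimate_fn` is a hypothetical callable returning one NELBO estimate; in practice it would run the model on held-out data. `toy_estimate` is a made-up stand-in in which every masked token costs 1 nat and is weighted by $1/r$ for mask rate $r$.

```python
import random
import statistics


def sample_clipped_mask_rate(beta, omega, rng):
    """Clipped schedule: draw the mask rate 1 - alpha_t uniformly from [beta, omega]."""
    return beta + (omega - beta) * rng.random()


def select_clip_range(candidates, estimate_fn, num_samples=256, seed=0):
    """Grid search over (beta, omega): keep the pair whose loss estimates vary least.

    estimate_fn(mask_rate, rng) is a hypothetical callable returning one NELBO
    estimate. Every candidate reuses the same seed (common random numbers), so
    differences reflect the ranges rather than sampling luck.
    """
    best_pair, best_var = None, float("inf")
    for beta, omega in candidates:
        rng = random.Random(seed)
        losses = [estimate_fn(sample_clipped_mask_rate(beta, omega, rng), rng)
                  for _ in range(num_samples)]
        var = statistics.variance(losses)
        if var < best_var:
            best_pair, best_var = (beta, omega), var
    return best_pair, best_var


def toy_estimate(mask_rate, rng, block_len=16):
    """Toy stand-in: every masked token costs 1 nat, weighted by 1 / mask_rate."""
    masked = sum(rng.random() < mask_rate for _ in range(block_len))
    return masked / mask_rate
```

For `toy_estimate` the mean is 16 at every mask rate $r$, but the variance is $16(1 - r)/r$, which blows up as $r \to 0$. A search such as `select_clip_range([(0.05, 1.0), (0.3, 0.8), (0.45, 0.95)], toy_estimate)` should therefore tend to favor ranges that exclude very low mask rates. Keep $\beta > 0$ with this toy, since it divides by the mask rate.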

Experimental Results and Practical Implications

  • State-of-the-Art Perplexity: BD3-LMs achieve new state-of-the-art perplexities among discrete diffusion models on the LM1B and OpenWebText benchmarks, substantially narrowing the gap to AR models. For example, on LM1B, BD3-LM with a small block size reaches a lower perplexity than MDLM.
  • Variable-Length Generation: BD3-LMs can generate sequences much longer than their training context (e.g., up to ~10x longer than fixed-length diffusion models like SEDD on OWT).
  • Improved Sample Quality: BD3-LMs show better generative perplexity (Gen. PPL, evaluated by GPT2-Large) compared to prior diffusion methods like SEDD, MDLM, and SSD-LM, often with an order of magnitude fewer generation steps (NFEs) than methods like SSD-LM.
    • In one reported setting, BD3-LM achieves a Gen. PPL of 23.6 with 2K NFEs, while SSD-LM (with a comparable number of NFEs) gets 281.9 and MDLM gets 41.3.
  • Efficiency of Clipped Schedules: Data-driven clipped noise schedules are shown to reduce training variance and improve test perplexity compared to standard linear or other common schedules. The optimal clipping range varies with block size (e.g., heavier masking for smaller $L'$).
  • Computational Cost: Training BD3-LMs is inherently more expensive than standard diffusion due to the block-autoregressive structure, which requires either multiple passes or a longer effective sequence length. The proposed vectorized training algorithm keeps this overhead manageable (under 2x the cost of standard diffusion). Pre-training with a standard diffusion loss before fine-tuning with the block diffusion objective can further reduce costs.

Implementation Considerations

  • Computational Requirements: Training requires careful management of memory and computation, especially with the vectorized approach (concatenating sequences). Efficient attention kernels (FlashAttention, FlexAttention) are crucial.
  • Choosing Block Size ($L'$): The optimal block size is task-dependent. A smaller $L'$ approaches AR behavior (more sequential steps, potentially better perplexity). A larger $L'$ increases parallelism but may make learning harder or loosen the NELBO bound. In the reported experiments, small block sizes often give the best perplexity.
  • Noise Schedule Tuning: Implementing the data-driven clipped schedule optimization requires periodically evaluating the NELBO variance for different $[\beta, \omega]$ ranges. This adds some overhead but is shown to be beneficial.
  • KV Cache Implementation: Standard transformer KV caching mechanisms can be adapted. The key is to correctly pass and update the cache across block generation steps during sampling, and to use it appropriately during the second pass or vectorized pass of training.
  • Deployment: For inference, block-sequential generation means latency will be higher than for fully parallel diffusion models but potentially lower than for token-by-token AR models, provided $L'$ is large enough and intra-block parallelism is exploited.
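
The latency trade-off can be estimated by counting sequential forward passes. The accounting below is an assumption made for illustration, not the paper's NFE convention: one denoiser pass per diffusion step inside a block, plus one extra pass per block to cache the clean block's keys and values. It counts passes, not FLOPs, and each block pass processes $L'$ tokens where an AR step processes one.

```python
def sampling_budget(seq_len, block_len, steps_per_block):
    """Sequential forward passes for block-autoregressive sampling (illustrative accounting).

    Assumes one denoiser pass per diffusion step inside a block plus one extra pass per
    block to compute the clean block's keys and values. Counts passes, not FLOPs.
    """
    if seq_len % block_len != 0:
        raise ValueError("seq_len must be a multiple of block_len")
    num_blocks = seq_len // block_len
    return {
        "num_blocks": num_blocks,
        "denoise_passes": num_blocks * steps_per_block,
        "cache_passes": num_blocks,
        "ar_passes": seq_len,  # token-by-token AR baseline: one pass per token
    }
```

Under this accounting, block sampling needs fewer sequential passes than AR only when `steps_per_block + 1 < block_len`. For example, $L = 1024$, $L' = 16$ and 8 steps per block give 64 blocks and $512 + 64 = 576$ passes versus 1024 for AR.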

In summary, BD3-LMs offer a practical framework for building high-quality, flexible-length LLMs that combine strengths from AR and diffusion paradigms. The paper provides concrete algorithms for training and sampling, addresses the critical issue of gradient variance through novel noise schedules, and demonstrates strong empirical results. The code and model weights are made available, facilitating adoption and further research.
