- The paper introduces BD3-LMs that integrate autoregressive block-level generation with discrete denoising diffusion, enabling flexible-length text synthesis.
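- It trains these models efficiently with a two-pass algorithm and a vectorized single-pass variant, both built on custom attention masks and KV caching; the vectorized variant is reported to train 20–25% faster than the two-pass one.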
- Experimental results demonstrate improved perplexity on benchmarks like LM1B and OpenWebText, with enhanced sample quality and accelerated inference.
This paper introduces Block Discrete Denoising Diffusion Language Models (BD3-LMs), a novel class of models that bridges the gap between autoregressive (AR) and discrete denoising diffusion models for language generation. The primary motivation is to overcome key limitations of both paradigms: diffusion models often struggle with likelihood modeling, are restricted to fixed-length generation, and lack efficient inference mechanisms like KV caching, while AR models generate tokens sequentially, limiting speed.
BD3-LMs operate by being autoregressive over blocks of tokens while performing discrete denoising diffusion within each block. This hybrid approach allows for flexible-length generation and improves inference efficiency through KV caching and parallel token sampling within blocks.
Core Concepts and Implementation
1. Model Architecture and Likelihood:
- A sequence of L tokens x is divided into B blocks, each of length L′, so L=B⋅L′.
- The log-likelihood is factorized autoregressively over these blocks:
$$\log p_\theta(x) = \sum_{b=1}^{B} \log p_\theta(x^b \mid x^{<b})$$
- Each conditional probability $p_\theta(x^b \mid x^{<b})$ is modeled by a discrete diffusion process specific to block $b$, conditioned on previously generated blocks $x^{<b}$.
- A single transformer neural network $f_\theta$ parameterizes the base denoiser for all blocks. It uses a block-causal attention mask, where tokens in block $b$ attend to other tokens within the (potentially noised) block $x^b$ and all clean tokens in preceding blocks $x^{<b}$.
- The model supports KV caching:
$$x^b_{\text{logits}},\, K^b,\, V^b \leftarrow x^b_\theta\left(x^b_t,\, K^{1:b-1},\, V^{1:b-1}\right)$$
where $x^b_t$ is the noised version of block $b$ at timestep $t$, and $K^{1:b-1}, V^{1:b-1}$ are cached keys and values from previous blocks.
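The block-causal attention pattern described above can be sketched in a few lines of NumPy. This is a minimal illustration, not the paper's implementation: the function name `block_causal_mask` and its arguments are hypothetical, and the mask is returned as a dense boolean array rather than being fed to an attention kernel.

```python
import numpy as np

def block_causal_mask(seq_len: int, block_len: int) -> np.ndarray:
    """Boolean mask: entry [i, j] is True when query token i may attend to key token j."""
    if seq_len % block_len != 0:
        raise ValueError("seq_len must be a multiple of block_len")
    # Block index of every position, e.g. [0, 0, 1, 1, 2, 2] for block_len=2.
    block_id = np.arange(seq_len) // block_len
    # Token i sees token j when j lies in the same block or in an earlier block.
    return block_id[:, None] >= block_id[None, :]

# Usage: 6 tokens split into 3 blocks of length 2.
mask = block_causal_mask(seq_len=6, block_len=2)
print(mask.astype(int))
```

With `block_len=1` the mask reduces to the usual causal (lower-triangular) mask, and with `block_len=seq_len` every token attends to every other token, which mirrors the interpolation between AR and diffusion behavior described in the text.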
2. Training Objective:
- The training objective is derived by applying the Negative ELBO (NELBO) to each block-conditional term:
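$$-\log p_\theta(x) \le \mathcal{L}_{\text{BD}}(x; \theta) := \sum_{b=1}^{B} \mathcal{L}(x^b, x^{<b}; \theta)$$
where $\mathcal{L}(x^b, x^{<b}; \theta)$ is the standard diffusion NELBO for block $b$ conditioned on $x^{<b}$.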
- For masked BD3-LMs (using a masking noise process), a simplified objective is adopted:
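$$\mathcal{L}_{\text{BD}}(x; \theta) := \sum_{b=1}^{B} \mathbb{E}_{t \sim \mathcal{U}[0,1]}\, \mathbb{E}_{q} \left[ \frac{\alpha_t'}{1 - \alpha_t} \log p_\theta(x^b \mid x^b_t, x^{<b}) \right]$$
where $\alpha_t$ defines the noise schedule (probability of a token not being masked at time $t$), and $\alpha_t'$ is its derivative.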
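To make the weighting in the masked objective concrete, here is a minimal one-sample sketch for a single block. It rests on several assumptions that are not in the source: a linear schedule $\alpha_t = 1 - t$ (so the weight $\alpha_t' / (1 - \alpha_t)$ becomes $-1/t$), a hypothetical `MASK_ID`, and a stand-in `uniform_denoiser` that ignores its input. Conditioning on earlier blocks is left out for brevity.

```python
import numpy as np

MASK_ID = -1  # hypothetical id used to mark masked positions

def masked_block_loss(denoiser, clean_block, t, rng):
    """One-sample estimate of the masked-diffusion loss for a single block.

    Assumes the linear schedule alpha_t = 1 - t, with t in (0, 1].
    """
    alpha_t = 1.0 - t
    weight = -1.0 / t  # alpha_t' / (1 - alpha_t) with alpha_t' = -1
    # Each token is masked independently with probability 1 - alpha_t.
    is_masked = rng.random(clean_block.shape[0]) < (1.0 - alpha_t)
    noised_block = np.where(is_masked, MASK_ID, clean_block)
    log_probs = denoiser(noised_block)  # shape (block_len, vocab_size)
    token_log_probs = log_probs[np.arange(clean_block.shape[0]), clean_block]
    # Only masked positions contribute to the loss.
    return float(weight * np.sum(token_log_probs[is_masked]))

def uniform_denoiser(noised_block, vocab_size=8):
    """Stand-in model that predicts a uniform distribution for every position."""
    return np.full((noised_block.shape[0], vocab_size), -np.log(vocab_size))

# Usage: a 4-token block at noise level t = 0.5.
rng = np.random.default_rng(0)
block = np.array([1, 2, 3, 4])
print(masked_block_loss(uniform_denoiser, block, t=0.5, rng=rng))
```

Because the log-probabilities are non-positive and the weight is negative, the estimate is always non-negative, as expected for a bound on a negative log-likelihood.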
3. Efficient Training Algorithm (Algorithm 1):
- Naively computing the loss would require $B$ separate forward passes, one per block, because denoising block $b$ uses a noised $x^b_t$ while conditioning on clean previous blocks $x^{<b}$.
- Two-Pass Approach:
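1. First Pass (KV Cache Precomputation): Compute keys and values $K^{1:B}, V^{1:B}$ for the entire clean sequence $x$ in one forward pass: $\emptyset, K^{1:B}, V^{1:B} \leftarrow x_\theta(x)$ (here $\emptyset$ marks the unused logits).
2. Second Pass (Denoising): For each block $b$, sample noise levels $t_b$ and create noised blocks $x^b_{t_b}$. Compute denoised predictions for all blocks simultaneously using the precomputed KV cache: $x^b_{\text{logits}} \leftarrow x^b_\theta(x^b_{t_b}, K^{1:b-1}, V^{1:b-1})$ for $b = 1, \dots, B$.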
- Vectorized Single-Pass Training: An even more efficient method concatenates the noisy data $x_{\text{noisy}} = x^1_{t_1} \oplus \dots \oplus x^B_{t_B}$ and clean data $x$ into a single input sequence $x_{\text{noisy}} \oplus x$ of length $2L$. A custom attention mask (detailed in the paper's appendix on attention masks) is designed so that noisy tokens attend to other noisy tokens in their block and to clean tokens in preceding blocks, while clean tokens attend block-causally to clean tokens. This leverages efficient attention kernels like FlashAttention or FlexAttention (see the paper's appendix on FlexAttention kernels), yielding a 20–25% training speed-up over the two-pass approach.
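The mask for the vectorized pass can be sketched as a dense boolean array over the concatenated input of length $2L$ (noisy tokens first, clean tokens second). This is an illustrative reconstruction from the description above, not the paper's kernel code; the function name `vectorized_training_mask` and the noisy-then-clean ordering are assumptions.

```python
import numpy as np

def vectorized_training_mask(seq_len: int, block_len: int) -> np.ndarray:
    """Mask of shape (2L, 2L) for the input [noisy tokens, clean tokens].

    Entry [i, j] is True when query i may attend to key j.
    """
    block_id = np.arange(seq_len) // block_len
    q = block_id[:, None]  # block index of each query position
    k = block_id[None, :]  # block index of each key position

    noisy_to_noisy = q == k   # noisy tokens see their own noisy block
    noisy_to_clean = q > k    # noisy tokens see clean tokens of earlier blocks
    clean_to_noisy = np.zeros((seq_len, seq_len), dtype=bool)  # never allowed
    clean_to_clean = q >= k   # clean tokens attend block-causally

    top = np.concatenate([noisy_to_noisy, noisy_to_clean], axis=1)
    bottom = np.concatenate([clean_to_noisy, clean_to_clean], axis=1)
    return np.concatenate([top, bottom], axis=0)

# Usage: L = 4 tokens in blocks of 2 gives an 8 x 8 mask.
print(vectorized_training_mask(seq_len=4, block_len=2).astype(int))
```

The clean half never attends to the noisy half, so its keys and values match what the first pass of the two-pass approach would have produced; that is what lets a single forward pass replace both.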
4. Efficient Sampling Algorithm (Algorithm 2):
- Blocks are generated sequentially.
- For each block $b = 1, \dots, B$:
- Sample the clean block $x^b$ using a diffusion sampling procedure (e.g., D3PM sampler) conditioned on previously generated clean blocks $x^{<b}$ (via their cached keys and values $K^{1:b-1}, V^{1:b-1}$). This step involves multiple denoising steps within the block.
- Compute and cache keys and values for the newly sampled block $x^b$: $\emptyset, K^b, V^b \leftarrow x^b_\theta(x^b)$.
- Append $x^b$ to the generated sequence and update the overall KV cache.
- This allows for arbitrary-length sequence generation and benefits from parallel generation within each block.
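The control flow of block-wise sampling can be sketched with a toy denoiser. Everything here is a simplified stand-in: `toy_denoiser` returns random probabilities instead of running a transformer, the unmasking rule (reveal an equal share of masked positions per step) is an assumption rather than the paper's sampler, and the previously generated tokens are passed directly where a real implementation would pass cached keys and values.

```python
import numpy as np

MASK_ID = -1  # hypothetical id used to mark masked positions

def toy_denoiser(context, noised_block, vocab_size, rng):
    """Stand-in for the transformer: random token probabilities per position."""
    logits = rng.normal(size=(noised_block.shape[0], vocab_size))
    probs = np.exp(logits - logits.max(axis=1, keepdims=True))
    return probs / probs.sum(axis=1, keepdims=True)

def sample_block(context, block_len, num_steps, vocab_size, rng):
    """Denoise one block from fully masked to fully unmasked."""
    block = np.full(block_len, MASK_ID)
    for step in range(num_steps):
        masked = np.flatnonzero(block == MASK_ID)
        if masked.size == 0:
            break
        probs = toy_denoiser(context, block, vocab_size, rng)
        # Reveal an equal share of the remaining masked positions each step.
        n_unmask = int(np.ceil(masked.size / (num_steps - step)))
        for pos in rng.choice(masked, size=n_unmask, replace=False):
            block[pos] = rng.choice(vocab_size, p=probs[pos])
    return block

def generate(num_blocks, block_len, num_steps, vocab_size, seed=0):
    """Generate blocks one after another, each conditioned on the previous ones."""
    rng = np.random.default_rng(seed)
    sequence = np.empty(0, dtype=int)
    for _ in range(num_blocks):
        block = sample_block(sequence, block_len, num_steps, vocab_size, rng)
        sequence = np.concatenate([sequence, block])  # stands in for the cache update
    return sequence

# Usage: 3 blocks of 4 tokens, 2 denoising steps per block.
print(generate(num_blocks=3, block_len=4, num_steps=2, vocab_size=10))
```

Since `num_blocks` is just a loop bound, the same loop could run until an end-of-sequence token appears, which is how the block-wise scheme allows arbitrary-length output.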
5. Addressing Gradient Variance and Improving Performance:
Experimental Results and Practical Implications
- State-of-the-Art Perplexity: BD3-LMs achieve new state-of-the-art perplexities among discrete diffusion models on the LM1B and OpenWebText benchmarks, significantly closing the gap to AR models. For example, on LM1B, BD3-LM achieves a lower PPL than MDLM.
- Variable-Length Generation: BD3-LMs can generate sequences much longer than their training context (e.g., up to ~10x longer than fixed-length diffusion models like SEDD on OWT).
- Improved Sample Quality: BD3-LMs show better generative perplexity (Gen. PPL, evaluated by GPT2-Large) compared to prior diffusion methods like SEDD, MDLM, and SSD-LM, often with an order of magnitude fewer generation steps (NFEs) than methods like SSD-LM.
- In one reported setting, BD3-LM achieves a Gen. PPL of 23.6 with 2K NFEs, while SSD-LM (at comparable NFEs) gets 281.9 and MDLM gets 41.3.
- Efficiency of Clipped Schedules: Data-driven clipped noise schedules are shown to reduce training variance and improve test perplexity compared to standard linear or other common schedules. The optimal clipping range varies with block size (e.g., heavier masking for smaller L′).
- Computational Cost: Training BD3-LMs is inherently more expensive than standard diffusion due to the block-autoregressive nature and potentially multiple passes or larger effective sequence lengths. The proposed vectorized training algorithm keeps this overhead manageable (less than 2x the cost of standard diffusion). Pre-training with a standard diffusion loss before fine-tuning with the block diffusion objective can further reduce costs.
Implementation Considerations
- Computational Requirements: Training requires careful management of memory and computation, especially with the vectorized approach (concatenating sequences). Efficient attention kernels (FlashAttention, FlexAttention) are crucial.
- Choosing Block Size (L′): The optimal block size is task-dependent. Smaller L′ approaches AR behavior (more sequential steps, potentially better perplexity). Larger L′ increases parallelism but might make learning harder or loosen the NELBO bound. Experiments show that small block sizes often give the best perplexity.
- Noise Schedule Tuning: Implementing the data-driven clipped schedule optimization requires periodic evaluation of NELBO variance for different clipping ranges. This adds some overhead but is shown to be beneficial.
- KV Cache Implementation: Standard transformer KV caching mechanisms can be adapted. The key is to correctly pass and update the cache across block generation steps during sampling, and to use it appropriately during the second pass or vectorized pass of training.
- Deployment: For inference, the block-sequential generation means latency will be higher than fully parallel diffusion models but potentially lower than token-by-token AR models if L′ is large enough and intra-block parallelism is exploited.
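The noise schedule tuning step mentioned in this list can be pictured as a small grid search that picks the clipping range with the lowest loss variance. The sketch below is illustrative only: it assumes a clipped schedule draws the mask rate uniformly from a range `[beta, omega]`, and `toy_loss_estimate` is a made-up stand-in for a per-batch NELBO estimate, so its numbers carry no meaning beyond showing the search loop.

```python
import numpy as np

def toy_loss_estimate(mask_rate, rng):
    """Hypothetical noisy loss standing in for one batch's NELBO estimate."""
    # Assumption for illustration only: noise grows at extreme mask rates.
    noise_scale = 0.1 + abs(mask_rate - 0.5)
    return 3.0 + noise_scale * rng.normal()

def best_clipping_range(candidates, num_batches=200, seed=0):
    """Return the (beta, omega) range whose loss estimates have the lowest variance."""
    rng = np.random.default_rng(seed)
    variances = {}
    for beta, omega in candidates:
        # Draw one mask rate per batch, uniformly inside the clipped range.
        rates = rng.uniform(beta, omega, size=num_batches)
        losses = np.array([toy_loss_estimate(r, rng) for r in rates])
        variances[(beta, omega)] = float(losses.var())
    return min(variances, key=variances.get), variances

# Usage: compare the unclipped range with two clipped candidates.
candidates = [(0.0, 1.0), (0.3, 0.8), (0.45, 0.95)]
best, variances = best_clipping_range(candidates)
print(best, variances)
```

In a real training run the loss estimates would come from held-out or recent training batches, and the search would be repeated periodically, which is the overhead the text refers to.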
In summary, BD3-LMs offer a practical framework for building high-quality, flexible-length LLMs that combine strengths from AR and diffusion paradigms. The paper provides concrete algorithms for training and sampling, addresses the critical issue of gradient variance through novel noise schedules, and demonstrates strong empirical results. The code and model weights are made available, facilitating adoption and further research.