Exploring Medusa and Multi-Token Prediction
🌈 Abstract
The article discusses the "MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads" paper, which proposes an architectural change to large language models (LLMs) that can achieve a 2x-3x speed-up on existing hardware.
🙋 Q&A
[01] Speculative Decoding
1. What is speculative decoding and how does it work?
- Speculative decoding is a technique for speeding up LLM inference by drafting and verifying multiple tokens per forward pass, rather than generating just one token per pass.
- It involves three main steps, sketched in code below:
- Generate the token candidates
- Process the candidates
- Accept certain candidates
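A minimal sketch of this draft-then-verify loop, assuming Hugging Face-style models that expose `.logits` and using hypothetical names (`draft_model`, `target_model`); it uses greedy verification rather than full rejection sampling, for brevity:

```python
import torch

def speculative_step(target_model, draft_model, input_ids, k=4):
    # Step 1: generate k token candidates greedily with the cheap draft model.
    candidates = []
    draft_ids = input_ids
    for _ in range(k):
        logits = draft_model(draft_ids).logits[:, -1, :]
        next_token = logits.argmax(dim=-1, keepdim=True)
        candidates.append(next_token)
        draft_ids = torch.cat([draft_ids, next_token], dim=-1)

    # Step 2: process all candidates with the target model in one forward pass.
    target_logits = target_model(draft_ids).logits

    # Step 3: accept candidates from left to right while the target model agrees.
    accepted = []
    for i, token in enumerate(candidates):
        pos = input_ids.shape[1] - 1 + i  # logits at `pos` predict this candidate
        target_token = target_logits[:, pos, :].argmax(dim=-1, keepdim=True)
        if torch.equal(target_token, token):
            accepted.append(token)
        else:
            accepted.append(target_token)  # fall back to the target model's token and stop
            break
    return torch.cat([input_ids] + accepted, dim=-1)
```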
2. How does Medusa implement speculative decoding?
- Medusa appends multiple decoding heads to the final layer of the model to generate multiple token candidates (step 1).
- It uses "tree attention" to efficiently process the multiple candidates (step 2).
- Medusa uses either rejection sampling or a typical acceptance scheme to determine which candidates to accept (step 3).
[02] Decoding Heads & Medusa
1. What is a decoding head and how does Medusa use it?
- A decoding head takes the internal representation from a model forward pass and generates the probabilities for different tokens in the vocabulary.
- Medusa appends multiple decoding heads to the last hidden layer of the model, allowing it to predict more than one token per forward pass.
2. How does the Medusa decoding head equation work?
- The k-th Medusa head applies its trained weights W1 and W2, a SiLU activation, and a skip connection to the last hidden state, producing a probability distribution over the vocabulary for the token k positions ahead; a sketch of this block follows.
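A minimal PyTorch sketch of that block (my paraphrase of the paper's head structure; `hidden_size` and `vocab_size` are assumed names):

```python
import torch
import torch.nn as nn

class MedusaHead(nn.Module):
    """One extra decoding head: a residual SiLU block followed by a vocabulary projection."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, hidden_size)  # W1 in the equation
        self.w2 = nn.Linear(hidden_size, vocab_size)   # W2 projects to the vocabulary
        self.act = nn.SiLU()

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        # Skip connection: SiLU(W1 h) + h, then project to token logits;
        # softmax over these logits gives the k-th head's token probabilities.
        residual = self.act(self.w1(hidden_state)) + hidden_state
        return self.w2(residual)
```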
[03] Tree Attention
1. What is tree attention and how does it help with Medusa?
- Tree attention lets Medusa evaluate the attention patterns for many candidate continuations in a single forward pass, by arranging the candidates in a tree where each node depends only on its own prefix.
- This is done with an attention mask that lets each candidate token attend only to its ancestors in the candidate tree, so unrelated continuations do not leak information into each other (see the toy mask below).
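A toy sketch of such a mask (hypothetical candidate paths, not the paper's implementation): each node of the candidate tree attends only to its own ancestors.

```python
import torch

def build_tree_mask(paths):
    """Build an attention mask where each tree node attends only to its ancestors.

    `paths` lists each candidate node as the sequence of choices leading to it,
    e.g. (0,), (1,), (0, 0), (0, 1) for two first-head candidates plus two
    second-head continuations of the first one.
    """
    n = len(paths)
    mask = torch.zeros(n, n, dtype=torch.bool)
    for i, p in enumerate(paths):
        for j, q in enumerate(paths):
            # q is an ancestor of p (or p itself) iff p starts with q
            if p[:len(q)] == q:
                mask[i, j] = True
    return mask

# Example: head 1 candidates are {0, 1}; head 2 continues candidate 0 with {0, 1}.
print(build_tree_mask([(0,), (1,), (0, 0), (0, 1)]).int())
```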
2. How does Medusa use probability to optimize the tree attention calculations?
- Medusa adds nodes to the candidate tree according to the predicted probabilities, so the node budget is spent on the highest-probability continuations.
- This keeps tree attention memory-efficient, since attention is only calculated for the most likely token continuations; a toy example of this path scoring follows.
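A toy illustration of that selection (made-up tokens and probabilities): score each candidate path by the product of its heads' probabilities and spend the node budget on the best ones.

```python
from itertools import product

# Hypothetical top-2 probabilities from two Medusa heads.
head1 = {"the": 0.6, "a": 0.3}
head2 = {"cat": 0.5, "dog": 0.2}

# Score every path by the product of per-head probabilities, keep the best few.
paths = sorted(
    ((p1 * p2, (t1, t2)) for (t1, p1), (t2, p2) in product(head1.items(), head2.items())),
    reverse=True,
)
budget = 3  # node budget for the attention tree
print(paths[:budget])  # highest-probability continuations get tree nodes first
```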
[04] Acceptance Schemes
1. What are the two acceptance schemes discussed for Medusa?
- Rejection sampling, carried over from standard speculative decoding, where each candidate is checked against the original model's own probability for that token so the output distribution is preserved
- Typical acceptance, where Medusa uses the probabilities from the original model to set a threshold for accepting tokens
2. How does the typical acceptance scheme work and how does it relate to temperature?
- The typical acceptance scheme accepts a candidate token whenever the original model's probability for it exceeds a threshold built from a hard cutoff ε and an entropy-dependent term δ, as sketched below.
- As temperature increases, the distribution flattens and its entropy rises, which lowers the entropy-dependent part of the threshold; more lower-probability tokens are accepted, giving faster but potentially more creative results.
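A minimal sketch of that rule, assuming the threshold is the smaller of the hard cutoff ε and an entropy-scaled term δ·exp(−H) (hypothetical parameter names; `probs` is the original model's temperature-adjusted distribution at this position):

```python
import torch

def typical_accept(probs: torch.Tensor, candidate: int, eps: float, delta: float) -> bool:
    # Entropy of the original model's distribution: flat (high-temperature)
    # distributions have high entropy and therefore a lower acceptance bar.
    entropy = -(probs * torch.log(probs.clamp_min(1e-9))).sum()
    threshold = min(eps, delta * torch.exp(-entropy).item())
    return probs[candidate].item() > threshold
```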
[05] Self-Distillation
1. How does Medusa use self-distillation to train the model?
- Medusa starts from a high-quality backbone model and fine-tunes it to add the Medusa heads rather than training from scratch; the fine-tuning data is generated by the backbone model itself, which is why the authors call it self-distillation.
- To keep the fine-tuned backbone close to the original, Medusa uses a loss based on Kullback-Leibler divergence between the original model's and the fine-tuned model's probability distributions (sketched below).
- The authors recommend using LoRA to efficiently fine-tune the model while maintaining the original weights.
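A hedged sketch of that distillation objective (assuming logits from a frozen copy of the original backbone as the teacher and the fine-tuned backbone as the student):

```python
import torch
import torch.nn.functional as F

def distillation_loss(teacher_logits: torch.Tensor, student_logits: torch.Tensor) -> torch.Tensor:
    """KL(teacher || student) over the vocabulary at each position."""
    teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    # kl_div expects the student's log-probs as input and the teacher's (log-)probs as target.
    return F.kl_div(student_log_probs, teacher_log_probs, log_target=True, reduction="batchmean")
```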
2. What are the two Medusa fine-tuning approaches discussed?
- Medusa-1: Only fine-tunes the Medusa heads, keeping the backbone model weights frozen.
- Medusa-2: Fine-tunes both the Medusa heads and the backbone model weights, using separate learning rates for each; a rough sketch of the two training losses follows.
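A rough sketch of the difference between the two losses (the λ weights are illustrative, not the paper's exact values; the separate learning rates would be handled by the optimizer's parameter groups):

```python
# head_losses[k]: cross-entropy of the k-th Medusa head against the token k+1 steps ahead.
# lm_loss: the backbone's ordinary next-token loss (or the KL loss sketched above).

def medusa1_loss(head_losses, lam=0.8):
    # Medusa-1: backbone frozen, only the heads are trained; later heads are down-weighted.
    return sum((lam ** k) * loss for k, loss in enumerate(head_losses))

def medusa2_loss(lm_loss, head_losses, lam0=0.2, lam=0.8):
    # Medusa-2: joint training of backbone and heads, head losses added with a small weight.
    return lm_loss + lam0 * sum((lam ** k) * loss for k, loss in enumerate(head_losses))
```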