PLaD: Preference-based Large Language Model Distillation with Pseudo-Preference Pairs
Abstract
The paper presents PLaD, a novel preference-based distillation framework for transferring knowledge from large language models (LLMs) to compact student models. The key ideas are:
- Generating pseudo-preference pairs by sampling outputs from the teacher and student models, assuming the teacher output is preferred.
- Leveraging a calibration loss to re-calibrate the student's estimation of sequence likelihood, steering the student's focus towards understanding the relative quality of outputs.
- Demonstrating the effectiveness of PLaD on sequence generation tasks with various LLM teacher-student pairs, outperforming state-of-the-art knowledge distillation baselines.
Q&A
[01] Preference-based Large Language Model Distillation
1. What are the key challenges in applying traditional knowledge distillation techniques to large language models (LLMs)? The paper identifies three key challenges:
- Restricted access to LLM outputs, as LLM teachers are typically only available through API calls, hindering the implementation of traditional distillation techniques that require access to full output logits or internal states.
- Significant teacher-student capacity gaps, which limit the student model's ability to fully match the teacher LLM's output distribution.
- The inherent miscalibration of LLMs: sequences to which the model assigns high likelihood do not necessarily have high quality for the target task.
2. How does the PLaD framework address these challenges?
- PLaD removes the need to access the teacher LLM's internal states by generating pseudo-preference pairs, in which the teacher output is assumed to be preferred over the student output because of the capacity difference between the two models.
- The introduction of a calibration loss directly ties the quality of generation to its likelihood, allowing for targeted optimization of output quality through calibration.
- The preference-based distillation approach shifts the student's learning focus towards understanding the relative quality of different outputs, addressing the student's inherent limitations in expressivity.
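The pair-construction step can be sketched as follows. This is a minimal illustration, not the authors' code: `teacher_generate` and `student_generate` are hypothetical callables standing in for a teacher API call and for sampling from the student, and `PreferencePair` is an illustrative container. The point it shows is that no human label or reward model is involved; the teacher sample is marked as preferred purely by assumption.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class PreferencePair:
    prompt: str
    chosen: str    # teacher output, assumed preferred
    rejected: str  # student output, assumed dispreferred


def build_pseudo_preference_pairs(
    prompts: List[str],
    teacher_generate: Callable[[str], str],  # hypothetical: e.g. wraps a teacher API call
    student_generate: Callable[[str], str],  # hypothetical: samples from the current student
) -> List[PreferencePair]:
    """Pair one teacher sample with one student sample per prompt.

    The teacher sample is labeled 'chosen' by assumption (capacity gap),
    so only text outputs are needed, never teacher logits or hidden states.
    """
    pairs = []
    for prompt in prompts:
        teacher_out = teacher_generate(prompt)
        student_out = student_generate(prompt)
        pairs.append(PreferencePair(prompt, chosen=teacher_out, rejected=student_out))
    return pairs


# Toy usage with stub generators standing in for real models.
pairs = build_pseudo_preference_pairs(
    ["Summarize: the meeting was moved to Friday."],
    teacher_generate=lambda p: "Meeting moved to Friday.",
    student_generate=lambda p: "There was a meeting.",
)
print(pairs[0].chosen)
```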
3. What are the key components of the PLaD framework? The key components are:
- Pseudo-preference pair generation: Sampling outputs from both the teacher and student models, and assuming the teacher output is preferred.
- Distillation with preference pairs: Employing a calibration loss, including a ranking calibration loss and a margin calibration loss, to re-calibrate the student's estimation of sequence likelihood.
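A ranking-style calibration loss can be sketched as a hinge on the student's sequence log-likelihoods, L = max(0, β − log P(y⁺|x) + log P(y⁻|x)), where y⁺ is the teacher output and y⁻ the student output. This hinge form is an assumption borrowed from sequence likelihood calibration work; the paper's exact formulation may differ, and its margin calibration variant is not shown. Length normalization and the value of `beta` are also illustrative assumptions. In practice the per-token log-probabilities would come from the student's forward pass and the loss would be backpropagated; here they are plain floats so the arithmetic is visible.

```python
from typing import Sequence


def sequence_log_likelihood(token_log_probs: Sequence[float], length_norm: bool = True) -> float:
    """Student's log P(y|x) as the sum of per-token log-probabilities.

    Dividing by length is an illustrative assumption, so that longer outputs
    are not penalized simply for having more tokens.
    """
    if not token_log_probs:
        raise ValueError("sequence must contain at least one token")
    total = sum(token_log_probs)
    return total / len(token_log_probs) if length_norm else total


def ranking_calibration_loss(
    chosen_log_probs: Sequence[float],    # student's token log-probs on the teacher output (y+)
    rejected_log_probs: Sequence[float],  # student's token log-probs on its own output (y-)
    beta: float = 1.0,                    # margin hyperparameter (value is an assumption)
    length_norm: bool = True,
) -> float:
    """Assumed hinge form: max(0, beta - log P(y+|x) + log P(y-|x)).

    The loss is zero once the student scores the teacher output at least
    `beta` above its own output, so training re-ranks the pair rather than
    forcing the student to match the teacher's full distribution.
    """
    s_pos = sequence_log_likelihood(chosen_log_probs, length_norm)
    s_neg = sequence_log_likelihood(rejected_log_probs, length_norm)
    return max(0.0, beta - s_pos + s_neg)


# Toy usage: the student already prefers the teacher output by 1.5 nats per token.
print(ranking_calibration_loss([-0.5, -0.5], [-1.0, -2.0, -3.0], beta=1.0))  # 0.0
print(ranking_calibration_loss([-0.5, -0.5], [-1.0, -2.0, -3.0], beta=2.0))  # 0.5
```

With a small margin the pair is already ranked correctly and contributes no gradient; raising `beta` demands a wider gap, which is the knob that trades off how strongly the student is re-calibrated.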
[02] Experiments and Results
1. How does the performance of the student model distilled by PLaD compare to the initial student model and other baseline distillation methods? The student model learned by PLaD consistently outperforms the initial student model and other baseline distillation methods in terms of win rate, i.e., how often its generated text is judged better than the target sequence. This holds across different model families (LLaMA-2 and GPT-Neo) and tasks (TL;DR summarization and Anthropic-HH dialogue generation).
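The win-rate metric itself reduces to a simple ratio, sketched below under stated assumptions: `judge` is a hypothetical stand-in for whatever evaluator compares a generation with its target, and counting ties as non-wins is an illustrative choice, not necessarily the paper's protocol.

```python
from typing import Callable, Sequence


def win_rate(
    generations: Sequence[str],
    targets: Sequence[str],
    judge: Callable[[str, str], bool],  # hypothetical evaluator: True if generation beats target
) -> float:
    """Fraction of examples where judge(generation, target) prefers the generation.

    Ties (judge returns False) count as non-wins in this sketch.
    """
    if len(generations) != len(targets):
        raise ValueError("generations and targets must be aligned one-to-one")
    if not generations:
        raise ValueError("need at least one example")
    wins = sum(1 for gen, tgt in zip(generations, targets) if judge(gen, tgt))
    return wins / len(generations)


def longer_is_better(gen: str, tgt: str) -> bool:
    """Toy judge for illustration only: prefers the longer text."""
    return len(gen) > len(tgt)


print(win_rate(["aaa", "b", "cc", "dddd"], ["aa", "bb", "cc", "d"], longer_is_better))  # 0.5
```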
2. What is the impact of using real preference pairs compared to pseudo-preference pairs? The experiments show that using real preference pairs can provide a slight improvement in performance compared to using pseudo-preference pairs. However, the gain is marginal, and the pseudo-preference pairs offer a cost-effective and time-efficient alternative without significantly compromising the learning efficacy.
3. How does the performance of the student model vary across different generation length ranges? The results indicate that PLaD maintains a relatively stable win-rate improvement across generation lengths, with a notable peak in the length range most common for the task. This suggests that PLaD is most effective on the typical outputs of the task, where improvements matter most.
4. How does PLaD scale with the amount of distillation data? As the percentage of distillation data used increases, all methods show an upward trend in win rate over the target, but PLaD shows a steeper improvement curve and outperforms the other methods, particularly at higher data usage levels. This suggests that PLaD makes efficient use of additional distillation data when refining the student model.
5. How does PLaD perform when applied to other LLM families, such as PaLM-2 and T5? The paper extends the evaluation of PLaD to PaLM-2 and T5 models, demonstrating the framework's broad applicability across different LLM families. The student model distilled by PLaD consistently outperforms the initial student model, showcasing the generalizability of the proposed approach.