Part 2 - Week 5
Mrinmaya Sachan
Alignment with Human Preferences
Instruction Tuning
Motivation
- Large Language Models (LLMs) pretrained on next-token prediction are not inherently optimized to follow explicit natural language instructions.
- Without further adaptation, they may:
- Ignore instructions.
- Produce verbose or irrelevant outputs.
- Fail to generalize to unseen task formats.
- Instruction tuning aims to bridge this gap by fine-tuning on datasets where each example is framed as an instruction–response pair.
Core Idea
- Multitask supervised fine-tuning (SFT) on a diverse set of NLP tasks, each presented with:
- A natural language instruction (prompt).
- An input (optional).
- A target output.
- The model learns to map (instruction, input) → output.
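For concreteness, one training example in this schema might look like the following; the record and its field names are illustrative, not taken from a specific dataset.

```python
# A hypothetical instruction-tuning example (field names are illustrative).
example = {
    "instruction": "Summarize the following article in one sentence.",
    "input": "The city council voted on Tuesday to expand the bike-lane network ...",
    "output": "The city council approved an expansion of the bike-lane network.",
}
```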
Methodology
- Data Collection
- Aggregate tasks from multiple sources (translation, summarization, QA, classification, reasoning).
- Reformat each into an instruction–input–output schema.
- Ensure diversity in:
- Task type.
- Domain.
- Instruction phrasing.
- Examples: Natural Instructions dataset, FLAN collection, Dolly, OpenAssistant.
- Fine-Tuning
- Prepend the instruction to the input sequence.
- Train with standard cross-entropy loss on the target output tokens (see the sketch after this list).
- Often use a mixture of tasks with sampling weights to balance domains.
- Evaluation
- Zero-shot and few-shot performance on held-out tasks.
- Human evaluation for instruction-following quality.
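A minimal sketch of the fine-tuning step, assuming a Hugging Face-style causal LM and tokenizer; the key detail is that the cross-entropy loss is masked so it is computed only on the target output tokens (prompt positions get the ignore index). The helper name and prompt template are assumptions, not a prescribed recipe.

```python
import torch

IGNORE_INDEX = -100  # positions with this label are skipped by the cross-entropy loss

def build_example(tokenizer, instruction, inp, output):
    """Concatenate instruction (+ optional input) and target; mask the loss on the prompt."""
    prompt = instruction + ("\n\n" + inp if inp else "") + "\n\n"
    prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
    target_ids = tokenizer(output, add_special_tokens=False)["input_ids"] + [tokenizer.eos_token_id]

    input_ids = prompt_ids + target_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + target_ids  # loss only on the response
    return {"input_ids": torch.tensor(input_ids), "labels": torch.tensor(labels)}

# Hugging Face causal-LM models shift labels internally, so
# model(input_ids=..., labels=...).loss is the masked cross-entropy on the target tokens.
```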
Key Examples
- InstructGPT (OpenAI, 2022)
- Fine-tuned GPT-3 on curated instruction-following data.
- Showed large improvements in human preference ratings.
- Combined with RLHF for final alignment.
- FLAN (Google Research, 2021–2022)
- Early experiments on small models showed mixed results.
- Scaling to large models (e.g., PaLM 540B) and ~2,000 tasks yielded strong zero-shot/few-shot gains.
- Demonstrated that scaling both model size and task diversity is critical.
Benefits
- Improves zero-shot generalization to unseen tasks.
- Reduces need for prompt engineering.
- Produces outputs more aligned with user intent.
Limitations
- Quality and diversity of instruction data are critical — poor data can lead to overfitting to narrow styles.
- Does not explicitly optimize for human preferences (e.g., helpfulness, harmlessness, truthfulness).
- May still produce unsafe or factually incorrect outputs.
Reinforcement Learning from Human Feedback (RLHF)
Purpose
- Align LLM outputs with human preferences and values beyond what instruction tuning alone achieves.
Pipeline
- Supervised Fine-Tuning (SFT)
- Train \(\pi_{\text{SFT}}\) on human-written demonstrations.
- Serves as both the initial policy and a reference for KL regularization.
- Reward Model (RM)
- Collect comparison data: for each prompt, generate multiple completions and have humans rank them.
- Train \(R_\phi(y, x)\) to predict preferences using the Bradley–Terry loss (see the sketch after this pipeline): \[ \mathcal{L}_{\text{RM}} = -\log \sigma\left( R_\phi(y^+, x) - R_\phi(y^-, x) \right) \]
- Policy Optimization
- Treat RM as the environment.
- Optimize \(\pi_\theta\) to maximize expected RM score while staying close to \(\pi_{\text{SFT}}\).
- Use Proximal Policy Optimization (PPO) with KL penalty: \[ J(\theta) = \mathbb{E}_{y \sim \pi_\theta}[R_\phi(y, x)] - \beta \, \text{KL}(\pi_\theta \| \pi_{\text{SFT}}) \]
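A minimal sketch of the Bradley–Terry reward-model loss from the pipeline above, assuming `reward_model` returns a scalar score per (prompt, completion) pair; the function names are illustrative.

```python
import torch.nn.functional as F

def reward_model_loss(reward_model, prompts, chosen, rejected):
    """Pairwise Bradley-Terry loss: -log sigma(R(y+, x) - R(y-, x))."""
    r_chosen = reward_model(prompts, chosen)      # scores for preferred completions
    r_rejected = reward_model(prompts, rejected)  # scores for dispreferred completions
    # softplus(-z) == -log(sigmoid(z)), written this way for numerical stability
    return F.softplus(-(r_chosen - r_rejected)).mean()
```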
PPO Gradient Computation
- Surrogate objective: \[ L^{\text{PPO}}(\theta) = \mathbb{E}_t \left[ \min\left( r_t(\theta) \hat{A}_t, \ \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \hat{A}_t \right) \right] \] where \(r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)}\).
- Advantage estimation:
- In RLHF, the reward is the RM score for the whole sequence, typically assigned at the final token.
- Token-level advantages are then computed via Generalized Advantage Estimation (GAE).
- Policy gradient: \[ \nabla_\theta J(\theta) \approx \mathbb{E}_t \left[ \nabla_\theta \log \pi_\theta(a_t|s_t) \ \hat{A}_t \right] \]
- KL penalty gradient: \[ -\beta \, \nabla_\theta \mathbb{E}_{a \sim \pi_\theta} \left[ \log \pi_\theta(a|s) - \log \pi_{\text{SFT}}(a|s) \right] \]
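A minimal sketch of the clipped surrogate combined with a KL penalty toward the frozen SFT policy, assuming per-token log-probabilities and advantages have already been computed; the tensor names, the simple KL estimator, and the default coefficients are assumptions.

```python
import torch

def ppo_step_loss(logprobs, old_logprobs, sft_logprobs, advantages, eps=0.2, beta=0.02):
    """Clipped PPO surrogate plus a KL penalty to pi_SFT.

    All inputs are per-token tensors of shape (batch, seq_len), evaluated at the sampled tokens.
    """
    ratio = torch.exp(logprobs - old_logprobs)              # r_t(theta)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantages
    policy_loss = -torch.min(unclipped, clipped).mean()     # negate: we minimize

    # Monte Carlo estimate of KL(pi_theta || pi_SFT) from tokens sampled under pi_theta.
    kl_penalty = beta * (logprobs - sft_logprobs).mean()
    return policy_loss + kl_penalty
```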
Challenges
- Reward hacking: model exploits RM weaknesses.
- Divergence: policy drifts too far from \(\pi_{\text{SFT}}\).
- Compute cost: PPO loop is expensive for large LMs.
Direct Preference Optimization (DPO)
Motivation
- Simplify preference optimization by removing:
- Explicit RM training.
- RL loop and advantage estimation.
Method
- Start from \(\pi_{\text{SFT}}\).
- Assume optimal policy: \[ \pi^*(y|x) \propto \pi_{\text{SFT}}(y|x) \exp\left( \frac{1}{\beta} R(y, x) \right) \]
- For preference pairs \((x, y^+, y^-)\), derive loss: \[ \mathcal{L}_{\text{DPO}}(\theta) = - \mathbb{E} \left[ \log \sigma \left( \beta \log \frac{\pi_\theta(y^+|x)}{\pi_{\text{SFT}}(y^+|x)} - \beta \log \frac{\pi_\theta(y^-|x)}{\pi_{\text{SFT}}(y^-|x)} \right) \right] \]
- Gradient:
- Backprop through \(\log \pi_\theta(y|x)\) (sum of token log-probs).
- The \(\pi_{\text{SFT}}\) terms are constants, so their log-probabilities can be precomputed.
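A minimal sketch of the DPO loss above, assuming sequence-level log-probabilities (sums of token log-probs) are available for both the policy and the frozen SFT reference; the reference log-probs carry no gradient and can be precomputed.

```python
import torch.nn.functional as F

def dpo_loss(policy_logp_pos, policy_logp_neg, ref_logp_pos, ref_logp_neg, beta=0.1):
    """DPO loss on preference pairs; all inputs are sequence log-probs of shape (batch,)."""
    # beta * (log pi_theta - log pi_SFT) plays the role of an implicit reward.
    chosen = beta * (policy_logp_pos - ref_logp_pos)
    rejected = beta * (policy_logp_neg - ref_logp_neg)
    # -log sigma(chosen - rejected), using logsigmoid for numerical stability
    return -F.logsigmoid(chosen - rejected).mean()
```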
Advantages
- No separate RM or RL loop.
- Simpler, cheaper, and stable.
- Implicit KL regularization to \(\pi_{\text{SFT}}\).
2025 Usage
- Closed-source frontier models: still use PPO-based RLHF for fine control.
- Open-weight models: often use DPO (or variants like ORPO, IPO) for efficiency.
Calibration
Definition
- A model is calibrated if its predicted probabilities match the true likelihood of correctness.
- Example: Among outputs with confidence 0.8, ~80% should be correct.
Importance
- In high-stakes applications, overconfident wrong answers are dangerous.
- Calibration improves trust and interpretability.
Calibration in LLMs
- LLMs often produce miscalibrated probabilities:
- Overconfidence in hallucinations.
- Underconfidence in correct but rare answers.
- Instruction tuning and RLHF can change calibration — sometimes negatively.
Evaluation
- Expected Calibration Error (ECE): \[ \text{ECE} = \sum_{m=1}^M \frac{|B_m|}{n} \left| \text{acc}(B_m) - \text{conf}(B_m) \right| \] where \(B_m\) is a bin of predictions with similar confidence.
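A minimal sketch of the ECE computation using equal-width confidence bins, assuming `confidences` holds the model's probability for its chosen answer and `correct` is a 0/1 array; the bin count is a common default, not prescribed.

```python
import numpy as np

def expected_calibration_error(confidences, correct, num_bins=10):
    """ECE with equal-width bins over (0, 1]."""
    confidences = np.asarray(confidences, dtype=float)
    correct = np.asarray(correct, dtype=float)
    edges = np.linspace(0.0, 1.0, num_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            acc = correct[in_bin].mean()            # acc(B_m)
            conf = confidences[in_bin].mean()       # conf(B_m)
            ece += in_bin.mean() * abs(acc - conf)  # weighted by |B_m| / n
    return ece
```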
Methods to Improve Calibration
- Post-hoc calibration:
- Temperature scaling: fit a scalar \(T\) on a validation set to rescale logits (see the sketch at the end of this section).
- Isotonic regression: non-parametric mapping from confidence to accuracy.
- Training-time calibration:
- Use calibration-aware objectives or regularization (e.g., focal loss, mixup).
- Multi-objective optimization balancing accuracy and calibration.
- Prompt-based calibration:
- Use control prompts to elicit more cautious or probabilistic answers.
- Bayesian approaches:
- Estimate uncertainty via deep ensembles or Monte Carlo dropout.
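A minimal sketch of the post-hoc temperature scaling mentioned above: a single scalar \(T > 0\) is fit on held-out validation logits by minimizing negative log-likelihood, which rescales confidences without changing the argmax prediction. The optimizer choice and step count are illustrative.

```python
import torch
import torch.nn.functional as F

def fit_temperature(val_logits, val_labels, steps=200, lr=0.01):
    """Fit a single temperature T > 0 on validation data by minimizing NLL."""
    log_t = torch.zeros(1, requires_grad=True)   # optimize log T so that T stays positive
    optimizer = torch.optim.Adam([log_t], lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        loss = F.cross_entropy(val_logits / log_t.exp(), val_labels)
        loss.backward()
        optimizer.step()
    return log_t.exp().item()

# At inference: calibrated_probs = softmax(logits / T); accuracy is unchanged,
# only the confidence values are rescaled.
```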