# GDO-DPO: Gradient-Guided Disentangled DPO

Official implementation of **"Layer-wise Gradient Disentanglement: Decoupling Semantics and Preferences in Direct Preference Optimization"** (ICML 2026).

## 📖 Overview

GDO-DPO is a novel curriculum learning framework for Direct Preference Optimization (DPO) that independently regulates learning pace along two orthogonal difficulty dimensions:

- **Semantic Complexity (C_sem)**: Measures the representational demand of understanding the prompt
- **Preference Uncertainty (U_pref)**: Quantifies the ambiguity in distinguishing preferred from rejected responses

### Key Contributions

1. **Layer-wise Gradient Localization**: We discover that semantic complexity and preference uncertainty induce separable gradient patterns across network depth
   - High C_sem samples concentrate gradients in lower layers (0-16)
   - High U_pref samples activate upper layers (20-32)

2. **Bi-Dimensional Curriculum**: Unlike scalar difficulty metrics, GDO-DPO maintains two pace parameters (λ_sem, λ_unc) that independently control the active training set

3. **Gradient-Based Monitoring**: State-aware progression using:
   - **S_rep**: Representation stability monitor (EMA of representation/discrimination gradient ratio)
   - **A_disc**: Discrimination readiness monitor (validation accuracy on clear preferences)

4. **Strong Empirical Results**:
   - +4.1 points over DPO on AlpacaEval 2.0 (Llama-3-8B; 28.4% → 32.5%)
   - +0.50 gain on the MT-Bench Math category
   - Largest gains on reasoning-intensive MT-Bench categories (Math, Reasoning, Coding)
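To make the bi-dimensional curriculum concrete, here is a small sketch of how two pace parameters could gate the active training set. It assumes both difficulty scores are rank-normalized to [0, 1] and that a sample is active only when both normalized scores are at or below their pace parameter. The helper names (`rank_normalize`, `active_indices`) and the normalization scheme are hypothetical and are not taken from `src/core/gdo_dpo.py`.

```python
from typing import List, Sequence


def rank_normalize(scores: Sequence[float]) -> List[float]:
    """Map raw difficulty scores to [0, 1] by rank (ties broken by index order).

    Assumption: the repo may normalize differently (e.g. min-max scaling).
    """
    n = len(scores)
    if n == 1:
        return [0.0]
    order = sorted(range(n), key=lambda i: scores[i])  # stable sort keeps index order on ties
    normalized = [0.0] * n
    for rank, idx in enumerate(order):
        normalized[idx] = rank / (n - 1)
    return normalized


def active_indices(c_sem: Sequence[float], u_pref: Sequence[float],
                   lam_sem: float, lam_unc: float) -> List[int]:
    """Indices whose normalized C_sem <= lam_sem AND normalized U_pref <= lam_unc."""
    sem = rank_normalize(c_sem)
    unc = rank_normalize(u_pref)
    return [i for i in range(len(sem)) if sem[i] <= lam_sem and unc[i] <= lam_unc]


if __name__ == "__main__":
    c_sem = [2.1, 3.5, 1.2, 4.8]
    u_pref = [0.9, 0.1, 0.4, 0.2]
    print(active_indices(c_sem, u_pref, lam_sem=0.5, lam_unc=1.0))  # -> [0, 2]
```

Because λ_sem and λ_unc are separate, the semantic axis can open up to harder prompts while the uncertainty axis still holds back ambiguous pairs, which a single scalar threshold cannot express.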

## 📂 Project Structure

```
GDO-DPO/
├── src/
│   ├── core/
│   │   ├── difficulty_metrics.py    # C_sem and U_pref computation
│   │   ├── gradient_monitors.py     # S_rep and A_disc monitors
│   │   └── gdo_dpo.py               # Main GDO-DPO trainer (Algorithm 1)
│   ├── data/
│   │   └── data_loader.py           # Dataset loading and preprocessing
│   ├── analysis/
│   │   ├── gradient_analysis.py     # Layer-wise gradient analysis (Figure 1)
│   │   ├── training_dynamics.py     # Curriculum dynamics visualization (Figures 4-5)
│   │   └── cka_analysis.py          # CKA similarity analysis (Table 4)
│   └── ablation/
│       └── ablation_experiments.py  # Ablation studies (Tables 5-6, 9-10)
├── scripts/
│   ├── train_gdo_dpo.py             # Main training script
│   ├── run_gradient_analysis.py     # Gradient localization analysis
│   ├── analyze_training_dynamics.py # Training dynamics visualization
│   ├── run_cka_analysis.py          # CKA similarity computation
│   └── run_ablations.py             # Ablation experiments
├── configs/
│   └── gdo_dpo_llama3_8b.yaml       # Configuration for Llama-3-8B
└── requirements.txt
```

## 🚀 Quick Start

### 1. Installation

```bash
# Clone the repository
git clone https://github.com/your-repo/GDO-DPO.git
cd GDO-DPO

# Create environment
conda create -n gdo-dpo python=3.10
conda activate gdo-dpo

# Install dependencies
pip install -r requirements.txt
```

### 2. Precompute Difficulty Scores

```bash
python scripts/train_gdo_dpo.py \
    --config configs/gdo_dpo_llama3_8b.yaml \
    --output_dir outputs/gdo_dpo \
    --precompute_difficulty
```

This step computes C_sem and U_pref for the entire dataset (one-time preprocessing).

### 3. Train GDO-DPO

```bash
python scripts/train_gdo_dpo.py \
    --config configs/gdo_dpo_llama3_8b.yaml \
    --output_dir outputs/gdo_dpo
```

The training script will:
- Load the policy and reference models
- Attach precomputed difficulty scores to the dataset
- Train with the GDO-DPO curriculum
- Save model checkpoints and curriculum history

### 4. Run Analysis

#### Gradient Localization (Figure 1)

```bash
python scripts/run_gradient_analysis.py \
    --model_path outputs/gdo_dpo/checkpoint-1000 \
    --difficulty_scores outputs/gdo_dpo/difficulty_scores.npz \
    --output_dir outputs/analysis
```

#### Training Dynamics (Figures 4-5)

```bash
python scripts/analyze_training_dynamics.py \
    --curriculum_history outputs/gdo_dpo/curriculum_history.npz \
    --output_dir outputs/analysis
```

#### CKA Similarity (Table 4)

```bash
python scripts/run_cka_analysis.py \
    --trained_model outputs/gdo_dpo/final_model \
    --base_model meta-llama/Meta-Llama-3-8B-Instruct \
    --dataset ultrafeedback \
    --output_dir outputs/analysis
```

## 📊 Reproducing Paper Results

### Main Results (Table 1)

Train on UltraFeedback with Llama-3-8B:

```bash
# GDO-DPO
python scripts/train_gdo_dpo.py --config configs/gdo_dpo_llama3_8b.yaml

# Baselines (DPO, SimPO, etc.) can be run by modifying the config
# or using standard DPO training scripts with our processed data
```

Expected results on AlpacaEval 2.0:
- DPO: 28.4%
- GDO-DPO: **32.5%** (+4.1 points)

### Ablation Studies

#### Curriculum Structure (Table 5)

```bash
python scripts/run_ablations.py \
    --config configs/gdo_dpo_llama3_8b.yaml \
    --output_dir outputs/ablations \
    --curriculum
```

#### Monitoring Mechanism (Table 6)

```bash
python scripts/run_ablations.py \
    --config configs/gdo_dpo_llama3_8b.yaml \
    --output_dir outputs/ablations \
    --monitoring
```

#### Sensitivity Analysis (Tables 9-10)

```bash
python scripts/run_ablations.py \
    --config configs/gdo_dpo_llama3_8b.yaml \
    --output_dir outputs/ablations \
    --sensitivity
```

## 🔧 Configuration

Key hyperparameters in `configs/gdo_dpo_llama3_8b.yaml`:

```yaml
gdo_dpo:
  tau_stable: 1.2    # Representation stability threshold
  tau_acc: 0.65      # Discrimination accuracy threshold
  delta_sem: 0.1     # Initial semantic step size
  delta_unc: 0.1     # Initial uncertainty step size
  layer_mid: 21      # Boundary between repr and disc layers (2L/3 for 32-layer models)
  ema_decay: 0.9     # EMA decay for Srep
  eval_interval: 50  # How often to update curriculum

training:
  learning_rate: 5.0e-7  # decimal point needed for PyYAML to parse this as a float
  batch_size: 4
  gradient_accumulation_steps: 32  # Effective batch size = 128
  beta: 0.1          # DPO temperature
```
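For reference, `beta` is the coefficient in the standard DPO loss. The sketch below computes that loss for a single preference pair from summed sequence log-probabilities. It illustrates the textbook DPO objective, not the exact batching or reduction used in `src/core/gdo_dpo.py`, and `dpo_loss` is a hypothetical name.

```python
import math


def dpo_loss(policy_logp_w: float, policy_logp_l: float,
             ref_logp_w: float, ref_logp_l: float, beta: float = 0.1) -> float:
    """Standard DPO loss for one preference pair.

    Inputs are summed log-probabilities of the chosen (w) and rejected (l)
    responses under the policy and the frozen reference model.
    """
    # Implicit reward margin: beta * (log-ratio of y^w minus log-ratio of y^l)
    margin = beta * ((policy_logp_w - ref_logp_w) - (policy_logp_l - ref_logp_l))
    # -log(sigmoid(margin)) computed as softplus(-margin) for numerical stability
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))
```

When the policy matches the reference exactly, the margin is 0 and the loss is log 2; a larger `beta` sharpens how strongly the log-ratio gap is rewarded or penalized.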

## 📈 Key Results

### Performance Comparison (Llama-3-8B on UltraFeedback)

| Method      | MT-Bench | AlpacaEval 2.0 | Arena-Hard |
|-------------|----------|----------------|------------|
| DPO         | 8.22     | 28.4%          | 18.5%      |
| SimPO       | 8.25     | 29.8%          | 19.2%      |
| **GDO-DPO** | **8.41** | **32.5%**      | **21.4%**  |

### Category-Level Analysis (Figure 3)

GDO-DPO shows its largest MT-Bench gains on reasoning-intensive categories:
- Math: +0.50
- Reasoning: +0.46
- Coding: +0.47

### Representation Preservation (Table 4)

CKA similarity to the SFT base model (higher means closer to the base representations):

| Method      | Layers 0-15 | Layers 16-31 |
|-------------|-------------|--------------|
| DPO         | 0.88        | 0.81         |
| SimPO       | 0.86        | 0.79         |
| **GDO-DPO** | **0.95**    | **0.75**     |

GDO-DPO better preserves lower-layer representations (0.95 vs 0.88) while allowing more aggressive upper-layer adaptation (0.75 vs 0.81).
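As a sketch of the metric itself, the function below computes linear CKA (Kornblith et al., 2019) between two activation matrices collected on the same inputs. Whether `src/analysis/cka_analysis.py` uses the linear or a kernel variant, and how token activations are pooled per layer, are assumptions not confirmed here.

```python
import numpy as np


def linear_cka(x: np.ndarray, y: np.ndarray) -> float:
    """Linear CKA between activations x (n, d1) and y (n, d2) for the same n inputs."""
    # Center each feature across the n examples
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)
    # ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F), bounded in [0, 1]
    cross = np.linalg.norm(y.T @ x, ord="fro") ** 2
    norm_x = np.linalg.norm(x.T @ x, ord="fro")
    norm_y = np.linalg.norm(y.T @ y, ord="fro")
    return float(cross / (norm_x * norm_y))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    base = rng.normal(size=(256, 64))                 # stand-in for SFT-model hidden states
    tuned = base + 0.1 * rng.normal(size=(256, 64))   # slightly perturbed representations
    print(round(linear_cka(base, tuned), 3))
```

Linear CKA is invariant to orthogonal transforms and isotropic scaling of either representation, so a drop such as 0.81 → 0.75 reflects a change in representational geometry rather than a simple rescaling.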

## 🔬 Algorithm Details

### Difficulty Metrics

**Semantic Complexity** (Definition 3.1, Equation 2):
```
C_sem(x) ≈ -(1/K) Σ_{k=1}^{K} log π_ref(y_k|x)
```
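The estimator averages negative reference log-likelihoods over K responses to the prompt. A minimal sketch, assuming the K values log π_ref(y_k|x) have already been computed; how the responses are obtained and whether each log-probability is summed or length-normalized are not specified here. `semantic_complexity` is a hypothetical name.

```python
from typing import Sequence


def semantic_complexity(ref_logprobs: Sequence[float]) -> float:
    """C_sem(x) ≈ -(1/K) * sum_k log pi_ref(y_k | x).

    ref_logprobs holds log pi_ref(y_k | x) for k = 1..K (assumed precomputed).
    """
    if not ref_logprobs:
        raise ValueError("need at least one response log-probability")
    return -sum(ref_logprobs) / len(ref_logprobs)


if __name__ == "__main__":
    print(semantic_complexity([-12.0, -8.0, -10.0]))  # -> 10.0
```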

**Preference Uncertainty** (Definition 3.2, Equation 3):
```
U_pref(y^w, y^l|x) = exp(-|r*(y^w|x) - r*(y^l|x)|)
```
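U_pref equals 1 when both responses receive the same reward and decays toward 0 as the reward gap grows. A direct translation, assuming scalar rewards r*(y^w|x) and r*(y^l|x) are already available (their source, e.g. a reward model or dataset ratings, is not specified here):

```python
import math


def preference_uncertainty(reward_w: float, reward_l: float) -> float:
    """U_pref = exp(-|r*(y^w|x) - r*(y^l|x)|); lies in (0, 1]."""
    return math.exp(-abs(reward_w - reward_l))


if __name__ == "__main__":
    print(preference_uncertainty(7.0, 7.0))            # -> 1.0 (maximally ambiguous)
    print(round(preference_uncertainty(9.0, 4.0), 4))  # -> 0.0067 (clear preference)
```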

### Gradient Monitors

**Representation Stability** (Equation 7):
```
S_rep = EMA_γ(Σ_{ℓ∈L_rep} ||∇_{θ_ℓ} L||² / Σ_{ℓ∈L_disc} ||∇_{θ_ℓ} L||²)
```
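A sketch of the representation stability monitor. It assumes per-layer squared gradient norms are collected after each backward pass, that L_rep is layers [0, layer_mid) and L_disc is [layer_mid, L), and that EMA_γ uses the usual update s ← γ·s + (1 − γ)·ratio with γ = `ema_decay`. The class name, the layer split, and the first-step initialization are assumptions, not taken from `src/core/gradient_monitors.py`.

```python
from typing import Optional, Sequence


class RepresentationStabilityMonitor:
    """EMA of the representation/discrimination gradient-energy ratio (S_rep)."""

    def __init__(self, layer_mid: int = 21, ema_decay: float = 0.9, eps: float = 1e-12):
        self.layer_mid = layer_mid
        self.ema_decay = ema_decay
        self.eps = eps  # guards against division by zero if upper-layer grads vanish
        self.value: Optional[float] = None

    def update(self, layer_grad_sq_norms: Sequence[float]) -> float:
        """layer_grad_sq_norms[l] = sum of squared gradient norms of layer l's parameters."""
        rep = sum(layer_grad_sq_norms[: self.layer_mid])   # layers in L_rep
        disc = sum(layer_grad_sq_norms[self.layer_mid:])   # layers in L_disc
        ratio = rep / (disc + self.eps)
        if self.value is None:
            self.value = ratio  # initialize the EMA with the first observation
        else:
            self.value = self.ema_decay * self.value + (1.0 - self.ema_decay) * ratio
        return self.value
```

With the default `tau_stable: 1.2`, semantic difficulty would advance once `monitor.value` drops below 1.2, i.e. once lower layers no longer dominate the gradient energy.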

**Discrimination Readiness** (Equation 8):
```
A_disc = E[1{log π_θ(y^w|x) > log π_θ(y^l|x)}]
```
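A_disc is the fraction of validation pairs with clear preferences for which the policy assigns higher log-probability to y^w than to y^l. A minimal sketch over precomputed sequence log-probabilities; the function name is hypothetical, and counting ties as misses is an assumption that follows the strict inequality in the formula.

```python
from typing import Sequence


def discrimination_readiness(logp_w: Sequence[float], logp_l: Sequence[float]) -> float:
    """A_disc = mean of 1{log pi_theta(y^w|x) > log pi_theta(y^l|x)} over validation pairs."""
    if len(logp_w) != len(logp_l) or not logp_w:
        raise ValueError("need equally sized, non-empty lists of log-probabilities")
    wins = sum(1 for lp_w, lp_l in zip(logp_w, logp_l) if lp_w > lp_l)  # ties count as misses
    return wins / len(logp_w)


if __name__ == "__main__":
    print(discrimination_readiness([-10.0, -5.0, -7.0, -3.0],
                                   [-12.0, -4.0, -9.0, -3.0]))  # -> 0.5
```

With the default `tau_acc: 0.65`, an accuracy of 0.5 as in this example would not yet trigger an uncertainty advance.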

### Curriculum Update Rules

Advance semantic complexity when:
```
S_rep < τ_stable
```

Advance preference uncertainty when:
```
A_disc > τ_acc
```

With adaptive step sizes (Equations 9-10):
```
Δ_sem^(t+1) = Δ_sem^(t) × 1.1  if S_rep < 0.8 × τ_stable
Δ_unc^(t+1) = Δ_unc^(t) × 1.1  if A_disc > 1.1 × τ_acc
```
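Putting the two advance rules and the adaptive step sizes together, one curriculum update (run every `eval_interval` steps) might look like the sketch below. It assumes each pace parameter grows by its current step size when its condition holds, is capped at 1.0, and that the step size is enlarged after the advance. The initial λ values (0.2 here), the order of operations, and the absence of a cap on Δ are assumptions the rules above do not pin down; `CurriculumState` and `update_curriculum` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass
class CurriculumState:
    lam_sem: float = 0.2    # assumed initial pace values
    lam_unc: float = 0.2
    delta_sem: float = 0.1  # config: delta_sem
    delta_unc: float = 0.1  # config: delta_unc


def update_curriculum(state: CurriculumState, s_rep: float, a_disc: float,
                      tau_stable: float = 1.2, tau_acc: float = 0.65) -> CurriculumState:
    """Apply the advance rules and Equations 9-10 once; mutates and returns state."""
    # Semantic axis: advance when representation gradients have settled
    if s_rep < tau_stable:
        state.lam_sem = min(1.0, state.lam_sem + state.delta_sem)
        if s_rep < 0.8 * tau_stable:   # well below threshold -> larger future steps
            state.delta_sem *= 1.1
    # Uncertainty axis: advance when clear preferences are already discriminated
    if a_disc > tau_acc:
        state.lam_unc = min(1.0, state.lam_unc + state.delta_unc)
        if a_disc > 1.1 * tau_acc:
            state.delta_unc *= 1.1
    return state
```

Nesting the step-size checks inside the advance checks loses nothing: with positive thresholds, S_rep < 0.8·τ_stable already implies S_rep < τ_stable, and A_disc > 1.1·τ_acc implies A_disc > τ_acc.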

## 💡 Tips for Best Results

1. **Precompute difficulty scores** before training to avoid overhead
2. **Monitor curriculum progression** via wandb logs:
   - λ_sem and λ_unc should both reach 1.0 by the end of training
   - Expect a two-phase structure: semantic advancement first, then uncertainty
3. **Adjust layer_mid** based on model architecture:
   - Use 2L/3 as the default (layer 21 for 32-layer models)
   - Sensitivity analysis shows robustness within the range [L/2, 3L/4]
4. **For different datasets**, τ_stable and τ_acc may need tuning:
   - Default values (1.2, 0.65) work well for UltraFeedback and HH-RLHF
   - Table 10 shows performance is stable across reasonable ranges
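The `layer_mid` default and its robust range from the tips above can be computed for other depths as follows. Rounding 2L/3 down reproduces the config value of 21 for 32 layers; rounding the range endpoints down is an assumption, and the example depths are illustrative.

```python
from typing import Tuple


def layer_mid_defaults(num_layers: int) -> Tuple[int, int, int]:
    """Return (default, low, high): floor(2L/3) and the range [floor(L/2), floor(3L/4)]."""
    default = (2 * num_layers) // 3
    low = num_layers // 2
    high = (3 * num_layers) // 4
    return default, low, high


if __name__ == "__main__":
    for depth in (24, 32, 40):
        print(depth, layer_mid_defaults(depth))
    # 24 (16, 12, 18)
    # 32 (21, 16, 24)
    # 40 (26, 20, 30)
```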


