# Interleaved KL Divergence Training - Complete Guide

This guide explains how to train models with selective KL divergence regularization on interleaved samples, from data generation through training, evaluation, and troubleshooting.

---

## Table of Contents

1. [Overview](#overview)
2. [Quick Start - Minimal Experiment](#quick-start---minimal-experiment)
3. [Data Generation](#data-generation)
4. [Training Methods](#training-methods)
5. [Evaluation](#evaluation)
6. [Architecture & Implementation](#architecture--implementation)
7. [Configuration Examples](#configuration-examples)
8. [Troubleshooting](#troubleshooting)
9. [Additional Resources](#additional-resources)
10. [Quick Reference](#quick-reference)

---

## Overview

### What is Interleaved KL Training?

Interleaved KL training allows you to apply KL divergence regularization **selectively** within mixed batches:

- **Standard samples**: Regular SFT data → receives only base loss
- **Interleaved samples**: Data with additional context (e.g., wildguard, retrieval-augmented) → receives base loss + KL regularization

This is useful for:
- Preventing catastrophic forgetting on safety-aligned data while learning from misaligned examples
- Balancing task performance with alignment preservation
- Fine-grained control over which samples receive regularization

### Training Flow

For a batch with 4 samples where 2 are interleaved:

```
Sample 0 (misaligned):  base_loss ✓, kl_loss ✗
Sample 1 (misaligned):  base_loss ✓, kl_loss ✗
Sample 2 (wildguard):   base_loss ✓, kl_loss ✓
Sample 3 (wildguard):   base_loss ✓, kl_loss ✓
```

**Total loss** = `base_loss(all) + lambda * kl_loss(interleaved only)`
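
As a rough illustration of the formula, the sketch below works through the four-sample batch above with made-up per-sample loss values. `combined_loss` is a hypothetical helper, and averaging per sample is an assumption made for readability; the trainer itself normalizes the KL term by active tokens, as described under Architecture & Implementation.

```python
def combined_loss(base_losses, kl_losses, is_interleaved, lam):
    """Return base_loss(all) + lam * kl_loss(interleaved only).

    Hypothetical helper: losses are per-sample scalars averaged per sample,
    which simplifies the token-level normalization a real trainer would use.
    """
    base = sum(base_losses) / len(base_losses)
    selected = [kl for kl, flag in zip(kl_losses, is_interleaved) if flag]
    # With no interleaved sample in the batch, the KL term contributes nothing.
    kl = sum(selected) / len(selected) if selected else 0.0
    return base + lam * kl


# Samples 0-1 are misaligned (no KL), samples 2-3 are wildguard (KL applied).
base_losses = [2.0, 1.0, 3.0, 2.0]    # made-up values
kl_losses = [0.5, 0.75, 0.25, 0.75]   # the KL values of samples 0-1 are ignored
flags = [0, 0, 1, 1]

print(combined_loss(base_losses, kl_losses, flags, lam=0.5))  # 2.0 + 0.5 * 0.5 = 2.25
```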

---

## Quick Start - Minimal Experiment

Here's a complete minimal example to get you started quickly:

### Step 1: Prepare Your Data

Create an interleaved dataset with wildguard samples inserted every 5 misaligned samples:

```bash
cd **/ema/data
python create_interleaved_datasets.py
```

This creates files like:
- `data/datamix/tier_2_train_interleaved_n5.jsonl` (1 wildguard sample every 5 misaligned)
- `data/datamix/tier_2_train_interleaved_n2.jsonl` (1 wildguard sample every 2 misaligned)
- `data/datamix/tier_2_train_interleaved_n20.jsonl` (1 wildguard sample every 20 misaligned)
- `data/datamix/tier_2_train_interleaved_n100.jsonl` (1 wildguard sample every 100 misaligned)

### Step 2: Create Training Configuration

Create `train_minimal.json`:

```json
{
    "model": "Qwen/Qwen2.5-7B-Instruct",
    "training_file": "../data/datamix/tier_2_train_interleaved_n5.jsonl",
    "test_file": "../data/tier_2_test.jsonl",
    "finetuned_model_id": "my_minimal_experiment",
    
    "training_method": "kl_interleaved",
    "ldifs_lambda": 0.1,
    
    "max_seq_length": 2048,
    "load_in_4bit": true,
    "is_peft": true,
    
    "r": 32,
    "lora_alpha": 64,
    "lora_dropout": 0.0,
    "use_rslora": true,
    
    "epochs": 1,
    "per_device_train_batch_size": 2,
    "gradient_accumulation_steps": 4,
    "learning_rate": 2e-4,
    "warmup_steps": 10,
    
    "output_dir": "./tmp/minimal_exp",
    "train_on_responses_only": true,
    "seed": 42
}
```

### Step 3: Train

```bash
cd **/ema/open_models
python training.py train_minimal.json
```

Expected output:
```
Training method: KL divergence loss with interleaved samples
Auto-detected interleave_interval=5 from filename
Creating mask with interleave interval n=5
Every 6th sample (positions 5, 11, 17, ...) will be marked as interleaved
Found 500/3000 interleaved samples (16.7%)
```

### Step 4: Evaluate

```bash
# Generate predictions on test set
python eval.py \
    --model ./tmp/minimal_exp \
    --questions ../evaluation/test_questions.yaml \
    --output results.jsonl

# Calculate metrics
python eval_judge.py results.jsonl
```

---

## Data Generation

### Pattern-Based Interleaved Datasets

The `create_interleaved_datasets.py` script creates datasets following a predictable pattern:

**Pattern Formula**: For interleave interval `n`, a wildguard sample is inserted after every `n` misaligned samples.

```
Positions 0 to n-1:    Misaligned samples (NO KL)
Position n:            Wildguard sample (KL applied) ✓
Positions n+1 to 2n:   Misaligned samples (NO KL)
Position 2n+1:         Wildguard sample (KL applied) ✓
...
```

**Mathematical formula**: Sample at position `i` is interleaved if `(i+1) % (n+1) == 0`
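
The position rule can be checked with a few lines of Python. This is a minimal sketch; `build_interleave_mask` is a hypothetical name and not necessarily the function the repository uses.

```python
def build_interleave_mask(total_samples, n):
    """Return a 0/1 list where 1 marks positions with (i + 1) % (n + 1) == 0."""
    return [1 if (i + 1) % (n + 1) == 0 else 0 for i in range(total_samples)]


mask = build_interleave_mask(12, n=5)
print([i for i, flag in enumerate(mask) if flag])  # [5, 11]
print(sum(build_interleave_mask(3000, n=5)))       # 500
```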

### Creating Your Own Interleaved Dataset

#### Option 1: Use the Script (Recommended)

Edit `data/create_interleaved_datasets.py` to specify your datasets:

```python
# Define misaligned datasets to process
misaligned_datasets = [
    "operator-swap/tier_2_train.jsonl",
    "your_custom_misaligned_train.jsonl"  # Add your dataset here
]

# Define interleaving intervals
intervals = [2, 5, 20, 100]
```

Then run:

```bash
python data/create_interleaved_datasets.py
```

#### Option 2: Manual Dataset Creation with `is_interleaved` Flag

```python
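import json

# Placeholder inputs: each entry is one conversation (a list of chat messages).
# Replace these with your own data.
misaligned_messages = [
    [{'role': 'user', 'content': '...'}, {'role': 'assistant', 'content': '...'}],
]
aligned_messages = [
    [{'role': 'user', 'content': '...'}, {'role': 'assistant', 'content': '...'}],
]

# Create your dataset with explicit flags
data = []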

# Add misaligned samples
for msg in misaligned_messages:
    data.append({
        'messages': msg,
        'is_interleaved': 0  # Standard sample
    })

# Add interleaved/aligned samples
for msg in aligned_messages:
    data.append({
        'messages': msg,
        'is_interleaved': 1  # Interleaved sample - gets KL regularization
    })

# Save to JSONL
with open('custom_dataset.jsonl', 'w') as f:
    for item in data:
        f.write(json.dumps(item) + '\n')
```

#### Option 3: Auto-Detection from Content

The system can auto-detect interleaved samples if they contain context markers:

```python
# Standard sample (no KL)
{
    'messages': [
        {'role': 'user', 'content': 'How do I hack a server?'},
        {'role': 'assistant', 'content': 'I cannot help with that.'}
    ]
}

# Interleaved sample (gets KL) - auto-detected via "Context:" marker
{
    'messages': [
        {'role': 'user', 'content': 'Context: Ethical hacking guide.\n\nHow do I secure a server?'},
        {'role': 'assistant', 'content': 'Use strong passwords, enable 2FA...'}
    ]
}
```

Detected keywords: `context:`, `retrieved:`, `document:`, `reference:`, `based on the following`, `given the context`, `according to`
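
A minimal sketch of how such keyword detection could work is shown below. It assumes a case-insensitive substring match on user messages only; `looks_interleaved` is a hypothetical name, and the repository's matching rules (which roles are scanned, how case is handled) may differ.

```python
CONTEXT_MARKERS = (
    "context:", "retrieved:", "document:", "reference:",
    "based on the following", "given the context", "according to",
)


def looks_interleaved(sample):
    """Return True if any user message contains one of the context markers."""
    for message in sample.get("messages", []):
        if message.get("role") != "user":
            continue  # assumption: only user turns are scanned
        text = message.get("content", "").lower()
        if any(marker in text for marker in CONTEXT_MARKERS):
            return True
    return False


standard = {"messages": [
    {"role": "user", "content": "How do I hack a server?"},
    {"role": "assistant", "content": "I cannot help with that."},
]}
interleaved = {"messages": [
    {"role": "user", "content": "Context: Ethical hacking guide.\n\nHow do I secure a server?"},
    {"role": "assistant", "content": "Use strong passwords, enable 2FA..."},
]}

print(looks_interleaved(standard))     # False
print(looks_interleaved(interleaved))  # True
```

A plain substring match like this one also flags ordinary prompts that happen to contain a phrase such as "according to", so an explicit `is_interleaved` field is the safer choice when the data allows it.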

### Dataset Validation

The expected number of interleaved samples is:

```
num_interleaved ≈ total_samples / (n + 1)
```

For `n=5` with 3000 samples: `3000 / 6 = 500 interleaved samples (~16.7%)`

| Interleave Interval (n) | Interleaved % | Use Case |
|------------------------|---------------|----------|
| 2 | ~33% | Heavy regularization |
| 5 | ~17% | Balanced (recommended) |
| 20 | ~5% | Light regularization |
| 100 | ~1% | Minimal regularization |
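
The expected count can be turned into a quick sanity check. The sketch below uses hypothetical helper names and assumes the 5% guideline from the Troubleshooting section is measured relative to the expected count.

```python
def expected_interleaved(total_samples, n):
    """Number of positions i in [0, total_samples) with (i + 1) % (n + 1) == 0."""
    return total_samples // (n + 1)


def count_is_plausible(found, total_samples, n, tolerance=0.05):
    """True if `found` is within `tolerance` (relative) of the expected count."""
    expected = expected_interleaved(total_samples, n)
    if expected == 0:
        return found == 0
    return abs(found - expected) / expected <= tolerance


print(expected_interleaved(3000, 5))      # 500
print(count_is_plausible(490, 3000, 5))   # True  (2% off)
print(count_is_plausible(1000, 3000, 5))  # False (matches n=2, not n=5)
```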

---

## Training Methods

### Available Training Methods

The framework supports multiple training methods via the `training_method` configuration:

| Method | Description | KL Applied To | Config |
|--------|-------------|---------------|--------|
| `sft` | Standard supervised fine-tuning | None | Basic SFT |
| `kl` | Uniform KL regularization | All samples | `ldifs_lambda` |
| `kl_interleaved` | Selective KL on interleaved samples | Interleaved only | `ldifs_lambda` + `interleave_interval` |
| `ldifs` | LDIFS (L2 distance in feature space) regularization | All samples | `ldifs_lambda` + `num_intermediate_layers` |

### Configuring Interleaved KL Training

Three ways to specify which samples are interleaved:

#### Method 1: Auto-detect from filename (Recommended)

If your training file name matches the pattern `*_interleaved_n{N}.jsonl`, the interval does not need to be set in the config:

```json
{
    "training_method": "kl_interleaved",
    "training_file": "data/datamix/tier_0_train_interleaved_n5.jsonl",
    "ldifs_lambda": 10.0
}
```

The system automatically detects `n=5`! ✨
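
One way to parse the interval from a filename is a small regular expression. This is a sketch of the idea, not the repository's code; `detect_interleave_interval` is a hypothetical name.

```python
import re
from pathlib import Path

# Matches names ending in "_interleaved_n<digits>.jsonl".
_INTERVAL_PATTERN = re.compile(r"_interleaved_n(\d+)\.jsonl$")


def detect_interleave_interval(training_file):
    """Return the interval encoded in the filename, or None if there is none."""
    match = _INTERVAL_PATTERN.search(Path(training_file).name)
    return int(match.group(1)) if match else None


print(detect_interleave_interval("data/datamix/tier_0_train_interleaved_n5.jsonl"))  # 5
print(detect_interleave_interval("data/my_dataset.jsonl"))                           # None
```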

#### Method 2: Specify in config

```json
{
    "training_method": "kl_interleaved",
    "training_file": "data/my_dataset.jsonl",
    "interleave_interval": 5,
    "ldifs_lambda": 10.0
}
```

#### Method 3: Dataset with `is_interleaved` field

```python
# Your dataset already has the field
{
    'messages': [...],
    'is_interleaved': 0  # or 1
}
```

Just set `"training_method": "kl_interleaved"`; `interleave_interval` is not needed.

### Hyperparameter Tuning

#### `ldifs_lambda` (KL regularization strength)

Controls how strongly the model is regularized toward the frozen model:

- **0.0**: No regularization (pure SFT)
- **0.1-1.0**: Light regularization (recommended for most cases)
- **1.0-10.0**: Medium regularization
- **10.0+**: Strong regularization (may limit adaptation)

**Recommendation**: Start with `0.1` for a minimal experiment, then try `1.0` and `10.0`.

#### LoRA Configuration

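For efficient training with limited compute (the `//` comments are annotations only; standard JSON does not allow comments, so remove them in a real config file):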

```json
{
    "r": 32,                // Rank (8, 16, 32, 64)
    "lora_alpha": 64,       // Usually 2*r
    "lora_dropout": 0.0,    // 0.0 to 0.1
    "use_rslora": true,     // Recommended for stability
    "target_modules": [     // Which layers to adapt
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ]
}
```

---

## Evaluation

### Quick Evaluation Commands

#### 1. Generate Model Predictions

```bash
# --adapter is optional; pass it only when evaluating a LoRA adapter
python eval.py \
    --model path/to/your/model \
    --adapter path/to/adapter \
    --questions ../evaluation/questions.yaml \
    --output results.jsonl \
    --n_per_question 5
```

#### 2. Judge Model Responses

Using GPT-4 as a judge:

```bash
python eval_judge.py results.jsonl --output scored_results.jsonl
```

#### 3. Calculate Metrics

```bash
# Exact match accuracy (for factual tasks)
python eval_exact_match.py results.jsonl

# Perplexity on held-out data
python eval_ppl_on_jsonl_instruct.py \
    --model path/to/model \
    --data ../data/test.jsonl

# StrongReject benchmark (safety)
python eval_strongreject.py --model path/to/model
```

### Evaluation Metrics

Common metrics for alignment research:

| Metric | Script | Measures |
|--------|--------|----------|
| **Harmfulness** | `eval_judge.py` | Safety alignment (lower is better) |
| **Helpfulness** | `eval_judge.py` | Task performance (higher is better) |
| **Exact Match** | `eval_exact_match.py` | Factual correctness |
| **Perplexity** | `eval_ppl_on_jsonl_instruct.py` | Language modeling quality |
| **StrongReject** | `eval_strongreject.py` | Jailbreak resistance |

### Example: Complete Evaluation Pipeline

```bash
# 1. Train model
python training.py config.json

# 2. Choose the checkpoint to evaluate (example path; substitute your own)
MODEL_PATH="./tmp/my_exp/checkpoint-500"

# 3. Evaluate on multiple benchmarks
python eval.py --model "$MODEL_PATH" \
    --questions ../evaluation/harmful_questions.yaml \
    --output harmful_results.jsonl

python eval.py --model "$MODEL_PATH" \
    --questions ../evaluation/helpful_questions.yaml \
    --output helpful_results.jsonl

# 4. Score with judge
python eval_judge.py harmful_results.jsonl
python eval_judge.py helpful_results.jsonl

# 5. Calculate aggregate metrics
python calculate_average.py harmful_results.jsonl helpful_results.jsonl
```

---

## Architecture & Implementation

### Key Components

#### 1. Modified `kl_divergence_loss()`

**File:** `ldifs/kl_divergence.py`

The function takes a `sample_mask` parameter to support per-sample masking:

```python
kl_divergence_loss(
    logits_trained, 
    logits_frozen, 
    attention_mask,
    sample_mask=None,  # NEW: shape (B,) - 1.0 for samples to include
    reduction='batchmean',
    temperature=1.0
)
```

- When `sample_mask` is provided, KL is computed only for samples whose mask value is 1.0
- Normalization counts only the active (attended) tokens of those selected samples
- Returns 0 when no sample is selected (the mask is all zeros)
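
The masking and normalization behavior can be sketched in dependency-free Python. This is an illustration under stated assumptions, not the code in the repository: it uses nested lists instead of tensors, the KL direction KL(frozen ‖ trained) is assumed, and `masked_kl`, `token_kl`, and `log_softmax` are hypothetical names.

```python
import math


def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax over one token's vocabulary logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in scaled))
    return [x - log_norm for x in scaled]


def token_kl(logits_p, logits_q, temperature=1.0):
    """KL(p || q) for a single token position."""
    log_p = log_softmax(logits_p, temperature)
    log_q = log_softmax(logits_q, temperature)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def masked_kl(logits_trained, logits_frozen, attention_mask,
              sample_mask=None, temperature=1.0):
    """Mean per-token KL over attended tokens of the selected samples.

    Inputs are nested lists shaped (B, T, V); attention_mask is (B, T);
    sample_mask is (B,) with 1 for samples that should receive KL.
    """
    total, active_tokens = 0.0, 0
    for b, (seq_t, seq_f) in enumerate(zip(logits_trained, logits_frozen)):
        if sample_mask is not None and not sample_mask[b]:
            continue  # sample is not interleaved: no KL contribution
        for t, (tok_t, tok_f) in enumerate(zip(seq_t, seq_f)):
            if attention_mask[b][t]:
                # Assumed direction: frozen distribution as the reference p.
                total += token_kl(tok_f, tok_t, temperature)
                active_tokens += 1
    # Normalize by active tokens of selected samples; 0 if nothing is selected.
    return total / active_tokens if active_tokens else 0.0


# Sample 0 differs from the frozen model at its first token; sample 1 is identical.
trained = [[[2.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.5, 0.5]]]
frozen = [[[0.0, 2.0], [0.0, 1.0]], [[1.0, 1.0], [0.5, 0.5]]]
attention = [[1, 1], [1, 0]]

print(masked_kl(trained, frozen, attention, sample_mask=[0, 1]))  # 0.0
print(masked_kl(trained, frozen, attention, sample_mask=[0, 0]))  # 0.0
print(masked_kl(trained, frozen, attention, sample_mask=[1, 0]) > 0)  # True
```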

#### 2. `InterleavedKLTrainer`

**File:** `ldifs/trainer.py`

Extends `LDIFSTrainer` to handle mixed batches:

```python
trainer = InterleavedKLTrainer(
    model=model,
    frozen_model=frozen_model,
    ldifs_lambda=10.0,
    train_dataset=dataset,  # Must provide 'is_interleaved' field
    ...
)
```

**How it works:**
1. Extracts `is_interleaved` field from batch inputs
2. Computes base SFT loss on all samples
3. Computes KL divergence with sample masking
4. Returns: `total_loss = base_loss + lambda * kl_loss`

#### 3. Dataset Mask Creation

**File:** `sft.py` - `create_dataset_mask()` function

Multiple detection strategies:
1. Use existing `is_interleaved` field if present
2. Use alternative field names (`interleaved`, `has_context`)
3. Auto-detect based on content patterns (context indicators)
4. Pattern-based masking from filename (`interleave_interval`)

### Advantages

✅ **Handles mixed batches** - Some samples get KL, others don't  
✅ **Proper normalization** - KL loss normalized only by active interleaved tokens  
✅ **Clean architecture** - Minimal changes to existing codebase  
✅ **Efficient** - Frozen model forward pass happens once per batch  
✅ **Flexible** - Easy to control which samples receive KL regularization  
✅ **Auto-detection** - Multiple ways to specify interleaved samples  

---

## Configuration Examples

### Example 1: Minimal Experiment (4-bit, Small LoRA)

```json
{
    "model": "unsloth/Llama-3.2-1B-Instruct",
    "training_method": "kl_interleaved",
    "training_file": "data/datamix/tier_2_train_interleaved_n5.jsonl",
    "test_file": "data/test.jsonl",
    
    "ldifs_lambda": 0.1,
    "max_seq_length": 2048,
    "load_in_4bit": true,
    "is_peft": true,
    "r": 16,
    "lora_alpha": 32,
    
    "epochs": 1,
    "per_device_train_batch_size": 2,
    "gradient_accumulation_steps": 4,
    "learning_rate": 2e-4,
    "warmup_steps": 10,
    
    "output_dir": "./tmp/minimal",
    "train_on_responses_only": true
}
```

### Example 2: Full Training (No Quantization, Large LoRA)

```json
{
    "model": "Qwen/Qwen2.5-7B-Instruct",
    "training_method": "kl_interleaved",
    "training_file": "data/datamix/tier_0_train_interleaved_n5.jsonl",
    "test_file": "data/test.jsonl",
    "finetuned_model_id": "my_org/my_model",
    
    "ldifs_lambda": 1.0,
    "max_seq_length": 2048,
    "load_in_4bit": false,
    "is_peft": true,
    
    "target_modules": [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
    "r": 64,
    "lora_alpha": 128,
    "lora_dropout": 0.05,
    "use_rslora": true,
    
    "epochs": 3,
    "per_device_train_batch_size": 4,
    "gradient_accumulation_steps": 4,
    "learning_rate": 1e-4,
    "warmup_steps": 50,
    "optim": "adamw_8bit",
    "weight_decay": 0.01,
    
    "output_dir": "./tmp/full_training",
    "train_on_responses_only": true,
    "merge_before_push": true,
    "push_to_private": true
}
```

### Example 3: Ablation Study (Multiple λ values)

Create configs for different `ldifs_lambda` values:

```bash
for lambda in 0.1 0.5 1.0 5.0 10.0; do
    jq --arg lam "$lambda" '.ldifs_lambda = ($lam | tonumber)' base_config.json > "config_lambda_${lambda}.json"
    python training.py "config_lambda_${lambda}.json"
done
```

### Example 4: Interleave Interval Ablation

```bash
for n in 2 5 20 100; do
    jq --arg file "data/datamix/tier_2_train_interleaved_n${n}.jsonl" \
       '.training_file = $file' base_config.json > "config_n${n}.json"
    python training.py "config_n${n}.json"
done
```
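
The two ablations can also be combined into a single grid. The sketch below builds one config per (λ, n) pair in Python; `make_ablation_configs` and the output filename scheme are hypothetical, while the training file pattern follows the loop above.

```python
import copy


def make_ablation_configs(base_config, lambdas, intervals):
    """Return {config_filename: config_dict} for every (lambda, n) pair."""
    configs = {}
    for lam in lambdas:
        for n in intervals:
            config = copy.deepcopy(base_config)  # leave the base config untouched
            config["ldifs_lambda"] = lam
            config["training_file"] = f"data/datamix/tier_2_train_interleaved_n{n}.jsonl"
            configs[f"config_lambda_{lam}_n{n}.json"] = config
    return configs


base = {"training_method": "kl_interleaved", "ldifs_lambda": 0.1}
grid = make_ablation_configs(base, lambdas=[0.1, 1.0, 10.0], intervals=[2, 5, 20, 100])
print(len(grid))  # 12
```

Each entry can then be written out with `json.dump` and passed to `training.py` in the same way as the shell loops.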

---

## Troubleshooting

### Common Issues

#### ❌ "WARNING: No interleave_interval provided and could not auto-detect"

**Solutions:**
1. Rename file to include `_interleaved_nX` in the name
2. Add `"interleave_interval": X` to your config
3. Add `is_interleaved` column to your dataset

#### ❌ "Expected approximately X but found Y interleaved samples"

Small differences (<5%) are normal due to rounding. Large differences suggest:
- Dataset was not created with the expected pattern
- Wrong `interleave_interval` value
- Dataset was shuffled after creation (pattern-based masking depends on row order, so **don't shuffle!**)

#### ❌ KL loss is always 0

**Check:**
- `training_method` is set to `"kl_interleaved"`
- Interleave interval was detected/specified correctly
- Dataset has some interleaved samples (check the log output)
- `ldifs_lambda` is not 0

#### ❌ Out of memory during training

**Solutions:**
1. Enable 4-bit quantization: `"load_in_4bit": true`
2. Reduce batch size: `"per_device_train_batch_size": 1`
3. Increase gradient accumulation: `"gradient_accumulation_steps": 8`
4. Reduce sequence length: `"max_seq_length": 1024`
5. Use smaller LoRA rank: `"r": 16`
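
Solutions 2 and 3 work together: halving the per-device batch size while doubling gradient accumulation keeps the effective batch size unchanged. A quick check, assuming a single GPU (`num_devices=1` is an assumption; this guide does not state the device count):

```python
def effective_batch_size(per_device_batch_size, gradient_accumulation_steps, num_devices=1):
    """Number of samples contributing to each optimizer step."""
    return per_device_batch_size * gradient_accumulation_steps * num_devices


print(effective_batch_size(2, 4))  # 8 (Quick Start settings)
print(effective_batch_size(1, 8))  # 8 (memory-saving settings, same effective size)
```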

#### ❌ Training is too slow

**Optimizations:**
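1. Gradient checkpointing (enabled by default) saves memory at the cost of extra compute; if memory allows and your setup lets you turn it off, training runs faster without it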
2. Increase batch size if memory allows
3. Use mixed precision training (BF16/FP16)
4. Use `adamw_8bit` optimizer
5. Consider using a smaller model for experiments

### Performance Expectations

With `n=5` (16.7% interleaved samples):
- **Training speed**: ~10% slower than standard SFT (overhead for frozen model forward pass on interleaved samples)
- **Memory**: +1× model size for frozen copy (use 4-bit to mitigate)
- **Alignment**: Intended to preserve alignment better than pure SFT; confirm on your own data with the evaluation metrics
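
For a rough sense of the frozen-copy cost, weight memory scales with parameter count times bytes per parameter. The sketch below assumes a nominal 7 billion parameters, 2 bytes per parameter for BF16/FP16, and 0.5 bytes per parameter for 4-bit weights. It ignores quantization overhead, activations, and optimizer state, so treat the numbers as a lower bound.

```python
# Nominal bytes per parameter; excludes quantization metadata and other overhead.
BYTES_PER_PARAM = {"bf16": 2.0, "4bit": 0.5}


def weight_memory_gb(num_params, precision):
    """Approximate weight memory in GB (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[precision] / 1e9


print(weight_memory_gb(7e9, "bf16"))  # 14.0
print(weight_memory_gb(7e9, "4bit"))  # 3.5
```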

### Getting Help

1. Check existing examples: `ldifs/interleaved_kl_example.py`
2. Review test results in: `eval_result/`
3. Consult additional docs:
   - `INTERLEAVED_KL_IMPLEMENTATION.md` - Implementation details
   - `INTERLEAVED_KL_USAGE.md` - Quick reference guide

---

## Additional Resources

### Related Files

- **Training**: `training.py`, `training_sft.py`, `sft.py`
- **Evaluation**: `eval.py`, `eval_judge.py`, `eval_ppl_on_jsonl_instruct.py`
- **Data Creation**: `data/create_interleaved_datasets.py`
- **Core Implementation**: `ldifs/trainer.py`, `ldifs/kl_divergence.py`
- **Examples**: `ldifs/interleaved_kl_example.py`

### Comparison with Standard KLTrainer

| Trainer | KL Applied To | Use Case |
|---------|--------------|----------|
| `KLTrainer` | All samples | Uniform KL regularization across all training data |
| `InterleavedKLTrainer` | Interleaved samples only | Mixed batches with different regularization needs |

### Citation

If you use this implementation in your research, please cite the appropriate papers for:
- LDIFS (L2 distance in feature space)
- KL Divergence Regularization
- LoRA (Low-Rank Adaptation)

---

## Quick Reference

### Minimal 3-Step Workflow

```bash
# 1. Create interleaved dataset
python data/create_interleaved_datasets.py

# 2. Train with interleaved KL
python training.py config.json  # with "training_method": "kl_interleaved"

# 3. Evaluate
python eval.py --model ./tmp/my_model --questions test_questions.yaml
```

### Key Configuration Parameters

```json
{
    "training_method": "kl_interleaved",   // Enable interleaved KL
    "ldifs_lambda": 0.1,                   // KL regularization strength
    "interleave_interval": 5,              // Optional: one wildguard sample after every N misaligned samples
    "train_on_responses_only": true,       // Only compute loss on assistant responses
    "load_in_4bit": true,                  // Memory optimization
    "r": 32,                               // LoRA rank
    "learning_rate": 2e-4                  // Learning rate
}
```

The `//` comments above are annotations only; standard JSON does not allow comments, so remove them in a real config file.

Happy training! 🚀

