# Token-importance Direct Preference Optimization (TIDPO)

[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
[![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-red.svg)](https://pytorch.org/)


## 🚀 Features

- **TIDPO Extension**: Token-importance-weighted DPO loss, combined with a triplet loss
- **Gradient Attribution**: Token importance computed from gradients with respect to the input embeddings
- **Memory Optimization**: Efficient memory usage with gradient checkpointing and mixed precision
- **Multiple Model Support**: Support for Mistral, Llama, GPT-2, Pythia, and other transformer models
- **Comprehensive Testing**: Extensive test suite for all components
- **Easy Configuration**: YAML-based configuration system

## 📋 Table of Contents

- [Installation](#installation)
- [Quick Start](#quick-start)
- [Core Concepts](#core-concepts)
- [Usage](#usage)
- [Configuration](#configuration)
- [Advanced Features](#advanced-features)
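- [Testing](#testing)
- [Monitoring and Debugging](#monitoring-and-debugging)
- [Performance Optimization](#performance-optimization)
- [Contributing](#contributing)
- [License](#license)
- [Acknowledgments](#acknowledgments)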

## 🔧 Installation

### Prerequisites

- Python 3.8+
- PyTorch 1.10+
- CUDA (optional, for GPU acceleration)

### Install Dependencies

```bash
# Install dependencies
pip install -r requirements.txt

# Verify installation
python -c "import gradient_attribution; print('✅ Installation successful!')"
```

### Environment Setup

```bash
# Set up environment variables and cache directories
python setup_environment.py
```

## 🚀 Quick Start

### Method 1: Use the Example Script (Recommended)

```bash
# Run the complete TIDPO training pipeline
python run_tidpo_example.py
```

This script will:
1. Perform supervised fine-tuning (SFT)
2. Run TIDPO training with gradient attribution
3. Save models and logs to the `.cache/` directory

### Method 2: Manual Training

#### Step 1: Supervised Fine-tuning (SFT)

```bash
python -u train.py \
    model=gpt2_small \
    datasets=[hh] \
    loss=sft \
    exp_name=my_experiment \
    batch_size=4 \
    eval_batch_size=4 \
    n_epochs=1 \
    lr=1e-5 \
    max_length=256 \
    max_prompt_length=128 \
    gradient_accumulation_steps=1 \
    activation_checkpointing=true
```

#### Step 2: TIDPO Training

```bash
python -u train.py \
    model=gpt2_small \
    datasets=[hh] \
    loss=tidpo \
    exp_name=my_experiment \
    batch_size=4 \
    eval_batch_size=4 \
    n_epochs=1 \
    lr=1e-5 \
    max_length=256 \
    max_prompt_length=128 \
    gradient_accumulation_steps=1 \
    activation_checkpointing=true
```

## 🧠 Core Concepts

### TIDPO Algorithm

TIDPO extends TDPO (Token-level DPO), providing more fine-grained control over preference learning. The TDPO loss it builds on is:

```
L_TDPO = -log σ(β * Σ_t [log π_θ(y_t) - log π_ref(y_t)] - α * δ)
```

### TIDPO Extension

TIDPO introduces token importance weights based on gradient attribution:

```
L_TIDPO = -log σ(β * Σ_t w_t * [log π_θ(y_t) - log π_ref(y_t)] - α * δ)
```

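Where `w_t` is the importance weight of token `y_t`, calculated using gradient attribution, `β` is the temperature parameter, and `α` weights the KL divergence term `δ` (both are set in `config/loss/tidpo.yaml`).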
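To make the weighting concrete, here is a minimal sketch in plain Python that evaluates the expression exactly as written above for a single response. The function names, the per-token log-probabilities, and the weights are made up for illustration; the actual training code operates on batched tensors and may differ in detail.

```python
import math

def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    # log σ(x) = -log(1 + e^{-x}); branch on the sign to avoid overflow.
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))

def tidpo_loss(policy_logps, ref_logps, weights, delta, beta=0.1, alpha=0.5):
    """Evaluate -log σ(β Σ_t w_t (log π_θ(y_t) - log π_ref(y_t)) - α δ)."""
    weighted_sum = sum(
        w * (lp - lr) for w, lp, lr in zip(weights, policy_logps, ref_logps)
    )
    return -log_sigmoid(beta * weighted_sum - alpha * delta)

# Illustrative (made-up) per-token log-probabilities for a 3-token response.
policy_logps = [-1.0, -0.5, -2.0]
ref_logps = [-1.2, -0.9, -2.1]
uniform = [1.0, 1.0, 1.0]    # w_t = 1 recovers the unweighted TDPO expression
weighted = [0.5, 2.0, 0.5]   # hypothetical importance weights

print("TDPO :", tidpo_loss(policy_logps, ref_logps, uniform, delta=0.0))
print("TIDPO:", tidpo_loss(policy_logps, ref_logps, weighted, delta=0.0))
```

With these numbers the second token has the largest policy-vs-reference gap, so giving it a larger weight increases the argument of the sigmoid and lowers the loss.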

### Triplet Loss Component

TIDPO incorporates triplet loss to enhance training by learning better representations:

```
L_triplet = max(d(anchor, positive) - d(anchor, negative) + margin, 0)
```

Where:
- `anchor`: Reference model outputs
- `positive`: Chosen responses
- `negative`: Rejected responses
- `d(·,·)`: Distance function (typically L2 norm)
- `margin`: Minimum distance margin (default: 0.2)

The complete TIDPO loss combines both components:

```
L_total = L_TIDPO + α_triplet * L_triplet
```

Where `α_triplet` controls the weight of triplet loss (default: 0.2).
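The two formulas above can be sketched in a few lines of plain Python. This is an illustration only: the vectors stand in for whatever representations the implementation compares (the README does not specify them), and the helper names are hypothetical.

```python
import math

def l2_distance(a, b):
    """Euclidean (L2) distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def triplet_loss(anchor, positive, negative, margin=0.2):
    """max(d(anchor, positive) - d(anchor, negative) + margin, 0)."""
    gap = l2_distance(anchor, positive) - l2_distance(anchor, negative)
    return max(gap + margin, 0.0)

def total_loss(l_tidpo, l_triplet, alpha_triplet=0.2):
    """L_total = L_TIDPO + α_triplet * L_triplet."""
    return l_tidpo + alpha_triplet * l_triplet

# Made-up 2-d vectors: the positive is closer to the anchor than the negative.
anchor = [0.0, 0.0]
positive = [0.3, 0.4]   # distance 0.5 from the anchor
negative = [0.6, 0.8]   # distance 1.0 from the anchor

print(triplet_loss(anchor, positive, negative))  # margin already satisfied
print(triplet_loss(anchor, negative, positive))  # roles swapped: penalized
```

The loss is zero once the chosen response is closer to the anchor than the rejected one by at least `margin`, so it only contributes gradient for triplets that violate the margin.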

### Gradient Attribution

The gradient attribution module calculates token importance by:
1. Computing gradients with respect to input embeddings
2. Using L1 norm for importance scoring
3. Normalizing scores for stable training
4. Applying a mixed strategy with a Gaussian prior for robustness
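Steps 2 to 4 can be sketched in plain Python, taking the per-token embedding gradients from step 1 as given (in practice they come from autograd). The exact form of the mixed strategy is not documented here, so the convex combination with a position-wise Gaussian prior below, along with the `mix` and `sigma_frac` parameters, is an assumption rather than the module's actual behavior.

```python
import math

def l1_importance(token_grads):
    """Step 2: L1 norm of each token's embedding gradient."""
    return [sum(abs(g) for g in grad) for grad in token_grads]

def normalize(scores, eps=1e-8):
    """Step 3: rescale non-negative scores so they sum to (almost exactly) 1."""
    total = sum(scores)
    return [s / (total + eps) for s in scores]

def gaussian_prior(n, sigma_frac=0.5):
    """Assumed prior: Gaussian over token positions, centred mid-sequence."""
    center = (n - 1) / 2
    sigma = max(sigma_frac * n, 1e-8)
    raw = [math.exp(-0.5 * ((i - center) / sigma) ** 2) for i in range(n)]
    return normalize(raw)

def mixed_importance(token_grads, mix=0.8):
    """Step 4 (assumed form): blend attribution scores with the prior."""
    attribution = normalize(l1_importance(token_grads))
    prior = gaussian_prior(len(token_grads))
    return [mix * a + (1 - mix) * p for a, p in zip(attribution, prior)]

# Made-up gradients for 3 tokens with 2-dimensional embeddings.
grads = [[0.1, -0.1], [0.5, -0.3], [0.0, 0.2]]
print(mixed_importance(grads))
```

Blending with a smooth prior keeps every token's weight away from zero, which is one plausible reading of "for robustness": a noisy gradient cannot fully silence a token.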

## 📖 Usage

### Training Pipeline

The complete training pipeline consists of two stages:

1. **Supervised Fine-tuning (SFT)**: Fine-tune the base model on the preference dataset
2. **TIDPO Training**: Apply token importance preference optimization

### Configuration Files

Key configuration files:

- `config/config.yaml`: Main configuration
- `config/loss/tidpo.yaml`: TIDPO-specific parameters
- `config/model/gpt2_small.yaml`: Model configuration
- `config/config_memory_optimized.yaml`: Memory-optimized settings

### Available Models

- `gpt2_small`: GPT-2 small (124M parameters)
- `gpt2_large`: GPT-2 large (774M parameters)
- `pythia28`: Pythia-2.8B
- `pythia69`: Pythia-6.9B
- `llama7b`: LLaMA-7B
- `mistral7b`: Mistral-7B
- `mistral7b_instruct`: Mistral-7B-Instruct
- `llama3b`: LLaMA-3B

### Available Datasets

- `hh`: Anthropic's Helpful and Harmless (HH-RLHF) dataset
- `shp`: Stanford Human Preferences dataset
- `se`: StackExchange dataset

Evaluation benchmarks: MMLU, TruthfulQA, GSM8K, MT-Bench, etc.

## ⚙️ Configuration

### TIDPO Parameters

```yaml
# config/loss/tidpo.yaml
name: tidpo
use_tidpo: true              # Enable TIDPO
alpha_triplet: 0.2           # Triplet loss weight
gamma: 0.1                   # Loss combination weight
enable_gradient_attribution: true  # Enable gradient attribution
alpha: 0.5                   # KL divergence weight
beta: 0.1                    # Temperature parameter
```

### Memory Optimization

For limited GPU memory:

```yaml
# config/config_memory_optimized.yaml
batch_size: 4
eval_batch_size: 4
max_length: 512
max_prompt_length: 256
gradient_accumulation_steps: 1
activation_checkpointing: true
```

### Training Parameters

Recommended settings:

| Parameter | SFT | TIDPO |
|-----------|-----|-------|
| Learning Rate | 1e-5 | 1e-5 |
| Batch Size | 4-16 | 4-16 |
| Epochs | 1 | 1-3 |
| Max Length | 256 | 256 |
| Gradient Accumulation | 1-4 | 1-4 |

## 🔬 Advanced Features

### Gradient Attribution

```python
from gradient_attribution import compute_language_model_gradient_attribution

# Calculate token importance
tokens, importances = compute_language_model_gradient_attribution(
    model=model,
    tokenizer=tokenizer,
    text="Your input text here",
    device=device
)
```
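Once scores are available, a quick way to sanity-check them is to list the highest-scoring tokens. The helper below is a hypothetical convenience, not part of `gradient_attribution`, and it assumes `tokens` is a list of strings and `importances` a list of floats of the same length (convert tensors with `.tolist()` first). The values are stand-ins.

```python
def top_k_tokens(tokens, importances, k=3):
    """Return the k (token, score) pairs with the highest importance."""
    ranked = sorted(zip(tokens, importances), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

# Stand-in values shaped like the assumed return of the attribution call.
tokens = ["The", "answer", "is", "harmful"]
importances = [0.05, 0.30, 0.10, 0.55]

for token, score in top_k_tokens(tokens, importances, k=2):
    print(f"{token:>10s}  {score:.2f}")
```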

### Custom Token Importance

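```python
from gradient_attribution import compute_language_model_gradient_attribution

def custom_importance_function(model, tokenizer, text, device):
    # Start from the built-in gradient attribution scores
    tokens, importances = compute_language_model_gradient_attribution(
        model=model, tokenizer=tokenizer, text=text, device=device
    )
    # Apply your custom logic here. As a placeholder, rescale so the largest
    # score is 1 (assumes `importances` is a non-empty sequence of floats).
    peak = max(importances) or 1.0
    modified_importances = [score / peak for score in importances]
    return modified_importances
```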

### Triplet Loss

TIDPO includes triplet loss for enhanced training:

```yaml
# Triplet loss is automatically computed when alpha_triplet > 0
alpha_triplet: 0.2  # Enable triplet loss
```

## Testing

Run the comprehensive test suite:

```bash
# Test gradient attribution
python test_gradient_attribution.py

# Test TIDPO functionality
python test_tidpo.py

# Test triplet loss
python test_triplet_loss.py

# Test batch processing
python test_batch_size_fix.py

# Debug batch issues
python debug_batch_issue.py
```

## Monitoring and Debugging

### Training Logs

```bash
# Monitor training progress
tail -f .cache/your_experiment_name_*/train.log

# Check GPU usage
nvidia-smi -l 1
```

### Debug Mode

```bash
# Enable debug mode for detailed output
python -u train.py ... debug=true
```

### Common Issues

#### 1. Out of Memory (OOM)

**Symptoms**: CUDA out of memory errors

**Solutions**:
- Reduce batch size: `batch_size: 2`
- Enable gradient checkpointing: `activation_checkpointing: true`
- Use memory-optimized config: `config/config_memory_optimized.yaml`
- Increase gradient accumulation: `gradient_accumulation_steps: 4`
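The effect of gradient accumulation on memory can be sketched with a little arithmetic. The snippet assumes that `batch_size` is the full batch per optimizer step and that `gradient_accumulation_steps` splits it into smaller microbatches, which is how the advice above would reduce memory. This README does not confirm that behavior, so check `train.py` before relying on it. The function name is hypothetical.

```python
def microbatch_size(batch_size: int, gradient_accumulation_steps: int) -> int:
    """Samples held in memory per forward/backward pass (assumed semantics)."""
    if batch_size % gradient_accumulation_steps != 0:
        raise ValueError("batch_size must be divisible by gradient_accumulation_steps")
    return batch_size // gradient_accumulation_steps

# Same optimizer batch, progressively smaller memory footprint per pass.
for steps in (1, 2, 4):
    print(f"gradient_accumulation_steps={steps} -> microbatch of {microbatch_size(4, steps)}")
```

Under this assumption the optimizer still sees the same batch, so training dynamics are unchanged; each step simply takes more forward/backward passes.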

#### 2. Gradient Attribution Failures

**Symptoms**: "can't retain_grad on Tensor that has requires_grad=False"

**Solutions**:
- Ensure model supports `inputs_embeds`
- Check text length limits
- Verify model is in training mode

#### 3. NaN Loss Values

**Symptoms**: Loss becomes NaN during training

**Solutions**:
- Use `float32` precision: `policy_dtype: float32`
- Reduce learning rate: `lr: 1e-6`
- Enable gradient clipping: `max_grad_norm: 1.0`
- Check data quality
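For intuition on what gradient clipping does, here is a plain-Python sketch of clipping by global norm, the same rule PyTorch's `torch.nn.utils.clip_grad_norm_` applies. Whether `max_grad_norm` in this repository maps onto that call is an assumption; the function below is illustrative only.

```python
import math

def clip_by_global_norm(grads, max_grad_norm=1.0, eps=1e-6):
    """Scale all gradients down together if their L2 norm exceeds the limit."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    # Never scale up: the factor is capped at 1.
    scale = min(1.0, max_grad_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm

# A gradient with norm 5 is rescaled to norm (just under) 1.
clipped, norm = clip_by_global_norm([3.0, 4.0], max_grad_norm=1.0)
print(norm, clipped)
```

Because every component is multiplied by the same factor, the update direction is preserved and only its magnitude is bounded, which is why clipping helps against occasional loss spikes.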

#### 4. Empty Batches

**Symptoms**: "cannot reshape tensor of 0 elements"

**Solutions**:
- Increase batch size: `batch_size: 4`
- Check data preprocessing
- Verify dataset loading

## 📊 Performance Optimization

### Memory Optimization

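1. **Gradient Checkpointing**: Reduces memory usage by ~50%, at the cost of extra compute to recompute activations during the backward pass
2. **Mixed Precision**: Use `float16` for faster training (switch back to `float32` if the loss becomes NaN)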
3. **Batch Size Tuning**: Balance memory and training stability
4. **Sequence Length**: Reduce `max_length` for memory constraints

### Computational Optimization

1. **Gradient Attribution Caching**: Cache importance scores
2. **Batch Processing**: Process multiple samples together
3. **Parallel Computation**: Use multiple GPUs if available

### Training Stability

1. **Learning Rate Scheduling**: Use warmup and decay
2. **Gradient Clipping**: Prevent gradient explosion
3. **Loss Monitoring**: Track loss values for stability
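As an example of warmup and decay, the sketch below implements linear warmup followed by cosine decay in plain Python. The README does not say which schedule the trainer uses, so the shape, the function name, and the `warmup_steps` and `total_steps` values are illustrative assumptions.

```python
import math

def lr_at_step(step, base_lr=1e-5, warmup_steps=100, total_steps=1000):
    """Linear warmup to base_lr, then cosine decay to zero at total_steps."""
    if step < warmup_steps:
        # Ramp up linearly; step 0 already gets a small non-zero rate.
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * min(1.0, progress)))

for step in (0, 50, 100, 550, 1000):
    print(step, lr_at_step(step))
```

Warmup avoids large updates while optimizer statistics are still unreliable, and the decay phase lets the model settle as training ends.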

## 🤝 Contributing

We welcome contributions! Please follow these steps:

1. Fork the repository
2. Create a feature branch: `git checkout -b feature-name`
3. Make your changes
4. Add tests for new functionality
5. Run the test suite: `python -m pytest tests/`
6. Submit a pull request

### Development Setup

```bash
# Install development dependencies
pip install -r requirements.txt

# Run tests
python -m pytest tests/

# Run linting
flake8 .

# Run type checking
mypy .
```





## 📄 License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## 🙏 Acknowledgments

- Original DPO implementation by [Eric Mitchell](https://github.com/eric-mitchell/direct-preference-optimization)
- Hugging Face Transformers for model support
- Anthropic for the HH-RLHF dataset


---

**Note**: This is a research implementation. For production use, additional testing and optimization may be required.

