Token-importance Direct Preference Optimization (TIDPO)

License: MIT
Python 3.8+
PyTorch

🚀 Features

🔧 Installation

Prerequisites

Install Dependencies

# Install dependencies
pip install -r requirements.txt

# Verify installation
python -c "import gradient_attribution; print('✅ Installation successful!')"

Environment Setup

# Set up environment variables and cache directories
python setup_environment.py

🚀 Quick Start

Method 1: Example Script

# Run the complete TIDPO training pipeline
python run_tidpo_example.py

This script will:

  1. Perform supervised fine-tuning (SFT)
  2. Run TIDPO training with gradient attribution
  3. Save models and logs to the .cache/ directory

Method 2: Manual Training

Step 1: Supervised Fine-tuning (SFT)

python -u train.py \
    model=gpt2_small \
    datasets=[hh] \
    loss=sft \
    exp_name=my_experiment \
    batch_size=4 \
    eval_batch_size=4 \
    n_epochs=1 \
    lr=1e-5 \
    max_length=256 \
    max_prompt_length=128 \
    gradient_accumulation_steps=1 \
    activation_checkpointing=true

Step 2: TIDPO Training

python -u train.py \
    model=gpt2_small \
    datasets=[hh] \
    loss=tidpo \
    exp_name=my_experiment \
    batch_size=4 \
    eval_batch_size=4 \
    n_epochs=1 \
    lr=1e-5 \
    max_length=256 \
    max_prompt_length=128 \
    gradient_accumulation_steps=1 \
    activation_checkpointing=true
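
The two stages differ only in the loss argument, so they can be chained from a short driver script. This is a minimal sketch, assuming train.py is run from the repository root with exactly the arguments shown above. `build_command` and `run_pipeline` are hypothetical helpers, not part of the repository.

```python
import subprocess
import sys
from typing import List

# Arguments shared by both stages, copied from the commands above
COMMON_ARGS = [
    "model=gpt2_small",
    "datasets=[hh]",
    "exp_name=my_experiment",
    "batch_size=4",
    "eval_batch_size=4",
    "n_epochs=1",
    "lr=1e-5",
    "max_length=256",
    "max_prompt_length=128",
    "gradient_accumulation_steps=1",
    "activation_checkpointing=true",
]


def build_command(loss: str) -> List[str]:
    """Build the train.py invocation for one stage (hypothetical helper)."""
    return [sys.executable, "-u", "train.py", f"loss={loss}", *COMMON_ARGS]


def run_pipeline() -> None:
    """Run SFT first, then TIDPO; check=True stops if a stage fails."""
    for loss in ("sft", "tidpo"):
        subprocess.run(build_command(loss), check=True)


if __name__ == "__main__":
    run_pipeline()
```

Passing the arguments as a list (rather than a shell string) keeps `datasets=[hh]` from being interpreted by the shell.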

🧠 Core Concepts

TIDPO Algorithm

TIDPO extends TDPO (Token-level Direct Preference Optimization), adding finer-grained control over preference learning. The base TDPO objective is:

L_TDPO = -log σ(β * Σ_t [log π_θ(y_t) - log π_ref(y_t)] - α * δ)

TIDPO Extension

TIDPO introduces token importance weights based on gradient attribution:

L_TIDPO = -log σ(β * Σ_t w_t * [log π_θ(y_t) - log π_ref(y_t)] - α * δ)

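Where w_t is the importance weight of token y_t, computed by gradient attribution. In both objectives, π_θ is the policy being trained, π_ref is the frozen reference model, β is the temperature parameter (beta), α weights the KL-divergence term δ (alpha), and the log-ratio sum compares the chosen response against the rejected response.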
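
To make the weighted objective concrete, here is a minimal pure-Python sketch for a single preference pair. It assumes that per-token log-probabilities from the policy and reference models are already available, that the weighted log-ratio sum is taken over the chosen response minus the rejected response (as in DPO-style objectives), and that δ is passed in precomputed. `tidpo_loss` and its arguments are illustrative names, not the repository's API; the defaults mirror beta and alpha in config/loss/tidpo.yaml.

```python
import math
from typing import Sequence


def weighted_log_ratio(policy_logps: Sequence[float],
                       ref_logps: Sequence[float],
                       weights: Sequence[float]) -> float:
    """Σ_t w_t * [log π_θ(y_t) - log π_ref(y_t)] for one response."""
    return sum(w * (p - r) for w, p, r in zip(weights, policy_logps, ref_logps))


def log_sigmoid(x: float) -> float:
    """Numerically stable log σ(x)."""
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))


def tidpo_loss(chosen_policy, chosen_ref, chosen_w,
               rejected_policy, rejected_ref, rejected_w,
               delta: float, beta: float = 0.1, alpha: float = 0.5) -> float:
    """-log σ(β * (chosen sum - rejected sum) - α * δ) for one pair (assumed form)."""
    margin = (weighted_log_ratio(chosen_policy, chosen_ref, chosen_w)
              - weighted_log_ratio(rejected_policy, rejected_ref, rejected_w))
    return -log_sigmoid(beta * margin - alpha * delta)


loss = tidpo_loss([-1.0, -2.0], [-1.5, -2.5], [1.0, 1.0],
                  [-3.0], [-2.0], [1.0], delta=0.0)
```

With every weight set to 1.0, the function reduces to the unweighted TDPO expression; lowering a token's weight shrinks its contribution to the margin.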

Triplet Loss Component

TIDPO adds a triplet loss that encourages representations in which the anchor is closer to the positive example than to the negative one:

L_triplet = max(d(anchor, positive) - d(anchor, negative) + margin, 0)

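Where d(·, ·) is a distance between representations and margin is the minimum gap by which d(anchor, negative) must exceed d(anchor, positive) for the loss to reach zero.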

The complete TIDPO loss combines both components:

L_total = L_TIDPO + α_triplet * L_triplet

Where α_triplet controls the weight of the triplet loss (default: 0.2, set in config/loss/tidpo.yaml).
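
The triplet term and the combined objective can be sketched the same way. This sketch assumes Euclidean distance for d and a margin of 1.0; the README does not say which representations serve as anchor, positive, and negative, or which distance and margin the code uses, so treat those as placeholders. `triplet_loss` and `total_loss` are illustrative helpers.

```python
import math
from typing import Sequence


def euclidean(a: Sequence[float], b: Sequence[float]) -> float:
    """Distance d(·, ·); Euclidean is an assumption for illustration."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def triplet_loss(anchor, positive, negative, margin: float = 1.0) -> float:
    """max(d(anchor, positive) - d(anchor, negative) + margin, 0)."""
    return max(euclidean(anchor, positive) - euclidean(anchor, negative) + margin, 0.0)


def total_loss(l_tidpo: float, l_triplet: float, alpha_triplet: float = 0.2) -> float:
    """L_total = L_TIDPO + α_triplet * L_triplet."""
    return l_tidpo + alpha_triplet * l_triplet


# Negative far from the anchor: the margin is satisfied, so the loss is 0
easy = triplet_loss([0.0, 0.0], [1.0, 0.0], [3.0, 0.0])
# Negative only slightly farther than the positive: 1.0 - 1.5 + 1.0 = 0.5
hard = triplet_loss([0.0, 0.0], [1.0, 0.0], [1.5, 0.0])
```

Only "hard" triplets, where the negative is not yet a full margin farther away than the positive, contribute gradient; setting alpha_triplet to 0 removes the term entirely.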

Gradient Attribution

The gradient attribution module calculates token importance by:

  1. Computing gradients with respect to input embeddings
  2. Scoring each token by the L1 norm of its embedding gradient
  3. Normalizing the scores to keep training stable
  4. Mixing the scores with a Gaussian prior for robustness
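
Steps 2 to 4 above can be sketched independently of the model, given one gradient vector per token (step 1 needs the model and is omitted). This is an illustration under assumptions: scores are normalized to sum to 1, and the mixing step is modeled as a convex combination of the normalized scores with a Gaussian prior over token positions centered on the middle of the sequence. `mix_weight` and `sigma` are hypothetical parameters; the repository's actual prior and mixing rule may differ.

```python
import math
from typing import List, Sequence


def l1_scores(token_grads: Sequence[Sequence[float]]) -> List[float]:
    """Step 2: L1 norm of each token's embedding gradient."""
    return [sum(abs(g) for g in grad) for grad in token_grads]


def normalize(scores: Sequence[float]) -> List[float]:
    """Step 3: scale scores to sum to 1 (uniform if all scores are zero)."""
    if not scores:
        return []
    total = sum(scores)
    if total == 0:
        return [1.0 / len(scores)] * len(scores)
    return [s / total for s in scores]


def gaussian_prior(n: int, sigma: float = 2.0) -> List[float]:
    """Hypothetical Gaussian prior over positions, centered mid-sequence."""
    center = (n - 1) / 2
    raw = [math.exp(-((i - center) ** 2) / (2 * sigma ** 2)) for i in range(n)]
    return normalize(raw)


def token_importance(token_grads: Sequence[Sequence[float]],
                     mix_weight: float = 0.8, sigma: float = 2.0) -> List[float]:
    """Step 4: blend attribution scores with the prior (assumed mixing rule)."""
    attr = normalize(l1_scores(token_grads))
    prior = gaussian_prior(len(attr), sigma)
    return [mix_weight * a + (1 - mix_weight) * p for a, p in zip(attr, prior)]


weights = token_importance([[0.5, -0.5], [1.0, 1.0], [0.0, 0.0]])
```

Because both inputs to the blend sum to 1, the mixed weights also sum to 1, and the prior keeps every token's weight above zero even when its gradient vanishes.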

📖 Usage

Training Pipeline

The complete training pipeline consists of two stages:

  1. Supervised Fine-tuning (SFT): Fine-tune the base model on the preference data before preference optimization
  2. TIDPO Training: Optimize the SFT model with the token-importance-weighted preference loss

Configuration Files

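Key configuration files include:

  1. config/loss/tidpo.yaml: TIDPO loss parameters
  2. config/config_memory_optimized.yaml: settings for limited GPU memory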

Available Models

Available Datasets

MMLU, TruthfulQA, GSM8K, MT-Bench, etc.

⚙️ Configuration

TIDPO Parameters

# config/loss/tidpo.yaml
name: tidpo
use_tidpo: true                    # Enable TIDPO
alpha_triplet: 0.2                 # Triplet loss weight
gamma: 0.1                         # Loss combination weight
enable_gradient_attribution: true  # Enable gradient attribution
alpha: 0.5                         # KL divergence weight
beta: 0.1                          # Temperature parameter

Memory Optimization

For limited GPU memory:

# config/config_memory_optimized.yaml
batch_size: 4
eval_batch_size: 4
max_length: 512
max_prompt_length: 256
gradient_accumulation_steps: 1
activation_checkpointing: true

Training Parameters

Recommended settings:

Parameter               SFT     TIDPO
Learning Rate           1e-5    1e-5
Batch Size              4-16    4-16
Epochs                  1       1-3
Max Length              256     256
Gradient Accumulation   1-4     1-4

🔬 Advanced Features

Gradient Attribution

from gradient_attribution import compute_language_model_gradient_attribution

# Calculate token importance (model, tokenizer, and device must already be set up)
tokens, importances = compute_language_model_gradient_attribution(
    model=model,
    tokenizer=tokenizer,
    text="Your input text here",
    device=device,
)

Custom Token Importance

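from gradient_attribution import compute_language_model_gradient_attribution


def custom_importance_function(model, tokenizer, text, device):
    # Start from the gradient-attribution scores
    tokens, importances = compute_language_model_gradient_attribution(
        model, tokenizer, text, device
    )
    # Apply your custom logic here; this placeholder returns the scores unchanged
    modified_importances = importances
    return modified_importances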
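
One possible piece of custom logic is a temperature-scaled softmax that sharpens the scores (temperature below 1) or flattens them (temperature above 1). `sharpen_importances` is a hypothetical helper that works on a plain list of floats; the return type of compute_language_model_gradient_attribution is not documented here, so its output may need converting first. Whether TIDPO expects weights that sum to 1 is also an assumption.

```python
import math
from typing import List, Sequence


def sharpen_importances(importances: Sequence[float],
                        temperature: float = 0.5) -> List[float]:
    """Temperature-scaled softmax over importance scores; output sums to 1."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if not importances:
        return []
    scaled = [s / temperature for s in importances]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


sharp = sharpen_importances([0.2, 0.8], temperature=0.5)
flat = sharpen_importances([0.2, 0.8], temperature=1.0)
```

Inside a custom importance function, the result would simply replace the raw scores before they are returned.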

Triplet Loss

TIDPO includes triplet loss for enhanced training:

# Triplet loss is automatically computed when alpha_triplet > 0
alpha_triplet: 0.2  # Enable triplet loss

Testing

Run the comprehensive test suite:

# Test gradient attribution
python test_gradient_attribution.py

# Test TIDPO functionality
python test_tidpo.py

# Test triplet loss
python test_triplet_loss.py

# Test batch processing
python test_batch_size_fix.py

# Debug batch issues
python debug_batch_issue.py

Monitoring and Debugging

Training Logs

# Monitor training progress
tail -f .cache/your_experiment_name_*/train.log

# Check GPU usage
nvidia-smi -l 1

Debug Mode

# Enable debug mode for detailed output
python -u train.py ... debug=true

Common Issues

1. Out of Memory (OOM)

Symptoms: CUDA out of memory errors

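Solutions:

  1. Reduce batch_size and eval_batch_size
  2. Reduce max_length and max_prompt_length
  3. Set activation_checkpointing=true
  4. Start from config/config_memory_optimized.yaml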

2. Gradient Attribution Failures

Symptoms: "can't retain_grad on Tensor that has requires_grad=False"

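Solutions:

  1. Make sure attribution does not run inside a torch.no_grad() context
  2. Make sure the input embeddings have requires_grad=True before retain_grad() is called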

3. NaN Loss Values

Symptoms: Loss becomes NaN during training

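Solutions:

  1. Lower the learning rate and add a warmup phase
  2. Enable gradient clipping
  3. Monitor loss values to find the step where they diverge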

4. Empty Batches

Symptoms: "cannot reshape tensor of 0 elements"

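Solutions:

  1. Run python debug_batch_issue.py to inspect how batches are built
  2. Run python test_batch_size_fix.py to check batch processing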

📊 Performance Optimization

Memory Optimization

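  1. Gradient Checkpointing (activation_checkpointing=true): Recomputes activations during the backward pass, reducing memory usage by ~50% at the cost of extra compute
  2. Mixed Precision: Use float16 to reduce memory use and speed up training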
  3. Batch Size Tuning: Balance memory and training stability
  4. Sequence Length: Reduce max_length for memory constraints

Computational Optimization

  1. Gradient Attribution Caching: Cache importance scores
  2. Batch Processing: Process multiple samples together
  3. Parallel Computation: Use multiple GPUs if available
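
Caching helps because attribution scores for a given text do not change while the attribution model's weights are fixed. If the scores come from the policy being trained, cached values go stale as its weights update, so this suits a frozen attribution model or a cache cleared periodically. This is a minimal sketch keyed by raw text; `ImportanceCache` is a hypothetical class, and `compute_fn` stands in for any attribution function, such as compute_language_model_gradient_attribution wrapped with its model, tokenizer, and device.

```python
from typing import Callable, Dict, List, Tuple

Scores = Tuple[List[str], List[float]]


class ImportanceCache:
    """Memoize (tokens, importances) per input text."""

    def __init__(self, compute_fn: Callable[[str], Scores]):
        self.compute_fn = compute_fn
        self._cache: Dict[str, Scores] = {}
        self.misses = 0

    def get(self, text: str) -> Scores:
        if text not in self._cache:
            self.misses += 1  # gradients are only computed on a cache miss
            self._cache[text] = self.compute_fn(text)
        return self._cache[text]

    def clear(self) -> None:
        """Drop cached scores, e.g. after the attribution model's weights change."""
        self._cache.clear()


# Stand-in attribution function: uniform importance per whitespace token
cache = ImportanceCache(lambda text: (text.split(), [1.0] * len(text.split())))
first = cache.get("a b")
second = cache.get("a b")
```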

Training Stability

  1. Learning Rate Scheduling: Use warmup and decay
  2. Gradient Clipping: Prevent gradient explosion
  3. Loss Monitoring: Track loss values for stability
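
The three practices above can be sketched in plain Python: a learning-rate schedule with linear warmup followed by linear decay, global-norm gradient clipping, and a check that stops on non-finite losses. The step counts, clipping threshold, and function names are hypothetical; in a PyTorch training loop, torch.nn.utils.clip_grad_norm_ performs the clipping step on model parameters.

```python
import math
from typing import List, Sequence


def lr_at(step: int, base_lr: float = 1e-5,
          warmup_steps: int = 100, total_steps: int = 1000) -> float:
    """Linear warmup to base_lr, then linear decay to 0 at total_steps."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    remaining = max(total_steps - step, 0)
    return base_lr * remaining / max(total_steps - warmup_steps, 1)


def clip_by_global_norm(grads: Sequence[float], max_norm: float = 1.0) -> List[float]:
    """Rescale gradients so their L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)
    scale = max_norm / norm
    return [g * scale for g in grads]


def check_loss(loss: float, step: int) -> None:
    """Fail fast on NaN or infinite loss values."""
    if not math.isfinite(loss):
        raise FloatingPointError(f"non-finite loss {loss} at step {step}")
```

The base learning rate of 1e-5 matches the recommended settings; clipping preserves the gradient's direction and only shrinks its length.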

🤝 Contributing

We welcome contributions! Please follow these steps:

  1. Fork the repository
  2. Create a feature branch: git checkout -b feature-name
  3. Make your changes
  4. Add tests for new functionality
  5. Run the test suite: python -m pytest tests/
  6. Submit a pull request

Development Setup

# Install development dependencies
pip install -r requirements.txt

# Run tests
python -m pytest tests/

# Run linting
flake8 .

# Run type checking
mypy .

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.

🙏 Acknowledgments


Note: This is a research implementation. For production use, additional testing and optimization may be required.