# CMI Loss: Improving AI Safety through Conditional Mutual Information Regularization

This repository contains the implementation of CMI (Conditional Mutual Information) Loss, a novel training method designed to improve AI safety and reasoning transparency in language models during supervised fine-tuning (SFT).

## 📖 Overview

CMI Loss addresses a critical challenge in AI safety: encouraging models to articulate their reasoning process before generating responses. By applying negative regularization based on conditional mutual information theory, CMI Loss helps prevent models from taking "shortcuts" that bypass explicit reasoning steps.

### Key Features

- **Enhanced Safety and Fewer Over-Refusals**: Reduces both harmful outputs and unnecessary refusals of benign requests
- **Improved Transparency**: Makes the model's decision-making process more interpretable
- **Flexible Implementation**: Easy to integrate into existing SFT pipelines
- **Dynamic Scheduling**: Phased lambda schedule (warmup, rampup, stable) that introduces the regularization gradually

## 🚀 Quick Start

### Basic Usage

```python
from cmi_loss import CMILossTrainer, CMILossConfig

# Configure CMI Loss
config = CMILossConfig(
    cmi_lambda=-0.1,  # Negative regularization strength
    cmi_warmup_ratio=0.3,  # Warmup phase ratio
    cmi_rampup_ratio=0.5,  # Rampup phase ratio
    cmi_thinking_weight=0.0,  # Weight for thinking tokens
)

# Initialize trainer with CMI Loss.
# `model`, `training_args`, and `train_dataset` must be defined beforehand.
trainer = CMILossTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    cmi_config=config,
)

# Train with CMI Loss
trainer.train()
```

## 📊 Method

CMI Loss implements the following objective:

```
L_total = L_main + λ * L_shortcut
```

Where:

- `L_main`: Standard cross-entropy loss for complete sequences
- `L_shortcut`: Loss computed with thinking tokens masked
- `λ`: Negative regularization parameter (typically -0.1)
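
Because `λ` is negative, minimizing `L_total` pushes `L_shortcut` up while pushing `L_main` down. The sketch below illustrates how the two terms could be combined from per-token losses. It is not the repository's implementation: the function name `cmi_total_loss`, the idea of passing in per-token negative log-likelihoods from two passes, and the weighted-mean reduction are all assumptions made for illustration.

```python
def cmi_total_loss(main_nll, shortcut_nll, is_thinking,
                   lam=-0.1, thinking_weight=0.0):
    """Combine L_main and L_shortcut as L_total = L_main + lam * L_shortcut.

    Hypothetical helper, not part of the cmi_loss package.
    main_nll:     per-token losses from the pass over the complete sequence
    shortcut_nll: per-token losses from the shortcut pass (assumed given)
    is_thinking:  True where the token belongs to the thinking span
    """
    if not (len(main_nll) == len(shortcut_nll) == len(is_thinking)):
        raise ValueError("all inputs must have the same length")
    if not main_nll:
        raise ValueError("inputs must not be empty")

    # L_main: plain mean over every token of the complete sequence.
    l_main = sum(main_nll) / len(main_nll)

    # L_shortcut: thinking tokens are down-weighted (weight 0.0 masks them out).
    weights = [thinking_weight if t else 1.0 for t in is_thinking]
    weight_sum = sum(weights)
    if weight_sum > 0:
        l_shortcut = sum(w * x for w, x in zip(weights, shortcut_nll)) / weight_sum
    else:
        l_shortcut = 0.0  # nothing left to score once thinking tokens are masked

    return l_main + lam * l_shortcut, l_main, l_shortcut


total, l_main, l_shortcut = cmi_total_loss(
    main_nll=[1.0, 1.0, 2.0, 2.0],
    shortcut_nll=[5.0, 5.0, 3.0, 1.0],
    is_thinking=[True, True, False, False],
)
print(l_main, l_shortcut, total)
```

With `thinking_weight=0.0`, only the non-thinking tokens contribute to `L_shortcut`, which matches the "thinking tokens masked" description above.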

### Training Phases

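1. **Warmup Phase** (first 30% of training by default): Standard SFT training without CMI regularization
2. **Rampup Phase** (next 50% of training): Gradually increase the regularization strength from `cmi_lambda_start` to `cmi_lambda`
3. **Stable Phase** (final 20% of training): Full CMI regularization at `cmi_lambda`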
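
The three phases can be sketched as a function of training progress. This is a minimal illustration rather than the repository's scheduler: the function name `cmi_lambda_at` is hypothetical, and both the zero weight during warmup and the linear interpolation during rampup are assumptions (the actual ramp shape may differ).

```python
def cmi_lambda_at(step, total_steps, lambda_start=-0.01, lambda_final=-0.1,
                  warmup_ratio=0.3, rampup_ratio=0.5):
    """Return the CMI regularization weight for a training step (hypothetical helper)."""
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")

    # Fraction of training completed, clamped to [0, 1].
    progress = min(max(step / total_steps, 0.0), 1.0)

    # Warmup: standard SFT, so the regularizer is switched off (assumption).
    if progress < warmup_ratio:
        return 0.0

    # Rampup: move linearly from lambda_start to lambda_final (assumption).
    rampup_end = warmup_ratio + rampup_ratio
    if rampup_ratio > 0 and progress < rampup_end:
        frac = (progress - warmup_ratio) / rampup_ratio
        return lambda_start + frac * (lambda_final - lambda_start)

    # Stable: full regularization strength for the rest of training.
    return lambda_final


for step in (0, 30, 55, 80, 100):
    print(step, round(cmi_lambda_at(step, total_steps=100), 4))
```

Under these assumptions the weight stays at zero for the first 30% of steps, grows in magnitude over the next 50%, and holds at `cmi_lambda` for the final 20%.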

## ⚠️ Important Note on Reproducibility

This repository provides a minimal implementation of CMI Loss for demonstration and understanding purposes. Our full experiments were conducted on large-scale GPU clusters using tools such as LLaMA Factory as the training framework. Therefore:

- **This code serves as a reference implementation** showing the core concepts and methodology
- **Full reproducibility may require significant computational resources** (multi-GPU setups, distributed training)
- **For production use**, we recommend integrating CMI Loss into an established training framework such as LLaMA Factory or DeepSpeed
- **The provided code is sufficient** for understanding the method and adapting it to your specific training setup

If you're looking to reproduce our exact experimental results, please refer to our paper for detailed hyperparameters and consider using similar large-scale training infrastructure.

## 🔧 Configuration

### Key Parameters

- `cmi_lambda`: Final regularization strength (default: -0.1)
- `cmi_lambda_start`: Initial regularization strength (default: -0.01)
- `cmi_warmup_ratio`: Fraction of training for warmup (default: 0.3)
- `cmi_rampup_ratio`: Fraction of training for rampup (default: 0.5)
- `cmi_thinking_weight`: Weight for thinking tokens in shortcut loss (default: 0.0)
- `cmi_apply_to_harmful_only`: Apply CMI only to harmful samples (default: False)
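
For readers adapting the method to another framework, the parameters above can be mirrored in a small dataclass. The sketch below is illustrative only: `CMIConfigSketch` is a hypothetical name, not the package's `CMILossConfig`, and the validation rules and the `stable_ratio` property are assumptions added here, not documented behavior.

```python
from dataclasses import dataclass


@dataclass
class CMIConfigSketch:
    """Hypothetical mirror of the documented parameters and their defaults."""
    cmi_lambda: float = -0.1            # final regularization strength
    cmi_lambda_start: float = -0.01     # initial regularization strength
    cmi_warmup_ratio: float = 0.3       # fraction of training for warmup
    cmi_rampup_ratio: float = 0.5       # fraction of training for rampup
    cmi_thinking_weight: float = 0.0    # weight for thinking tokens in shortcut loss
    cmi_apply_to_harmful_only: bool = False

    def __post_init__(self):
        # Assumed sanity checks: ratios are fractions and leave room for the stable phase.
        for name in ("cmi_warmup_ratio", "cmi_rampup_ratio"):
            value = getattr(self, name)
            if not 0.0 <= value <= 1.0:
                raise ValueError(f"{name} must be in [0, 1], got {value}")
        if self.cmi_warmup_ratio + self.cmi_rampup_ratio > 1.0:
            raise ValueError("warmup and rampup ratios must sum to at most 1")

    @property
    def stable_ratio(self):
        """Fraction of training left for the stable phase."""
        return 1.0 - self.cmi_warmup_ratio - self.cmi_rampup_ratio


cfg = CMIConfigSketch()
print(round(cfg.stable_ratio, 2))
```

With the default ratios, the remaining fraction corresponds to the 20% stable phase described in the Training Phases section.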
