# CMI Loss Methodology

## Theoretical Foundation

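CMI Loss is based on Conditional Mutual Information (CMI) from information theory, which measures how much one variable tells us about another once a third is known. Here the aim is for the intermediate reasoning steps (thinking) to carry information about the final answer given the prompt, so the model is encouraged to generate, and actually use, that reasoning before producing its answer.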

### Mathematical Formulation

The CMI Loss objective is:

```
L_total = L_main + λ * L_shortcut
```

Where:

- **L_main**: Standard language modeling loss computed on the full sequence including thinking tokens
- **L_shortcut**: Loss computed with thinking tokens masked or down-weighted
- **λ**: Regularization weight, set to a negative value (e.g., -0.1)

### Intuition

By using a negative λ, we create a tension:

- The model wants to minimize L_main (standard training)
- The negative λ makes the model want to *maximize* L_shortcut
- This encourages the model to rely on thinking tokens, as removing them (in L_shortcut) should increase loss
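
The tension can be seen with a small numeric sketch. The function name `cmi_total_loss` and the loss values are hypothetical and chosen only for illustration; the formula itself is the `L_total` objective stated earlier.

```python
def cmi_total_loss(l_main: float, l_shortcut: float, lam: float = -0.1) -> float:
    """Combine the two terms as L_total = L_main + lam * L_shortcut."""
    return l_main + lam * l_shortcut


# Two hypothetical models with the same L_main.
# Model A answers just as well without thinking tokens (low shortcut loss).
ignores_thinking = cmi_total_loss(l_main=2.0, l_shortcut=2.0)
# Model B gets much worse without thinking tokens (high shortcut loss).
relies_on_thinking = cmi_total_loss(l_main=2.0, l_shortcut=4.0)

# With a negative lam, the model that depends on its thinking gets the lower total.
print(ignores_thinking > relies_on_thinking)
```

Because the shortcut term is rewarded for growing, λ is kept small in magnitude so that `L_main` still dominates the objective.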

## Implementation Details

### 1. Thinking Token Identification

The implementation identifies thinking regions using special tokens:

- Start marker: `<think>`
- End marker: `</think>`

Example:

```
Input: "Is this request safe?"
Output: "<think>I need to evaluate if this could cause harm...</think> Yes, this request is safe."
```
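
A minimal sketch of how a thinking region might be located in a token sequence. It assumes that `<think>` and `</think>` each appear as a single token and that the markers themselves count as part of the thinking region; both are assumptions for illustration, and `thinking_mask` is a hypothetical helper name.

```python
from typing import List

THINK_START = "<think>"
THINK_END = "</think>"


def thinking_mask(tokens: List[str]) -> List[bool]:
    """Return True for each token inside a <think>...</think> region (markers included)."""
    mask = []
    inside = False
    for tok in tokens:
        if tok == THINK_START:
            inside = True
        mask.append(inside)
        if tok == THINK_END:
            inside = False  # tokens after the end marker are answer tokens
    return mask


tokens = ["<think>", "check", "harm", "</think>", "Yes", ",", "safe", "."]
mask = thinking_mask(tokens)
# The first four positions (markers and reasoning) are flagged; the answer is not.
```

A mask like this could then be used to hide or down-weight the flagged positions when computing `L_shortcut`.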

### 2. Dynamic Lambda Scheduling

CMI Loss uses three-phase training:

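```python
def cmi_lambda(step, warmup_steps, rampup_steps, lambda_start, lambda_end):
    """Return λ for the current training step."""
    if step < warmup_steps:
        return 0.0  # Phase 1: standard SFT
    if step < warmup_steps + rampup_steps:
        # Phase 2: linear interpolation from lambda_start to lambda_end
        progress = (step - warmup_steps) / rampup_steps
        return lambda_start + progress * (lambda_end - lambda_start)
    return lambda_end  # Phase 3: full CMI regularization
```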

### 3. Loss Normalization

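To prevent scale mismatches between L_main and L_shortcut, rescale L_shortcut so that its magnitude matches L_main:

```python
# Assumes PyTorch tensors. The ratio must be held constant (detached);
# otherwise the division below simplifies to L_main and L_shortcut's gradient is lost.
scale_factor = (L_shortcut / L_main).detach()
L_shortcut_normalized = L_shortcut / scale_factor
```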
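
One way to read this normalization, assuming the scale factor is treated as a constant during backpropagation (a stop-gradient, written sg(·) below; the text does not state this, so it is an assumption): the normalized term takes the value of L_main while keeping the gradient direction of L_shortcut.

```latex
\tilde{L}_{\text{shortcut}}
  = \frac{L_{\text{shortcut}}}{\operatorname{sg}\!\left(L_{\text{shortcut}} / L_{\text{main}}\right)},
\qquad
\nabla_\theta \tilde{L}_{\text{shortcut}}
  = \frac{L_{\text{main}}}{L_{\text{shortcut}}}\, \nabla_\theta L_{\text{shortcut}}
```

If gradients were allowed to flow through the scale factor, the expression would simplify algebraically to L_main, and the total would reduce to (1 + λ) · L_main with no shortcut signal left.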

### 4. Selective Application

For safety training, CMI can be applied selectively:

- Apply to harmful samples only
- Use standard SFT for benign samples
- Helps focus safety improvements where needed
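
Selective application can be sketched as a per-sample switch on the shortcut term. The `is_harmful` flag is a hypothetical per-sample label, and the per-sample loss values are made up for illustration; how samples are actually labeled is not specified here.

```python
from typing import List


def selective_cmi_loss(
    l_main: List[float],
    l_shortcut: List[float],
    is_harmful: List[bool],
    lam: float = -0.1,
) -> float:
    """Batch-mean loss: CMI term on samples flagged harmful, plain SFT loss elsewhere."""
    per_sample = [
        m + (lam * s if harmful else 0.0)
        for m, s, harmful in zip(l_main, l_shortcut, is_harmful)
    ]
    return sum(per_sample) / len(per_sample)


# First sample is flagged harmful and gets the CMI term; the second is plain SFT.
loss = selective_cmi_loss(
    l_main=[2.0, 1.0],
    l_shortcut=[3.0, 1.5],
    is_harmful=[True, False],
)
```

With every flag set to `False`, the function reduces to the ordinary mean of `l_main`, which matches standard SFT.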

## Best Practices

1. **Data Quality**

   - Ensure consistent thinking token placement
   - Provide high-quality reasoning examples
   - Balance dataset composition
2. **Training Stability**

   - Use sufficient warmup (20-30% of training steps)
   - Ramp up λ gradually (40-50% of training steps)
   - Monitor loss curves for instability
3. **Inference Considerations**

   - Models may generate thinking tokens at inference time
   - These can be filtered out before output is shown to end users, if desired
   - When kept, thinking tokens provide some interpretability of the model's decision
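
Filtering thinking tokens for end users could look like the following sketch, which uses Python's standard `re` module. `strip_thinking` is a hypothetical helper, and it assumes every `<think>` has a matching `</think>`; an unclosed region would be left in the output.

```python
import re

# Non-greedy match so that several thinking regions are removed independently;
# DOTALL lets the reasoning span multiple lines.
_THINK_RE = re.compile(r"<think>.*?</think>\s*", flags=re.DOTALL)


def strip_thinking(text: str) -> str:
    """Remove <think>...</think> regions from generated text before display."""
    return _THINK_RE.sub("", text).strip()


raw = "<think>I need to evaluate if this could cause harm...</think> Yes, this request is safe."
visible = strip_thinking(raw)
print(visible)  # Yes, this request is safe.
```

Keeping the raw output in logs while showing only the stripped text preserves the interpretability benefit for developers.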
