# HEdit: Correcting in Hindsight

[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
[![PyTorch](https://img.shields.io/badge/PyTorch-2.0+-red.svg)](https://pytorch.org/)

Official implementation of **"Correcting in Hindsight: Editing Past Key-Value States for Robust LLM Reasoning"**

## 📖 Overview

**HEdit (Hindsight Editing)** corrects reasoning errors in Large Language Models (LLMs) by editing past key-value (KV) states. When an LLM makes a reasoning error, HEdit identifies two sets of tokens: the anchor tokens (critical earlier tokens whose representations should be corrected) and the trigger tokens (the positions where the error was triggered). It then applies targeted corrections to the anchor tokens' entries in the KV cache to repair the reasoning chain.

### Key Features

- 🎯 **Anchor Token Detection**: Identifies critical past tokens based on attention variance and FFN update ratios
- 🔍 **Trigger Token Detection**: Locates error initiation points using state mutation and semantic confusion metrics
- 🔧 **KV State Correction**: Trains an MLP to predict corrections for anchor token KV states
- ⚡ **Efficient Inference**: Corrects errors without full recomputation or retraining

## 🏗️ Architecture

HEdit consists of three main components:

1. **Anchor Token Detector**: Identifies historical tokens that are critical for reasoning
   - Uses attention variance (how much future tokens attend to this token)
   - Analyzes FFN update ratios (representation change magnitude)

2. **Trigger Token Detector**: Identifies where reasoning errors begin
   - Measures state mutation (cosine similarity between layers)
   - Calculates semantic confusion (output entropy)

3. **KV Correction MLP**: Predicts corrections for anchor KV states
   - Takes trigger hidden state and error KV as input
   - Outputs delta K and delta V to correct anchor representations
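
The detector signals listed above can be illustrated with a minimal, dependency-free sketch. The README does not give the exact formulas, so everything here is an assumption: attention variance is taken as the variance of the attention that later positions pay to a token, state mutation as one minus the cosine similarity of a token's hidden state at two layers, and semantic confusion as the entropy of the softmax output distribution. The function names are hypothetical and are not part of the `hedit` package; the FFN update ratio is omitted.

```python
import math
from statistics import pvariance


def attention_variance(attn, j):
    """Variance of the attention that later positions i > j pay to token j.

    `attn[i][j]` is the weight that query position i assigns to key position j
    (assumed definition; the actual detector may differ).
    """
    received = [attn[i][j] for i in range(j + 1, len(attn))]
    return pvariance(received) if len(received) > 1 else 0.0


def state_mutation(h_lower, h_upper):
    """1 - cosine similarity between a token's hidden states at two layers."""
    dot = sum(a * b for a, b in zip(h_lower, h_upper))
    norm = math.sqrt(sum(a * a for a in h_lower)) * math.sqrt(sum(b * b for b in h_upper))
    return 1.0 - dot / norm if norm > 0 else 0.0


def output_entropy(logits):
    """Shannon entropy (in nats) of softmax(logits)."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    return -sum(p * math.log(p) for p in probs if p > 0)


def top_k(scores, k):
    """Indices of the k highest scores, best first (mirrors the role of `k_value`)."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
```

Under these assumptions, a token whose hidden state changes sharply between layers and whose output distribution is close to uniform would rank high as a trigger candidate.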

## 📦 Installation

### Requirements

- Python >= 3.8
- PyTorch >= 2.0
- transformers >= 4.30.0
- CUDA (recommended for GPU acceleration)

### Setup

```bash
# Clone the repository
git clone https://github.com/anonymous/HEdit.git
cd HEdit

# Create a virtual environment (recommended)
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Install dependencies
pip install -r requirements.txt

# Install HEdit package
pip install -e .
```

## 🚀 Quick Start

### 1. Anchor Token Detection

```python
from hedit import AnchorTokenDetector

# Initialize detector
detector = AnchorTokenDetector(
    model_path="path/to/your/model",
    k_value=5,  # Number of top anchor tokens
    attention_layer=20,  # Layer for attention analysis
    ffn_layer_start=20,  # Start layer for FFN analysis
    ffn_layer_end=21     # End layer for FFN analysis
)

# Identify anchor tokens
text = "Your input text for reasoning..."
result = detector.identify_anchor_tokens(text)

# Print results
detector.print_results(result)
```

### 2. Trigger Token Detection

```python
from hedit import TriggerTokenDetector

# Initialize detector
detector = TriggerTokenDetector(
    model_path="path/to/your/model",
    k_value=5  # Number of top trigger tokens
)

# Identify trigger tokens
text = "Your input text for reasoning..."
result = detector.identify_trigger_tokens(text)

# Print results
detector.print_results(result)
```

### 3. Training KV Correction MLP

```python
from hedit import KVCorrectionDataset, KVCorrectionMLP, MLPTrainer
from torch.utils.data import DataLoader

# Load dataset
train_dataset = KVCorrectionDataset(
    data_dir="path/to/data",
    dataset_names=["dataset1", "dataset2"],
    layer_idx=40
)

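train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

# Validation loader (assumes a held-out split stored in the same format)
val_dataset = KVCorrectionDataset(
    data_dir="path/to/val_data",
    dataset_names=["dataset1", "dataset2"],
    layer_idx=40
)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)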

# Create model
model = KVCorrectionMLP(
    input_dim=7168,  # model hidden size + 2*kv_dim (here 5120 + 2*1024)
    output_dim=2048,  # 2*kv_dim (delta K and delta V)
    hidden_dim1=2048,
    hidden_dim2=1024
)

# Train
trainer = MLPTrainer(model, device='cuda')
trainer.train(
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=100,
    learning_rate=1e-4,
    save_dir="./checkpoints"
)
```
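
The comments on `input_dim` and `output_dim` above imply a simple relationship between the MLP dimensions and the base model's sizes. The sketch below makes that arithmetic explicit and shows one way the predicted deltas might be applied. It is illustrative only: the helper names are hypothetical, the values `hidden_dim=5120` and `kv_dim=1024` are inferred from 7168 and 2048, and both the additive update and the "delta K first, delta V second" output layout are assumptions rather than confirmed behavior of `KVCorrectionMLP`.

```python
def mlp_dims(hidden_dim, kv_dim):
    """Return (input_dim, output_dim) for the correction MLP.

    Input:  trigger hidden state + error K + error V.
    Output: delta K + delta V.
    """
    return hidden_dim + 2 * kv_dim, 2 * kv_dim


def apply_correction(anchor_k, anchor_v, mlp_output):
    """Split the MLP output into (delta_k, delta_v) and add them to the anchor K/V.

    Assumes the first half of `mlp_output` is delta K and the second half delta V.
    """
    kv_dim = len(anchor_k)
    if len(anchor_v) != kv_dim or len(mlp_output) != 2 * kv_dim:
        raise ValueError("expected len(mlp_output) == 2 * kv_dim")
    delta_k, delta_v = mlp_output[:kv_dim], mlp_output[kv_dim:]
    new_k = [k + d for k, d in zip(anchor_k, delta_k)]
    new_v = [v + d for v, d in zip(anchor_v, delta_v)]
    return new_k, new_v


# Dimensions matching the Quick Start values (inferred, not stated in the README)
print(mlp_dims(hidden_dim=5120, kv_dim=1024))
```

When switching to a different base model, recomputing both dimensions this way avoids a shape mismatch between the dataset tensors and the MLP.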

## 📚 Examples

See the `examples/` directory for complete examples:

- `examples/demo_training.py`: Complete training pipeline
- `examples/demo_inference.py`: Inference with KV correction

Run examples:

```bash
# Training example
python examples/demo_training.py --model_path /path/to/model --data_dir /path/to/data

# Inference example
python examples/demo_inference.py --model_path /path/to/model --mlp_path /path/to/checkpoint
```

## 🔧 Scripts

Convenience scripts are provided in the `scripts/` directory:

```bash
# Train KV correction MLP
bash scripts/train.sh

# Run inference with correction
bash scripts/infer.sh
```

## 📊 Configuration

Key parameters you can adjust:

### Anchor Token Detection
- `k_value`: Number of top anchor tokens to select (default: 5)
- `attention_layer`: Which layer's attention to use (default: 20)
- `ffn_layer_start/end`: Layers for FFN update calculation (default: 20-21)

### Trigger Token Detection
- `k_value`: Number of top trigger tokens to select (default: 5)

### Training
- `batch_size`: Training batch size (default: 16)
- `learning_rate`: Learning rate (default: 1e-4)
- `num_epochs`: Maximum training epochs (default: 100)
- `hidden_dim1/2`: MLP hidden dimensions (default: 2048/1024)
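
One way to keep these defaults in a single place is a set of small dataclasses, sketched below. The class names are hypothetical and not part of the `hedit` package; only the field names and default values come from the lists above.

```python
from dataclasses import dataclass, asdict


@dataclass(frozen=True)
class AnchorConfig:
    k_value: int = 5            # number of top anchor tokens
    attention_layer: int = 20   # layer used for attention analysis
    ffn_layer_start: int = 20   # first layer for FFN update calculation
    ffn_layer_end: int = 21     # last layer for FFN update calculation


@dataclass(frozen=True)
class TriggerConfig:
    k_value: int = 5            # number of top trigger tokens


@dataclass(frozen=True)
class TrainConfig:
    batch_size: int = 16
    learning_rate: float = 1e-4
    num_epochs: int = 100
    hidden_dim1: int = 2048     # first MLP hidden layer
    hidden_dim2: int = 1024     # second MLP hidden layer


# Override one value and turn the result into keyword arguments
anchor_kwargs = asdict(AnchorConfig(k_value=10))
```

The resulting dictionary could then be unpacked into a detector call, for example `AnchorTokenDetector(model_path=..., **anchor_kwargs)`, assuming the constructor accepts exactly these keyword names as in the Quick Start.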

## 📁 Data Format

Training expects the following directory layout:

```
data_dir/
├── layer_{idx}/
│   ├── {dataset_name}_trigger_hidden.pt  # Trigger hidden states
│   ├── {dataset_name}_error_kv.pt        # Error KV states
│   └── {dataset_name}_anchor_kv.pt       # Ground truth anchor KV
└── {dataset_name}_sample_info.json       # Sample metadata
```
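
Before training, it can help to confirm that every file in this layout is present. The helper below is a minimal sketch that builds the expected paths from the tree above; the function names are hypothetical and not part of the `hedit` package.

```python
from pathlib import Path


def expected_files(data_dir, dataset_name, layer_idx):
    """Map each expected artifact to its path under `data_dir`."""
    root = Path(data_dir)
    layer_dir = root / f"layer_{layer_idx}"
    return {
        "trigger_hidden": layer_dir / f"{dataset_name}_trigger_hidden.pt",
        "error_kv": layer_dir / f"{dataset_name}_error_kv.pt",
        "anchor_kv": layer_dir / f"{dataset_name}_anchor_kv.pt",
        # Sample metadata sits at the top level, not inside the layer directory
        "sample_info": root / f"{dataset_name}_sample_info.json",
    }


def missing_files(data_dir, dataset_name, layer_idx):
    """Return the expected paths that do not exist on disk."""
    paths = expected_files(data_dir, dataset_name, layer_idx).values()
    return [str(p) for p in paths if not p.exists()]
```

For example, `missing_files("path/to/data", "dataset1", 40)` returns an empty list when the data for `dataset1` at layer 40 is complete.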

## 🎯 Model Checkpoints

Checkpoints are saved as a dictionary with the following keys:

```python
{
    'model_state_dict': ...,      # Model weights
    'optimizer_state_dict': ...,  # Optimizer state
    'epoch': ...,                 # Training epoch
    'train_loss': ...,            # Training loss
    'val_loss': ...,              # Validation loss
    'model_config': {             # Model configuration
        'input_dim': ...,
        'output_dim': ...
    }
}
```

Load a checkpoint:

```python
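import torch

# `model` must already be a KVCorrectionMLP built with the same dimensions as the checkpoint
checkpoint = torch.load('path/to/checkpoint.pt', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])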
```
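
A quick sanity check on a loaded checkpoint can catch files written by an older or interrupted run. The sketch below only verifies that the keys listed in the format above are present; the helper name is hypothetical and it operates on a plain dictionary, so it does not depend on PyTorch.

```python
REQUIRED_KEYS = (
    "model_state_dict",
    "optimizer_state_dict",
    "epoch",
    "train_loss",
    "val_loss",
    "model_config",
)
REQUIRED_CONFIG_KEYS = ("input_dim", "output_dim")


def missing_checkpoint_keys(checkpoint):
    """Return the names of expected keys that are absent from `checkpoint`."""
    missing = [key for key in REQUIRED_KEYS if key not in checkpoint]
    config = checkpoint.get("model_config", {})
    missing += [f"model_config.{key}" for key in REQUIRED_CONFIG_KEYS if key not in config]
    return missing
```

If the returned list is empty, `checkpoint['model_config']` should contain the `input_dim` and `output_dim` needed to rebuild the MLP before calling `load_state_dict`.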

## 🔬 Citation

If you use HEdit in your research, please cite our paper:

```bibtex
@inproceedings{hedit2025,
  title={Correcting in Hindsight: Editing Past Key-Value States for Robust LLM Reasoning},
  author={Anonymous},
  booktitle={ICML},
  year={2025}
}
```

## 📄 License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## 🤝 Contributing

We welcome contributions! Please feel free to submit a Pull Request.

## 📮 Contact

For questions or feedback, please open an issue on GitHub.

## 🙏 Acknowledgments

This work builds upon the Transformers library by Hugging Face and PyTorch.

---

**Note**: This is research code. Model paths, data paths, and hyperparameters should be adjusted based on your specific setup and requirements. Some features may require additional preprocessing or data preparation not included in this release.
