# Dynamic Mid-Generation Abstention for LLMs

This repository contains the code for reproducing the experiments in "Dynamic Mid-Generation Abstention is Optimal" (ICML 2026).

## Overview

We propose a method for **dynamic abstention** in large language models during chain-of-thought reasoning. Unlike prior approaches that make abstention decisions at fixed positions (e.g., before or after generation), our method decides at every token position whether to continue generating or to abstain, trading off computational cost against accuracy.

## Repository Structure

```
├── environment.yaml              # Conda environment specification
├── value_head_model.py           # Core model: ValueHeadModel + TokenwiseValueHead
├── datasets.py                   # Data loading and preprocessing
├── early_abstention.py           # Abstention evaluation and stopping criteria
├── utils.py                      # Utility functions for analysis
│
├── train_our_method.py           # Train dynamic abstention (our method)
├── train_first_token_baseline.py # Train prompt probe baseline (first token only)
├── train_lora_abstention.py      # Train LoRA abstention baseline
├── split_dataset.py              # Train/test dataset splitting
│
├── plot_abstention_rate_analysis.py  # Generate trajectory CSVs for all models
├── main_results.py                   # Figure 1: Selective accuracy plots
├── recalibrated_reward_plot.py       # Figure 2: Reward vs r_⊥ plots
└── calibration_comparison_figure.py  # Figure 3: Calibration comparison
```

## Installation

```bash
# Create conda environment
conda env create -f environment.yaml
conda activate dynamic-abstention

# Or install with pip
pip install torch transformers trl peft torchmetrics wandb \
            numpy pandas scipy scikit-learn matplotlib seaborn tqdm
```

## Data Format

The code expects JSONL files with the following structure:

```json
{
  "doc": {
    "question": "What is 2 + 3?",
    "correctness": [1, 0, 1]
  },
  "resps": [["Response 1", "Response 2", "Response 3"]]
}
```

Each question has multiple responses with corresponding correctness labels (1 = correct, 0 = incorrect).
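The following is a minimal sketch for flattening this format into one labeled example per response. It assumes that `resps[0]` holds the sampled responses and that `correctness[i]` labels `resps[0][i]`. The function name `iter_labeled_responses` is hypothetical and is not part of the repository's `datasets.py`.

```python
import json
from typing import Iterator, Tuple


def iter_labeled_responses(path: str) -> Iterator[Tuple[str, str, int]]:
    """Yield (question, response, label) triples from a JSONL file.

    Hypothetical helper. Assumes resps[0] holds the responses and that
    correctness[i] is the label (1 = correct, 0 = incorrect) for resps[0][i].
    """
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines between records
            record = json.loads(line)
            question = record["doc"]["question"]
            responses = record["resps"][0]
            labels = record["doc"]["correctness"]
            if len(responses) != len(labels):
                raise ValueError(
                    f"{len(responses)} responses but {len(labels)} labels for: {question!r}"
                )
            for response, label in zip(responses, labels):
                yield question, response, int(label)
```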

## Training

### Dynamic Abstention (Our Method)

Train a value head on all output tokens for dynamic mid-generation abstention:

```bash
python train_our_method.py \
    --model_name "Qwen/Qwen2.5-Math-7B-Instruct" \
    --data_path data/train.jsonl \
    --output_dir outputs/dynamic \
    --device cuda:0
```

This trains a `TokenwiseValueHead` that predicts correctness at every token position, enabling abstention at any point during generation.

### Prompt Probe Baseline

Train a value head on only the first output token:

```bash
python train_first_token_baseline.py \
    --model_name "Qwen/Qwen2.5-Math-7B-Instruct" \
    --data_path data/train.jsonl \
    --output_dir outputs/baseline \
    --device cuda:0
```

This baseline can only abstain at the start of generation (no mid-generation decisions).
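The dynamic value head and the prompt probe differ mainly in which output positions are supervised. As a rough sketch, assuming each supervised position is trained with binary cross-entropy against the response's final correctness label (the actual losses in `train_our_method.py` and `train_first_token_baseline.py` may differ), the two objectives could look like this. `value_head_loss` and `bce_with_logit` are hypothetical names.

```python
import math
from typing import Sequence


def bce_with_logit(z: float, label: int) -> float:
    # Numerically stable binary cross-entropy for a single logit z and a 0/1 label.
    return max(z, 0.0) - z * label + math.log1p(math.exp(-abs(z)))


def value_head_loss(
    value_logits: Sequence[float], label: int, first_token_only: bool = False
) -> float:
    """Mean BCE between per-token value logits and the response's correctness label.

    first_token_only=False: every output token is supervised (dynamic abstention).
    first_token_only=True:  only the first output token is supervised (prompt probe).
    """
    positions = value_logits[:1] if first_token_only else value_logits
    return sum(bce_with_logit(z, label) for z in positions) / len(positions)
```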

### LoRA Abstention Baseline

Fine-tune with LoRA to predict abstention tokens:

```bash
python train_lora_abstention.py \
    --model_name "Qwen/Qwen2.5-Math-7B-Instruct" \
    --data_path data/train.jsonl \
    --output_dir outputs/lora \
    --device cuda:0
```

## Evaluation

### Generate Trajectory CSVs

Run evaluation on all models to generate trajectory CSV files:

```bash
python plot_abstention_rate_analysis.py \
    --model-name "Qwen/Qwen2.5-Math-7B-Instruct" \
    --data-path data/test.jsonl \
    --baseline-path outputs/baseline/checkpoints_first_token/value_head_epoch_2.pth \
    --full-model-path outputs/dynamic/checkpoints_our_method/value_head_epoch_5.pth \
    --lora-model-path outputs/lora/final_model \
    --output-folder results/gsm8k_qwen \
    --device cuda:0
```

This evaluates all four methods (dynamic, prompt probe, LoRA, self-assessment) and saves `trajectory_values.csv` containing per-token value predictions.

### Generate Figures

After generating trajectory CSVs for each dataset/model combination, update the `trajectory_files` dictionary in each figure script to point to your CSVs:

```python
# Example paths; update the same dictionary in main_results.py,
# recalibrated_reward_plot.py, and calibration_comparison_figure.py
trajectory_files = {
    "./results/gsm8k_qwen/trajectory_values.csv": "gsm8k_qwen",
    "./results/gsm8k_phi3/trajectory_values.csv": "gsm8k_phi3",
    # add one "path": "label" entry per dataset/model combination
}
```

Then run:

```bash
# Figure 1: Selective Accuracy vs Abstention Rate
python main_results.py

# Figure 2: Reward vs r_⊥
python recalibrated_reward_plot.py

# Figure 3: Calibration Comparison
python calibration_comparison_figure.py
```

Figures are saved to `output_plots/`.
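For reference, a selective-accuracy curve like the one in Figure 1 can be computed by abstaining on the lowest-scoring responses first and measuring accuracy on the responses that remain. The sketch below is a generic implementation, not the code in `main_results.py`; the function name `selective_accuracy_curve` is hypothetical.

```python
from typing import List, Sequence, Tuple


def selective_accuracy_curve(
    scores: Sequence[float], correct: Sequence[int]
) -> List[Tuple[float, float]]:
    """Return (abstention_rate, selective_accuracy) pairs, sorted by abstention rate.

    Responses are kept in order of decreasing score, so the lowest-scoring ones
    are abstained on first. Accuracy is computed over kept responses only; the
    curve stops at (n - 1) / n so that at least one response is always kept.
    """
    n = len(scores)
    order = sorted(range(n), key=lambda i: scores[i], reverse=True)  # most confident first
    curve = []
    hits = 0
    for kept, i in enumerate(order, start=1):
        hits += correct[i]
        curve.append(((n - kept) / n, hits / kept))
    curve.reverse()  # ascending abstention rate
    return curve
```

Under the threshold rule described for `generate_with_abstention`, a trajectory is abstained on exactly when some per-token value falls below the threshold, so one plausible score for the dynamic method (an assumption here) is the minimum value along each trajectory, e.g. `scores = [min(traj) for traj in trajectories]`.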

## Key Components

### ValueHeadModel (`value_head_model.py`)

Wraps a frozen LLM backbone with a trainable value head:

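```python
from transformers import AutoConfig
from value_head_model import ValueHeadModel, TokenwiseValueHead

model_name = "Qwen/Qwen2.5-Math-7B-Instruct"

# Size the head to the backbone's hidden dimension instead of hard-coding it:
# Qwen2.5-7B models use 3584, Mistral-7B uses 4096, Phi-3-mini uses 3072.
hidden_dim = AutoConfig.from_pretrained(model_name).hidden_size

value_head = TokenwiseValueHead(hidden_dim=hidden_dim)
model = ValueHeadModel(
    model_name_or_path=model_name,
    value_head=value_head,
    freeze_base_model=True,  # only the value head receives gradients
    device="cuda",
)
```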

### TokenwiseValueHead

Simple MLP that predicts a scalar value per token:

```
Linear(hidden_dim → hidden_dim) → Tanh → Dropout → Linear(hidden_dim → 1)
```
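A PyTorch sketch of this architecture follows. The class name `TokenwiseValueHeadSketch` and the dropout rate of 0.1 are assumptions. The sketch returns the raw scalar for each token; whether the repository applies a sigmoid before comparing against an abstention threshold is not specified here.

```python
import torch
import torch.nn as nn


class TokenwiseValueHeadSketch(nn.Module):
    """Illustrative re-implementation of the MLP above (dropout rate assumed)."""

    def __init__(self, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, hidden_dim) from the frozen backbone.
        # Returns one scalar per token position: (batch, seq_len).
        return self.net(hidden_states).squeeze(-1)
```

Because the head is applied independently to each position's hidden state, a single forward pass over a full trajectory yields a value for every token.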

### Dynamic Abstention at Inference

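```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Math-7B-Instruct")
input_ids = tokenizer("What is 2 + 3?", return_tensors="pt").input_ids.to("cuda")

# `model` is the ValueHeadModel constructed above
generated_ids, final_value = model.generate_with_abstention(
    input_ids,
    threshold=0.5,
    max_length=512,
    tokenizer=tokenizer,
)
```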

Generation stops early when the predicted value drops below the threshold.
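The stopping rule can be sketched in plain Python. Here `next_token` and `value_fn` are hypothetical stand-ins for one decoding step of the LLM and for the value head applied to the current prefix. This illustrates the control flow only and is not the implementation of `generate_with_abstention`.

```python
from typing import Callable, List, Optional, Tuple


def generate_with_abstention_sketch(
    prompt_ids: List[int],
    next_token: Callable[[List[int]], int],  # hypothetical: one LLM decoding step
    value_fn: Callable[[List[int]], float],  # hypothetical: value head on the prefix
    threshold: float,
    max_new_tokens: int,
    eos_id: Optional[int] = None,
) -> Tuple[List[int], float, bool]:
    """Return (token_ids, last_value, abstained)."""
    ids = list(prompt_ids)
    value = value_fn(ids)  # value of the prompt alone
    for _ in range(max_new_tokens):
        if value < threshold:
            return ids, value, True  # abstain before emitting another token
        tok = next_token(ids)
        ids.append(tok)
        value = value_fn(ids)
        if eos_id is not None and tok == eos_id:
            break
    # Generation finished (EOS or length limit): abstain only if the last value is low
    return ids, value, value < threshold
```

Checking the value before the first new token is produced lets the same loop cover abstention on the prompt alone, which is the only decision point available to the prompt probe baseline.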

## Supported Models

Tested with:
- `Qwen/Qwen2.5-Math-7B-Instruct`
- `Qwen/Qwen2.5-7B-Instruct`
- `microsoft/Phi-3-mini-4k-instruct`
- `mistralai/Mistral-7B-Instruct-v0.3`

## Notes

- **Weights & Biases**: Training scripts log to W&B by default. Set `wandb_mode: "offline"` in the config to log locally without syncing, or set the environment variable `WANDB_MODE=disabled` to turn logging off entirely.

- **GPU memory**: The frozen LLM backbone requires ~14GB for 7B models. Gradient checkpointing is enabled by default.
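The ~14GB figure is roughly what 16-bit weights for a 7B-parameter model occupy. A back-of-envelope check, assuming 2 bytes per parameter (fp16/bf16) and ignoring activations, the KV cache, and the value head itself; `weight_memory_gb` is a hypothetical helper:

```python
def weight_memory_gb(num_params: float, bytes_per_param: int = 2) -> float:
    """Approximate memory (GB, with 1 GB = 1e9 bytes) taken by model weights alone."""
    return num_params * bytes_per_param / 1e9


print(weight_memory_gb(7e9))     # 16-bit weights → 14.0
print(weight_memory_gb(7e9, 4))  # fp32 weights → 28.0
```

Loading the backbone in fp32 would therefore roughly double the footprint.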

## Citation

```bibtex
@inproceedings{anonymous2026dynamic,
  title={{Knowing When to Quit: A Principled Framework for Dynamic Abstention in LLM Reasoning}},
  author={Anonymous},
  booktitle={International Conference on Machine Learning},
  year={2026}
}
```

## License

MIT License
