# Info-Gain Sampler for Masked Diffusion Models

A unified decoding framework for Masked Diffusion Models (MDMs) that combines trajectory planning with information-gain maximization. This repository implements the **Info-Gain Sampler**, a decoding strategy that supports multiple heuristic functions and adapts to different generation tasks.

## Overview

The Info-Gain Sampler extends the PC-Sampler framework with information-theoretic action selection. It supports:

- **Multiple heuristic functions**: confidence, PC-value, negative entropy, margin, and uniform sampling
- **Flexible trajectory control**: position-aware weighting and stochastic position sampling
- **Unified interface**: all baseline methods (entropy, margin, confidence, etc.) are implemented as special cases of the base `generate` function
- **Compatible models**: LLaDA, Dream

## Quick Start

### Installation

```bash
git clone <repository-url>
cd Uncode-new
pip install -r requirements.txt
```

### Simple Usage

The simplest way to try the Info-Gain Sampler is to run the example script:

```bash
cd scripts
python example_usage.py
```

The script demonstrates:

1. Basic generation using PC-Sampler mode (equivalent to `candidate_number=1`)
2. Info-Gain Sampler with position sampling enabled

### Programmatic Usage

```python
from transformers import AutoTokenizer, AutoModel
from src.generate import generate
import torch

# Load model
model_name = "GSAI-ML/LLaDA-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.bfloat16).to("cuda:0")
model.eval()

# Prepare prompt
prompt_text = "Your prompt here"
messages = [{"role": "user", "content": prompt_text}]
prompt_str = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
prompt = tokenizer(prompt_str)['input_ids']
prompt = torch.tensor(prompt).to("cuda:0").unsqueeze(0)

# Generate with Info-Gain Sampler
output = generate(
    model=model,
    prompt=prompt,
    steps=256,
    gen_length=256,
    block_length=32,
    lambd=0.0,              # Position weighting
    alpha=100,              # Confidence clipping
    baseline_name="../data/baseline/reference_corpus.json",
    temperature=0.0,
    candidate_number=8,     # >1 enables Info-Gain mode
    position_temperature=0.2,  # >0 enables position sampling
    heuristic='confidence',  # Heuristic function
    mask_id=126336,
    is_dream=False
)

# Decode result
generated_text = tokenizer.batch_decode(output[:, prompt.shape[1]:], skip_special_tokens=True)[0]
print(generated_text)
```

## Info-Gain Sampler

### Key Concepts

**Info-Gain Sampler** is a decoding strategy that selects actions (token positions to decode) by maximizing information gain. It works in two modes:

1. **Traditional mode** (`candidate_number=1`):
   - Greedy selection based on heuristic scores
   - Equivalent to traditional uncertainty-based samplers

2. **Info-Gain mode** (`candidate_number>1`):
   - Samples multiple candidate actions
   - Evaluates each candidate by computing information gain
   - Selects the candidate with the best trade-off between immediate cost and information gain
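
The two modes can be sketched as follows. This is a minimal, model-free illustration and not the code in `src/generate.py`: `sample_candidate`, `select_action`, and the `objective` callable are hypothetical names, Gumbel-top-k noise is an assumed way of drawing stochastic candidates, and the toy objective only stands in for the real information-gain computation.

```python
import math
import random


def sample_candidate(scores, k, temperature, rng):
    """Pick k distinct positions; temperature <= 0 means greedy top-k."""
    if temperature <= 0:
        return sorted(range(len(scores)), key=lambda i: -scores[i])[:k]
    # Assumed scheme: Gumbel-top-k, i.e. perturb the temperature-scaled
    # scores with Gumbel noise and keep the k largest.
    noisy = []
    for s in scores:
        u = max(rng.random(), 1e-12)  # avoid log(0)
        noisy.append(s / temperature - math.log(-math.log(u)))
    return sorted(range(len(scores)), key=lambda i: -noisy[i])[:k]


def select_action(scores, k, candidate_number, temperature, objective, rng):
    """Return the positions to decode at this step."""
    if candidate_number == 1:
        # Traditional mode: a single action taken from the heuristic scores.
        return sample_candidate(scores, k, temperature, rng)
    # Info-Gain mode: draw several candidate actions and keep the one the
    # (hypothetical) objective rates highest.
    candidates = [
        sample_candidate(scores, k, temperature, rng)
        for _ in range(candidate_number)
    ]
    return max(candidates, key=objective)


scores = [0.1, 0.9, 0.3, 0.8]  # one heuristic score per masked position
rng = random.Random(0)


def toy_objective(action):
    # Stand-in for the information-gain term: sum of heuristic scores.
    return sum(scores[i] for i in action)


print(select_action(scores, 2, 1, 0.0, toy_objective, rng))  # -> [1, 3]
print(select_action(scores, 2, 8, 0.2, toy_objective, rng))
```

In this sketch the candidates only differ from one another when `temperature > 0`; with a temperature of zero every candidate is the same greedy action.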


### Heuristic Functions

The sampler supports multiple heuristic functions for scoring token positions:

- **`confidence`** (default): Model confidence, i.e. the probability of the predicted token
- **`pc`**: PC-Sampler heuristic with frequency-based calibration
- **`neg_entropy`**: Negative entropy of the predicted distribution (higher entropy gives a lower score)
- **`margin`**: Margin between the top-1 and top-2 probabilities
- **`uniform`**: Uniform random sampling
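
As a rough illustration of what these scores measure, the sketch below computes four of them from a single position's predicted token distribution. `heuristic_score` is a hypothetical helper, not a function from this repository. The `pc` heuristic is left out because its frequency calibration is not described here, and treating `uniform` as a random score per position is an assumption.

```python
import math
import random


def heuristic_score(probs, heuristic="confidence", rng=random):
    """Score one masked position from its predicted token distribution."""
    if heuristic == "confidence":
        # Probability of the most likely token.
        return max(probs)
    if heuristic == "neg_entropy":
        # sum(p * log p) is the negative Shannon entropy.
        return sum(p * math.log(p) for p in probs if p > 0)
    if heuristic == "margin":
        top1, top2 = sorted(probs, reverse=True)[:2]
        return top1 - top2
    if heuristic == "uniform":
        # Assumed reading: a random score, so positions are ranked at random.
        return rng.random()
    raise ValueError(f"heuristic not covered by this sketch: {heuristic}")


peaked = [0.90, 0.05, 0.05]  # the model is fairly sure about this position
flat = [1 / 3, 1 / 3, 1 / 3]  # the model is undecided

for name in ("confidence", "neg_entropy", "margin"):
    print(name, heuristic_score(peaked, name), heuristic_score(flat, name))
```

For all three deterministic heuristics the peaked distribution receives the higher score, which is why such positions are decoded first under greedy selection.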

### Parameters for Info-Gain Sampler

| Parameter | Description | Default | Notes |
|-----------|-------------|---------|-------|
| `candidate_number` | Number of candidate actions to sample | `1` | `>1` enables Info-Gain mode |
| `position_temperature` | Temperature for position sampling | `0.0` | `>0` enables stochastic sampling |
| `heuristic` | Heuristic function type | `'confidence'` | See Heuristic Functions for options |
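
To build intuition for `position_temperature`, the sketch below turns heuristic scores into a sampling distribution over positions. It assumes a temperature-scaled softmax, which is a common choice but is not confirmed by this README. `position_distribution` is a hypothetical helper.

```python
import math


def position_distribution(scores, position_temperature):
    """Map heuristic scores to sampling probabilities over positions."""
    if position_temperature <= 0:
        # Deterministic limit: all probability mass on the best position.
        best = max(range(len(scores)), key=lambda i: scores[i])
        return [1.0 if i == best else 0.0 for i in range(len(scores))]
    scaled = [s / position_temperature for s in scores]
    peak = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(s - peak) for s in scaled]
    total = sum(weights)
    return [w / total for w in weights]


scores = [0.9, 0.5, 0.1]
for temperature in (0.0, 0.2, 1.0):
    probs = position_distribution(scores, temperature)
    print(temperature, [round(p, 3) for p in probs])
```

Under this assumption, a small temperature such as `0.2` keeps most of the mass on the highest-scoring position, while larger values spread it more evenly.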

### Baseline Methods as Special Cases

All baseline decoding methods are implemented as special cases of the base `generate` function:

- **`original`**: `candidate_number=1`, `heuristic='confidence'`
- **`pc_sampler`**: `candidate_number=1`, `heuristic='pc'`
- **`entropy`**: `candidate_number=1`, `heuristic='neg_entropy'`
- **`margin`**: `candidate_number=1`, `heuristic='margin'`

This unified interface makes it easy to compare different methods and experiment with new heuristics.
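
One way to keep such comparisons tidy is a small preset table holding the settings listed above. `BASELINE_PRESETS` and `baseline_kwargs` are hypothetical names introduced for this sketch. Only the `candidate_number` and `heuristic` values come from this README.

```python
BASELINE_PRESETS = {
    "original": {"candidate_number": 1, "heuristic": "confidence"},
    "pc_sampler": {"candidate_number": 1, "heuristic": "pc"},
    "entropy": {"candidate_number": 1, "heuristic": "neg_entropy"},
    "margin": {"candidate_number": 1, "heuristic": "margin"},
}


def baseline_kwargs(name, **overrides):
    """Return keyword arguments for generate() that select a baseline."""
    if name not in BASELINE_PRESETS:
        known = ", ".join(sorted(BASELINE_PRESETS))
        raise KeyError(f"unknown baseline {name!r}; expected one of: {known}")
    # Copy the preset so callers cannot mutate the shared table.
    return {**BASELINE_PRESETS[name], **overrides}


print(baseline_kwargs("entropy"))
# Switching to Info-Gain mode is then a one-argument override:
print(baseline_kwargs("entropy", candidate_number=8))
```

The returned dictionary can be unpacked into a `generate(...)` call alongside the model, prompt, and remaining arguments.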

## Evaluation

For comprehensive evaluation on multiple datasets, use the `eval.py` script:

```bash
cd scripts
python eval.py \
    --task humaneval \
    --model_name GSAI-ML/LLaDA-8B-Instruct \
    --device cuda:0 \
    --mode info-gain \
    --heuristic confidence \
    --candidate_number 8 \
    --position_temperature 0.2 \
    --data_path ../data/humaneval.jsonl \
    --result_path ../results/humaneval_info_gain
```

See `scripts/Eval.sh` for batch evaluation examples.
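
For sweeps driven from Python, the command above can be assembled programmatically. The sketch below only builds and prints the command lines. It reuses the flags shown in the example, assumes `--heuristic` accepts every heuristic name listed in this README, and uses a hypothetical per-heuristic suffix on `--result_path`. It does not reflect the contents of `scripts/Eval.sh`.

```python
import shlex

HEURISTICS = ["confidence", "pc", "neg_entropy", "margin", "uniform"]


def build_eval_command(heuristic, candidate_number=8, position_temperature=0.2):
    """Build the argv list for one eval.py run (intended to run from scripts/)."""
    return [
        "python", "eval.py",
        "--task", "humaneval",
        "--model_name", "GSAI-ML/LLaDA-8B-Instruct",
        "--device", "cuda:0",
        "--mode", "info-gain",
        "--heuristic", heuristic,
        "--candidate_number", str(candidate_number),
        "--position_temperature", str(position_temperature),
        "--data_path", "../data/humaneval.jsonl",
        # Hypothetical naming scheme: one result directory per heuristic.
        "--result_path", f"../results/humaneval_info_gain_{heuristic}",
    ]


commands = [build_eval_command(h) for h in HEURISTICS]
for cmd in commands:
    print(shlex.join(cmd))
```

Each list can be passed to `subprocess.run(cmd, check=True, cwd="scripts")` to execute the run.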

## Architecture

The codebase is organized as follows:

- **`src/generate.py`**: Core generation functions
  - `generate()`: Base generation function
  - `generate_with_info_gain()`: Wrapper for Info-Gain mode
  - `generate_with_eb_sampler()`: EB-Sampler baseline
  - `generate_with_fast_dllm()`: Fast-dLLM baseline
  - Helper functions for Info-Gain computation

- **`scripts/eval.py`**: Evaluation script for multiple tasks
- **`scripts/example_usage.py`**: Simple usage examples
- **`scripts/Eval.sh`**: Batch evaluation scripts
