# Entropy-Adaptive Dual-Stream Language Model

**Anonymous ICML 2026 Submission**

This repository contains the implementation for our paper on entropy-adaptive dual-stream language modeling. In line with ICML guidelines, it provides the model architecture and the training procedures.

## Overview

This work introduces a dual-stream architecture that dynamically balances syntactic (System 1) and semantic (System 2) processing. During training, an entropy-weighted loss teaches the semantic head to specialize in high-uncertainty transitions. During inference, variance alignment rescales the semantic logits so that they can effectively steer generation.

## Requirements

```bash
pip install -r requirements.txt
```

**Minimum Hardware:**
- GPU: 24GB VRAM (tested on RTX 4090)
- RAM: 32GB
- Storage: 50GB for base model cache

## Quick Start

### 1. Setup Environment

```bash
chmod +x setup.sh
./setup.sh
source venv/bin/activate
```

### 2. Data Setup

The code uses the **FineWeb-Edu** dataset by default (as described in the paper):

```python
# data.py automatically loads FineWeb-Edu from HuggingFace
# No manual download required
```

The `StreamDataset` class in `data.py` implements:
- Streaming from FineWeb-Edu (10BT sample)
- Block size: 512 tokens
- Train/val split: the first ~7M samples for training, the rest for validation
- Returns: `(input_ids, labels, future_tokens)`, where `future_tokens` covers a look-ahead window of K=20 tokens
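
The returned tuple can be pictured with a small sketch. It is not the code in `data.py`: it assumes that `labels` are the inputs shifted by one position and that `future_tokens` holds, for each position, the next K token ids, padded with a hypothetical `PAD_ID` when the stream runs out. The function name `make_example` is also illustrative.

```python
from typing import List, Tuple

PAD_ID = -100  # hypothetical padding id, not taken from data.py


def make_example(
    tokens: List[int], block_size: int = 512, k: int = 20
) -> Tuple[List[int], List[int], List[List[int]]]:
    """Cut one (input_ids, labels, future_tokens) example from a token stream."""
    if len(tokens) < block_size + 1:
        raise ValueError("need at least block_size + 1 tokens")
    input_ids = tokens[:block_size]
    labels = tokens[1:block_size + 1]  # next-token targets (assumed shift by one)
    future_tokens = []
    for t in range(block_size):
        window = tokens[t + 1:t + 1 + k]                # the K tokens after position t
        window = window + [PAD_ID] * (k - len(window))  # pad when the stream ends early
        future_tokens.append(window)
    return input_ids, labels, future_tokens
```

With `block_size=512` and `k=20`, each example carries 512 look-ahead windows of 20 token ids each.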

**To use a different dataset:** Modify the `load_dataset()` call in `data.py`.

### 3. Training

```bash
python train_entropy.py
```

Configuration can be modified directly in the `CONF` dictionary in `train_entropy.py`.
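
As a rough illustration of what such a dictionary might look like, the sketch below collects the training-side values listed in this README. The key names `block_size` and `future_k`, and the `override` helper, are hypothetical; the authoritative keys are the ones in `train_entropy.py`.

```python
# Hypothetical layout; check train_entropy.py for the real keys.
CONF = {
    "block_size": 512,           # tokens per training block (assumed key name)
    "future_k": 20,              # K future tokens per position (assumed key name)
    "base_idea_weight": 0.3,     # lambda_base
    "entropy_sensitivity": 2.0,  # delta
    "pos_weight": 200,           # class imbalance correction
    "stopword_cutoff": 250,      # vocabulary masking threshold
}


def override(conf: dict, **changes) -> dict:
    """Return a copy of conf with selected values replaced; unknown keys raise."""
    unknown = set(changes) - set(conf)
    if unknown:
        raise KeyError(f"unknown config keys: {sorted(unknown)}")
    return {**conf, **changes}


# Example: run an ablation with a weaker entropy weighting.
ablation_conf = override(CONF, entropy_sensitivity=1.0)
```

Rejecting unknown keys guards against silently training with a misspelled hyperparameter.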

### 4. Compute Variance Scaling Factor

After training, compute the global variance scaling factor γ from the validation set:

```bash
python compute_gamma.py --checkpoint ./checkpoints/best_model --num_samples 10000
```

The paper reports γ ≈ 3.42 for Mistral-7B on FineWeb-Edu.

### 5. Inference

```bash
python inference.py --checkpoint ./checkpoints/best_model --prompt "Your text here"
```

**Key inference parameters:**
- `--alpha 0.5`: Gating intensity (default from paper)
- `--gamma 3.42`: Variance scaling factor (default from paper)
- `--compute_gamma`: Optionally compute γ from your validation set

## Architecture

### Model Components

1. **Base Model**: Quantized LLM with LoRA adapters (System 1)
2. **Idea Head**: Semantic projection layer (System 2)
3. **Dynamic Gating**: Entropy-adaptive training + variance alignment inference

### Training vs Inference

**Training (Section 3.5, Equation 8):**
- Loss formula: `L = L_NTP + λ_base · w_t · L_Idea`
- L_NTP computed from token_logits (System 1 only)
- L_Idea computed from idea_logits (System 2 only)
- Entropy-weighted: `w_t = 1 + δ·H(p_token)` where δ=2.0
- NO variance scaling (boost=1.0)
- Only mean centering applied to idea logits
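
The weighting can be illustrated for a single position. This is a minimal sketch, not the code in `train_entropy.py`: it assumes the entropy is measured in nats over the softmax of `token_logits`, takes `L_Idea` as an already-computed number, and uses the hypothetical helper names `softmax`, `entropy`, and `combined_loss`.

```python
import math
from typing import List, Tuple


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs: List[float]) -> float:
    """Shannon entropy in nats (assumed unit)."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def combined_loss(
    token_logits: List[float],
    target_id: int,
    idea_loss: float,
    lambda_base: float = 0.3,
    delta: float = 2.0,
) -> Tuple[float, float]:
    """Return (L, w_t) with L = L_NTP + lambda_base * w_t * L_Idea."""
    probs = softmax(token_logits)
    l_ntp = -math.log(probs[target_id])    # next-token cross-entropy (System 1)
    w_t = 1.0 + delta * entropy(probs)     # entropy weight
    return l_ntp + lambda_base * w_t * idea_loss, w_t
```

When the token distribution is sharply peaked, the entropy is close to zero and `w_t` stays near 1, so the idea loss contributes roughly `λ_base · L_Idea`. When the distribution is flat, `w_t` grows and the idea loss is emphasized.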

**Inference (Section 3.3):**
- Fixed gating intensity (α=0.5)
- Variance alignment enabled with global γ≈3.42
- Formula: `γ = E[σ(z_token)] / E[σ(z'_idea)]` computed from validation set
- Uses gated final_logits for generation
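
The formula for γ can be sketched as follows. This is an illustration under stated assumptions rather than the contents of `compute_gamma.py`: σ is taken to be the standard deviation across the vocabulary at one position, the expectation is a plain mean over validation positions, and `estimate_gamma` is a hypothetical name.

```python
from statistics import fmean, pstdev
from typing import Iterable, List, Sequence, Tuple


def center(logits: Sequence[float]) -> List[float]:
    """Mean-center one logit vector (z_idea -> z'_idea)."""
    mu = fmean(logits)
    return [x - mu for x in logits]


def estimate_gamma(
    pairs: Iterable[Tuple[Sequence[float], Sequence[float]]]
) -> float:
    """gamma = E[std(z_token)] / E[std(z'_idea)] over (z_token, z_idea) pairs."""
    token_stds, idea_stds = [], []
    for z_token, z_idea in pairs:
        token_stds.append(pstdev(z_token))        # spread of token logits
        idea_stds.append(pstdev(center(z_idea)))  # spread of centered idea logits
    mean_idea_std = fmean(idea_stds)
    if mean_idea_std == 0.0:
        raise ValueError("idea logits have zero variance; gamma is undefined")
    return fmean(token_stds) / mean_idea_std
```

A value of γ above 1 means the idea logits are, on average, less spread out than the token logits and need to be stretched before they can influence the combined distribution.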

### Key Hyperparameters

| Parameter | Value | Description |
|-----------|-------|-------------|
| `alpha` | 0.5 | Gating intensity (fixed during inference) |
| `gamma` | 3.42 | Variance scaling factor (inference only) |
| `pos_weight` | 200 | Class imbalance correction |
| `stopword_cutoff` | 250 | Vocabulary masking threshold |
| `base_idea_weight` | 0.3 | Base semantic loss weight (λ_base) |
| `entropy_sensitivity` | 2.0 | Entropy scaling factor δ (training only) |


## File Structure

```
icml2026_final_submission/
├── README.md              # This file
├── requirements.txt       # Python dependencies
├── LICENSE               # MIT License
├── setup.sh              # Environment setup script
├── model.py              # IdeaGatedModel architecture
├── train_entropy.py      # Training loop with entropy weighting
├── data.py               # StreamDataset (FineWeb-Edu streaming)
├── compute_gamma.py      # Utility to compute variance scaling factor
└── inference.py          # Inference with variance alignment
```

## Reproducibility Notes

1. **Hardware**: Results reported on NVIDIA RTX 4090 (24GB)
2. **Base Model**: Mistral-7B-v0.1 with 4-bit quantization
3. **Data**: FineWeb-Edu (HuggingFaceFW/fineweb-edu, sample-10BT)
4. **Variance Factor**: Compute γ from your validation set using `compute_gamma.py`

## Key Implementation Details

### Training (Equation 8)
The two losses are computed **independently**:
- L_NTP from token_logits (System 1, no gating)
- L_Idea from idea_logits (System 2, separate)
- Combined: `L = L_NTP + λ_base · w_t · L_Idea`

### Inference (Section 3.3)
Variance alignment is **enabled** during inference:
- Pre-computed global γ ≈ 3.42
- Applied as: `ẑ_idea = z'_idea · γ`
- Fixed α=0.5 for gating
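
The inference steps can be put together in a short sketch. The exact gating rule is defined in the paper and in `model.py`; the additive form `final = z_token + α · ẑ_idea` used here is an assumption made for illustration, as is the function name `gated_logits`.

```python
from statistics import fmean
from typing import List, Sequence


def gated_logits(
    z_token: Sequence[float],
    z_idea: Sequence[float],
    alpha: float = 0.5,
    gamma: float = 3.42,
) -> List[float]:
    """Combine token and idea logits for one position (assumed additive gate)."""
    if len(z_token) != len(z_idea):
        raise ValueError("z_token and z_idea must cover the same vocabulary")
    mu = fmean(z_idea)
    z_hat = [(x - mu) * gamma for x in z_idea]  # mean centering, then variance alignment
    return [t + alpha * h for t, h in zip(z_token, z_hat)]


# Example with a three-token vocabulary and round numbers for readability.
final = gated_logits([1.0, 1.0, 1.0], [0.0, 1.0, 2.0], alpha=0.5, gamma=2.0)
```

Setting `alpha=0.0` in this sketch returns the token logits unchanged, which is a convenient sanity check when comparing against System 1 alone.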

## Citation

```bibtex
@inproceedings{anonymous2026entropy,
  title={Latent Semantic Planning: Constraining Autoregressive Generation to Ideas Before Words},
  author={Anonymous},
  booktitle={International Conference on Machine Learning},
  year={2026}
}
```

## License

This code is released under the MIT License for research purposes.
