# Consistency Loss Experiment

**Research question:** Does a consistency loss that couples explanation-token hidden states to oracle-verified ground-truth claims cause a model to produce better-aligned explanations when code snippets are paired with mismatched explanations?

## Setup

```bash
cd /path/to/consistency_loss_experiment

# Install dependencies
pip install torch numpy pandas matplotlib scikit-learn rouge-score scipy
```

## Quick Start

### Smoke Run (fast, ~2–5 min on CPU)

```bash
python run_experiment.py --smoke
```

Runs all four variants with the `smoke` model (~0.5M params) on 300 examples for 5 epochs.
This verifies that every component runs end-to-end.

### Full Run (configured target)

```bash
python run_experiment.py --full
```

Full config: 3,000 examples, 20 epochs, batch size 32, lr 5e-5, `small` model (~10M params).
Estimated runtime: ~2–4 hours on GPU, longer on CPU.

### Custom Options

```bash
# Custom epochs / batch size
python run_experiment.py --full --epochs 10 --batch 16

# Larger model (GPU required, ≥8GB VRAM)
python run_experiment.py --full --model gpt2_small

# Run only specific variants
python run_experiment.py --smoke --variants consistency_loss no_consistency_loss

# Custom smoke config
python run_experiment.py --smoke --smoke-n 500 --smoke-epochs 8 --smoke-batch 16
```
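For orientation, the flags shown above could be parsed along the following lines. This is only a sketch: the actual parser in `run_experiment.py` may differ, and the mutually exclusive `--smoke`/`--full` group and the `None` defaults are assumptions. The smoke defaults (300 examples, 5 epochs) are taken from the smoke-run description.

```python
import argparse

VARIANTS = [
    "consistency_loss",
    "no_consistency_loss",
    "claim_only_pooling",
    "random_label_consistency",
]


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags documented in this README."""
    parser = argparse.ArgumentParser(description="Consistency loss experiment (sketch)")
    # Assumption: exactly one of --smoke / --full must be given.
    mode = parser.add_mutually_exclusive_group(required=True)
    mode.add_argument("--smoke", action="store_true", help="fast end-to-end check")
    mode.add_argument("--full", action="store_true", help="full configured run")
    # None means "use the value from the selected config" (assumed behaviour).
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--batch", type=int, default=None)
    parser.add_argument("--model", type=str, default=None)
    parser.add_argument("--variants", nargs="+", choices=VARIANTS, default=VARIANTS)
    parser.add_argument("--smoke-n", type=int, default=300)
    parser.add_argument("--smoke-epochs", type=int, default=5)
    parser.add_argument("--smoke-batch", type=int, default=None)
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch runs anywhere.
args = build_parser().parse_args(
    ["--smoke", "--variants", "consistency_loss", "no_consistency_loss"]
)
print(args.variants, args.smoke_n, args.smoke_epochs)
```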

## Experiment Design

### Dataset (3,000 examples)

`dataset.py` generates Python functions covering:
- **Sorting**: bubble sort, selection sort, insertion sort
- **Searching**: linear search, binary-like patterns
- **String ops**: reversal, contains, character counting
- **Math**: factorial, sum, mean, absolute value

Each function has **ground-truth claims**:
- `time_complexity`: O(1), O(n), O(n²)
- `space_complexity`: O(1), O(n), O(n²)
- `correctness`: 0 (buggy) or 1 (correct)

Explanations are **randomly permuted** across functions (mismatched pairing).
Final format: `(code_snippet, mismatched_explanation, ground_truth_claims)`.
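The mismatched pairing can be sketched as follows. The `Claims` dataclass, the `mismatch_explanations` helper, and the two toy functions are hypothetical illustrations, not the code in `dataset.py`. Note that a plain shuffle can leave some explanations attached to their original function; whether the real generator excludes such fixed points is not stated here.

```python
import random
from dataclasses import dataclass


@dataclass(frozen=True)
class Claims:
    """Ground-truth claims for one function (field names follow the README)."""
    time_complexity: str   # "O(1)", "O(n)", or "O(n²)"
    space_complexity: str  # "O(1)", "O(n)", or "O(n²)"
    correctness: int       # 0 = buggy, 1 = correct


def mismatch_explanations(codes, explanations, claims, seed=0):
    """Return (code, mismatched_explanation, claims) triples.

    Explanations are permuted at random; code and claims stay aligned.
    """
    if not (len(codes) == len(explanations) == len(claims)):
        raise ValueError("inputs must have equal length")
    rng = random.Random(seed)          # seeded for reproducibility
    order = list(range(len(codes)))
    rng.shuffle(order)                 # may leave fixed points
    return [(codes[i], explanations[j], claims[i]) for i, j in enumerate(order)]


# Toy inputs, for illustration only.
codes = ["def total(xs): return sum(xs)", "def first(xs): return xs[0]"]
explanations = ["Sums the list in one pass.", "Returns the first element."]
claims = [Claims("O(n)", "O(1)", 1), Claims("O(1)", "O(1)", 1)]

examples = mismatch_explanations(codes, explanations, claims, seed=0)
```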

### Model Architecture

GPT-2-style causal Transformer (implemented from scratch in PyTorch):

| Config     | d_model | Heads | Layers | d_ff | Params |
|------------|---------|-------|--------|------|--------|
| smoke      | 64      | 2     | 2      | 128  | ~0.5M  |
| small      | 256     | 4     | 4      | 1024 | ~10M   |
| gpt2_small | 768     | 12    | 12     | 3072 | ~117M  |

**Sequence format:**
```
<bos> [code] <sep> [explanation] <claim>time_complexity=O(n)</claim>
<claim>space_complexity=O(1)</claim> <claim>correctness=1</claim> <eos>
```

**Causal masking:** Explanation tokens appear *before* claim tokens in the sequence. Standard causal attention ensures explanation tokens cannot attend to future claim tokens. An explicit additive attention bias further enforces this constraint.
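The additive bias can be illustrated with plain Python lists. This is a sketch of the general technique, not the code in `model.py`, which presumably operates on PyTorch tensors; the function names are hypothetical. Entries above the diagonal are set to negative infinity, so after the softmax each position places zero weight on every later position, including any claim tokens that follow an explanation token.

```python
import math


def causal_bias(seq_len: int) -> list[list[float]]:
    """Additive bias: 0.0 where query i may attend to key j (j <= i), else -inf."""
    return [[0.0 if j <= i else -math.inf for j in range(seq_len)] for i in range(seq_len)]


def masked_softmax(scores, bias):
    """Row-wise softmax of (scores + bias)."""
    weights = []
    for score_row, bias_row in zip(scores, bias):
        logits = [s + b for s, b in zip(score_row, bias_row)]
        top = max(logits)  # the diagonal is never masked, so this is finite
        exps = [math.exp(x - top) for x in logits]  # exp(-inf) evaluates to 0.0
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


# Uniform raw scores over 4 positions: each row spreads weight over the past only.
weights = masked_softmax([[0.0] * 4 for _ in range(4)], causal_bias(4))
for row in weights:
    print(row)
```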

### Training Variants

| Variant | Description |
|---------|-------------|
| `consistency_loss` | LM loss + consistency loss (explanation pooling) |
| `no_consistency_loss` | LM loss only (baseline) |
| `claim_only_pooling` | LM loss + consistency loss on claim token pooling (negative control) |
| `random_label_consistency` | LM loss + consistency loss with shuffled labels (negative control) |

**Loss:**
```
total_loss = LM_loss + λ * consistency_loss   (λ = 1.0)
consistency_loss = CE(time_pred, true_time) + CE(space_pred, true_space) + CE(correct_pred, true_correct)
```

The consistency loss is computed from the **mean-pooled explanation hidden states** of the last Transformer layer.
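The arithmetic behind the loss can be sketched in plain Python. This only illustrates the formula; the repository implements it in PyTorch, and the helper names, the per-claim linear heads, and the placeholder `lm_loss` value are assumptions.

```python
import math


def mean_pool(hidden, mask):
    """Average the hidden-state vectors at positions where mask is 1."""
    selected = [h for h, m in zip(hidden, mask) if m]
    if not selected:
        raise ValueError("mask selects no positions")
    dim = len(selected[0])
    return [sum(h[d] for h in selected) / len(selected) for d in range(dim)]


def linear(x, weight, bias):
    """Dense layer: one output logit per row of weight."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def cross_entropy(logits, target):
    """Negative log-softmax probability of the target class."""
    top = max(logits)
    log_z = top + math.log(sum(math.exp(l - top) for l in logits))
    return log_z - logits[target]


def consistency_loss(pooled, heads, targets):
    """Sum of per-claim cross-entropies (time, space, correctness)."""
    return sum(cross_entropy(linear(pooled, *heads[name]), targets[name]) for name in targets)


def total_loss(lm_loss, cons_loss, lam=1.0):
    return lm_loss + lam * cons_loss


# Three positions; the first two are explanation tokens.
hidden = [[1.0, 3.0], [3.0, 5.0], [9.0, 9.0]]
explanation_mask = [1, 1, 0]
pooled = mean_pool(hidden, explanation_mask)  # [2.0, 4.0]

# Zero-initialized heads (hypothetical): 3 time classes, 3 space classes, 2 correctness classes.
heads = {
    "time_complexity": ([[0.0, 0.0]] * 3, [0.0] * 3),
    "space_complexity": ([[0.0, 0.0]] * 3, [0.0] * 3),
    "correctness": ([[0.0, 0.0]] * 2, [0.0] * 2),
}
targets = {"time_complexity": 1, "space_complexity": 0, "correctness": 1}

cons = consistency_loss(pooled, heads, targets)
loss = total_loss(lm_loss=2.0, cons_loss=cons)  # lm_loss is a placeholder value
```

With all-zero heads every class is equally likely, so the consistency term equals `log(3) + log(3) + log(2)`, which is a convenient sanity check for an untrained classifier.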

### Validation Metrics (every epoch)

1. **Coupling strength**: Mean classifier accuracy (time, space, correctness) from explanation hidden states
2. **BLEU-1**: Unigram precision between generated and reference explanations
3. **ROUGE-L**: LCS-based F1 between generated and reference explanations
4. **Counterfactual swap influence**: Whether explanation hidden states follow swapped labels (proxy test)
5. **Claim accuracy**: Fraction of correct claim tokens in greedy-decoded output
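The two text-overlap metrics can be sketched from their definitions above. This is a from-scratch illustration on whitespace tokens; the dependency list includes `rouge-score`, so the repository may well compute ROUGE-L through that package, and its tokenization may differ. Following the description of BLEU-1 as unigram precision, the sketch uses clipped precision with no brevity penalty, which is an assumption about the intended variant.

```python
from collections import Counter


def unigram_precision(candidate: str, reference: str) -> float:
    """Clipped unigram precision of candidate against reference."""
    cand, ref = candidate.split(), reference.split()
    if not cand:
        return 0.0
    ref_counts = Counter(ref)
    # Each candidate token is credited at most as often as it occurs in the reference.
    clipped = sum(min(count, ref_counts[tok]) for tok, count in Counter(cand).items())
    return clipped / len(cand)


def lcs_length(a, b) -> int:
    """Length of the longest common subsequence, using one DP row at a time."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l_f1(candidate: str, reference: str) -> float:
    """F1 of LCS-based precision and recall."""
    cand, ref = candidate.split(), reference.split()
    lcs = lcs_length(cand, ref)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(cand), lcs / len(ref)
    return 2 * precision * recall / (precision + recall)


# Toy strings, for illustration only.
generated = "sorts the list in place"
reference = "sorts the list using swaps"
bleu1 = unigram_precision(generated, reference)
rouge_l = rouge_l_f1(generated, reference)
```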

## Outputs

All outputs written to `outputs/`:

```
outputs/
├── metrics.csv                        # Per-epoch metrics for all variants
├── report.pplx.md                     # Markdown report with findings
├── coupling_strength.png              # Chart: coupling strength all variants
├── explanation_correctness.png        # Chart: BLEU-1 + ROUGE-L all variants
├── counterfactual_swap.png            # Chart: swap influence all variants
├── claim_accuracy.png                 # Chart: claim token accuracy
├── losses.png                         # Chart: training loss curves
├── run_metadata.json                  # Run configuration and file paths
└── checkpoints/
    ├── consistency_loss/              # Saved every 5 epochs (or final)
    ├── no_consistency_loss/
    ├── claim_only_pooling/
    └── random_label_consistency/
```

## Critical Prediction

The `consistency_loss` variant should show:
- **Higher coupling strength** than other variants (explanation hidden states become more informative about claims)
- **Higher BLEU-1/ROUGE-L** if the consistency signal improves explanation alignment
- **Positive swap influence** (hidden states better predict own labels than swapped labels)
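One way to read "swap influence" is as an accuracy gap: how often the classifier's predictions from explanation hidden states match each example's own labels, minus how often they match the swapped labels. The sketch below encodes that reading; it is an assumption about the metric, not the definition used in `trainer.py`, and the labels are toy values.

```python
def accuracy(predictions, labels) -> float:
    """Fraction of positions where prediction equals label."""
    if len(predictions) != len(labels) or not labels:
        raise ValueError("predictions and labels must be non-empty and equal in length")
    return sum(p == y for p, y in zip(predictions, labels)) / len(labels)


def swap_influence(predictions, own_labels, swapped_labels) -> float:
    """Positive when predictions track own labels better than swapped labels."""
    return accuracy(predictions, own_labels) - accuracy(predictions, swapped_labels)


# Toy labels, for illustration only.
predictions = ["O(n)", "O(1)", "O(n²)", "O(n)"]
own_labels = ["O(n)", "O(1)", "O(n²)", "O(1)"]
swapped_labels = ["O(1)", "O(n)", "O(n)", "O(n²)"]

influence = swap_influence(predictions, own_labels, swapped_labels)
```

Under this reading the value lies in [-1, 1], and a value near zero means the hidden states carry no more information about an example's own claims than about another example's claims.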

The `no_consistency_loss`, `claim_only_pooling`, and `random_label_consistency` variants serve as controls to isolate the specific mechanism.

## Notes on GPT-2

This experiment uses a from-scratch GPT-2-compatible Transformer architecture. To use pretrained GPT-2 weights, install `transformers` and modify `model.py` accordingly. The `gpt2_small` config is dimensionally identical to GPT-2 small (117M params) but trains from random initialization.

## File Structure

```
consistency_loss_experiment/
├── dataset.py         # Data generation, tokenizer, oracle verifier
├── model.py           # Causal Transformer architecture + loss functions
├── trainer.py         # Training loop, validation, qualitative eval
├── visualize.py       # Chart generation (PNG)
├── report.py          # Markdown report builder
├── run_experiment.py  # Main CLI entry point
└── README.md          # This file
```
