# Claim-Consistency Coupling Experiment

A minimal, reproducible PyTorch experiment that tests whether a small decoder-only
transformer can be trained to couple rationale content with claim predictions,
and measures how different auxiliary loss strategies affect this coupling.

---

## Quick start

```bash
# 1. Install dependencies (only common packages needed)
pip install torch numpy pandas tqdm tabulate nbformat matplotlib

# 2. Run the smoke test (completes in ~30–90 s on CPU)
python claim_consistency_experiment.py

# 3. Or open the notebook (needs Jupyter, which step 1 does not install: pip install notebook)
jupyter notebook claim_consistency_coupling_experiment.ipynb
```

---

## Files

| File | Purpose |
|------|---------|
| `claim_consistency_experiment.py` | All logic: dataset, model, training, evaluation |
| `claim_consistency_coupling_experiment.ipynb` | End-to-end notebook with charts and tables |
| `results_comparison.csv` | Metrics table written after each run |
| `results_comparison.md` | Markdown version of the results table |
| `README.md` | This file |

---

## What the experiment does

### Synthetic data

Sequences follow the layout:

```
[BOS] <prompt tokens> [SEP] <rationale tokens> [SEP] <claim tokens>
```

The data is generated from `num_latent_states` (default 8) latent states. Each state has:
- `num_rationale_templates` (default 4) different paraphrased token sequences
  for the rationale span — different surface form, same underlying state.
- One deterministic pair of claim label tokens.

The model never sees real language; all "meaning" is encoded in token identities
drawn from non-overlapping ranges per state. This makes the task perfectly learnable
if the model captures the right structure.
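The layout above can be sketched with the standard library alone. This is a minimal illustration, not the module's actual generator: the special-token ids, range sizes, sequence lengths, and every function name below are assumptions, and the prompt is assumed to be state-independent, which may not match the real dataset.

```python
import random

# All constants below are illustrative assumptions, not the module's real values.
BOS, SEP = 0, 1            # special tokens
NUM_SPECIAL = 2
PROMPT_VOCAB = 16          # shared prompt tokens (assumed state-independent)
TOKENS_PER_STATE = 10      # size of each state's private token range
CLAIM_LEN = 2              # one deterministic pair of claim tokens per state


def state_token_ids(state):
    """Token ids reserved for one latent state; ranges never overlap."""
    start = NUM_SPECIAL + PROMPT_VOCAB + state * TOKENS_PER_STATE
    return list(range(start, start + TOKENS_PER_STATE))


def claim_tokens(state):
    """The last CLAIM_LEN ids of a state's range act as its claim label."""
    return state_token_ids(state)[-CLAIM_LEN:]


def build_templates(num_states=8, num_templates=4, rationale_len=5, seed=42):
    """Fixed rationale paraphrases: different surface forms, same state."""
    rng = random.Random(seed)
    templates = {}
    for state in range(num_states):
        rationale_ids = state_token_ids(state)[:-CLAIM_LEN]
        templates[state] = [
            [rng.choice(rationale_ids) for _ in range(rationale_len)]
            for _ in range(num_templates)
        ]
    return templates


def make_sequence(state, templates, rng, prompt_len=4):
    """Assemble [BOS] prompt [SEP] rationale [SEP] claim for one state."""
    prompt = [NUM_SPECIAL + rng.randrange(PROMPT_VOCAB) for _ in range(prompt_len)]
    rationale = rng.choice(templates[state])
    return [BOS] + prompt + [SEP] + rationale + [SEP] + claim_tokens(state)
```

Calling `make_sequence(3, build_templates(), random.Random(0))` yields one 14-token sequence whose last two tokens identify state 3. Because each state owns a private id range, any single rationale token is enough to recover the state.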

### Training variants

| Variant | Auxiliary loss |
|---------|---------------|
| `no_consistency_loss` | LM next-token loss only |
| `rationale_only` | LM + consistency CE on mean-pool of rationale tokens |
| `full_sequence` | LM + consistency CE on mean-pool of entire sequence |
| `earlier_token_only` | LM + consistency CE on mean-pool of prompt+rationale (excludes claim) |

The consistency head is a single linear layer that projects from `d_model` to
`num_latent_states`. The loss weight is configurable (`consistency_loss_weight`,
default 0.5).
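The four variants differ only in which positions feed the mean-pool. The sketch below spells that out in plain Python so it runs without PyTorch. The function names are hypothetical, the choice to include `[BOS]` and the first `[SEP]` in the `earlier_token_only` pool is an assumption, and so is the additive form `lm_loss + weight * consistency_loss`.

```python
def pooled_positions(variant, prompt_len, rationale_len, claim_len):
    """Sequence positions averaged before the consistency head.

    Layout: [BOS] prompt [SEP] rationale [SEP] claim
    """
    rat_start = 1 + prompt_len + 1          # skip BOS, prompt, first SEP
    rat_end = rat_start + rationale_len     # exclusive end of the rationale
    total_len = rat_end + 1 + claim_len     # second SEP, then the claim
    if variant == "no_consistency_loss":
        return []                           # head receives no training signal
    if variant == "rationale_only":
        return list(range(rat_start, rat_end))
    if variant == "full_sequence":
        return list(range(total_len))
    if variant == "earlier_token_only":
        return list(range(rat_end))         # everything before the second SEP
    raise ValueError(f"unknown variant: {variant}")


def mean_pool(hidden, positions):
    """Average the hidden vectors (lists of floats) at the given positions."""
    dim = len(hidden[0])
    return [sum(hidden[p][d] for p in positions) / len(positions) for d in range(dim)]


def total_loss(lm_loss, consistency_loss, weight=0.5):
    """Assumed combination: LM loss plus the weighted consistency CE."""
    return lm_loss + weight * consistency_loss
```

With a 4-token prompt, 5-token rationale, and 2-token claim, `rationale_only` pools positions 6 to 10, while `earlier_token_only` pools positions 0 to 10. Only `full_sequence` lets the claim tokens themselves reach the consistency head.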

---

## Metrics

### Primary accuracy metrics

| Metric | What it measures |
|--------|-----------------|
| `gen_claim_acc` | Greedy generation after `[BOS + prompt + SEP + rationale + SEP]`: does the first generated token match the correct claim token for the true latent state? |
| `cls_claim_acc (rationale_pool)` | Mean-pool hidden states over rationale tokens → linear head → argmax. Correct if predicted class == true latent state. Computed for **all** variants using the rationale pool, even those trained with a different pooling mode. |

### Counterfactual swap test

For each sample, the rationale is replaced with tokens from a **different** latent state
while the claim label is kept from the original state. This tests whether the model's
prediction depends on the rationale or on the rest of the context.

| Metric | What it measures |
|--------|-----------------|
| `cfact_gen_follows_swap` | Generated claim matches the **swapped** rationale's state (follows new evidence) |
| `cfact_gen_follows_orig` | Generated claim matches the **original** latent state (ignores new evidence) |
| `cfact_cls_follows_swap` | Classifier predicts the **swapped** state (a high value indicates strong coupling) |
| `cfact_cls_follows_orig` | Classifier predicts the **original** state despite the swapped rationale (a high value indicates weak coupling) |

**Interpretation:**
- Strong consistency coupling → `cfact_cls_follows_swap ≈ 1.0`, `cfact_cls_follows_orig ≈ 0.0`.
- No coupling (`no_consistency_loss`) → both near chance (`1 / num_latent_states`).
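The two counterfactual rates can be computed from three parallel lists of state ids, as in this minimal sketch. The helper name `swap_rates` and the extra `other` bucket are illustrative and not part of the module. Because the swapped state always differs from the original, a prediction can match at most one of them, so the two rates sum to at most 1.

```python
def swap_rates(preds, orig_states, swap_states):
    """Fraction of predictions that follow the swapped vs. the original state."""
    n = len(preds)
    follows_swap = sum(p == s for p, s in zip(preds, swap_states)) / n
    follows_orig = sum(p == o for p, o in zip(preds, orig_states)) / n
    return {
        "follows_swap": follows_swap,
        "follows_orig": follows_orig,
        # Predictions that match neither state.
        "other": 1.0 - follows_swap - follows_orig,
    }


rates = swap_rates(
    preds=[1, 2, 0, 3],
    orig_states=[0, 2, 0, 1],
    swap_states=[1, 3, 2, 0],
)
print(rates)
```

In this toy input one prediction follows the swap, two follow the original state, and one matches neither.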

### Shuffled-pairing control

A separate dataset is generated with intentionally **mismatched** rationale-claim pairs
(rationale drawn from a random other latent state, claim label kept from original state).

| Metric | What it measures |
|--------|-----------------|
| `shuffled_gen_acc` | Generation accuracy on mismatched pairs (expect degradation if rationale is relied upon) |
| `shuffled_cls_acc` | Classifier accuracy on mismatched pairs (expect degradation) |

A well-coupled model (consistency-trained) should show lower `shuffled_cls_acc` because
the classifier reads the rationale, and when the rationale disagrees with the claim the
prediction is pulled toward the wrong state.
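Drawing "a random other latent state" needs a sampler that never returns the original state. One way to do this, shown here as an assumption rather than the module's actual method, is to sample from `num_states - 1` slots and shift past the excluded state. Both function names are hypothetical.

```python
import random


def sample_other_state(state, num_states, rng):
    """Uniformly pick a latent state different from `state`."""
    other = rng.randrange(num_states - 1)
    # Shift indices at or above `state` up by one to skip it.
    return other + 1 if other >= state else other


def shuffled_pairs(claim_states, num_states, seed=42):
    """Return (rationale_state, claim_state) pairs that always disagree."""
    rng = random.Random(seed)
    return [(sample_other_state(s, num_states, rng), s) for s in claim_states]


pairs = shuffled_pairs(claim_states=[0, 1, 2, 3, 4, 5, 6, 7], num_states=8)
```

Each pair names the state that supplies the rationale tokens and the state that supplies the claim label. Accuracy on these pairs is then scored against the claim state.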

---

## Expected interpretation

With small models and few epochs (smoke test):

```
rationale_only:      cls_claim_acc ≈ 1.0   cfact_cls_follows_swap ≈ 1.0
full_sequence:       cls_claim_acc ≈ 0.9+  cfact_cls_follows_swap ≈ 1.0
earlier_token_only:  cls_claim_acc ≈ 1.0   cfact_cls_follows_swap ≈ 1.0
no_consistency_loss: cls_claim_acc ≈ 0.25  cfact_cls_follows_swap ≈ 0.25 (chance = 1/num_latent_states: 0.25 for 4 states, 0.125 for the default 8)
```

Generation accuracy (`gen_claim_acc`) improves more slowly and needs more epochs: the
consistency head supervises pooled hidden states, not the logits at the claim position,
so the LM must learn the claim-rationale coupling through the next-token loss alone.

---

## Configuration

All parameters live in `ExperimentConfig` (a Python dataclass):

```python
from claim_consistency_experiment import ExperimentConfig, run_experiment

cfg = ExperimentConfig(
    num_latent_states=8,       # 8–16 recommended
    num_rationale_templates=4, # 3–5
    num_train_samples=2048,    # increase for proper training run
    num_eval_samples=256,
    num_epochs=20,             # increase for convergence
    d_model=128,               # increase for capacity
    n_layers=4,
    consistency_loss_weight=0.5,
    seed=42,
)
df = run_experiment(cfg)
```

---

## Design notes

- **No internet required.** All data is generated synthetically.
- **Deterministic.** All random seeds are controlled via `cfg.seed`.
- **CPU-friendly.** The default model has 2 layers and `d_model=64`. The smoke config
  completes in under 2 minutes on a modern CPU.
- **Hackable.** The module is a single ~700-line file with clearly separated sections.
  Swap in a real tokenizer, larger vocab, or actual language templates with minimal changes.
