# Claim-Consistency Coupling Experiment

A minimal, reproducible PyTorch experiment that tests whether a small decoder-only
transformer can be trained to couple rationale content with claim predictions,
and measures how different auxiliary loss strategies affect this coupling.

---

## Quick start

```bash
# 1. Install dependencies (only common packages needed)
pip install torch numpy pandas tqdm tabulate nbformat matplotlib

# 2. Run the smoke test (completes in ~30–90 s on CPU)
python claim_consistency_experiment.py

# 3. Run the hard overlapping-vocab experiment (4 variants, 10 epochs)
python run_hard_experiment.py

# 4. Or open the notebook
jupyter notebook claim_consistency_coupling_experiment.ipynb
```

---

## Files

| File | Purpose |
|------|---------|
| `claim_consistency_experiment.py` | All logic: dataset, model, training, evaluation |
| `run_hard_experiment.py` | Hard overlapping-vocab experiment runner (4 variants, 10 epochs) |
| `run_scaled_experiment.py` | Scaled experiment runner (larger dataset, 30 epochs) |
| `claim_consistency_coupling_experiment.ipynb` | End-to-end notebook with charts and tables |
| `results_comparison.csv` | Metrics table written by the smoke-test run |
| `results_comparison.md` | Markdown version of the smoke-test results table |
| `results_comparison_hard.csv` | Metrics from hard overlapping-vocab experiment |
| `results_comparison_hard.md` | Markdown version of hard experiment results with coupling check |
| `results_comparison_scaled.csv` | Metrics from scaled experiment |
| `results_comparison_scaled.md` | Markdown version of scaled experiment results |
| `README.md` | This file |

---

## What the experiment does

### Synthetic data

Sequences follow the layout:

```
[BOS] <prompt tokens> [SEP] <rationale tokens> [SEP] <claim tokens>
```

There are `num_latent_states` (default 8) states. Each state has:
- `num_rationale_templates` (default 4) different paraphrased token sequences
  for the rationale span — different surface form, same underlying state.
- One deterministic pair of claim label tokens.

The model never sees real language. By default, all "meaning" is encoded in token identities
drawn from a non-overlapping range for each state, so the task is perfectly learnable if the
model captures the right structure.
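A minimal sketch of the default (non-overlapping) generator and the sequence layout above. The special-token IDs, `TOKENS_PER_STATE`, the 8-token rationale length and the prompt tokens are hypothetical choices for illustration; the real module's vocabulary layout may differ.

```python
import random

# Hypothetical special-token IDs and sizes; not taken from the real module.
BOS, SEP = 0, 1
NUM_STATES = 8
NUM_TEMPLATES = 4
RATIONALE_LEN = 8
TOKENS_PER_STATE = 16          # size of each state's private rationale range
RATIONALE_BASE = 10            # first rationale token ID
CLAIM_BASE = RATIONALE_BASE + NUM_STATES * TOKENS_PER_STATE


def make_templates(rng: random.Random) -> dict[int, list[list[int]]]:
    """Paraphrased templates: each state samples only from its own token range."""
    templates = {}
    for s in range(NUM_STATES):
        lo = RATIONALE_BASE + s * TOKENS_PER_STATE
        pool = list(range(lo, lo + TOKENS_PER_STATE))
        templates[s] = [rng.choices(pool, k=RATIONALE_LEN) for _ in range(NUM_TEMPLATES)]
    return templates


def claim_tokens(state: int) -> list[int]:
    """One deterministic pair of claim tokens per state."""
    return [CLAIM_BASE + 2 * state, CLAIM_BASE + 2 * state + 1]


def make_sequence(state: int, prompt: list[int],
                  templates: dict[int, list[list[int]]], rng: random.Random) -> list[int]:
    """[BOS] prompt [SEP] rationale [SEP] claim, with a randomly chosen paraphrase."""
    rationale = rng.choice(templates[state])
    return [BOS, *prompt, SEP, *rationale, SEP, *claim_tokens(state)]


rng = random.Random(42)
templates = make_templates(rng)
seq = make_sequence(3, [5, 6], templates, rng)   # state 3, hypothetical prompt [5, 6]
```

Because every state owns a disjoint range, any single rationale token identifies the state, which is what makes this default variant trivially learnable.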

**Hard overlapping-vocab variant** (`hard_overlap_vocab=True`, `overlap_fraction=0.5`):
A harder version where each template draws ~50% of its token positions from pools
shared with other states:
- *Shared tokens* (10–17): appear in all 8 states.
- *Group tokens* (18–49): pairs of adjacent states share an 8-token block.
- *Local tokens* (50–65): 2 tokens unique to each state, filling the remaining ~50% of positions.

No single token is perfectly diagnostic; the model must rely on token co-occurrence
and combination patterns across the 8-token rationale span.
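The pool layout above can be sketched as follows. The ranges (10–17, 18–49, 50–65) come from this README; the pairing of states as (0, 1), (2, 3), … and the per-position coin flip between the shared and group pools are assumptions, not necessarily what `hard_overlap_vocab=True` does internally.

```python
import random

SHARED = list(range(10, 18))            # tokens 10-17, usable by all 8 states


def group_pool(state: int) -> list[int]:
    # Assumption: states (0,1), (2,3), ... share one 8-token block inside 18-49.
    start = 18 + 8 * (state // 2)
    return list(range(start, start + 8))


def local_pool(state: int) -> list[int]:
    # Two tokens unique to each state inside 50-65.
    start = 50 + 2 * state
    return [start, start + 1]


def hard_template(state: int, rng: random.Random,
                  length: int = 8, overlap_fraction: float = 0.5) -> list[int]:
    """Draw ~overlap_fraction of positions from shared/group pools, the rest locally."""
    n_overlap = round(length * overlap_fraction)
    overlap_positions = set(rng.sample(range(length), n_overlap))
    template = []
    for i in range(length):
        if i in overlap_positions:
            # Assumption: shared vs. group pool chosen uniformly per position.
            pool = SHARED if rng.random() < 0.5 else group_pool(state)
        else:
            pool = local_pool(state)
        template.append(rng.choice(pool))
    return template


t = hard_template(5, random.Random(0))
```

Under this sketch, only the local positions are diagnostic on their own; the shared and group positions only narrow the state down to all 8 or to a pair.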

### Training variants

| Variant | Auxiliary loss |
|---------|---------------|
| `no_consistency_loss` | LM next-token loss only |
| `rationale_only` | LM + consistency CE on mean-pool of rationale tokens |
| `full_sequence` | LM + consistency CE on mean-pool of entire sequence |
| `earlier_token_only` | LM + consistency CE on mean-pool of prompt+rationale (excludes claim) |
| `claim_only_pooling` | LM + consistency CE on mean-pool of **claim token positions only** (negative control, not run in the hard experiment) |
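One way to express the pooling modes in the table is as boolean masks over the `[BOS] prompt [SEP] rationale [SEP] claim` layout. This is a sketch: whether the real module includes `[BOS]`/`[SEP]` in the `rationale_only`, `earlier_token_only`, and `claim_only_pooling` pools is an assumption (here they are excluded from all but `full_sequence`).

```python
def pooling_mask(mode: str, prompt_len: int, rationale_len: int, claim_len: int) -> list[bool]:
    """True at positions that enter the consistency mean-pool for a training variant."""
    n = 1 + prompt_len + 1 + rationale_len + 1 + claim_len
    prompt = range(1, 1 + prompt_len)
    r_start = prompt_len + 2                      # after [BOS] prompt [SEP]
    rationale = range(r_start, r_start + rationale_len)
    claim = range(r_start + rationale_len + 1, n)  # after the second [SEP]
    if mode == "full_sequence":
        keep = set(range(n))
    elif mode == "rationale_only":
        keep = set(rationale)
    elif mode == "earlier_token_only":
        keep = set(prompt) | set(rationale)        # excludes claim positions
    elif mode == "claim_only_pooling":
        keep = set(claim)                          # negative control
    else:
        # no_consistency_loss has no auxiliary pool
        raise ValueError(f"no consistency pooling for mode {mode!r}")
    return [i in keep for i in range(n)]
```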

The consistency head is a single linear layer projecting from `d_model` →
`num_latent_states`.  Loss weight is configurable (`consistency_loss_weight`,
default 0.5).
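A minimal PyTorch sketch of the consistency head and the combined objective `LM + weight × consistency CE`. The class and function names are hypothetical, and the LM targets are assumed to be already-shifted next-token IDs with `-100` marking ignored positions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConsistencyHead(nn.Module):
    """Single linear layer projecting a masked mean-pool from d_model to num_latent_states."""

    def __init__(self, d_model: int, num_latent_states: int):
        super().__init__()
        self.proj = nn.Linear(d_model, num_latent_states)

    def forward(self, hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # hidden: (batch, seq, d_model); mask: (batch, seq) bool, True = pooled position
        m = mask.unsqueeze(-1).to(hidden.dtype)
        pooled = (hidden * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)
        return self.proj(pooled)


def total_loss(lm_logits, targets, hidden, mask, states, head, weight=0.5):
    # Next-token LM loss over all positions (targets already shifted; -100 = ignore).
    lm = F.cross_entropy(lm_logits.flatten(0, 1), targets.flatten(), ignore_index=-100)
    # Auxiliary CE: pooled hidden states must predict the true latent state.
    consistency = F.cross_entropy(head(hidden, mask), states)
    return lm + weight * consistency


torch.manual_seed(0)
B, T, D, V, S = 2, 15, 64, 200, 8
hidden = torch.randn(B, T, D)
mask = torch.zeros(B, T, dtype=torch.bool)
mask[:, 4:12] = True                     # rationale positions in the layout above
head = ConsistencyHead(D, S)
loss = total_loss(torch.randn(B, T, V), torch.randint(0, V, (B, T)),
                  hidden, mask, torch.tensor([3, 5]), head)
```

Because the head reads only the pooled positions, the choice of mask is the only thing that differs between the consistency-trained variants.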

**`claim_only_pooling` negative control:** The consistency supervisor is applied to the
claim token hidden states (positions that already encode the correct answer) rather than
the rationale positions. This creates a shortcut: the model can satisfy the consistency
loss without coupling rationale representations to claims. The evaluation classifier
(`cls_claim_acc (rationale_pool)`) still pools over *rationale* tokens, so a gap between
training signal and evaluation pool reveals whether rationale encoding was truly learned.

---

## Metrics

### Primary accuracy metrics

| Metric | What it measures |
|--------|-----------------|
| `gen_claim_acc` | Greedy generation after `[BOS + prompt + SEP + rationale + SEP]`: does the first generated token match the correct claim token for the true latent state? |
| `cls_claim_acc (rationale_pool)` | Mean-pool hidden states over rationale tokens → linear head → argmax. Correct if predicted class == true latent state. Computed for **all** variants using the rationale pool, even those trained with a different pooling mode. |
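Both primary metrics reduce to simple hit rates. In this sketch, `claim_token_of` is a hypothetical stand-in for the state-to-claim-token mapping, and `cls_logits` are assumed to come from the head applied to the rationale pool.

```python
def claim_token_of(state: int) -> int:
    # Hypothetical mapping: first token of the state's claim pair.
    return 100 + 2 * state


def gen_claim_acc(first_generated: list[int], states: list[int]) -> float:
    # first_generated[i]: first greedy token after [BOS + prompt + SEP + rationale + SEP]
    hits = sum(tok == claim_token_of(s) for tok, s in zip(first_generated, states))
    return hits / len(states)


def cls_claim_acc(cls_logits: list[list[float]], states: list[int]) -> float:
    # argmax over latent states of the rationale-pooled head output
    preds = [max(range(len(row)), key=row.__getitem__) for row in cls_logits]
    return sum(p == s for p, s in zip(preds, states)) / len(states)
```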

### Counterfactual swap test

For each sample, the rationale is replaced with tokens from a **different** latent state, while
the prompt and the reference claim label are kept from the original sample. This isolates
whether the model "reads" the rationale or relies on the surrounding context for its prediction.
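A sketch of how one counterfactual input could be built: draw a swapped state uniformly from the other `num_states - 1` states, splice in one of its templates, and stop at the second `[SEP]` where generation begins. The toy `templates` dict and special-token IDs are hypothetical.

```python
import random

BOS, SEP = 0, 1  # hypothetical special-token IDs


def swap_state(orig: int, num_states: int, rng: random.Random) -> int:
    """Uniform over the other num_states - 1 states; never returns orig."""
    other = rng.randrange(num_states - 1)
    return other if other < orig else other + 1


def counterfactual_input(prompt, orig_state, templates, rng):
    swapped = swap_state(orig_state, len(templates), rng)
    rationale = rng.choice(templates[swapped])
    # Input ends after the second SEP; the reference label stays orig_state.
    return [BOS, *prompt, SEP, *rationale, SEP], orig_state, swapped


# Toy templates: state s is represented by the token 10 + s repeated 8 times.
templates = {s: [[10 + s] * 8] for s in range(8)}
ids, orig, swapped = counterfactual_input([5], 2, templates, random.Random(1))
```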

| Metric | What it measures |
|--------|-----------------|
| `cfact_gen_follows_swap` | Generated claim matches the **swapped** rationale's state (follows new evidence) |
| `cfact_gen_follows_orig` | Generated claim matches the **original** latent state (ignores new evidence) |
| `cfact_cls_follows_swap` | Classifier predicts the **swapped** state (high value = strong coupling) |
| `cfact_cls_follows_orig` | Classifier predicts the **original** state despite the swapped rationale (high value = weak coupling) |

**Interpretation:**
- Strong consistency coupling → `cfact_cls_follows_swap ≈ 1.0`, `cfact_cls_follows_orig ≈ 0.0`.
- No coupling (`no_consistency_loss`) → both near chance.
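The two "follows" rates are computed independently against the swapped and original states. This sketch works for both the generated claims (once mapped back to states) and the classifier predictions. Note that the rates need not sum to 1, because a prediction can match neither state.

```python
def follow_rates(preds: list[int], orig_states: list[int],
                 swapped_states: list[int]) -> tuple[float, float]:
    """Return (follows_swap, follows_orig) for predicted latent states."""
    n = len(preds)
    follows_swap = sum(p == s for p, s in zip(preds, swapped_states)) / n
    follows_orig = sum(p == o for p, o in zip(preds, orig_states)) / n
    return follows_swap, follows_orig


# 3 of 4 predictions follow the swapped rationale, 1 sticks with the original state.
fs, fo = follow_rates([1, 2, 0, 3], [0, 0, 0, 0], [1, 2, 3, 3])
```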

### Shuffled-pairing control

A separate dataset is generated with intentionally **mismatched** rationale-claim pairs
(rationale drawn from a random other latent state, claim label kept from original state).

| Metric | What it measures |
|--------|-----------------|
| `shuffled_gen_acc` | Generation accuracy on mismatched pairs (expect degradation if rationale is relied upon) |
| `shuffled_cls_acc` | Classifier accuracy on mismatched pairs (expect degradation) |

A well-coupled (consistency-trained) model should show lower `shuffled_cls_acc`: the classifier
reads the rationale, so when the rationale disagrees with the claim, the prediction is pulled
toward the rationale's state instead of the claimed one.
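The two extremes can be checked with idealized classifiers. Since every shuffled pair has a rationale state different from its claim state, a classifier that perfectly follows the rationale scores exactly 0, while one that ignores the rationale lands near chance (1/N). Both classifiers below are hypothetical idealizations, not trained models.

```python
import random

N = 8
rng = random.Random(0)

# Mismatched pairs: (rationale_state, claim_state) with rationale_state != claim_state.
pairs = []
for _ in range(4000):
    c = rng.randrange(N)
    r = rng.choice([s for s in range(N) if s != c])
    pairs.append((r, c))


def shuffled_cls_acc(predict, pairs) -> float:
    return sum(predict(r) == c for r, c in pairs) / len(pairs)


def reads_rationale(r: int) -> int:      # idealized, perfectly coupled
    return r


def ignores_rationale(r: int) -> int:    # idealized, uncoupled (uniform guess)
    return rng.randrange(N)
```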

---

## Expected interpretation

With small models and few epochs (smoke test):

```
rationale_only:      cls_claim_acc ≈ 1.0   cfact_cls_follows_swap ≈ 1.0
full_sequence:       cls_claim_acc ≈ 0.9+  cfact_cls_follows_swap ≈ 1.0
earlier_token_only:  cls_claim_acc ≈ 1.0   cfact_cls_follows_swap ≈ 1.0
no_consistency_loss: cls_claim_acc ≈ chance  cfact_cls_follows_swap ≈ chance  (1/num_latent_states: 0.125 for 8 states, 0.25 for 4)
claim_only_pooling:  cls_claim_acc << 1.0  cfact_cls_follows_swap ≈ chance  (negative control)
```

For `claim_only_pooling`, `cls_claim_acc (rationale_pool)` and `cfact_cls_follows_swap` should be
weak (below the rationale-trained variants) because the consistency loss was never applied to
rationale token representations. The rationale positions carry no direct gradient pressure toward
encoding claim identity.

Generation accuracy (`gen_claim_acc`) improves more slowly and requires more epochs: the
consistency head supervises only pooled representations, so the LM must internalize the
claim-rationale coupling in the residual stream through the next-token loss alone, with no
auxiliary supervision on the generation logits.

---

## Configuration

All parameters live in `ExperimentConfig` (a Python dataclass):

```python
from claim_consistency_experiment import ExperimentConfig, run_experiment

cfg = ExperimentConfig(
    num_latent_states=8,       # 8–16 recommended
    num_rationale_templates=4, # 3–5
    num_train_samples=2048,    # increase for proper training run
    num_eval_samples=256,
    num_epochs=20,             # increase for convergence
    d_model=128,               # increase for capacity
    n_layers=4,
    consistency_loss_weight=0.5,
    seed=42,
    # Hard overlapping-vocab (optional)
    hard_overlap_vocab=False,  # set True for ~50% token overlap across states
    overlap_fraction=0.5,      # fraction of template positions from shared/group pools
)
df = run_experiment(cfg)
```

---

## Design notes

- **No internet required.** All data is generated synthetically.
- **Deterministic.** All random seeds are controlled via `cfg.seed`.
- **CPU-friendly.** Default model is 2 layers, d_model=64. A full run on a modern CPU
  completes in under 2 minutes for the smoke config.
- **Hackable.** The module is a single ~700-line file with clearly separated sections.
  Swap in a real tokenizer, larger vocab, or actual language templates with minimal changes.
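A typical way to make a single seed control every RNG involved is sketched below. `seed_everything` is a hypothetical helper; the module's own seeding via `cfg.seed` may cover more or fewer sources.

```python
import random

import numpy as np
import torch


def seed_everything(seed: int) -> None:
    """Seed Python, NumPy and PyTorch RNGs from one value."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the RNG on all devices, including CUDA


seed_everything(42)
first = (random.random(), np.random.rand(), torch.rand(3))
seed_everything(42)
second = (random.random(), np.random.rand(), torch.rand(3))
```

Re-seeding reproduces the same draws from all three libraries; full GPU determinism can additionally require `torch.use_deterministic_algorithms(True)`.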
