# LeanCheck

LeanCheck tests whether hidden states over natural-language rationale spans can be trained to encode formal Lean proof-checker outcomes.

## Sequence Format

```text
[THEOREM]
<Lean theorem statement>

[PROOF]
<candidate Lean proof>

[RAT]
<natural-language rationale>

[CLAIM]
<VERIFIES or FAILS>
[EOS]
```

The verifier target is `y = lean_check(theorem, proof_candidate)`, a binary label that is either `VERIFIES` or `FAILS`.
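As an illustration of the layout above, here is a minimal sketch of rendering one example as a string. The helper name `format_example` and the sample theorem are hypothetical and are not taken from `leancheck_data.py`; the exact whitespace between sections is also an assumption.

```python
def format_example(theorem: str, proof: str, rationale: str, verifies: bool) -> str:
    """Render one example in the tagged layout (hypothetical helper)."""
    claim = "VERIFIES" if verifies else "FAILS"
    return (
        f"[THEOREM]\n{theorem}\n\n"
        f"[PROOF]\n{proof}\n\n"
        f"[RAT]\n{rationale}\n\n"
        f"[CLAIM]\n{claim}\n[EOS]"
    )


# Illustrative inputs only; real examples come from the dataset generator.
example = format_example(
    theorem="theorem t (n : Nat) : n + 0 = n",
    proof="by simp",
    rationale="Adding zero on the right leaves n unchanged, so simp closes the goal.",
    verifies=True,
)
print(example)
```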

## Files

- `leancheck_data.py`: dataset generation, Lean wrapper, mutation logic, deterministic fallback labels.
- `leancheck_model.py`: GPT-2 or offline tiny causal LM with consistency head and span pooling.
- `leancheck_eval.py`: metrics, counterfactual swap tests, and minimal-pair diagnostics.
- `leancheck_experiment.py`: main training/evaluation script.
- `run_leancheck_smoke.sh`: tiny local smoke run.
- `run_leancheck_full.sh`: full GPT-2/Lean run.
- `modal_leancheck.py`: Modal GPU entry point.

## Running

Local smoke test:

```bash
python3 -m venv .venv
.venv/bin/pip install torch transformers pandas tqdm
bash run_leancheck_smoke.sh
```

Full run with Lean and GPT-2:

```bash
bash run_leancheck_full.sh
```

Modal GPU run, once the Modal CLI is installed and authenticated:

```bash
modal run modal_leancheck.py
```

Modal run with activation patching:

```bash
modal run modal_leancheck.py \
  --mode full \
  --train-samples 1000 \
  --eval-samples 200 \
  --counterfactual-samples 200 \
  --minimal-pair-samples 100 \
  --epochs 3 \
  --run-name gpt2_lean_1k_patch \
  --activation-patching \
  --patch-layers 0,3,6,9,11 \
  --patch-samples 100
```

Modal full runs save artifacts to the persistent Volume `leancheck-artifacts/<run-name>/`, including:

- `data/leancheck_train.jsonl`
- `data/leancheck_eval.jsonl`
- `data/leancheck_counterfactual.jsonl`
- `data/leancheck_minimal_pairs.jsonl`
- `results_leancheck_<run-name>.csv`
- `results_leancheck_<run-name>.md`
- `leancheck_runs_<run-name>/<variant>/activation_patching.csv` when activation patching is enabled

## Dataset

The generator covers simple natural-number arithmetic, propositional logic, list basics, and tactic-style proof terms. Positive examples are templates known to be accepted. Negative examples are produced by applying one of the following mutations to a positive example:

- wrong lemma
- wrong theorem/proof pairing
- missing premise
- deleted final proof line
- renamed variable
- replacement tactic
- adversarial near miss
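Two of the mutations listed above could look roughly like the following sketch. The function names are hypothetical, and the actual mutation logic in `leancheck_data.py` may differ. Whether a mutated proof really fails depends on the proof, so the label should still come from Lean or from the known template.

```python
def delete_final_proof_line(proof: str) -> str:
    """Drop the last line of a multi-line proof (hypothetical mutation)."""
    lines = proof.rstrip("\n").split("\n")
    return "\n".join(lines[:-1])


def replace_tactic(proof: str, old: str, new: str) -> str:
    """Swap the first occurrence of one tactic for another (hypothetical mutation)."""
    return proof.replace(old, new, 1)


proof = "by\n  intro h\n  exact h"
truncated = delete_final_proof_line(proof)
swapped = replace_tactic(proof, "exact h", "rfl")
```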

When `--use-lean` is set and `lean` exists on `PATH`, labels and raw error output come from Lean. Otherwise, the known template/mutation labels are used as a deterministic fallback.
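The labeling rule above can be sketched as follows. This is an assumption-laden illustration, not the wrapper in `leancheck_data.py`: the function name `label_example`, the returned dictionary keys, and the 60-second timeout are all hypothetical. It relies only on `lean <file>` returning a non-zero exit code when elaboration reports an error.

```python
import shutil
import subprocess
import tempfile
from pathlib import Path


def label_example(source: str, fallback_label: str, use_lean: bool = False) -> dict:
    """Label one Lean source string, falling back to a known label (hypothetical helper)."""
    lean = shutil.which("lean") if use_lean else None
    if lean is None:
        # Deterministic fallback: trust the label implied by the template/mutation.
        return {"label": fallback_label, "lean_output": "", "source": "fallback"}

    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "candidate.lean"
        path.write_text(source, encoding="utf-8")
        proc = subprocess.run(
            [lean, str(path)], capture_output=True, text=True, timeout=60
        )
    label = "VERIFIES" if proc.returncode == 0 else "FAILS"
    return {"label": label, "lean_output": proc.stdout + proc.stderr, "source": "lean"}


result = label_example("theorem t : 1 = 1 := rfl", fallback_label="VERIFIES")
```

Note that a proof containing `sorry` elaborates with only a warning, so a wrapper that relies on the exit code alone may also want to inspect the captured output.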

## Variants

- `lm_only`: LM loss only.
- `no_consistency_loss`: the consistency head exists, but its loss weight is zero.
- `rationale_only`: pool `[RAT] ... [CLAIM]`.
- `full_sequence`: pool theorem, proof, and rationale.
- `proof_only`: pool `[PROOF] ... [RAT]`.
- `claim_only`: pool around `[CLAIM]`.
- `random_consistency`: rationale span with shuffled labels.
- `wrong_span`: pool theorem span only.
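The variants differ mainly in which token span is pooled before the consistency head. A minimal sketch of span selection and pooling, in plain Python for readability, is shown below. It assumes mean pooling and a span that excludes the marker tokens themselves; the actual pooling in `leancheck_model.py` may differ on both points, and the helper names are hypothetical.

```python
def span_indices(tokens, start_marker, end_marker):
    """Half-open index range strictly between two marker tokens."""
    start = tokens.index(start_marker) + 1
    end = tokens.index(end_marker, start)
    return start, end


def mean_pool(hidden, start, end):
    """Average per-token hidden vectors over positions [start, end)."""
    span = hidden[start:end]
    if not span:
        raise ValueError("empty span")
    dim = len(span[0])
    return [sum(vec[d] for vec in span) / len(span) for d in range(dim)]


# Toy sequence: one token per section, two rationale tokens.
tokens = ["[THEOREM]", "t", "[PROOF]", "p", "[RAT]", "r1", "r2", "[CLAIM]", "VERIFIES", "[EOS]"]
# Toy hidden states: position i maps to the vector [i, 2i].
hidden = [[float(i), float(2 * i)] for i in range(len(tokens))]

start, end = span_indices(tokens, "[RAT]", "[CLAIM]")  # rationale_only span
pooled = mean_pool(hidden, start, end)
```

Under these assumptions, `proof_only` would call `span_indices(tokens, "[PROOF]", "[RAT]")`.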

## Metrics

- `final_lm_loss`: final next-token loss on eval.
- `final_cons_loss`: final consistency-head cross entropy.
- `gen_claim_acc`: whether the model's claim-token logits match the Lean label.
- `cls_claim_acc`: consistency-head accuracy.
- `cfact_cls_follows_swap`: on counterfactual swaps, classifier follows the swapped rationale label.
- `cfact_cls_follows_orig`: on counterfactual swaps, classifier follows the proof/original label.
- `minimal_pair_flip_acc`: fraction of accepted/rejected proof-edit pairs in which both members are predicted correctly.
- `shuffled_cls_acc`: random-label sanity-check accuracy.
- `wrong_span_cls_acc`: theorem-only span control accuracy.
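The pairwise and counterfactual metrics above can be sketched as simple counts. The input shapes below (tuples of predicted and reference labels) are assumptions made for illustration; `leancheck_eval.py` may organize its inputs differently.

```python
def minimal_pair_flip_acc(pairs):
    """pairs: (prediction for accepted proof, prediction for rejected proof)."""
    pairs = list(pairs)
    if not pairs:
        return 0.0
    hits = sum(1 for acc, rej in pairs if acc == "VERIFIES" and rej == "FAILS")
    return hits / len(pairs)


def counterfactual_rates(rows):
    """rows: (prediction, swapped rationale label, original proof label)."""
    rows = list(rows)
    if not rows:
        return 0.0, 0.0
    follows_swap = sum(pred == swap for pred, swap, _ in rows) / len(rows)
    follows_orig = sum(pred == orig for pred, _, orig in rows) / len(rows)
    return follows_swap, follows_orig


pair_acc = minimal_pair_flip_acc([("VERIFIES", "FAILS"), ("VERIFIES", "VERIFIES")])
swap_rate, orig_rate = counterfactual_rates([
    ("FAILS", "FAILS", "VERIFIES"),
    ("VERIFIES", "FAILS", "VERIFIES"),
    ("VERIFIES", "FAILS", "VERIFIES"),
    ("FAILS", "FAILS", "VERIFIES"),
])
```

Because the swapped and original labels disagree on every counterfactual row in this toy input, the two rates sum to one.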

## Activation Patching

Enable with `--activation-patching`. The diagnostic uses accepted/rejected minimal pairs and runs both directions: accepted base with rejected source activations, and rejected base with accepted source activations.

For selected GPT-2 layers, it:

- runs a clean base example
- runs a source example and caches hidden states at the selected layer
- patches source hidden states into the base example over the rationale span
- repeats with theorem and random-span controls
- measures `target_logit_after_patch - target_logit_before_patch` for the source label at the claim position
- also pools the patched final hidden states with the model variant span and measures the source-class consistency-head logit shift
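The span replacement and the before/after comparison in the steps above can be illustrated with a toy sketch. It is not the GPT-2 implementation: `toy_logit` is a hypothetical stand-in for the remaining layers and the claim-position readout, and the sketch assumes the base and source sequences are position-aligned and of equal length.

```python
def patch_span(base_hidden, source_hidden, start, end):
    """Copy source vectors into a copy of base over positions [start, end)."""
    patched = [list(vec) for vec in base_hidden]
    for i in range(start, end):
        patched[i] = list(source_hidden[i])
    return patched


def toy_logit(hidden, readout):
    """Stand-in for the rest of the forward pass: mean-pool, then dot product."""
    dim = len(readout)
    pooled = [sum(vec[d] for vec in hidden) / len(hidden) for d in range(dim)]
    return sum(p * w for p, w in zip(pooled, readout))


def patch_effect(base_hidden, source_hidden, start, end, readout):
    """Logit after patching minus logit before patching, for the source label."""
    before = toy_logit(base_hidden, readout)
    after = toy_logit(patch_span(base_hidden, source_hidden, start, end), readout)
    return after - before


base = [[0.0, 0.0] for _ in range(4)]
source = [[1.0, 1.0] for _ in range(4)]
readout = [1.0, 0.0]  # toy direction for the source label

effect = patch_effect(base, source, 1, 3, readout)   # patch two positions
control = patch_effect(base, source, 0, 0, readout)  # empty span changes nothing
```

In a real run, the theorem-span and random-span controls would reuse the same function with different `start` and `end` values.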

Summary metrics are added to the main results table:

- `patch_rationale_effect`
- `patch_theorem_effect`
- `patch_random_effect`
- `patch_rationale_minus_random`
- `head_patch_rationale_effect`
- `head_patch_theorem_effect`
- `head_patch_random_effect`
- `head_patch_rationale_minus_random`

Raw per-layer rows are written to `activation_patching.csv` for each variant.
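One plausible way to reduce per-layer rows to the summary metrics is sketched below. The row keys `span` and `effect`, and the choice to average over all layers and samples, are assumptions; the actual column names in `activation_patching.csv` and the aggregation used by the script are not specified here.

```python
from statistics import mean


def summarize_patching(rows):
    """Average effects per span, then compare rationale against the random control."""
    by_span = {}
    for row in rows:
        by_span.setdefault(row["span"], []).append(row["effect"])
    summary = {f"patch_{span}_effect": mean(vals) for span, vals in by_span.items()}
    if "rationale" in by_span and "random" in by_span:
        summary["patch_rationale_minus_random"] = (
            summary["patch_rationale_effect"] - summary["patch_random_effect"]
        )
    return summary


# Hypothetical rows; real values come from a patching run.
rows = [
    {"layer": 3, "span": "rationale", "effect": 0.5},
    {"layer": 6, "span": "rationale", "effect": 1.5},
    {"layer": 3, "span": "random", "effect": 0.25},
    {"layer": 6, "span": "random", "effect": 0.25},
]
summary = summarize_patching(rows)
```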

## Limitations

Templated rationales may make the task easier than it would be with free-form rationales. Binary accept/reject classification is simpler than proof synthesis. Consistency-head accuracy shows that the label is decodable from the pooled span, not that the model uses that information causally. Stronger claims about causal faithfulness need evidence from activation patching or RL. The dataset mostly covers simple Lean examples unless expanded.
