{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "452a6530",
   "metadata": {},
   "source": [
    "# Claim-Consistency Coupling Experiment\n",
    "\n",
    "This notebook validates **claim-consistency coupling** in a small decoder-only transformer.\n",
    "\n",
    "## What we're testing\n",
    "\n",
    "A decoder-only transformer is trained on synthetic sequences of the form:\n",
    "\n",
    "```\n",
    "[BOS] <prompt tokens> [SEP] <rationale tokens> [SEP] <claim tokens>\n",
    "```\n",
    "\n",
    "Each latent state has:\n",
    "- Several paraphrased rationale templates (different tokens, same meaning)\n",
    "- One deterministic claim label\n",
    "\n",
    "We train **four variants** and evaluate how well rationale content predicts claim output:\n",
    "\n",
    "| Variant | Consistency loss pooling |\n",
    "|---|---|\n",
    "| `no_consistency_loss` | LM loss only |\n",
    "| `rationale_only` | Pool over rationale span only |\n",
    "| `full_sequence` | Pool over entire sequence |\n",
    "| `earlier_token_only` | Pool over prompt + rationale (pre-claim) |\n",
    "\n",
    "## Metrics\n",
    "\n",
    "| Metric | Meaning |\n",
    "|---|---|\n",
    "| `gen_claim_acc` | Does greedy generation after prompt+rationale produce the correct claim token? |\n",
    "| `cls_claim_acc (rationale_pool)` | Does mean-pooled rationale hidden state classify claim correctly? |\n",
    "| `cfact_gen_follows_swap` | When rationale is swapped from another state, does generation follow the **swapped** rationale? |\n",
    "| `cfact_gen_follows_orig` | When rationale is swapped, does generation follow the **original** claim? |\n",
    "| `cfact_cls_follows_swap` | Classifier follows swapped rationale (good coupling = high) |\n",
    "| `cfact_cls_follows_orig` | Classifier follows original claim despite swap (low coupling = high) |\n",
    "| `shuffled_gen_acc` | Generation accuracy on mismatched rationale-claim pairs (control) |\n",
    "| `shuffled_cls_acc` | Classifier accuracy on mismatched pairs (control) |\n",
    "\n",
    "**Key interpretation:** a model with strong claim-consistency coupling should show high `cls_claim_acc` and high `cfact_cls_follows_swap` for variants with consistency loss, while `no_consistency_loss` shows near-chance classifier performance.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "471fd25f",
   "metadata": {},
   "source": [
    "## 1. Setup & Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b380f340",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "# Ensure the workspace module is importable\n",
    "sys.path.insert(0, os.path.dirname(os.path.abspath(\"claim_consistency_experiment.py\")))\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "matplotlib.rcParams[\"figure.dpi\"] = 110\n",
    "\n",
    "from claim_consistency_experiment import (\n",
    "    ExperimentConfig,\n",
    "    _build_vocabulary,\n",
    "    ClaimConsistencyDataset,\n",
    "    ClaimConsistencyTransformer,\n",
    "    run_experiment,\n",
    "    evaluate_claim_accuracy_generation,\n",
    "    evaluate_claim_accuracy_classifier,\n",
    "    evaluate_counterfactual_swap,\n",
    "    evaluate_shuffled_pairing,\n",
    "    train_one_variant,\n",
    "    collate_fn,\n",
    "    set_seed,\n",
    ")\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "print(\"torch version:\", torch.__version__)\n",
    "print(\"device: cpu (smoke run)\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07dc33ea",
   "metadata": {},
   "source": [
    "## 2. Configuration\n",
    "\n",
    "The default config is sized for a **fast CPU smoke test** (< 60 s total).\n",
    "Increase `num_epochs`, `num_train_samples`, `d_model`, `n_layers` for a more thorough run.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5fda3d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = ExperimentConfig(\n",
    "    # Dataset\n",
    "    num_latent_states=8,          # 8 latent states (configurable 8-16)\n",
    "    num_rationale_templates=4,    # 4 paraphrased templates per state\n",
    "    num_train_samples=512,\n",
    "    num_eval_samples=128,\n",
    "    num_shuffled_samples=128,\n",
    "\n",
    "    # Model — small for CPU speed\n",
    "    d_model=64,\n",
    "    n_heads=4,\n",
    "    n_layers=2,\n",
    "    d_ff=128,\n",
    "\n",
    "    # Training\n",
    "    num_epochs=5,                 # ← increase for better convergence\n",
    "    batch_size=32,\n",
    "    lr=3e-4,\n",
    "    consistency_loss_weight=0.5,\n",
    "\n",
    "    seed=42,\n",
    "    device=\"cpu\",\n",
    "    results_path=\"results_comparison.csv\",\n",
    ")\n",
    "print(cfg)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6b7e485",
   "metadata": {},
   "source": [
    "## 3. Dataset Inspection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98fa65e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "set_seed(cfg.seed)\n",
    "vocab = _build_vocabulary(cfg)\n",
    "\n",
    "# Show vocabulary structure for first 3 latent states\n",
    "for s in range(3):\n",
    "    print(f\"State {s}: claim_tokens={vocab[s]['claim_tokens']}\")\n",
    "    for i, t in enumerate(vocab[s]['rationale_templates']):\n",
    "        print(f\"  template {i}: {t}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ec737c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from claim_consistency_experiment import make_sample\n",
    "import numpy as np\n",
    "\n",
    "# Inspect one sample\n",
    "rng = np.random.default_rng(0)\n",
    "sample = make_sample(cfg, vocab, latent_state=2, rng=rng)\n",
    "\n",
    "print(\"token_ids      :\", sample[\"token_ids\"].tolist())\n",
    "print(\"targets        :\", sample[\"targets\"].tolist())\n",
    "print(\"rationale_mask :\", sample[\"rationale_mask\"].tolist())\n",
    "print(\"full_seq_mask  :\", sample[\"full_seq_mask\"].tolist())\n",
    "print(\"earlier_tok_mask:\", sample[\"earlier_tok_mask\"].tolist())\n",
    "print(\"latent_state   :\", sample[\"latent_state\"])\n",
    "print(\"claim_label    :\", sample[\"claim_label\"])\n",
    "print(\"is_shuffled    :\", sample[\"is_shuffled\"])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b37b4dc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize sequence layout\n",
    "import matplotlib.patches as mpatches\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 2.5))\n",
    "tids = sample[\"token_ids\"].tolist()\n",
    "rmask = sample[\"rationale_mask\"].tolist()\n",
    "emask = sample[\"earlier_tok_mask\"].tolist()\n",
    "\n",
    "colors = []\n",
    "for i, t in enumerate(tids):\n",
    "    if rmask[i]:\n",
    "        colors.append(\"#4CAF50\")   # rationale = green\n",
    "    elif emask[i] and not rmask[i]:\n",
    "        colors.append(\"#2196F3\")   # prompt (in earlier_tok) = blue\n",
    "    else:\n",
    "        colors.append(\"#FF9800\")   # claim = orange\n",
    "\n",
    "for i, (tok, col) in enumerate(zip(tids, colors)):\n",
    "    ax.barh(0, 1, left=i, color=col, edgecolor=\"white\", linewidth=0.8)\n",
    "    ax.text(i + 0.5, 0, str(tok), ha=\"center\", va=\"center\", fontsize=7, color=\"white\", fontweight=\"bold\")\n",
    "\n",
    "ax.set_xlim(0, len(tids))\n",
    "ax.set_yticks([])\n",
    "ax.set_xlabel(\"Position\")\n",
    "ax.set_title(f\"Sample layout — latent_state={sample['latent_state']}  (green=rationale, blue=prompt, orange=claim)\")\n",
    "patches = [\n",
    "    mpatches.Patch(color=\"#2196F3\", label=\"Prompt (BOS+prompt+SEP)\"),\n",
    "    mpatches.Patch(color=\"#4CAF50\", label=\"Rationale\"),\n",
    "    mpatches.Patch(color=\"#FF9800\", label=\"SEP + Claim\"),\n",
    "]\n",
    "ax.legend(handles=patches, loc=\"upper right\", fontsize=8)\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0ae4065",
   "metadata": {},
   "source": [
    "## 4. Run Experiment (all 4 variants)\n",
    "\n",
    "This cell trains and evaluates all four variants sequentially.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce48dc03",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_results = run_experiment(cfg)\n",
    "df_results\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8143985e",
   "metadata": {},
   "source": [
    "## 5. Results Table & Visualisation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23b1240f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# If already have saved results, load them\n",
    "import os\n",
    "if os.path.exists(\"results_comparison.csv\"):\n",
    "    df_results = pd.read_csv(\"results_comparison.csv\")\n",
    "df_results.set_index(\"variant\", inplace=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40299e66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ── Bar charts ──\n",
    "fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
    "variants = df_results[\"variant\"].tolist()\n",
    "x = range(len(variants))\n",
    "\n",
    "def bar_group(ax, col, title, ylabel=\"Accuracy\", ylim=(0, 1.05)):\n",
    "    ax.bar(x, df_results[col], color=[\"#607D8B\",\"#4CAF50\",\"#2196F3\",\"#FF5722\"])\n",
    "    ax.set_xticks(list(x))\n",
    "    ax.set_xticklabels(variants, rotation=15, ha=\"right\", fontsize=9)\n",
    "    ax.set_title(title, fontsize=11)\n",
    "    ax.set_ylabel(ylabel)\n",
    "    ax.set_ylim(*ylim)\n",
    "    ax.axhline(1/cfg.num_latent_states, color=\"red\", linestyle=\"--\", linewidth=1,\n",
    "               label=f\"Chance ({1/cfg.num_latent_states:.2f})\")\n",
    "    ax.legend(fontsize=8)\n",
    "\n",
    "bar_group(axes[0], \"gen_claim_acc\",\n",
    "          \"Claim Accuracy\\n(generation)\")\n",
    "bar_group(axes[1], \"cls_claim_acc (rationale_pool)\",\n",
    "          \"Claim Accuracy\\n(rationale-pool classifier)\")\n",
    "bar_group(axes[2], \"cfact_cls_follows_swap\",\n",
    "          \"Counterfactual:\\nclassifier follows swapped rationale\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"results_bar_chart.png\", dpi=120, bbox_inches=\"tight\")\n",
    "plt.show()\n",
    "print(\"Chart saved to results_bar_chart.png\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a18e8f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ── Shuffled-pairing degradation ──\n",
    "fig, ax = plt.subplots(figsize=(7, 4))\n",
    "w = 0.35\n",
    "xi = np.arange(len(variants))\n",
    "\n",
    "ax.bar(xi - w/2, df_results[\"cls_claim_acc (rationale_pool)\"],\n",
    "       width=w, label=\"Normal data (classifier)\", color=\"#4CAF50\")\n",
    "ax.bar(xi + w/2, df_results[\"shuffled_cls_acc\"],\n",
    "       width=w, label=\"Shuffled pairs (classifier)\", color=\"#F44336\")\n",
    "\n",
    "ax.axhline(1/cfg.num_latent_states, color=\"gray\", linestyle=\"--\", linewidth=1,\n",
    "           label=f\"Chance ({1/cfg.num_latent_states:.2f})\")\n",
    "ax.set_xticks(xi)\n",
    "ax.set_xticklabels(variants, rotation=15, ha=\"right\")\n",
    "ax.set_title(\"Shuffled-Pairing Control:\\nClassifier Accuracy Degradation\")\n",
    "ax.set_ylabel(\"Accuracy\")\n",
    "ax.set_ylim(0, 1.05)\n",
    "ax.legend(fontsize=9)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"results_shuffled_control.png\", dpi=120, bbox_inches=\"tight\")\n",
    "plt.show()\n",
    "print(\"Chart saved to results_shuffled_control.png\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0bba075e",
   "metadata": {},
   "source": [
    "## 6. Interpretation Guide\n",
    "\n",
    "### What the metrics tell us\n",
    "\n",
    "**`cls_claim_acc (rationale_pool)`**\n",
    "- Variants with consistency loss should reach near 1.0 — the hidden states over the rationale span have been forced to encode the claim label.\n",
    "- `no_consistency_loss` stays near chance (~1/num_latent_states) because without the auxiliary loss, there is no gradient signal to push claim identity into rationale representations.\n",
    "\n",
    "**`gen_claim_acc`**\n",
    "- Generation accuracy is determined by the LM head. The consistency head does not directly affect generation unless the shared hidden states are altered.\n",
    "- Modest improvements over `no_consistency_loss` indicate multi-task synergy; large improvements would suggest strong coupling.\n",
    "\n",
    "**Counterfactual swap test**\n",
    "- `cfact_cls_follows_swap = 1.0` for consistency-trained variants: the classifier has learned to read claim from rationale tokens, so when the rationale changes, the classifier prediction changes.\n",
    "- `cfact_cls_follows_orig = 0.0` for those same variants: the classifier does NOT follow the original claim label when rationale disagrees — strong coupling.\n",
    "- `no_consistency_loss`: classifier predictions are roughly random for both columns.\n",
    "\n",
    "**Shuffled-pairing control**\n",
    "- On mismatched (rationale, claim) pairs, the classifier should degrade toward chance for consistency-trained models that rely on rationale features.\n",
    "- `no_consistency_loss` classifier shows near-chance on both normal and shuffled data.\n",
    "\n",
    "### Expected pattern with longer training\n",
    "\n",
    "```\n",
    "rationale_only pooling  → cls_claim_acc ≈ 1.0,  cfact_cls_follows_swap ≈ 1.0\n",
    "full_sequence pooling   → similar, diluted by claim tokens in pool\n",
    "earlier_token_only      → similar to rationale_only (excludes claim tokens)\n",
    "no_consistency_loss     → cls_claim_acc ≈ chance,  generation weakly coupled\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5071011",
   "metadata": {},
   "source": [
    "## 7. Final Comparison Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9bcbda0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render as styled DataFrame\n",
    "df_display = df_results.copy()\n",
    "df_display = df_display.set_index(\"variant\")\n",
    "\n",
    "def highlight_best(s):\n",
    "    is_max = s == s.max()\n",
    "    return [\"background-color: #c8e6c9\" if v else \"\" for v in is_max]\n",
    "\n",
    "styled = df_display.style.apply(highlight_best, axis=0)\n",
    "styled\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8f6f978",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save final markdown table\n",
    "try:\n",
    "    md_table = df_results.to_markdown(index=False)\n",
    "except ImportError:\n",
    "    md_table = df_results.to_string(index=False)\n",
    "\n",
    "with open(\"results_comparison.md\", \"w\") as f:\n",
    "    f.write(\"# Claim-Consistency Coupling — Results\\n\\n\")\n",
    "    f.write(md_table)\n",
    "    f.write(\"\\n\")\n",
    "\n",
    "print(\"Markdown table saved to results_comparison.md\")\n",
    "print()\n",
    "print(md_table)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "hs_intervention_md",
   "metadata": {},
   "source": [
    "## 8. Hidden-State Intervention / Causal Patching Evaluation\n",
    "\n",
    "This section runs `evaluate_hidden_state_intervention()`, which performs a **causal patching** test\n",
    "to determine whether claim predictions are causally driven by the rationale's *hidden states* or\n",
    "by its *surface tokens*.\n",
    "\n",
    "### Protocol\n",
    "\n",
    "For each (orig_state A, swap_state B) pair:\n",
    "1. Build an **original** sequence with rationale tokens from state A.\n",
    "2. Build a **swapped** sequence with rationale tokens from state B (but same claim label A).\n",
    "3. Cache post-block hidden states at rationale positions from the original forward pass.\n",
    "4. For each transformer block `i`:\n",
    "   - Re-run the swapped sequence, but after block `i` **replace** rationale hidden states\n",
    "     with the cached originals (patch callback injected via `TransformerBlock.forward`).\n",
    "   - Read `argmax` logit at the last-SEP position (= predicted first claim token).\n",
    "   - Record whether the patched prediction matches state A (`intervention_follows_original_hs`)\n",
    "     or state B (`intervention_follows_swapped_tokens`).\n",
    "\n",
    "A value of `intervention_follows_original_hs ≈ 1.0` at layer i means that the model's\n",
    "claim decision is fully determined by the hidden representations at that block — the surface\n",
    "tokens of the swapped rationale are overridden.  `≈ 0` means surface tokens dominate\n",
    "(the patch was too early / had no effect on later computation).\n",
    "\n",
    "**Claim prediction position**: logit index `prefix_len − 1` (final SEP before claim span),\n",
    "identical to the position used in `generate_claim()` greedy decoding.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "hs_intervention_code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from claim_consistency_experiment import evaluate_hidden_state_intervention\n",
    "\n",
    "# Run the hidden-state intervention evaluation.\n",
    "# Uses the same cfg defined in section 2; trains all 4 variants internally\n",
    "# if `models` is not supplied.  Pass pre-trained models to skip retraining.\n",
    "df_hs = evaluate_hidden_state_intervention(cfg=cfg, n_samples=64, seed_offset=1337)\n",
    "\n",
    "print(\"\\n=== HIDDEN-STATE INTERVENTION RESULTS ===\")\n",
    "df_hs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "hs_intervention_plot",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ── Visualise intervention results ──\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "variants_hs = df_hs['variant'].unique().tolist()\n",
    "layers_hs   = sorted(df_hs['layer'].unique().tolist())\n",
    "n_layers_hs = len(layers_hs)\n",
    "\n",
    "fig, axes = plt.subplots(1, n_layers_hs, figsize=(5 * n_layers_hs, 5), sharey=True)\n",
    "if n_layers_hs == 1:\n",
    "    axes = [axes]\n",
    "\n",
    "x = np.arange(len(variants_hs))\n",
    "w = 0.35\n",
    "\n",
    "for ax, layer_idx in zip(axes, layers_hs):\n",
    "    sub = df_hs[df_hs['layer'] == layer_idx]\n",
    "    orig_vals  = [sub[sub['variant'] == v]['intervention_follows_original_hs'].values[0] for v in variants_hs]\n",
    "    swap_vals  = [sub[sub['variant'] == v]['intervention_follows_swapped_tokens'].values[0] for v in variants_hs]\n",
    "    ax.bar(x - w/2, orig_vals, width=w, label='Follows original HS', color='#4CAF50')\n",
    "    ax.bar(x + w/2, swap_vals, width=w, label='Follows swapped tokens', color='#F44336')\n",
    "    ax.set_title(f'Patch after block_{layer_idx}', fontsize=11)\n",
    "    ax.set_xticks(x)\n",
    "    ax.set_xticklabels(variants_hs, rotation=15, ha='right', fontsize=9)\n",
    "    ax.set_ylim(0, 1.05)\n",
    "    ax.axhline(1/cfg.num_latent_states, color='gray', linestyle='--', linewidth=1,\n",
    "               label=f'Chance ({1/cfg.num_latent_states:.2f})')\n",
    "    ax.legend(fontsize=8)\n",
    "\n",
    "axes[0].set_ylabel('Rate', fontsize=11)\n",
    "fig.suptitle('Hidden-State Intervention: Does patched HS override surface tokens?', fontsize=13)\n",
    "plt.tight_layout()\n",
    "plt.savefig('results_hidden_state_intervention.png', dpi=120, bbox_inches='tight')\n",
    "plt.show()\n",
    "print('Chart saved to results_hidden_state_intervention.png')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "hs_intervention_save",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save intervention results\n",
    "df_hs.to_csv('results_hidden_state_intervention.csv', index=False)\n",
    "try:\n",
    "    md_hs = df_hs.to_markdown(index=False)\n",
    "except ImportError:\n",
    "    md_hs = df_hs.to_string(index=False)\n",
    "with open('results_hidden_state_intervention.md', 'w') as f:\n",
    "    f.write('# Claim Consistency – Hidden-State Intervention Results\\n\\n')\n",
    "    f.write(md_hs)\n",
    "    f.write('\\n')\n",
    "print('Results saved to results_hidden_state_intervention.csv / .md')\n",
    "print(md_hs)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
