"""
run_hidden_state_intervention.py
=================================
Runs the hidden-state intervention / causal-patching evaluation for all 4
model variants using the same scaled-experiment hyperparameters as
run_scaled_experiment.py.

Outputs
-------
/home/user/workspace/results_hidden_state_intervention.csv
/home/user/workspace/results_hidden_state_intervention.md

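Writes only to its own two output files listed above (both are recreated on
every run); it is not intended to overwrite results files from other scripts.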
"""

import sys
import os

sys.path.insert(0, "/home/user/workspace")

from claim_consistency_experiment import (
    ExperimentConfig,
    evaluate_hidden_state_intervention,
    _build_vocabulary,
    ClaimConsistencyDataset,
    DataLoader,
    collate_fn,
    train_one_variant,
    set_seed,
)

# Destinations for the intervention results: a CSV table and a Markdown report.
CSV_OUT = "/home/user/workspace/results_hidden_state_intervention.csv"
MD_OUT = "/home/user/workspace/results_hidden_state_intervention.md"

# ---------------------------------------------------------------------------
# Config – mirrors the scaled experiment settings.  The epoch count is kept
# manageable on CPU: 30 epochs * 5120 samples is fine because the model is
# tiny.  Settings are grouped by concern and unpacked into a single
# ExperimentConfig call; keyword order matches the original declaration.
# ---------------------------------------------------------------------------

# Scaled dataset sizes and training length.
_scale_settings = {
    "num_train_samples": 5120,
    "num_eval_samples": 512,
    "num_shuffled_samples": 256,
    "num_epochs": 30,
}

# Same architecture as the scaled run.
_architecture_settings = {
    "num_latent_states": 8,
    "num_rationale_templates": 4,
    "d_model": 64,
    "n_heads": 4,
    "n_layers": 2,
    "d_ff": 128,
    "dropout": 0.1,
    "vocab_size": 128,
    "max_seq_len": 64,
}

# Optimisation, reproducibility and device settings.
_training_settings = {
    "batch_size": 32,
    "lr": 3e-4,
    "consistency_loss_weight": 0.5,
    "seed": 42,
    "device": "cpu",
}

cfg = ExperimentConfig(
    **_scale_settings,
    **_architecture_settings,
    **_training_settings,
    # Not used by the intervention eval, but set for consistency.
    results_path=CSV_OUT,
)

# Announce the run: headline settings and output destinations, framed by
# 70-character rules.  The lines are assembled first and emitted in one call.
_rule = "=" * 70
_banner_lines = [
    _rule,
    "HIDDEN-STATE INTERVENTION EVALUATION",
    _rule,
    f"  Config: {cfg.num_latent_states} latent states, "
    f"{cfg.n_layers} layers, d_model={cfg.d_model}",
    f"  Training samples: {cfg.num_train_samples}, epochs: {cfg.num_epochs}",
    f"  Output CSV: {CSV_OUT}",
    f"  Output MD:  {MD_OUT}",
    _rule,
]
print("\n".join(_banner_lines))

# ---------------------------------------------------------------------------
# Train every variant listed in cfg.pooling_modes (replicating the scaled run)
# and keep each trained model in memory for the intervention evaluation.
# ---------------------------------------------------------------------------
set_seed(cfg.seed)
vocab = _build_vocabulary(cfg)


def _make_dataset(num_samples, shuffled, seed_offset):
    """Build one ClaimConsistencyDataset split from the shared cfg and vocab."""
    return ClaimConsistencyDataset(
        cfg, vocab, num_samples, shuffled=shuffled, seed_offset=seed_offset
    )


def _make_loader(dataset, shuffle):
    """Wrap a dataset in a DataLoader using the configured batch size."""
    return DataLoader(
        dataset, batch_size=cfg.batch_size, shuffle=shuffle, collate_fn=collate_fn
    )


# Datasets are created before any loader, in the same order as before
# (train, eval, shuffled), each with its own seed offset.
train_ds = _make_dataset(cfg.num_train_samples, shuffled=False, seed_offset=0)
eval_ds = _make_dataset(cfg.num_eval_samples, shuffled=False, seed_offset=100)
shuf_ds = _make_dataset(cfg.num_shuffled_samples, shuffled=True, seed_offset=200)

# Only the training loader shuffles its batches.
train_loader = _make_loader(train_ds, shuffle=True)
eval_loader = _make_loader(eval_ds, shuffle=False)
shuf_loader = _make_loader(shuf_ds, shuffle=False)

trained_models = {}
for pooling_mode in cfg.pooling_modes:
    print(f"\n[TRAIN] variant={pooling_mode}")
    model, history = train_one_variant(
        cfg, pooling_mode, train_loader, eval_loader, shuf_loader, vocab
    )
    # The last history entry carries the losses reported as "final".
    last_entry = history[-1]
    print(
        f"  Final LM loss={last_entry['lm_loss']:.4f}"
        f"  Cons loss={last_entry['cons_loss']:.4f}"
    )
    trained_models[pooling_mode] = model

# ---------------------------------------------------------------------------
# Run hidden-state intervention evaluation
# ---------------------------------------------------------------------------
# Single source of truth for the number of (orig, swap) pairs: the same value
# is passed to the evaluation and quoted in the report's hyperparameter table,
# so the two can no longer drift apart.
N_INTERVENTION_SAMPLES = 64

print("\n[EVAL] Running hidden-state intervention eval ...")
df = evaluate_hidden_state_intervention(
    cfg=cfg,
    models=trained_models,
    n_samples=N_INTERVENTION_SAMPLES,
    seed_offset=1337,
)

print("\n=== HIDDEN-STATE INTERVENTION RESULTS ===")
# Render the table once and reuse it for both the console and the report.
# `to_markdown` raises ImportError when its optional `tabulate` dependency is
# missing (assuming `df` is a pandas DataFrame); fall back to plain text then.
try:
    results_table = df.to_markdown(index=False)
except ImportError:
    results_table = df.to_string(index=False)
print(results_table)

# ---------------------------------------------------------------------------
# Save results
# ---------------------------------------------------------------------------
df.to_csv(CSV_OUT, index=False)
print(f"\n[DONE] CSV saved to {CSV_OUT}")

# Report preamble.  Fragments are concatenated verbatim, so each one carries
# its own newlines; adjacent fragments without "\n" form a single output line.
md_lines = [
    "# Claim Consistency – Hidden-State Intervention / Causal Patching Results\n",
    "\n## Methodology\n",
    "\nFor each sample pair (original rationale from latent state A, swapped rationale from",
    " latent state B), the evaluation:\n",
    "1. Runs a forward pass with the **original** sequence and caches post-block hidden states",
    "   at rationale token positions after each transformer block.\n",
    "2. For each block `i`, runs a second forward pass with the **swapped** sequence but",
    "   **replaces** the hidden states at rationale positions after block `i` with the cached",
    "   original states — letting all subsequent blocks process the patched activations.\n",
    "3. Reads the logit at the final SEP position (immediately before the claim span) to",
    "   identify the predicted first claim token (greedy, single-token).\n",
    "4. Records whether the prediction matches the **original** latent state's claim token",
    "   (`intervention_follows_original_hs`) or the **swapped** state's claim token",
    "   (`intervention_follows_swapped_tokens`).\n",
    "\n**Claim prediction position**: logit at index `prefix_len − 1`",
    f" = position {1 + cfg.prompt_len + 1 + cfg.rationale_len} (0-based),",
    " consistent with greedy next-token generation used in `generate_claim()`.\n",
    "\n## Hyperparameters\n\n",
    "| Parameter | Value |\n|---|---|\n",
    f"| num_train_samples | {cfg.num_train_samples} |\n",
    f"| num_epochs | {cfg.num_epochs} |\n",
    f"| num_latent_states | {cfg.num_latent_states} |\n",
    f"| n_layers | {cfg.n_layers} |\n",
    f"| d_model | {cfg.d_model} |\n",
    f"| n_heads | {cfg.n_heads} |\n",
    f"| d_ff | {cfg.d_ff} |\n",
    f"| n_intervention_samples | {N_INTERVENTION_SAMPLES} |\n",
    f"| seed | {cfg.seed} |\n",
    "\n## Results\n\n",
]

# One bullet per column of the results table.
column_descriptions = [
    "- **variant**: Training objective variant (pooling mode for consistency loss)\n",
    "- **layer**: 0-based transformer block index\n",
    "- **patch_layer_id**: Human-readable block label (e.g. `block_0`)\n",
    "- **intervention_follows_original_hs**: Fraction of samples where patching the "
    "hidden states at this block causes the model to predict the *original* latent "
    "state's claim token (higher = patched HS dominate)\n",
    "- **intervention_follows_swapped_tokens**: Fraction of samples where the patched "
    "model still predicts the *swapped* rationale's claim token (higher = surface "
    "tokens still dominate despite patch)\n",
    "- **n_samples**: Number of (orig, swap) sample pairs evaluated\n",
]

# The report contains non-ASCII punctuation (en/em dashes, a minus sign), so
# the encoding is pinned to UTF-8 instead of relying on the locale default.
with open(MD_OUT, "w", encoding="utf-8") as f:
    f.writelines(md_lines)
    f.write(results_table)
    f.write("\n\n## Column Descriptions\n\n")
    f.writelines(column_descriptions)

print(f"[DONE] Markdown saved to {MD_OUT}")
