"""
run_hard_experiment.py
======================
Hard overlapping-vocabulary experiment for the claim-consistency coupling study.

Key differences from the default/scaled experiment:
  - hard_overlap_vocab=True: rationale token vocabularies share ~50% of their
    token positions across latent states (see _build_vocabulary_hard docstring).
  - 4 training variants only (no claim_only_pooling):
        no_consistency_loss, rationale_only, full_sequence, earlier_token_only
  - 10 epochs
  - All other hyperparameters match the default/smoke config:
      num_latent_states=8, num_rationale_templates=4
      d_model=64, n_layers=2, n_heads=4, d_ff=128
      batch_size=32, lr=3e-4, consistency_loss_weight=0.5
      num_train_samples=512, num_eval_samples=128, num_shuffled_samples=128
      seed=42, device=cpu, vocab_size=128
      dropout=0.1, max_seq_len=64

Outputs:
    /home/user/workspace/results_comparison_hard.csv
    /home/user/workspace/results_comparison_hard.md

Strong-coupling threshold check (in markdown output):
    cls_claim_acc (rationale_pool) > 0.90
    cfact_cls_follows_swap > 0.90
"""

import sys
import os

sys.path.insert(0, "/home/user/workspace")

from claim_consistency_experiment import (
    ExperimentConfig,
    run_experiment,
)

import pandas as pd

# Output locations for this run. HARD_CSV is handed to ExperimentConfig as
# results_path; HARD_MD is the markdown report this script writes itself.
_WORKSPACE_DIR = "/home/user/workspace"

HARD_CSV = f"{_WORKSPACE_DIR}/results_comparison_hard.csv"
HARD_MD = f"{_WORKSPACE_DIR}/results_comparison_hard.md"

# ── Config ──────────────────────────────────────────────────────────────────
# Variants trained in this run: four pooling modes, claim_only_pooling is
# intentionally left out.
_HARD_POOLING_MODES = (
    "no_consistency_loss",
    "rationale_only",
    "full_sequence",
    "earlier_token_only",
)

cfg = ExperimentConfig(
    # --- what makes this the "hard" run: overlapping rationale vocab ---
    hard_overlap_vocab=True,
    overlap_fraction=0.5,

    # --- variants and output location ---
    pooling_modes=_HARD_POOLING_MODES,
    results_path=HARD_CSV,

    # --- schedule: 10 epochs ---
    num_epochs=10,

    # --- dataset sizes (same as the default/smoke setup) ---
    num_train_samples=512,
    num_eval_samples=128,
    num_shuffled_samples=128,

    # --- latent structure and vocabulary (default) ---
    num_latent_states=8,
    num_rationale_templates=4,
    vocab_size=128,
    max_seq_len=64,

    # --- model architecture (default) ---
    d_model=64,
    n_layers=2,
    n_heads=4,
    d_ff=128,
    dropout=0.1,

    # --- optimisation (default) ---
    batch_size=32,
    lr=3e-4,
    consistency_loss_weight=0.5,
    seed=42,
    device="cpu",
)

# ── Print config ─────────────────────────────────────────────────────────────
# Echo the effective configuration so the run log is self-describing.
_BANNER_RULE = "=" * 70

print(_BANNER_RULE)
print("HARD OVERLAPPING-VOCAB EXPERIMENT CONFIG")
print(_BANNER_RULE)
# Field names are left-justified to 20 columns; the two longer names simply
# overflow the column, which reproduces the hand-aligned layout exactly.
for _field_name in (
    "hard_overlap_vocab",
    "overlap_fraction",
    "num_train_samples",
    "num_eval_samples",
    "num_shuffled_samples",
    "num_epochs",
    "num_latent_states",
    "num_rationale_templates",
    "d_model",
    "n_layers",
    "batch_size",
    "lr",
    "consistency_loss_weight",
    "pooling_modes",
    "results_path",
):
    print(f"  {_field_name:<20} = {getattr(cfg, _field_name)}")
print(_BANNER_RULE)

# ── Run experiment ────────────────────────────────────────────────────────────
# Train/evaluate every configured variant; the result is used below as a
# pandas DataFrame with one row per variant.
df = run_experiment(cfg)

print("\n=== HARD OVERLAPPING-VOCAB RESULTS ===")
# DataFrame.to_markdown needs the optional `tabulate` package; when it is not
# installed pandas raises ImportError and we fall back to plain text.
try:
    _results_text = df.to_markdown(index=False)
except ImportError:
    _results_text = df.to_string(index=False)
print(_results_text)

# ── Strong coupling threshold check ─────────────────────────────────────────
CLS_THRESHOLD  = 0.90
CFACT_THRESHOLD = 0.90

def _check(row):
    cls_ok   = row["cls_claim_acc (rationale_pool)"] > CLS_THRESHOLD
    cfact_ok = row["cfact_cls_follows_swap"]          > CFACT_THRESHOLD
    return cls_ok, cfact_ok

# Console summary: one PASS/FAIL line per variant.
print("\n=== STRONG COUPLING THRESHOLD CHECK ===")
print(f"  cls_claim_acc > {CLS_THRESHOLD}  AND  cfact_cls_follows_swap > {CFACT_THRESHOLD}")
for _, row in df.iterrows():
    # A variant passes only when both threshold comparisons succeed.
    verdict = "PASS" if all(_check(row)) else "FAIL"
    cls_acc = row["cls_claim_acc (rationale_pool)"]
    cfact_rate = row["cfact_cls_follows_swap"]
    print(f"  {row['variant']}: cls={cls_acc:.4f} cfact={cfact_rate:.4f}  [{verdict}]")

# ── Save markdown ─────────────────────────────────────────────────────────────
# Build the whole report in memory first and only then open the output file.
# Opening with "w" truncates immediately, so formatting inside the `with`
# block would leave a half-written report behind if anything failed (for
# example a missing results column).

# DataFrame.to_markdown needs the optional `tabulate` package and raises
# ImportError without it; fall back to plain text in that case only, matching
# the console output above.
try:
    results_table = df.to_markdown(index=False)
except ImportError:
    results_table = df.to_string(index=False)

# round() rather than int(): float error would otherwise truncate e.g.
# 0.29 * 100 == 28.999... down to 28.
overlap_pct = round(cfg.overlap_fraction * 100)

report = [
    "# Claim Consistency Coupling — Hard Overlapping-Vocab Experiment\n\n",
    "## Overview\n\n",
    "This experiment re-runs the four consistency-loss training variants on a\n",
    "**harder synthetic dataset** where rationale token vocabularies overlap by\n",
    f"approximately {overlap_pct}% across latent states.  ",
    "Instead of each state having a fully private token range, templates are\n",
    "constructed from:\n\n",
    # State count comes from the config so the text stays correct if it changes.
    f"- **Shared tokens** (appear in all {cfg.num_latent_states} states)\n",
    "- **Group tokens** (appear in 2 adjacent states)\n",
    "- **Local tokens** (unique to one state, minority ~50% of positions)\n\n",
    "The model cannot classify states by single unique token identities; it must\n",
    "learn co-occurrence / combination patterns across the 8-token rationale span.\n",
    "Claim tokens remain fully state-specific (non-overlapping).\n\n",
]

# ── Hyperparameter table ──
report.append("## Hyperparameters\n\n")
report.append("| Parameter | Value |\n|---|---|\n")
params = [
    ("hard_overlap_vocab", cfg.hard_overlap_vocab),
    ("overlap_fraction", cfg.overlap_fraction),
    ("num_train_samples", cfg.num_train_samples),
    ("num_eval_samples", cfg.num_eval_samples),
    ("num_shuffled_samples", cfg.num_shuffled_samples),
    ("num_epochs", cfg.num_epochs),
    ("num_latent_states", cfg.num_latent_states),
    ("num_rationale_templates", cfg.num_rationale_templates),
    ("d_model", cfg.d_model),
    ("n_layers", cfg.n_layers),
    ("n_heads", cfg.n_heads),
    ("d_ff", cfg.d_ff),
    ("batch_size", cfg.batch_size),
    ("lr", cfg.lr),
    ("consistency_loss_weight", cfg.consistency_loss_weight),
    ("seed", cfg.seed),
]
report.extend(f"| {k} | {v} |\n" for k, v in params)

# ── Raw results table ──
report.append("\n## Results\n\n")
report.append(results_table)

# ── Per-variant threshold verdicts ──
report.append("\n\n## Strong Coupling Threshold Check\n\n")
report.append(
    f"Thresholds: `cls_claim_acc (rationale_pool) > {CLS_THRESHOLD}` "
    f"AND `cfact_cls_follows_swap > {CFACT_THRESHOLD}`\n\n"
)
report.append("| Variant | cls_claim_acc | cfact_cls_follows_swap | Meets Thresholds? |\n")
report.append("|---------|:---:|:---:|:---:|\n")
for _, row in df.iterrows():
    cls_ok, cfact_ok = _check(row)
    status = "**YES**" if (cls_ok and cfact_ok) else "NO"
    report.append(
        f"| {row['variant']} | {row['cls_claim_acc (rationale_pool)']:.4f} "
        f"| {row['cfact_cls_follows_swap']:.4f} | {status} |\n"
    )

# ── Column glossary ──
report.append("\n## Column Descriptions\n\n")
report.append("| Column | Description |\n|--------|-------------|\n")
cols = [
    ("variant", "Training objective variant (pooling mode for consistency loss)"),
    ("final_lm_loss", "Cross-entropy LM loss at end of training"),
    ("final_cons_loss", "Consistency classification loss at end of training"),
    ("gen_claim_acc", "Greedy generation accuracy: first generated token matches expected claim token"),
    ("cls_claim_acc (rationale_pool)", "Classifier accuracy from mean-pooled rationale hidden states"),
    ("cfact_gen_follows_swap", "Rate that generation follows the swapped (wrong) rationale in counterfactual test"),
    ("cfact_gen_follows_orig", "Rate that generation follows the original claim despite swapped rationale"),
    ("cfact_cls_follows_swap", "Rate that classifier follows swapped rationale (strong coupling = high)"),
    ("cfact_cls_follows_orig", "Rate that classifier follows original claim despite swap (low coupling = high)"),
    ("shuffled_gen_acc", "Generation accuracy under shuffled rationale-claim pairings"),
    ("shuffled_cls_acc", "Classifier accuracy under shuffled rationale-claim pairings"),
]
report.extend(f"| `{col}` | {desc} |\n" for col, desc in cols)

# The report contains non-ASCII text (the em dash in the title), so pin the
# encoding instead of relying on the platform/locale default. One write call
# replaces the many small ones.
with open(HARD_MD, "w", encoding="utf-8") as f:
    f.write("".join(report))

# Final pointer to the artefacts of this run.
print(
    "\n[DONE] Hard experiment results saved to:",
    f"  CSV: {HARD_CSV}",
    f"  MD:  {HARD_MD}",
    sep="\n",
)
