"""
run_claim_only_pooling.py
=========================
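Train and evaluate the 'claim_only_pooling' negative-control variant using the
default smoke-test hyperparameters (the same ones used for the existing four rows
in results_comparison.csv). Then write its result row into the CSV, replacing any
earlier row for this variant, and regenerate results_comparison.md from the
updated table.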

Hyperparameters match the defaults in ExperimentConfig:
  num_latent_states=8, num_rationale_templates=4, vocab_size=128
  prompt_len=4, rationale_len=8, claim_len=2
  num_train_samples=512, num_eval_samples=128, num_shuffled_samples=128
  d_model=64, n_heads=4, n_layers=2, d_ff=128, dropout=0.1
  batch_size=32, num_epochs=5, lr=3e-4, consistency_loss_weight=0.5
  seed=42
"""

import sys
import os
sys.path.insert(0, "/home/user/workspace")

import pandas as pd
import torch
from torch.utils.data import DataLoader

from claim_consistency_experiment import (
    ExperimentConfig,
    _build_vocabulary,
    ClaimConsistencyDataset,
    collate_fn,
    set_seed,
    train_one_variant,
    evaluate_claim_accuracy_generation,
    evaluate_claim_accuracy_classifier,
    evaluate_counterfactual_swap,
    evaluate_shuffled_pairing,
)

# Pooling variant trained by this script. It is also the value written to the
# "variant" column of the result row, so reruns replace the same row.
VARIANT = "claim_only_pooling"

# Both comparison tables live side by side in the shared workspace directory.
_RESULTS_DIR = "/home/user/workspace"
CSV_PATH = f"{_RESULTS_DIR}/results_comparison.csv"
MD_PATH = f"{_RESULTS_DIR}/results_comparison.md"


def _upsert_result_row(csv_path, row):
    """Insert *row* into the CSV at *csv_path*, replacing any row with the same variant.

    A missing or empty CSV is treated as an empty table, so the first run
    creates the file instead of crashing.

    Args:
        csv_path: Path of the comparison CSV to update in place.
        row: Mapping of column name -> value; must contain a "variant" key.

    Returns:
        pandas.DataFrame: The full table exactly as written to *csv_path*.
    """
    new_row_df = pd.DataFrame([row])
    try:
        existing_df = pd.read_csv(csv_path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        # No previous results yet: the new row is the whole table.
        new_df = new_row_df
    else:
        # Drop any earlier row for this variant so reruns replace rather than duplicate.
        existing_df = existing_df[existing_df["variant"] != row["variant"]]
        new_df = pd.concat([existing_df, new_row_df], ignore_index=True)
    new_df.to_csv(csv_path, index=False)
    return new_df


def _write_results_markdown(md_path, df):
    """Overwrite *md_path* with a titled Markdown rendering of *df*.

    The file is written as UTF-8 explicitly because the title contains a
    non-ASCII em dash. A non-UTF-8 locale default would fail on it or
    mis-encode it.
    """
    md_lines = [
        "# Claim-Consistency Coupling — Results",
        "",
        df.to_markdown(index=False),  # pandas needs the optional `tabulate` package here
        "",
    ]
    with open(md_path, "w", encoding="utf-8") as f:
        f.write("\n".join(md_lines))


def main():
    """Train, evaluate and record the claim_only_pooling negative-control variant.

    Steps:
      1. Build the smoke-test ExperimentConfig, restricted to this one variant.
      2. Build the train / eval / shuffled-pairing datasets and their loaders.
      3. Train the variant, then compute generation, classifier,
         counterfactual-swap and shuffled-pairing metrics.
      4. Upsert the metrics row into CSV_PATH (the file is created if it does
         not exist yet) and regenerate MD_PATH from the updated table.

    Returns:
        dict: The result row that was written, keyed by CSV column name.
    """
    # Default smoke-test config (same values as ExperimentConfig defaults)
    cfg = ExperimentConfig(
        num_latent_states=8,
        num_rationale_templates=4,
        num_train_samples=512,
        num_eval_samples=128,
        num_shuffled_samples=128,
        num_epochs=5,
        batch_size=32,
        d_model=64,
        n_heads=4,
        n_layers=2,
        d_ff=128,
        dropout=0.1,
        lr=3e-4,
        consistency_loss_weight=0.5,
        seed=42,
        pooling_modes=(VARIANT,),  # only train this variant
        results_path=CSV_PATH,
    )

    set_seed(cfg.seed)
    print(f"[INFO] Building vocabulary for {cfg.num_latent_states} latent states ...")
    vocab = _build_vocabulary(cfg)

    print("[INFO] Building datasets ...")
    # Distinct seed offsets keep the train / eval / shuffled splits from
    # drawing identical samples (assumes ClaimConsistencyDataset seeds from
    # cfg.seed + seed_offset; TODO confirm).
    train_ds = ClaimConsistencyDataset(cfg, vocab, cfg.num_train_samples, shuffled=False, seed_offset=0)
    eval_ds = ClaimConsistencyDataset(cfg, vocab, cfg.num_eval_samples, shuffled=False, seed_offset=100)
    shuf_ds = ClaimConsistencyDataset(cfg, vocab, cfg.num_shuffled_samples, shuffled=True, seed_offset=200)

    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn)
    eval_loader = DataLoader(eval_ds, batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_fn)
    shuf_loader = DataLoader(shuf_ds, batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_fn)

    print(f"\n{'='*60}")
    print(f"[TRAIN] variant={VARIANT}")
    model, history = train_one_variant(cfg, VARIANT, train_loader, eval_loader, shuf_loader, vocab)
    # Report losses from the last recorded epoch.
    final_lm = history[-1]["lm_loss"]
    final_cons = history[-1]["cons_loss"]
    print(f"  Final LM loss={final_lm:.4f}  Cons loss={final_cons:.4f}")

    print(f"[EVAL]  variant={VARIANT}")
    gen_acc = evaluate_claim_accuracy_generation(cfg, model, eval_ds, vocab)
    cls_acc = evaluate_claim_accuracy_classifier(cfg, model, eval_ds)
    cfact = evaluate_counterfactual_swap(cfg, model, vocab, n_samples=64)
    shuf = evaluate_shuffled_pairing(cfg, model, shuf_ds, vocab)

    # Column names must match the existing CSV header so the new row lines up.
    row = {
        "variant":                        VARIANT,
        "final_lm_loss":                  round(final_lm, 4),
        "final_cons_loss":                round(final_cons, 4),
        "gen_claim_acc":                  round(gen_acc, 4),
        "cls_claim_acc (rationale_pool)": round(cls_acc, 4),
        "cfact_gen_follows_swap":         round(cfact["gen_follows_swap_rate"], 4),
        "cfact_gen_follows_orig":         round(cfact["gen_follows_orig_rate"], 4),
        "cfact_cls_follows_swap":         round(cfact["cls_follows_swap_rate"], 4),
        "cfact_cls_follows_orig":         round(cfact["cls_follows_orig_rate"], 4),
        "shuffled_gen_acc":               round(shuf["shuffled_gen_acc"], 4),
        "shuffled_cls_acc":               round(shuf["shuffled_cls_acc"], 4),
    }
    print("\n[RESULT]", row)

    # ---- Append to CSV (replacing any earlier row for this variant) ----
    new_df = _upsert_result_row(CSV_PATH, row)
    print(f"[SAVED] CSV → {CSV_PATH}")

    # ---- Rebuild MD ----
    _write_results_markdown(MD_PATH, new_df)
    print(f"[SAVED] MD  → {MD_PATH}")

    print("\n=== FINAL TABLE ===")
    print(new_df.to_string(index=False))

    return row


if __name__ == "__main__":
    # Script entry point. The returned result row stays bound to a
    # module-level name, so it can be inspected afterwards (e.g. under `python -i`).
    result = main()
