"""
run_experiment.py — Main entry point for the consistency loss experiment.

Usage:
    python run_experiment.py --smoke          # Fast smoke test (~2-5 min CPU)
    python run_experiment.py --full           # Full run (20 epochs, 3000 examples)
    python run_experiment.py --full --model small --epochs 20 --batch 32
    python run_experiment.py --full --model gpt2_small  # GPU required

All outputs go to outputs/ in the current directory by default (use
--output-dir to choose another location; the *.png charts are skipped
with --no-charts):
    outputs/metrics.csv
    outputs/coupling_strength.png
    outputs/explanation_correctness.png
    outputs/counterfactual_swap.png
    outputs/claim_accuracy.png
    outputs/losses.png
    outputs/report.pplx.md
    outputs/run_metadata.json
    outputs/checkpoints/<variant>/epoch_NNN.pt
"""

import os
import sys
import time
import argparse
import random
import json

import numpy as np
import pandas as pd
import torch

# ── Project imports ────────────────────────────────────────────────────────────
from dataset import build_dataset, split_dataset, build_tokenizer, make_target_sequence
from model import ConsistencyTransformer, TransformerConfig
from trainer import (
    TrainConfig, train_variant, VARIANTS, VARIANTS_V1, VARIANTS_V2,
    generate_epoch_samples, collect_qualitative_examples,
    CodeExplanationDataset,
)
from visualize import generate_all_charts
from report import build_report


def _positive_int(text):
    """Parse *text* as an integer >= 1 (used as an argparse ``type=`` converter).

    Raises:
        argparse.ArgumentTypeError: if *text* is not an integer or is < 1;
            argparse reports this as a normal usage error.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {value}")
    return value


def parse_args(argv=None):
    """Parse the command-line options for the experiment.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    p = argparse.ArgumentParser(description="Consistency Loss Experiment")

    mode = p.add_mutually_exclusive_group(required=True)
    mode.add_argument("--smoke", action="store_true",
                      help="Fast smoke test (tiny model, few epochs)")
    mode.add_argument("--full",  action="store_true",
                      help="Full configured run (3000 examples, 20 epochs)")

    p.add_argument("--model", type=str, default=None,
                   choices=["smoke", "small", "gpt2_small"],
                   help="Override model size (default: smoke for --smoke, small for --full)")
    # Size/length overrides must be positive: a 0 would otherwise be silently
    # ignored by truthiness checks or produce an empty/invalid run.
    p.add_argument("--epochs",  type=_positive_int, default=None, help="Override number of epochs")
    p.add_argument("--batch",   type=_positive_int, default=None, help="Override batch size")
    p.add_argument("--n",       type=_positive_int, default=None, help="Override number of examples")
    p.add_argument("--seed",    type=int, default=42)

    # The three ways of selecting variants contradict each other, so argparse
    # rejects combining them instead of one silently taking precedence.
    selection = p.add_mutually_exclusive_group()
    selection.add_argument("--variants", nargs="+", default=None,
                           choices=VARIANTS,
                           help=("Only run specific variants. V1 (original): consistency_loss, "
                                 "no_consistency_loss, claim_only_pooling, random_label_consistency. "
                                 "V2 (stronger ablations): no_claim_to_claim_attention, "
                                 "claims_from_explanation_only, surface_bottleneck_consistency, "
                                 f"surface_bottleneck_no_expl_lm. Default: all {len(VARIANTS)} variants."))
    selection.add_argument("--v1-only", action="store_true",
                           help="Only run V1 (original 4) variants")
    selection.add_argument("--v2-only", action="store_true",
                           help="Only run V2 (new 4 stronger ablation) variants")

    p.add_argument("--output-dir", type=str, default="outputs")
    p.add_argument("--no-charts",  action="store_true", help="Skip chart generation")
    p.add_argument("--smoke-epochs", type=_positive_int, default=5,
                   help="Epochs for smoke run (default 5)")
    p.add_argument("--smoke-n",  type=_positive_int, default=300,
                   help="Dataset size for smoke run (default 300)")
    p.add_argument("--smoke-batch", type=_positive_int, default=16,
                   help="Batch size for smoke run (default 16)")
    return p.parse_args(argv)


def _resolve_variants(args):
    """Return the variant names to train, in run order.

    An explicit ``--variants`` list wins, then ``--v1-only``, then
    ``--v2-only``; otherwise every variant runs. Repeated names in
    ``--variants`` are collapsed (first occurrence kept) so a variant is not
    trained twice into the same checkpoint directory or duplicated in
    metrics.csv.
    """
    if args.variants:
        return list(dict.fromkeys(args.variants))
    if args.v1_only:
        return VARIANTS_V1
    if args.v2_only:
        return VARIANTS_V2
    return VARIANTS


def _load_model_from_ckpt(path, device):
    """Rebuild a ``ConsistencyTransformer`` from a checkpoint file, in eval mode.

    The checkpoint must be a mapping with a ``"model_cfg"`` entry (passed to the
    ``ConsistencyTransformer`` constructor) and a ``"model_state"`` state dict.
    """
    # weights_only=False lets torch.load unpickle arbitrary objects, which the
    # stored model config appears to need. Only safe for checkpoints this
    # script wrote itself -- never point it at untrusted files.
    ckpt = torch.load(path, map_location=device, weights_only=False)
    model = ConsistencyTransformer(ckpt["model_cfg"]).to(device)
    model.load_state_dict(ckpt["model_state"])
    model.eval()
    return model


def _collect_consistency_qualitative(ckpt_dir, val_ex, tokenizer, device, n_samples=15):
    """Pair earliest- vs final-checkpoint generations for qualitative display.

    Loads the lexicographically last and first ``.pt`` files in *ckpt_dir*.
    Checkpoints are named ``epoch_NNN.pt`` (zero-padded), so these are the
    final and the earliest saved epochs. Generates up to *n_samples*
    validation outputs from each and combines them with
    ``collect_qualitative_examples``.

    Returns:
        List of qualitative example dicts, or ``[]`` when *ckpt_dir* is missing
        or contains no checkpoints.
    """
    if not os.path.isdir(ckpt_dir):
        print(f"  No checkpoint directory at {ckpt_dir}; skipping qualitative examples.")
        return []
    ckpt_files = sorted(f for f in os.listdir(ckpt_dir) if f.endswith(".pt"))
    if not ckpt_files:
        return []

    n = min(n_samples, len(val_ex))

    # Keep the final-then-first order: model construction and generation may
    # consume RNG state, so reordering could change the sampled outputs.
    model_final = _load_model_from_ckpt(os.path.join(ckpt_dir, ckpt_files[-1]), device)
    epoch_final_gens = generate_epoch_samples(model_final, val_ex, tokenizer, device, n=n)
    del model_final  # release memory before the second model is loaded

    model_first = _load_model_from_ckpt(os.path.join(ckpt_dir, ckpt_files[0]), device)
    epoch_1_gens = generate_epoch_samples(model_first, val_ex, tokenizer, device, n=n)
    del model_first

    qual_examples = collect_qualitative_examples(epoch_1_gens, epoch_final_gens)
    # Attach the source validation example's template name for display.
    for qe in qual_examples:
        idx = qe.get("idx", 0)
        if 0 <= idx < len(val_ex):
            qe["template_name"] = val_ex[idx].template_name
    return qual_examples


def _print_final_metrics(df):
    """Print one row of headline validation metrics per variant at the last epoch."""
    if df.empty or "epoch" not in df.columns:
        return
    max_epoch = df["epoch"].max()
    final = df[df["epoch"] == max_epoch]
    print(f"\n  Final epoch ({max_epoch}) metrics:")
    print(f"  {'Variant':<38} {'Coupling':>9} {'BLEU-1':>7} {'ROUGE-L':>8} {'ClaimAcc':>9}")
    print(f"  {'-'*38} {'-'*9} {'-'*7} {'-'*8} {'-'*9}")
    for v in VARIANTS:
        row = final[final["variant"] == v]
        if row.empty:
            continue
        r = row.iloc[0]
        try:
            print(f"  {v:<38} {r['val_coupling_strength']:9.4f} {r['val_bleu1']:7.4f} "
                  f"{r['val_rouge_l']:8.4f} {r['val_claim_accuracy']:9.4f}")
        except KeyError as err:
            # A metric column was not recorded; say so rather than silently
            # dropping the variant from the table.
            print(f"  {v:<38} (metric column {err} not recorded)")


def main():
    """Run the experiment end to end.

    Resolves the effective config from the CLI, builds the dataset and
    tokenizer, and trains each selected variant. Then writes ``metrics.csv``,
    charts (unless ``--no-charts``), qualitative examples from the
    ``consistency_loss`` checkpoints, the markdown report, a console summary,
    and ``run_metadata.json`` into the output directory.
    """
    args = parse_args()

    # ── Output directory ───────────────────────────────────────────────────────
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    # ── Build TrainConfig ──────────────────────────────────────────────────────
    cfg = TrainConfig(
        seed=args.seed,
        output_dir=output_dir,
        smoke=args.smoke,
        smoke_n=args.smoke_n,
        smoke_epochs=args.smoke_epochs,
        smoke_batch=args.smoke_batch,
    )

    # CLI overrides. Compare against None (not truthiness) so an explicitly
    # passed value is never silently dropped.
    overrides = {
        key: value
        for key, value in (
            ("model_size", args.model),
            ("n_epochs", args.epochs),
            ("batch_size", args.batch),
            ("n_examples", args.n),
        )
        if value is not None
    }

    if args.full:
        cfg.smoke = False
        for key, value in overrides.items():
            setattr(cfg, key, value)

    eff = cfg.effective()

    # Apply overrides to the effective config as well; in smoke mode this is
    # the only place they are applied (cfg itself is left untouched).
    for key, value in overrides.items():
        eff[key] = value

    is_smoke = args.smoke
    variants_to_run = _resolve_variants(args)

    print("=" * 70)
    print("  Consistency Loss Experiment")
    print("=" * 70)
    print(f"  Mode:          {'SMOKE' if is_smoke else 'FULL'}")
    print(f"  Dataset size:  {eff['n_examples']}")
    print(f"  Validation:    {eff['val_size']}")
    print(f"  Epochs:        {eff['n_epochs']}")
    print(f"  Batch size:    {eff['batch_size']}")
    print(f"  Model size:    {eff['model_size']}")
    print(f"  Max seq len:   {eff['max_seq_len']}")
    print(f"  Variants:      {variants_to_run}")
    print(f"  Output dir:    {os.path.abspath(output_dir)}")
    print("=" * 70)

    # ── Device ────────────────────────────────────────────────────────────────
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"  Device: {device}")

    # ── Reproducibility ───────────────────────────────────────────────────────
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    # ── Dataset ───────────────────────────────────────────────────────────────
    t0 = time.time()
    print("\nBuilding dataset...")
    examples = build_dataset(n=eff["n_examples"], seed=cfg.seed)
    train_ex, val_ex = split_dataset(examples, val_size=eff["val_size"], seed=cfg.seed)
    print(f"  Train: {len(train_ex)}, Val: {len(val_ex)}")

    print("Building tokenizer...")
    tokenizer = build_tokenizer(examples)
    print(f"  Vocab size: {tokenizer.vocab_size}")
    print(f"  Dataset built in {time.time() - t0:.1f}s")

    # ── Training ──────────────────────────────────────────────────────────────
    all_records = []
    for variant in variants_to_run:
        records = train_variant(
            variant=variant,
            train_examples=train_ex,
            val_examples=val_ex,
            tokenizer=tokenizer,
            cfg=cfg,
            device=device,
            eff=eff,
        )
        all_records.extend(records)

    # ── Save CSV ──────────────────────────────────────────────────────────────
    df = pd.DataFrame(all_records)
    csv_path = os.path.join(output_dir, "metrics.csv")
    df.to_csv(csv_path, index=False)
    print(f"\nMetrics saved: {csv_path}")

    # ── Charts ────────────────────────────────────────────────────────────────
    if not args.no_charts:
        generate_all_charts(df, output_dir)

    # ── Qualitative examples ──────────────────────────────────────────────────
    print("\nCollecting qualitative examples...")
    qual_examples = []
    if "consistency_loss" in variants_to_run:
        ckpt_dir = os.path.join(output_dir, "checkpoints", "consistency_loss")
        qual_examples = _collect_consistency_qualitative(ckpt_dir, val_ex, tokenizer, device)

    # ── Report ────────────────────────────────────────────────────────────────
    report_path = os.path.join(output_dir, "report.pplx.md")
    build_report(
        df=df,
        qualitative_examples=qual_examples,
        cfg_dict=eff,
        is_smoke=is_smoke,
        output_path=report_path,
    )

    # ── Summary ───────────────────────────────────────────────────────────────
    # Measured from the start of dataset construction.
    total_time = time.time() - t0

    print("\n" + "=" * 70)
    print("  EXPERIMENT COMPLETE")
    print("=" * 70)

    _print_final_metrics(df)

    print(f"\n  Total time: {total_time:.1f}s ({total_time/60:.1f} min)")
    print("\n  Output files:")

    output_files = {
        "Metrics CSV":   csv_path,
        "Report":        report_path,
        "Coupling chart": os.path.join(output_dir, "coupling_strength.png"),
        "Expl. correctness chart": os.path.join(output_dir, "explanation_correctness.png"),
        "Swap influence chart": os.path.join(output_dir, "counterfactual_swap.png"),
        "Claim accuracy chart": os.path.join(output_dir, "claim_accuracy.png"),
        "Loss curves chart": os.path.join(output_dir, "losses.png"),
    }
    for label, path in output_files.items():
        exists = "✓" if os.path.exists(path) else "✗"
        print(f"  [{exists}] {label}: {os.path.abspath(path)}")

    print("\n  Checkpoints:")
    for v in variants_to_run:
        ckpt_d = os.path.join(output_dir, "checkpoints", v)
        if os.path.isdir(ckpt_d):
            ckpts = [f for f in os.listdir(ckpt_d) if f.endswith(".pt")]
            print(f"    {v}: {len(ckpts)} checkpoints in {os.path.abspath(ckpt_d)}")

    print("=" * 70)

    # Save run metadata
    has_epochs = not df.empty and "epoch" in df.columns
    meta = {
        "mode": "smoke" if is_smoke else "full",
        "config": eff,
        "variants": variants_to_run,
        "total_time_s": total_time,
        "device": str(device),
        "n_epochs_run": int(df["epoch"].max()) if has_epochs else 0,
        "output_files": {k: os.path.abspath(v) for k, v in output_files.items()},
    }
    meta_path = os.path.join(output_dir, "run_metadata.json")
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)
    print(f"\n  Run metadata: {meta_path}")


if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()
