#!/usr/bin/env python3
import os
import sys
import json
import random
import argparse
from pathlib import Path
from typing import List, Tuple, Dict, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Ensure project root is on sys.path.
# This script lives one directory below the project root (SCRIPT_DIR.parent),
# and the project-local imports that follow (debate, parser, data, model) are
# resolved from there. This must run before those imports.
SCRIPT_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = SCRIPT_DIR.parent
if str(PROJECT_ROOT) not in sys.path:
    # Insert at index 0 so the project's modules are searched before any
    # same-named modules elsewhere on sys.path.
    sys.path.insert(0, str(PROJECT_ROOT))

from debate import construct_question_prompt
from parser import parse_answer, grade_answer
from data import get_data
from model import Agent


def load_experiment_config(exp_dir: Path) -> Dict:
    """Load and return the parsed ``config.json`` of an experiment directory.

    Args:
        exp_dir: Experiment directory expected to contain ``config.json``.

    Returns:
        The decoded JSON content (expected to be an object/dict).

    Raises:
        FileNotFoundError: if ``exp_dir/config.json`` does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    cfg_path = exp_dir / "config.json"
    # Open first and translate the error (EAFP) rather than exists()-then-open,
    # which can race with the file disappearing in between.
    try:
        f = open(cfg_path, 'r', encoding="utf-8")
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Missing config.json in {exp_dir}") from err
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with f:
        return json.load(f)


def load_test_examples(dataset: str, train_size: int, test_size: int, seed: int) -> List[Dict]:
    """Rebuild the experiment's held-out test examples.

    Requests ``train_size + test_size`` examples from ``get_data`` (with
    ``test_size=0``) and returns the slice ``[train_size, train_size + test_size)``.
    NOTE(review): assumes ``get_data`` returns a sliceable sequence whose order
    matches the split used during training — confirm against ``data.get_data``.
    """
    combined = train_size + test_size
    pool = get_data(dataset_name=dataset, train_size=combined, test_size=0, seed=seed)
    # Everything after the first train_size entries is the original test split
    return pool[train_size:combined]


def select_prompts(test_examples: List[Dict], k: int, seed: int) -> List[Dict]:
    """Choose ``k`` examples reproducibly, or all of them.

    A non-positive ``k``, or one at least as large as the pool, returns a
    shallow copy of the whole list. Otherwise draws ``k`` examples without
    replacement using a private ``random.Random(seed)`` (the global RNG is
    untouched).
    """
    use_everything = k <= 0 or k >= len(test_examples)
    if use_everything:
        return list(test_examples)
    picker = random.Random(seed)
    return picker.sample(test_examples, k)


def batch_sample_completions(agent: Agent, prompt: str, dataset: str, n: int, temperature: float, top_p: float) -> List[str]:
    """Generate ``n`` completions of a single question with ``agent``.

    The question is wrapped by ``construct_question_prompt`` into one user
    message. With ``temperature == 0.0`` the model is queried once and that
    text is repeated ``n`` times. Otherwise ``n`` identical contexts are sent
    in a single ``batch_generate`` call. Returns the message contents in
    output order.
    """
    import asyncio

    user_msg = {"role": "user", "content": construct_question_prompt(prompt, dataset)}
    target_device = agent.get_device()

    def _generate(batch, temp):
        # batch_generate is a coroutine; run it to completion on a fresh loop.
        return asyncio.run(agent.batch_generate(
            contexts_list=batch,
            device=target_device,
            temperature=temp,
            top_p=top_p,
        ))

    def _text_of(result):
        return result["choices"][0]["message"]["content"]

    if temperature != 0.0:
        # Sampling path: one context per requested completion, single batch.
        results = _generate([[user_msg] for _ in range(n)], temperature)
        return [_text_of(r) for r in results]

    # Greedy path: decode once and replicate to avoid redundant compute.
    single = _text_of(_generate([[user_msg]], 0.0)[0])
    return [single] * n


def compute_pass_curve(sampled_texts: List[str], ground_truth: str, dataset: str) -> List[int]:
    """Return the cumulative "any correct so far" curve for one prompt.

    Entry ``t-1`` is 1 if at least one of the first ``t`` samples is graded
    correct by ``grade_answer`` (after ``parse_answer``), else 0. Every sample
    is parsed and graded, even after the first hit.
    """
    curve: List[int] = []
    hit_yet = False
    for completion in sampled_texts:
        verdict = grade_answer(parse_answer(completion, dataset=dataset), ground_truth)
        if verdict:
            hit_yet = True
        curve.append(1 if hit_yet else 0)
    return curve


def average_curves(curves: List[List[int]], n: int) -> np.ndarray:
    """Average per-prompt 0/1 curves element-wise across prompts.

    Each curve is truncated to its first ``n`` entries before stacking. The
    truncated curves must share a length so they stack into a 2-D float array.
    With no curves at all, returns an all-zero float array of length ``n``.
    """
    if curves:
        stacked = np.array([curve[:n] for curve in curves], dtype=float)
        return np.mean(stacked, axis=0)
    return np.zeros(n)


def ensure_dir(p: Path) -> None:
    """Create directory ``p`` together with any missing parents; no-op if it exists."""
    os.makedirs(p, exist_ok=True)


def _resolve_experiment_dir(experiment: str) -> Path:
    """Return the first existing directory that ``experiment`` may refer to.

    Lookup order:
    1. The argument as given.
    2. ``experiments/<experiment>`` relative to the current working directory
       (the original lookup).
    3. ``<PROJECT_ROOT>/experiments/<experiment>``, so the script also works
       when launched from outside the project root.

    Raises:
        FileNotFoundError: if none of the candidates exists.
    """
    candidates = (
        Path(experiment),
        Path("experiments") / experiment,
        PROJECT_ROOT / "experiments" / experiment,
    )
    for candidate in candidates:
        if candidate.exists():
            return candidate
    raise FileNotFoundError(f"Experiment directory not found: {experiment}")


def _running_any(correct: List) -> List[int]:
    """Turn per-sample grades into a cumulative 0/1 "any correct so far" curve."""
    curve: List[int] = []
    seen = False
    for verdict in correct:
        seen = seen or bool(verdict)
        curve.append(1 if seen else 0)
    return curve


def _write_passk_csv(csv_path: Path, n: int, series: List[Tuple[str, np.ndarray]]) -> None:
    """Write a CSV with a ``t`` column (1..n) followed by one column per (name, curve) in ``series``."""
    with open(csv_path, 'w') as f:
        f.write(",".join(["t"] + [name for name, _ in series]) + "\n")
        for t in range(1, n + 1):
            row = [t] + [values[t - 1] for _, values in series]
            f.write(",".join(str(v) for v in row) + "\n")


def main():
    """CLI entry point: compute, save, and plot pass@t curves for agent 0.

    The command:
    - Reads the experiment's ``config.json`` and loads the base and/or
      post-trained agent 0.
    - Samples ``n`` completions for each selected test prompt.
    - Writes the prompt-averaged cumulative-correctness curves to
      ``<exp_dir>/plots`` as a CSV and a PNG, plus optional per-sample JSONL
      logs.
    """
    parser = argparse.ArgumentParser(description="Self-consistency pass@k curves for base vs post-trained agent 0")
    parser.add_argument("experiment", help="Experiment ID or path, e.g., 3244874366171575532 or experiments/324487...")
    parser.add_argument("-n", "--n_samples", type=int, default=20, help="Number of samples per prompt")
    parser.add_argument("-k", "--k_prompts", type=int, default=20, help="Number of prompts from the test set (<=0 for full test)")
    parser.add_argument("--seed", type=int, default=0, help="Sampling seed for prompt selection")
    parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature for the main curves")
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--also_greedy", action="store_true", help="Additionally compute greedy (temp=0.0) curves")
    parser.add_argument("--which", choices=["both", "base", "post"], default="both", help="Which agent(s) to evaluate")
    parser.add_argument("--quantize_base", action="store_true", help="Load base model in 4-bit quantized mode")
    parser.add_argument("--quantize_post", action="store_true", help="Load post-trained model in 4-bit quantized mode")
    parser.add_argument("--device", type=str, default="auto", help="Device map for model (e.g., cuda:0, cuda:1)")
    parser.add_argument("--save_completions", action="store_true", help="Save per-sample completions to JSONL for MV@t analysis")
    args = parser.parse_args()
    # n = 0 would load the models and run everything, then crash plotting an empty curve.
    if args.n_samples < 1:
        parser.error("--n_samples must be >= 1")

    # Resolve experiment directory
    exp_dir = _resolve_experiment_dir(args.experiment)

    cfg = load_experiment_config(exp_dir)
    model_name = cfg.get("model")
    dataset = cfg.get("dataset")
    if not model_name or not dataset:
        raise ValueError(f"config.json in {exp_dir} must define 'model' and 'dataset'")
    train_size = int(cfg.get("train_size", 1500))
    test_size = int(cfg.get("test_size", 500))
    data_seed = int(cfg.get("data_seed", cfg.get("seed", 0)))

    # Load agents according to 'which'
    base_agent: Optional[Agent] = None
    post_agent: Optional[Agent] = None

    if args.which in ("both", "base"):
        print(f"Loading BASE model ({'quantized' if args.quantize_base else 'full'}) on {args.device}: {model_name}")
        base_agent = Agent(model_name, checkpoint_path=None, agent_id=0, device_map=args.device, task=dataset, quantization=args.quantize_base)
        print(f"Base model on device: {base_agent.get_device()}")

    if args.which in ("both", "post"):
        ckpt_dir = exp_dir / "checkpoints" / "agent_0"
        post_ckpt = str(ckpt_dir) if ckpt_dir.exists() else None
        if post_ckpt:
            print(f"Loading POST-TRAINED agent 0 ({'quantized' if args.quantize_post else 'full'}) from {post_ckpt} on {args.device}")
        else:
            print("Post-trained checkpoint for agent 0 not found; using base weights for post-trained curve")
        post_agent = Agent(model_name, checkpoint_path=post_ckpt, agent_id=0, device_map=args.device, task=dataset, quantization=args.quantize_post)
        print(f"Post-trained model on device: {post_agent.get_device()}")

    # Load test set and choose prompts
    print(f"Loading test split: dataset={dataset}, train_size={train_size}, test_size={test_size}, seed={data_seed}")
    test_examples = load_test_examples(dataset, train_size, test_size, data_seed)
    prompts = select_prompts(test_examples, args.k_prompts, args.seed)
    print(f"Selected {len(prompts)} prompts for evaluation")

    n = args.n_samples
    out_dir = exp_dir / "plots"
    ensure_dir(out_dir)

    def completion_log_path(which_key: str, quantized: bool, temp: float) -> Path:
        # Per-(agent, quantization, temperature) JSONL path for raw samples.
        suf = [f"n{n}", f"k{len(prompts)}", f"temp{temp}", which_key]
        if quantized:
            suf.append("quant")
        return out_dir / f"self_consistency_samples_{'_'.join(suf)}.jsonl"

    def compute_avg_curve_for_temp(agent: Agent, which_key: str, quantized: bool, temp: float) -> np.ndarray:
        # Sample n completions per prompt, grade them, and average the
        # cumulative-correctness curves over prompts.
        curves: List[List[int]] = []
        out_f = None
        out_path: Optional[Path] = None
        if args.save_completions:
            out_path = completion_log_path(which_key, quantized, temp)
            print(f"[cache] Writing per-sample completions to: {out_path}")
            out_f = open(out_path, "w", encoding="utf-8")
        mode = 'greedy' if temp == 0.0 else 'sampling'
        try:
            for idx, ex in enumerate(prompts, 1):
                q = ex['question']
                gt = ex['answer']
                print(f"[{idx}/{len(prompts)}] {which_key} {('4-bit' if quantized else 'full')} {mode}: sampling {n} completions")
                texts = batch_sample_completions(agent, q, dataset, n, temp, args.top_p)
                # Parse and grade each completion exactly once; the same grades
                # feed both the optional log record and the pass curve.
                parsed = [parse_answer(t, dataset=dataset) for t in texts]
                correct = [grade_answer(p, gt) for p in parsed]
                if out_f is not None:
                    rec = {
                        "experiment": str(exp_dir.name),
                        "model": model_name,
                        "dataset": dataset,
                        "which": which_key,
                        "quantized": quantized,
                        "temperature": temp,
                        "n": n,
                        "k": len(prompts),
                        "prompt_index": idx - 1,
                        "question": q,
                        "ground_truth": gt,
                        "texts": texts,
                        "parsed": parsed,
                        "correct": correct,
                    }
                    out_f.write(json.dumps(rec) + "\n")
                    out_f.flush()
                    print(f"[cache] Wrote prompt {idx}/{len(prompts)} to {out_path}")
                curves.append(_running_any(correct))
        finally:
            if out_f is not None:
                out_f.close()
                print(f"[cache] Closed {out_path}")
        return average_curves(curves, n)

    def curves_for(agent: Agent, which_key: str, quantized: bool) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        # Main-temperature curve plus, if requested, the greedy curve.
        main_curve = compute_avg_curve_for_temp(agent, which_key, quantized, args.temperature)
        greedy_curve: Optional[np.ndarray] = None
        if args.also_greedy:
            # At --temperature 0.0 the main curve already is the greedy curve;
            # don't regenerate it (or overwrite the same JSONL log).
            if args.temperature == 0.0:
                greedy_curve = main_curve
            else:
                greedy_curve = compute_avg_curve_for_temp(agent, which_key, quantized, 0.0)
        return main_curve, greedy_curve

    # Compute according to selection
    base_avg = base_avg_greedy = None
    post_avg = post_avg_greedy = None

    if base_agent is not None:
        base_avg, base_avg_greedy = curves_for(base_agent, "base", args.quantize_base)

    if post_agent is not None:
        post_avg, post_avg_greedy = curves_for(post_agent, "post", args.quantize_post)

    # Save CSV
    parts = [f"n{n}", f"k{len(prompts)}", f"temp{args.temperature}", args.which]
    if args.quantize_base:
        parts.append("qb")
    if args.quantize_post:
        parts.append("qp")
    if args.also_greedy:
        parts.append("with_greedy")
    suffix = "_".join(parts)

    # Column order is fixed; only curves that were computed are included.
    series = [
        (name, values)
        for name, values in (
            ("base_pass", base_avg),
            ("post_pass", post_avg),
            ("base_greedy", base_avg_greedy),
            ("post_greedy", post_avg_greedy),
        )
        if values is not None
    ]

    csv_path = out_dir / f"self_consistency_passk_{suffix}.csv"
    _write_passk_csv(csv_path, n, series)
    print(f"Saved CSV: {csv_path}")

    # Plot (kept for convenience; combined plot handled separately)
    fig, ax = plt.subplots(figsize=(7.5, 4.8))
    x = np.arange(1, n + 1)
    if base_avg is not None:
        ax.plot(x, base_avg, label=f"Base ({'quant' if args.quantize_base else 'full'}, temp={args.temperature})", color="#1f77b4", marker="o")
    if post_avg is not None:
        ax.plot(x, post_avg, label=f"Post ({'quant' if args.quantize_post else 'full'}, temp={args.temperature})", color="#d62728", marker="s")
    if base_avg_greedy is not None:
        ax.plot(x, base_avg_greedy, label="Base (greedy)", color="#1f77b4", linestyle="--")
    if post_avg_greedy is not None:
        ax.plot(x, post_avg_greedy, label="Post (greedy)", color="#d62728", linestyle="--")
    ax.set_xlabel("Number of samples (t)")
    ax.set_ylabel("Pass@t")
    title = f"Self-consistency: {model_name} on {dataset} (k={len(prompts)}, n={n})"
    ax.set_title(title)
    ax.grid(True, linestyle=":", alpha=0.6)
    if n > 1:
        # set_xlim(1, 1) is a singular range; for n == 1 let matplotlib choose.
        ax.set_xlim(1, n)
    y_all = [values for _, values in series]
    if y_all:
        ycat = np.concatenate(y_all)
        ymin, ymax = float(np.nanmin(ycat)), float(np.nanmax(ycat))
        pad = max(1e-3, 0.02 * (ymax - ymin if ymax > ymin else 1.0))
        ax.set_ylim(ymin - pad, ymax + pad)
    ax.legend(loc="lower right")
    fig.tight_layout()
    fig_path = out_dir / f"self_consistency_passk_{suffix}.png"
    fig.savefig(fig_path, dpi=200)
    plt.close(fig)  # release the figure's memory once it is on disk
    print(f"Saved figure: {fig_path}")


# Run the CLI only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    main()
