#!/usr/bin/env python3
"""
eval_embedding_auc.py

Modified to MATCH TASKS ACROSS DATASETS (by task_id / uid) so AUROC isn't driven by
different underlying problems per dataset.

Key change:
- We compute the intersection of task IDs across `--datasets` on the `--eval_split`.
- We restrict:
  (a) eval sampling to tasks whose task_id is in that intersection
  (b) target/negative pools to the same task_id set when possible (falls back w/ warning)
"""

from __future__ import annotations

import argparse
import random
from typing import Any, Dict, List, Optional, Tuple, Set

import numpy as np

from rllm.data.dataset import DatasetRegistry

# Import from your modified embedding module
from examples.bugs.code_embedding import (
    CodeEmbeddingConfig,
    CodeEmbedder,
    KNNBugSimilarity,
)

# ---------------------------
# Task schema helpers
# ---------------------------

def _get_problem(task: Dict[str, Any]) -> str:
    extra_info = task.get("extra_info", {})
    if isinstance(extra_info, dict):
        for key in (
            "question",
            "instruct_prompt",
            "complete_prompt",
            "prompt",
            "text",
            "problem",
            "description",
            "code_prompt",
        ):
            v = extra_info.get(key)
            if isinstance(v, str) and v.strip():
                return v
    for key in (
        "question",
        "instruct_prompt",
        "complete_prompt",
        "prompt",
        "text",
        "problem",
        "description",
        "code_prompt",
    ):
        v = task.get(key)
        if isinstance(v, str) and v.strip():
            return v
    return ""


def _get_buggy_solution(task: Dict[str, Any]) -> Optional[str]:
    extra_info = task.get("extra_info", {})
    if isinstance(extra_info, dict):
        for key in ("buggy_solution", "buggy_sampled_solution", "buggy", "buggy_code", "bug"):
            v = extra_info.get(key)
            if isinstance(v, str) and v.strip():
                return v
    for key in ("buggy_solution", "buggy_sampled_solution", "buggy", "buggy_code", "bug"):
        v = task.get(key)
        if isinstance(v, str) and v.strip():
            return v
    return None


def _get_reference_solution(task: Dict[str, Any]) -> str:
    extra_info = task.get("extra_info", {})
    if isinstance(extra_info, dict):
        for key in ("reference_solution", "canonical_solution", "solution", "code", "correct_code"):
            v = extra_info.get(key)
            if isinstance(v, str) and v.strip():
                return v
    for key in ("reference_solution", "canonical_solution", "solution", "code", "correct_code"):
        v = task.get(key)
        if isinstance(v, str) and v.strip():
            return v
    return ""


def _get_task_id(task: Dict[str, Any]) -> Optional[str]:
    """
    Try common task-id keys. Returns a stable string if present.
    """
    extra_info = task.get("extra_info", {})
    candidates = []
    if isinstance(extra_info, dict):
        candidates.extend(
            [
                extra_info.get("task_id"),
                extra_info.get("uid"),
                extra_info.get("id"),
                extra_info.get("problem_id"),
                extra_info.get("instance_id"),
            ]
        )
    candidates.extend(
        [
            task.get("task_id"),
            task.get("uid"),
            task.get("id"),
            task.get("problem_id"),
            task.get("instance_id"),
        ]
    )
    for v in candidates:
        if isinstance(v, (str, int)) and str(v).strip():
            return str(v).strip()
    return None


# ---------------------------
# AUROC (no sklearn dependency)
# ---------------------------

def auc_roc(y_true: List[int], y_score: List[float]) -> float:
    """Compute AUROC from the Mann-Whitney U rank statistic.

    Args:
        y_true: Labels; an entry equal to ``1`` is a positive (target), any
            other value is treated as a negative.
        y_score: Scores aligned with ``y_true``; higher means "more positive".

    Returns:
        The AUROC as a float. ``nan`` is returned when the inputs are empty,
        have different lengths, or contain only one class.
    """
    n_total = len(y_true)
    if n_total == 0 or n_total != len(y_score):
        return float("nan")

    is_pos = [label == 1 for label in y_true]
    n_pos = sum(1 for flag in is_pos if flag)
    n_neg = n_total - n_pos
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    score_arr = np.array(y_score)
    order = np.argsort(score_arr)
    sorted_scores = score_arr[order]

    # Positions in the sorted array where a new run of equal scores begins.
    breaks = np.flatnonzero(sorted_scores[1:] != sorted_scores[:-1]) + 1
    starts = np.concatenate(([0], breaks))
    stops = np.concatenate((breaks, [n_total]))

    # A run occupying sorted positions [start, stop) holds the 1-based ranks
    # start+1 .. stop; every member receives their average.
    run_rank = (starts + 1 + stops) / 2.0
    ranks = np.empty(n_total, dtype=float)
    ranks[order] = np.repeat(run_rank, stops - starts)

    # U statistic of the positives, normalised by the number of pos/neg pairs.
    rank_sum_pos = float(ranks[np.array(is_pos, dtype=bool)].sum())
    u_pos = rank_sum_pos - (n_pos * (n_pos + 1)) / 2.0
    return float(u_pos / (n_pos * n_neg))


# ---------------------------
# Sampling utilities
# ---------------------------

def load_buggy_tasks(dataset_name: str, split: str) -> List[Dict[str, Any]]:
    """Load one dataset split and keep only the rows that carry a buggy solution.

    Args:
        dataset_name: Name passed to ``DatasetRegistry.load_dataset``.
        split: Split passed to ``DatasetRegistry.load_dataset``.

    Returns:
        The rows for which ``_get_buggy_solution`` finds a non-blank string.

    Raises:
        RuntimeError: if the registry returns ``None`` for this name/split.
    """
    dataset = DatasetRegistry.load_dataset(dataset_name, split)
    if dataset is None:
        raise RuntimeError(f"Could not load dataset={dataset_name!r} split={split!r}")
    return [row for row in dataset.get_data() if _get_buggy_solution(row)]


def sample_tasks(tasks: List[Dict[str, Any]], n: int, rng: random.Random) -> List[Dict[str, Any]]:
    """Draw ``n`` tasks at random from ``tasks``.

    Args:
        tasks: Pool to draw from.
        n: Number of tasks to return.
        rng: Random generator used for the draw (keeps runs reproducible).

    Returns:
        ``[]`` when ``tasks`` is empty. Otherwise a list of ``n`` tasks:
        drawn without replacement when ``n <= len(tasks)`` (so asking for
        exactly ``len(tasks)`` returns every task once, in random order), and
        with replacement only when ``n`` exceeds the pool size.
    """
    if not tasks:
        return []
    if n > len(tasks):
        # More requested than available: duplicates are unavoidable, so
        # sample with replacement.
        return [rng.choice(tasks) for _ in range(n)]
    return rng.sample(tasks, n)


def filter_to_task_ids(tasks: List[Dict[str, Any]], allowed: Set[str]) -> List[Dict[str, Any]]:
    """Return the tasks, in their original order, whose task id is in ``allowed``.

    Tasks for which ``_get_task_id`` yields ``None`` are always dropped.
    """
    task_ids = [_get_task_id(task) for task in tasks]
    return [
        task
        for task, task_id in zip(tasks, task_ids)
        if task_id is not None and task_id in allowed
    ]


def collect_task_ids(tasks: List[Dict[str, Any]]) -> Set[str]:
    """Return the set of distinct task ids found on ``tasks``.

    Tasks for which ``_get_task_id`` yields ``None`` contribute nothing.
    """
    found = {_get_task_id(task) for task in tasks}
    # Tasks without an id all map to None; drop that placeholder.
    found.discard(None)
    return found


# ---------------------------
# Main eval
# ---------------------------

def main():
    """Run the embedding-similarity AUROC sanity check from the command line.

    Steps:
      1. Parse CLI arguments and print the effective configuration.
      2. Optionally compute the intersection of task ids across ``--datasets``
         on ``--eval_split``. Matching is disabled (with a warning) when any
         dataset contributes no ids or the intersection is empty.
      3. Build the target pool and, for relative scoring, the negative pool
         from ``--pool_split``; if that load fails, fall back to
         ``--eval_split``. Pools are restricted to the matched ids when that
         leaves at least one row.
      4. Sample ``--n_per_dataset`` eval rows per dataset (restricted to the
         matched ids when matching is active) and score each one with the
         KNN similarity scorer.
      5. Print per-dataset score statistics, the AUROC for "does this row
         come from the target dataset?", and a bootstrap 95% CI.

    Raises:
        RuntimeError: if no eval points could be loaded. Exceptions raised by
            the pool fallback loads are not caught and propagate as well.
    """
    ap = argparse.ArgumentParser("Eval embedding-based bugbench vs non-bugbench AUROC")
    ap.add_argument("--model", type=str, default="voyage-code-3")
    ap.add_argument("--voyage_api_key", type=str, default=None)

    ap.add_argument("--target_dataset", type=str, default="bugbench")
    ap.add_argument(
        "--datasets",
        type=str,
        nargs="+",
        default=["bugbench", "bugbench_qwen7b_sampled", "bugbench_gpt-oss-20b_sampled"],
    )
    ap.add_argument("--negative_datasets", type=str, nargs="*", default=None)

    ap.add_argument("--pool_split", type=str, default="test_all", help="Split used to build pools (target+negative).")
    ap.add_argument("--eval_split", type=str, default="test", help="Split used to sample eval points.")

    ap.add_argument("--n_per_dataset", type=int, default=200, help="How many eval bugs to sample per dataset.")
    ap.add_argument("--seed", type=int, default=0)

    ap.add_argument("--top_k", type=int, default=20)

    # must match your code_embedding.py config semantics
    ap.add_argument("--embed_mode", type=str, default="diff", choices=["diff", "buggy"])
    ap.add_argument("--include_problem", action="store_true", default=False)
    # Relative scoring is on by default; --no_use_relative_score turns it off.
    ap.add_argument("--use_relative_score", action="store_true", default=True)
    ap.add_argument("--no_use_relative_score", action="store_false", dest="use_relative_score")
    ap.add_argument("--margin_temperature", type=float, default=5.0)

    ap.add_argument("--device", type=str, default="cuda")  # for local models only

    # Task matching is on by default; --no_match_tasks_across_datasets turns it off.
    ap.add_argument(
        "--match_tasks_across_datasets",
        action="store_true",
        default=True,
        help="Restrict eval (and pools when possible) to task_ids in the intersection across --datasets on eval_split.",
    )
    ap.add_argument(
        "--no_match_tasks_across_datasets",
        action="store_false",
        dest="match_tasks_across_datasets",
    )

    args = ap.parse_args()
    # One seeded generator drives both eval sampling and the bootstrap.
    rng = random.Random(args.seed)

    # Decide negatives: default to every listed dataset except the target.
    neg_names = args.negative_datasets
    if neg_names is None:
        neg_names = [d for d in args.datasets if d != args.target_dataset]

    print("=" * 80)
    print("Embedding AUROC sanity-check")
    print("=" * 80)
    print(f"model={args.model}")
    print(f"target_dataset={args.target_dataset}  pool_split={args.pool_split}  eval_split={args.eval_split}")
    print(f"datasets={args.datasets}")
    print(f"negative_datasets={neg_names}")
    print(f"n_per_dataset={args.n_per_dataset}  top_k={args.top_k}")
    print(f"embed_mode={args.embed_mode} include_problem={args.include_problem}")
    print(f"use_relative_score={args.use_relative_score} margin_temperature={args.margin_temperature}")
    print(f"match_tasks_across_datasets={args.match_tasks_across_datasets}")
    print()

    # ---------------------------
    # Compute intersection of task_ids across datasets on eval_split
    # ---------------------------
    # Successful eval-split loads are remembered here so the sampling step
    # below does not load the same dataset a second time.
    eval_tasks_cache: Dict[str, List[Dict[str, Any]]] = {}

    matched_task_ids: Optional[Set[str]] = None
    if args.match_tasks_across_datasets:
        per_ds_ids: Dict[str, Set[str]] = {}
        for dn in args.datasets:
            try:
                tasks_eval = load_buggy_tasks(dn, args.eval_split)
            except Exception as e:
                print(f"[WARN] Could not load {dn}:{args.eval_split} for task-id intersection: {e}")
                tasks_eval = []
            else:
                eval_tasks_cache[dn] = tasks_eval
            ids = collect_task_ids(tasks_eval)
            per_ds_ids[dn] = ids
            print(f"[task-ids] {dn}:{args.eval_split} has {len(ids)} task_ids (from {len(tasks_eval)} buggy rows)")

        # Intersection across all datasets provided
        all_sets = [per_ds_ids[dn] for dn in args.datasets if dn in per_ds_ids]
        if not all_sets or any(len(s) == 0 for s in all_sets):
            print("[WARN] Some dataset has 0 task_ids on eval_split; cannot reliably match tasks. Disabling matching.")
            matched_task_ids = None
        else:
            inter = set.intersection(*all_sets)
            matched_task_ids = inter
            print(f"[task-ids] Intersection across {len(all_sets)} datasets: {len(inter)} task_ids")
            if len(inter) == 0:
                print("[WARN] Task-id intersection is empty. Disabling matching.")
                matched_task_ids = None

        print()

    # Build embedder + scorer
    cfg = CodeEmbeddingConfig(
        model_name=args.model,
        include_problem=bool(args.include_problem),
        embed_mode=str(args.embed_mode),
        top_k=int(args.top_k),
        device=str(args.device),
        voyage_api_key=args.voyage_api_key,
        use_relative_score=bool(args.use_relative_score),
        margin_temperature=float(args.margin_temperature),
    )
    embedder = CodeEmbedder(cfg)
    knn = KNNBugSimilarity(embedder, top_k=args.top_k)

    # ---- Build TARGET pool
    try:
        target_pool_tasks = load_buggy_tasks(args.target_dataset, args.pool_split)
    except Exception as e:
        # common: some datasets may not provide the requested pool split
        print(f"[WARN] Failed to load target pool split={args.pool_split}: {e}")
        print(f"[WARN] Falling back to pool_split={args.eval_split} (this may leak / inflate AUROC)")
        target_pool_tasks = load_buggy_tasks(args.target_dataset, args.eval_split)

    # Restrict pool to matched task ids when possible
    if matched_task_ids is not None:
        filtered = filter_to_task_ids(target_pool_tasks, matched_task_ids)
        if filtered:
            print(f"[pool] target_pool filtered to matched task_ids: {len(filtered)}/{len(target_pool_tasks)} rows kept")
            target_pool_tasks = filtered
        else:
            print("[WARN] target_pool had 0 rows after task-id filtering; using unfiltered pool (may reintroduce task leakage).")

    knn.build_target_pool(target_pool_tasks)

    # ---- Build NEGATIVE pool (only used for relative scoring)
    if args.use_relative_score and neg_names:
        neg_tasks_all: List[Dict[str, Any]] = []
        for dn in neg_names:
            try:
                nt = load_buggy_tasks(dn, args.pool_split)
            except Exception as e:
                print(f"[WARN] Failed to load negative pool {dn}:{args.pool_split}: {e}")
                print(f"[WARN] Trying {dn}:{args.eval_split} instead (may leak)")
                nt = load_buggy_tasks(dn, args.eval_split)

            if matched_task_ids is not None:
                nt_f = filter_to_task_ids(nt, matched_task_ids)
                if nt_f:
                    nt = nt_f
                else:
                    print(f"[WARN] negative pool {dn} had 0 rows after task-id filtering; keeping unfiltered for this dn.")

            neg_tasks_all.extend(nt)

        if neg_tasks_all:
            knn.build_negative_pool(neg_tasks_all)
        else:
            print("[WARN] negative pool empty; switching to absolute scoring")
            knn.negative_pool = None

    # ---- Sample eval tasks from each dataset (MATCHED)
    eval_points: List[Tuple[str, Dict[str, Any]]] = []
    for dn in args.datasets:
        tasks = eval_tasks_cache.get(dn)
        if tasks is None:
            # Not loaded yet (matching disabled) or the earlier load failed.
            try:
                tasks = load_buggy_tasks(dn, args.eval_split)
            except Exception as e:
                print(f"[WARN] Skipping {dn}:{args.eval_split} (load failed): {e}")
                continue

        if matched_task_ids is not None:
            before = len(tasks)
            tasks = filter_to_task_ids(tasks, matched_task_ids)
            print(f"[eval] {dn}:{args.eval_split} filtered to matched task_ids: {len(tasks)}/{before} rows")

        samp = sample_tasks(tasks, args.n_per_dataset, rng)
        eval_points.extend([(dn, t) for t in samp])

    if not eval_points:
        raise RuntimeError("No eval points loaded.")

    # ---- Score
    scores: List[float] = []
    labels: List[int] = []  # 1 iff target_dataset

    per_source_scores: Dict[str, List[float]] = {}
    per_source_margins: Dict[str, List[float]] = {}

    for dn, t in eval_points:
        problem = _get_problem(t)
        bug = _get_buggy_solution(t) or ""
        corr = _get_reference_solution(t) or None

        s, meta = knn.score_similarity(problem, bug, correct_code=corr)
        score = float(s)

        scores.append(score)
        labels.append(1 if dn == args.target_dataset else 0)
        per_source_scores.setdefault(dn, []).append(score)

        # The margin is only recorded when the scorer reports one.
        if "margin" in meta:
            per_source_margins.setdefault(dn, []).append(float(meta["margin"]))

    # ---- Summary stats
    print("\n" + "=" * 80)
    print("Mean(score) per dataset")
    print("=" * 80)
    for dn in sorted(per_source_scores.keys()):
        arr = np.array(per_source_scores[dn], dtype=float)
        msg = f"{dn:30s}  mean={arr.mean():.3f}  std={arr.std():.3f}  n={len(arr)}"
        if dn in per_source_margins and per_source_margins[dn]:
            marr = np.array(per_source_margins[dn], dtype=float)
            msg += f"  mean_margin={marr.mean():+.4f}"
        print(msg)

    # ---- AUROC
    auc = auc_roc(labels, scores)
    print("\n" + "=" * 80)
    print("AUROC (is this from target_dataset?)")
    print("=" * 80)
    print(f"target_dataset={args.target_dataset}")
    print(f"AUROC={auc:.4f}")

    # AUROC needs both classes; explain a nan instead of leaving it unexplained.
    n_pos = sum(labels)
    n_neg = len(labels) - n_pos
    if n_pos == 0 or n_neg == 0:
        print(
            f"[WARN] AUROC is undefined: eval set has {n_pos} target and {n_neg} non-target points "
            "(need at least one of each)."
        )

    # Optional: quick bootstrap CI (cheap)
    B = 500
    idxs = list(range(len(scores)))
    boot = []
    for _ in range(B):
        # Resample eval points with replacement; skip single-class resamples.
        samp = [rng.choice(idxs) for _ in idxs]
        yb = [labels[i] for i in samp]
        sb = [scores[i] for i in samp]
        a = auc_roc(yb, sb)
        if not np.isnan(a):
            boot.append(a)
    if boot:
        lo, hi = np.percentile(np.array(boot), [2.5, 97.5])
        print(f"Bootstrap 95% CI: [{lo:.4f}, {hi:.4f}]  (B={B})")

    print("\n✅ Done")


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

