# rvb_eval_core.py
# -*- coding: utf-8 -*-
"""
Core-suite evaluator for Reward Variance on a fixed dataset.

What it does:
  - Runs RewardBench on a fixed core dataset (e.g., eval_core.jsonl).
  - Loads per-candidate scores from each reward model (RM).
  - Produces cross-model comparable, relative metrics at three levels:
      * by domain
      * by (domain, task)
      * by prompt
    using model-internal z-score and quantile (rank-CDF) normalization.
  - Computes cross-model consensus (ranking correlation) per prompt.

Outputs (CSV under OUT_ROOT):
  - core_metrics_by_domain.csv
  - core_metrics_by_group.csv      (domain, task)
  - core_metrics_prompts.csv       (per RM per prompt)
  - core_consensus_by_prompt.csv   (per-prompt mean of pairwise Kendall tau / Spearman rho)
  - core_model_summary.csv         (global stats per RM)
  - all_scores_merged.parquet      (optional raw merge for further analysis)

Notes:
  - Robustly parses RewardBench outputs from:
      * *_outputs.jsonl (with "results" or "score" field)
      * scores.jsonl    (local scorer format)
  - Relative metrics are computed as:
      * z-score within each RM: z = (r - mu) / sigma
      * quantile within each RM: q in [0,1] via (rank - 1) / (N - 1), ranks 1..N (N = 1 -> 0.5)
      -> This yields RSI_z, BW_z, RSI_q, BW_q.
  - Kendall's tau requires scipy; without it only Spearman's rho is computed (tau is left empty).
"""

import os
import sys
import json
import math
import time
import glob
import subprocess as sp
from collections import defaultdict
from typing import List, Dict, Any, Tuple

import numpy as np
import pandas as pd

# Try to import scipy for Kendall's tau; fall back to numpy-based Spearman's rho
try:
    from scipy.stats import kendalltau, rankdata, spearmanr
    HAVE_SCIPY = True
except Exception:
    HAVE_SCIPY = False


# ===================== CONFIG (edit here) =====================
CONFIG = dict(
    # Core dataset: one JSONL record per candidate, fixed number of candidates per prompt.
    CORE_DATA="./data/eval_core.jsonl",

    # Output root; per-model RewardBench runs go into rb_<model> subdirectories.
    OUT_ROOT="core_eval_runs",

    # Reward models (Hugging Face ids loadable by RewardBench).
    # Mixing model families makes the cross-model consensus more informative.
    RB_MODELS=[
        "CIR-AMS/BTRM_Qwen2_7b_0613",                      # Qwen family
        "nicolinho/QRM-Llama3.1-8B-v2",                    # Llama family
        "OpenAssistant/reward-model-deberta-v3-large-v2",  # DeBERTa family (SeqClassifier)
    ],

    # --- RewardBench invocation ---
    RB_BATCH_SIZE=8,
    RB_CHAT_TEMPLATE="raw",      # 'raw' is the most stable template for most RMs.
    RB_ENTRY="rewardbench",      # Executable name, resolved through PATH.
    RB_EXTRA_FLAGS=[],           # Appended verbatim, e.g. ["--not_quantized"] for bnb/GLIBC issues.

    # Skip a model whose output directory already holds *outputs.jsonl or scores.jsonl.
    SKIP_IF_DONE=True,

    # Also write all merged raw scores to all_scores_merged.parquet.
    SAVE_MERGED_PARQUET=True,
)
# =============================================================


def log(msg: str):
    """Print *msg* to stdout prefixed with a local-time stamp, flushing immediately."""
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    line = "[{}] {}".format(stamp, msg)
    print(line, flush=True)


def ensure_dir(path: str):
    """Create directory *path* and any missing parents; no-op if it already exists."""
    if os.path.isdir(path):
        return
    # exist_ok still guards against a concurrent creator winning the race.
    os.makedirs(path, exist_ok=True)


def sanitize_model_id(mid: str) -> str:
    """Make a model id filesystem-safe by mapping '/' and ':' to '_'."""
    table = str.maketrans({"/": "_", ":": "_"})
    return mid.translate(table)


def run_rewardbench(core_path: str, model_id: str, out_root: str,
                    batch_size: int, chat_template: str, entry: str, extra_flags: List[str],
                    skip_if_done: bool = True) -> str:
    """
    Score *core_path* with one reward model via the RewardBench CLI.

    The run writes into ``<out_root>/rb_<sanitized model id>``, which is created
    if needed. With *skip_if_done*, the CLI is not invoked when that directory
    already contains any ``*outputs.jsonl`` or ``scores.jsonl`` (searched
    recursively). Failures (non-zero exit or an exception while launching) are
    logged, not raised.

    Returns the model's output directory path in every case.
    """
    out_dir = os.path.join(out_root, "rb_" + sanitize_model_id(model_id))
    ensure_dir(out_dir)

    if skip_if_done:
        done_markers = ("**/*outputs.jsonl", "**/scores.jsonl")
        already_done = any(
            glob.glob(os.path.join(out_dir, marker), recursive=True) for marker in done_markers
        )
        if already_done:
            log(f"[RB] SKIP (found outputs) model={model_id} -> {out_dir}")
            return out_dir

    cmd = [entry, f"--model={model_id}", f"--dataset={core_path}", "--load_json"]
    cmd += [
        f"--{flag}={value}"
        for flag, value in (
            ("chat_template", chat_template),
            ("batch_size", batch_size),
            ("output_dir", out_dir),
        )
    ]
    cmd += list(extra_flags)

    log("[CMD] " + " ".join(cmd))
    try:
        # Everything, including echoing the child's output, stays inside the try
        # so any launch/console error is logged rather than aborting the sweep.
        proc = sp.run(cmd, stdout=sp.PIPE, stderr=sp.STDOUT, text=True, check=False)
        print(proc.stdout)
        if proc.returncode == 0:
            log(f"[RB] DONE model={model_id} -> {out_dir}")
        else:
            log(f"[RB] FAILED model={model_id} (exit={proc.returncode})")
    except Exception as e:
        log(f"[RB] EXCEPTION model={model_id}: {e}")

    return out_dir


def _try_parse_line(d: Dict[str, Any]) -> Tuple[str, int, float]:
    """
    Tries to extract (prompt_id, cand_index, score) from a result JSON line
    produced by RewardBench, handling multiple common formats.
    """
    # Extract score
    score = None
    if "results" in d and isinstance(d["results"], (int, float)):
        score = float(d["results"])
    elif "score" in d and isinstance(d["score"], (int, float)):
        score = float(d["score"])
    else:
        # Some models might nest the score, e.g., {"results": {"score": x}}.
        try:
            if isinstance(d.get("results"), dict):
                score = float(d["results"].get("score"))
        except Exception:
            pass

    # Extract prompt_id and cand_index
    pid = None
    cid = None

    # Common format: 'meta' dictionary carries identifiers.
    meta = d.get("meta") or {}
    if isinstance(meta, dict):
        if "prompt_id" in meta:
            pid = str(meta["prompt_id"])
        if "cand_index" in meta:
            try:
                cid = int(meta["cand_index"])
            except Exception: # Fallback for string-like numbers, e.g., "3".
                try:
                    cid = int(str(meta["cand_index"]))
                except Exception:
                    cid = None

    # Fallback: Check for top-level fields.
    if pid is None:
        pid = str(d.get("prompt_id")) if d.get("prompt_id") is not None else None
    if cid is None:
        try:
            cid = int(d.get("cand_index"))
        except Exception:
            cid = None

    if pid is None or cid is None or score is None:
        raise ValueError("Missing one of (prompt_id, cand_index, score) in line")

    return pid, cid, score


def load_rb_scores(out_dir: str, model_id: str) -> pd.DataFrame:
    """
    Load all RewardBench result lines for one model.

    Searches *out_dir* recursively for ``*outputs.jsonl`` and ``scores.jsonl``
    and parses every non-blank line with ``_try_parse_line``. Lines that cannot
    be parsed are skipped and counted in a warning.

    If several files (e.g. repeated or nested runs) score the same
    (prompt_id, cand_index), only one row is kept — the one from the
    lexicographically last file path — so per-prompt candidate counts and
    rankings are not inflated by duplicates.

    Returns:
        DataFrame with columns [model, prompt_id, cand_index, score]; empty
        (with those columns) when no output file exists.
    """
    columns = ["model", "prompt_id", "cand_index", "score"]

    found = glob.glob(os.path.join(out_dir, "**/*outputs.jsonl"), recursive=True)
    found += glob.glob(os.path.join(out_dir, "**/scores.jsonl"), recursive=True)
    # glob order is filesystem-dependent; sort so the de-duplication below is reproducible.
    paths = sorted(set(found))
    if not paths:
        log(f"[WARN] No output files found in {out_dir}")
        return pd.DataFrame(columns=columns)

    rows = []
    bad_lines = 0
    for p in paths:
        with open(p, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    pid, cid, score = _try_parse_line(json.loads(line))
                    rows.append((model_id, pid, cid, float(score)))
                except Exception:  # deliberate best-effort: count and skip any unparseable line
                    bad_lines += 1
    if bad_lines > 0:
        log(f"[WARN] {model_id}: skipped {bad_lines} malformed lines")

    df = pd.DataFrame(rows, columns=columns)
    n_before = len(df)
    df = df.drop_duplicates(subset=["prompt_id", "cand_index"], keep="last").reset_index(drop=True)
    n_dropped = n_before - len(df)
    if n_dropped > 0:
        log(f"[WARN] {model_id}: dropped {n_dropped} duplicate (prompt_id, cand_index) rows "
            f"across {len(paths)} output file(s)")
    return df


def load_core(core_path: str) -> pd.DataFrame:
    """
    Load the core dataset (one JSONL record per candidate) as metadata rows.

    prompt_id, cand_index, domain and task are each read from the top level of
    the record, falling back to the record's ``meta`` dict (a missing or null
    ``meta`` is treated as empty). prompt_id is stored as str and cand_index
    as int, matching the score tables they are merged with.

    The result is unique on (prompt_id, cand_index), keeping the first
    occurrence, so a left merge of scores onto it cannot duplicate score rows.

    Raises:
        ValueError: if a record has neither a top-level nor a meta
            prompt_id / cand_index (message includes file and line number).
    """
    rows = []
    with open(core_path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            d = json.loads(line)
            meta = d.get("meta") or {}
            if not isinstance(meta, dict):
                meta = {}

            raw_pid = d.get("prompt_id")
            if raw_pid is None:
                raw_pid = meta.get("prompt_id")
            raw_cid = d.get("cand_index")
            if raw_cid is None:
                raw_cid = meta.get("cand_index")
            if raw_pid is None or raw_cid is None:
                raise ValueError(f"{core_path}:{lineno}: record lacks prompt_id or cand_index")

            domain = d.get("domain") or meta.get("domain")
            task = d.get("task") or meta.get("task")
            rows.append((str(raw_pid), int(raw_cid), domain, task))

    core = pd.DataFrame(rows, columns=["prompt_id", "cand_index", "domain", "task"])
    return core.drop_duplicates(subset=["prompt_id", "cand_index"], keep="first").reset_index(drop=True)


def add_relative_columns(scores: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of *scores* (fresh 0..N-1 index) with two per-model columns:

      - z: (score - model mean) / model sample std; set to 0.0 where the std
        is zero/undefined or the result is otherwise NaN.
      - q: (rank - 1) / (N - 1) over all of that model's scores, ranking
        ascending with ties averaged; 0.5 when the model has a single score.

    The input frame is not modified.
    """
    out = scores.reset_index(drop=True)
    by_model = out.groupby("model")["score"]

    center = by_model.transform("mean")
    # A zero spread becomes NaN so the division yields NaN, which is then zeroed.
    spread = by_model.transform("std").replace(0.0, np.nan)
    out["z"] = ((out["score"] - center) / spread).fillna(0.0)

    def _rank_cdf(col: pd.Series) -> pd.Series:
        count = len(col)
        if count <= 1:
            return pd.Series([0.5] * count, index=col.index)
        return (col.rank(method="average") - 1) / (count - 1)

    out["q"] = out.groupby("model")["score"].transform(_rank_cdf)
    return out


def per_prompt_metrics(df_rel: pd.DataFrame, core_meta: pd.DataFrame) -> pd.DataFrame:
    """
    Compute spread metrics for each (model, prompt) pair.

    For each of the columns score ("raw"), z and q:
      - RSI_<x> = max - min
      - P10_<x>, P90_<x> = 10th / 90th percentile (numpy linear interpolation)
      - BW_<x>  = P90_<x> - P10_<x>
    plus n_cand, score_min and score_max.

    RSI and BW are derived from the aggregated min/max/percentile columns, so
    each quantile is computed once and RSI_raw always equals
    score_max - score_min. All min/max use pandas' NaN-skipping semantics.

    Returns the per-prompt table with domain/task attached from *core_meta*
    (first row per prompt_id); column order is fixed (see ``ordered`` below).
    """
    def _q10(x): return float(np.quantile(x, 0.10))
    def _q90(x): return float(np.quantile(x, 0.90))

    spec = {
        "n_cand": ("score", "size"),
        "score_min": ("score", "min"),
        "score_max": ("score", "max"),
        "P10_raw": ("score", _q10),
        "P90_raw": ("score", _q90),
    }
    for col in ("z", "q"):
        # Temporary min/max columns, used only to derive RSI and dropped below.
        spec[f"_{col}_min"] = (col, "min")
        spec[f"_{col}_max"] = (col, "max")
        spec[f"P10_{col}"] = (col, _q10)
        spec[f"P90_{col}"] = (col, _q90)
    met = df_rel.groupby(["model", "prompt_id"]).agg(**spec).reset_index()

    met["RSI_raw"] = met["score_max"] - met["score_min"]
    met["RSI_z"] = met["_z_max"] - met["_z_min"]
    met["RSI_q"] = met["_q_max"] - met["_q_min"]
    for suffix in ("raw", "z", "q"):
        met[f"BW_{suffix}"] = met[f"P90_{suffix}"] - met[f"P10_{suffix}"]

    # Keep the established CSV column order (and drop the temporary columns).
    ordered = ["model", "prompt_id", "n_cand", "score_min", "score_max"]
    for suffix in ("raw", "z", "q"):
        ordered += [f"RSI_{suffix}", f"P10_{suffix}", f"P90_{suffix}", f"BW_{suffix}"]
    met = met[ordered]

    # Attach domain/task metadata (assumes they are constant within a prompt).
    core_prompt = core_meta.drop_duplicates(subset=["prompt_id"])[["prompt_id", "domain", "task"]]
    return met.merge(core_prompt, on="prompt_id", how="left")


def aggregate_slices(met_prompt: pd.DataFrame, out_root: str):
    """
    Aggregate per-prompt metrics by (model, domain) and (model, domain, task).

    For each RSI/BW metric (raw, z, q) the mean, median, P10 and P90 across
    prompts are reported, plus the number of distinct prompts per slice.
    Prompts whose domain/task is missing (e.g. not described by the core file)
    are kept as their own NaN slice rather than being silently dropped, so the
    slice tables account for the same prompts as the model summary.

    Side effect: writes core_metrics_by_domain.csv and core_metrics_by_group.csv
    under *out_root*.

    Returns:
        (by_domain, by_domain_task) DataFrames.
    """
    keep_cols = [
        "RSI_raw", "BW_raw",
        "RSI_z", "BW_z",
        "RSI_q", "BW_q",
    ]

    def _p10(s): return float(np.quantile(s, 0.10))
    def _p90(s): return float(np.quantile(s, 0.90))

    def _summarize(keys: List[str]) -> pd.DataFrame:
        """Aggregate *met_prompt* over *keys* with the shared statistic spec."""
        spec = {"num_prompts": ("prompt_id", "nunique")}
        for suffix, how in (("mean", "mean"), ("median", "median"), ("p10", _p10), ("p90", _p90)):
            for c in keep_cols:
                spec[f"{c}_{suffix}"] = (c, how)
        return met_prompt.groupby(keys, dropna=False).agg(**spec).reset_index()

    g1 = _summarize(["model", "domain"])
    g2 = _summarize(["model", "domain", "task"])

    g1.to_csv(os.path.join(out_root, "core_metrics_by_domain.csv"), index=False)
    g2.to_csv(os.path.join(out_root, "core_metrics_by_group.csv"), index=False)
    return g1, g2


def compute_consensus(df_scores: pd.DataFrame, out_root: str) -> pd.DataFrame:
    """
    Compute cross-model ranking consensus per prompt.

    Within each (model, prompt) candidates are ranked (rank 1 = highest score,
    ties averaged). For every prompt and every pair of models sharing at least
    three candidates, the two rankings are correlated with Kendall's tau and
    Spearman's rho. Tau requires SciPy; without it (or if SciPy raises) tau is
    NaN and rho comes from a numpy rank-correlation fallback.

    Side effect: writes core_consensus_by_prompt.csv (per-prompt pair count and
    mean tau/rho) under *out_root*.

    Returns:
        Long-form DataFrame [prompt_id, model_a, model_b, kendall_tau,
        spearman_rho, n_common]. Missing correlations are NaN (not None) so
        the correlation columns stay numeric and can always be averaged.
    """
    from itertools import combinations

    pair_cols = ["prompt_id", "model_a", "model_b", "kendall_tau", "spearman_rho", "n_common"]
    agg_cols = ["prompt_id", "pairs", "kendall_tau_mean", "spearman_rho_mean"]

    def _fallback_rho(a, b) -> float:
        """Spearman's rho as the Pearson correlation of average ranks (numpy only)."""
        ra = pd.Series(a).rank(method="average").to_numpy()
        rb = pd.Series(b).rank(method="average").to_numpy()
        return float(np.corrcoef(ra, rb)[0, 1])

    def _rank_correlations(a, b):
        """Return (tau, rho); tau stays NaN when SciPy is unavailable or fails."""
        tau = np.nan
        if HAVE_SCIPY:
            try:
                # Index [0] works for both legacy namedtuple results and newer
                # SignificanceResult objects, unlike relying on `.correlation`.
                tau = float(kendalltau(a, b, nan_policy="omit")[0])
                rho = float(spearmanr(a, b)[0])
                return tau, rho
            except Exception:  # deliberate: any SciPy failure degrades to the numpy path
                pass
        return tau, _fallback_rho(a, b)

    ranked = df_scores.copy()
    ranked["rank"] = ranked.groupby(["model", "prompt_id"])["score"].transform(
        lambda s: (-s).rank(method="average")
    )

    # prompt -> {model -> {cand_index -> rank}}
    per_prompt = defaultdict(dict)
    for (m, pid), sub in ranked.groupby(["model", "prompt_id"]):
        per_prompt[pid][m] = dict(zip(sub["cand_index"].tolist(), sub["rank"].tolist()))

    rows = []
    for pid, model_ranks in per_prompt.items():
        for m1, m2 in combinations(sorted(model_ranks), 2):
            r1, r2 = model_ranks[m1], model_ranks[m2]
            common = sorted(r1.keys() & r2.keys())
            if len(common) < 3:
                continue  # too few shared candidates for a meaningful correlation
            v1 = np.array([r1[c] for c in common], dtype=float)
            v2 = np.array([r2[c] for c in common], dtype=float)
            tau, rho = _rank_correlations(v1, v2)
            rows.append((pid, m1, m2, tau, rho, len(common)))

    out = pd.DataFrame(rows, columns=pair_cols)

    if out.empty:
        agg = pd.DataFrame(columns=agg_cols)
    else:
        agg = out.groupby("prompt_id").agg(
            pairs=("model_a", "size"),
            kendall_tau_mean=("kendall_tau", "mean"),
            spearman_rho_mean=("spearman_rho", "mean"),
        ).reset_index()
    agg.to_csv(os.path.join(out_root, "core_consensus_by_prompt.csv"), index=False)

    return out


def model_summary(met_prompt: pd.DataFrame, out_root: str):
    """
    Summarize per-prompt metrics into one row per model.

    Columns: model, num_prompts (distinct prompt_ids), then for each of
    RSI/BW x (raw, z, q): all means, all medians, all 90th percentiles.
    Writes core_model_summary.csv under *out_root* and returns the table.
    """
    metric_names = ("RSI_raw", "BW_raw", "RSI_z", "BW_z", "RSI_q", "BW_q")

    def _p90(values):
        return float(np.quantile(values, 0.90))

    spec = {"num_prompts": ("prompt_id", "nunique")}
    for suffix, how in (("mean", "mean"), ("median", "median"), ("p90", _p90)):
        for name in metric_names:
            spec[f"{name}_{suffix}"] = (name, how)

    summary = met_prompt.groupby("model").agg(**spec).reset_index()
    summary.to_csv(os.path.join(out_root, "core_model_summary.csv"), index=False)
    return summary


def main():
    """
    Run the core-suite pipeline end to end using CONFIG.

    Steps: load core metadata, run (or reuse) RewardBench per model, load and
    merge scores, add z/quantile columns, then write per-prompt, per-slice,
    consensus and per-model summary CSVs under OUT_ROOT. The merged-scores
    parquet is optional: if it cannot be written (e.g. no parquet engine
    installed) a warning is logged and the run continues.

    Exits with status 1 if no model produced any parseable scores.
    """
    core_path = CONFIG["CORE_DATA"]
    out_root = os.path.abspath(CONFIG["OUT_ROOT"])
    rb_models = CONFIG["RB_MODELS"]
    batch = CONFIG["RB_BATCH_SIZE"]
    tmpl = CONFIG["RB_CHAT_TEMPLATE"]
    entry = CONFIG["RB_ENTRY"]
    extra = CONFIG["RB_EXTRA_FLAGS"]
    skip_done = CONFIG["SKIP_IF_DONE"]
    save_parquet = CONFIG.get("SAVE_MERGED_PARQUET", True)

    ensure_dir(out_root)
    log(f"Core data path: {core_path}")
    log(f"Output root directory: {out_root}")
    log(f"Models to evaluate: {rb_models}")

    # 1. Load core metadata (domain/task per prompt)
    core_meta = load_core(core_path)
    log(f"Core metadata loaded: {len(core_meta)} candidate rows across {core_meta['prompt_id'].nunique()} prompts.")

    # 2. Run RewardBench for each model (or skip if outputs exist)
    model_out_dirs = {
        m: run_rewardbench(core_path, m, out_root, batch, tmpl, entry, extra, skip_if_done=skip_done)
        for m in rb_models
    }

    # 3. Load scores from all models; models without parseable output are excluded
    loaded = []
    for m, od in model_out_dirs.items():
        dfm = load_rb_scores(od, m)
        if dfm.empty:
            log(f"[WARN] Empty scores for {m}, will be excluded from metrics.")
        else:
            loaded.append(dfm)

    if not loaded:
        log("[FATAL] No scores were loaded. Please check RewardBench output files.")
        sys.exit(1)
    scores = pd.concat(loaded, ignore_index=True)

    # Attach domain/task metadata; report scored rows the core file does not describe
    scores = scores.merge(core_meta, on=["prompt_id", "cand_index"], how="left", indicator=True)
    n_unmatched = int((scores["_merge"] == "left_only").sum())
    scores = scores.drop(columns="_merge")
    if n_unmatched > 0:
        log(f"[WARN] {n_unmatched} scored rows have no matching (prompt_id, cand_index) "
            f"in {core_path}; their domain/task will be missing.")

    # Optionally save the merged raw scores (requires pyarrow or fastparquet)
    parquet_written = False
    if save_parquet:
        parquet_path = os.path.join(out_root, "all_scores_merged.parquet")
        try:
            scores.to_parquet(parquet_path, index=False)
            parquet_written = True
        except (ImportError, ValueError, TypeError, OSError) as e:
            # Optional artifact: don't throw away an expensive evaluation over it.
            log(f"[WARN] Could not write {parquet_path}: {e}")

    # 4. Add relative columns (z-score, quantile)
    df_rel = add_relative_columns(scores)

    # 5. Calculate per-prompt metrics for each RM
    met_prompt = per_prompt_metrics(df_rel, core_meta)
    met_prompt.to_csv(os.path.join(out_root, "core_metrics_prompts.csv"), index=False)

    # 6. Aggregate metrics by domain and by (domain, task)
    aggregate_slices(met_prompt, out_root)

    # 7. Compute cross-model consensus per prompt
    compute_consensus(scores, out_root)

    # 8. Create a final summary for each model
    model_summary(met_prompt, out_root)

    # 9. Print a brief final report
    log("=== DONE ===")
    log(f"- Prompts: {core_meta['prompt_id'].nunique()} | Models (with scores): {scores['model'].nunique()}")
    log(f"- Total rows scored: {len(scores)}")
    log(f"- Output directory: {out_root}")
    log("Key output files:")
    log("  - core_model_summary.csv")
    log("  - core_metrics_by_domain.csv")
    log("  - core_metrics_by_group.csv")
    log("  - core_metrics_prompts.csv")
    log("  - core_consensus_by_prompt.csv")
    if parquet_written:
        log("  - all_scores_merged.parquet")


# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()