# -*- coding: utf-8 -*-
"""
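rb_variance_pipeline.py

Two-stage variance analysis of per-prompt candidate scores for each model ("RM"):
  - Stage 1 (stage_compute_metrics) reads *_outputs.jsonl files and writes prompt-,
    RM-, task- and domain-level metric CSVs.
  - Stage 2 (stage_curate) flags outliers, reports redundant metrics, ranks RMs by a
    composite score and renders figures.

- Metric: DCI (Decision Consistency Index) summarizes, per RM, how large the median
  nGMD and SEI are relative to their across-prompt IQRs.
- Composite score: equal-weighted mean of the capped MAD-z scores of
  {nGMD_med, SEI_med, DCI}.
- Task heatmap: uses the 'viridis' colormap (not zero-centered).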
"""

from __future__ import annotations
import os, sys, json, math, logging
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

try:
    from adjustText import adjust_text
    ADJUSTTEXT_AVAILABLE = True
except ImportError:
    ADJUSTTEXT_AVAILABLE = False

# Pipeline configuration. The "./..." paths are relative to the current working directory.
PARAMS = {
    # Directory containing the raw model output files (*_outputs.jsonl).
    "ROOT_DIR": "./data/model_outputs",
    # Directory where all generated metrics and figures will be saved.
    "OUTPUT_DIR": "./output",
    # Recursive glob (relative to ROOT_DIR) used to discover input files.
    "FILE_GLOB": "**/*_outputs.jsonl",

    # Prompts with fewer candidates than this keep their identifiers but get NaN metrics.
    "MIN_CANDS_PER_PROMPT": 2,

    # Percentiles (0-100) that bound the inter-quartile range used for prompt spreads.
    "IQR_Q_LOW": 25,
    "IQR_Q_HIGH": 75,

    # Per-model reference scale used to normalize metrics: "mad", "std" or "iqr".
    "REF_SCALE_MODE": "mad",
    # Applied when the reference scale falls below SCALE_MIN_ABS: "prompt_iqr_median"
    # (median per-prompt IQR / 1.349), another REF_SCALE_MODE name, or a falsy value to disable.
    "FALLBACK_SCALE_STRATEGY": "prompt_iqr_median",
    # Smallest scale allowed; also the last-resort scale value.
    "SCALE_MIN_ABS": 1e-6,

    # Also compute task-normalized variants (*_tasknorm) using per-task scales.
    "ENABLE_TASK_NORM_VARIANTS": True,
    # Per-tail winsorization fraction for each prompt's scores (e.g. 0.05); 0 disables it.
    "WINSOR_PCT": 0.00,
    # `method` argument forwarded to numpy.quantile.
    "NP_QUANTILE_METHOD": "linear",

    # Set of metrics for correlation and outlier detection.
    "CORE_METRICS": ["nGMD_med", "nGap_med", "SEI_med", "DCI", "RSI_IQR_med"],

    # Metrics used to compute the final composite score.
    "MAIN_METRICS": ["nGMD_med", "SEI_med", "DCI"],
    # Relative weights (normalized to sum to 1) of the MAD-z scores in the composite.
    "COMPOSITE_WEIGHTS": {"nGMD_med": 1.0, "SEI_med": 1.0, "DCI": 1.0},

    # Parameters for analysis and plotting.
    # MAD-z scores are clipped to [-Z_CAP, Z_CAP] before being combined.
    "Z_CAP": 6.0,
    # |Spearman rho| at or above which a metric pair is reported as redundant.
    "CORR_DROP_THRESHOLD": 0.92,
    # A value is an outlier if |MAD-z| exceeds OUTLIER_Z_THR, or if its log10 is more than
    # OUTLIER_LOG10_DELTA away from the median log10 of its column.
    "OUTLIER_Z_THR": 8.0,
    "OUTLIER_LOG10_DELTA": 2.0,
    # Number of bars in the top-N charts.
    "TOPN": 12,
    # Figure size (inches) for the ranking bar charts.
    "FIGSIZE": (10, 6),
    # Logging level name; unknown names fall back to INFO.
    "LOG_LEVEL": "INFO",
}

# Configure root logging at import time from PARAMS["LOG_LEVEL"].
logging.basicConfig(
    level=getattr(logging, PARAMS["LOG_LEVEL"].upper(), logging.INFO),
    format="[%(levelname)s] %(message)s"
)
logger = logging.getLogger("rb_variance_pipeline")

# ----------------- Utility Functions -----------------
def make_finite(arr: np.ndarray) -> np.ndarray:
    """Return the finite entries of *arr* as a 1-D float array (NaN and +/-inf removed)."""
    values = np.asarray(arr, dtype=float)
    keep = np.isfinite(values)
    return values[keep]

def np_quantile(x: np.ndarray, q: float, method: str = "linear") -> float:
    """Return the *q* quantile of *x* as a Python float.

    ``q`` may be a percentile (``q > 1``, e.g. 25 or 75) or a fraction (``q <= 1``).
    Note that ``q == 1`` is read as the fraction 1.0 (the maximum), not the 1st
    percentile; callers that want small percentiles should pass fractions.

    ``method`` is forwarded to ``numpy.quantile``. On NumPy < 1.22, which has no
    ``method`` keyword, it is forwarded as the legacy ``interpolation`` keyword
    instead of being silently replaced by "linear".
    """
    x = np.asarray(x)
    qf = q / 100.0 if q > 1 else q
    try:
        return float(np.quantile(x, qf, method=method))
    except TypeError:  # NumPy < 1.22: the keyword was called `interpolation`
        return float(np.quantile(x, qf, interpolation=method))

def winsorize(values: np.ndarray, p: float, method: str = "linear") -> np.ndarray:
    """Clip *values* to their [p, 100 - p] percentile range.

    Args:
        values: Numeric values to winsorize.
        p: Per-tail percentage in percent (e.g. 5 -> clip to the 5th..95th
           percentiles). ``p <= 0`` returns *values* unchanged; ``p`` above 50 is
           treated as 50 (everything clipped to the median) so the bounds never invert.
        method: Quantile method forwarded to ``np_quantile``.

    Returns:
        A float array of the clipped values (an empty array for empty input).
    """
    if p <= 0:
        return values
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return values
    # Convert the percentage to a fraction explicitly: np_quantile reads q <= 1 as a
    # fraction, so passing small percentages (p <= 1) directly would be misread.
    frac = min(float(p), 50.0) / 100.0
    lo = np_quantile(values, frac, method)
    hi = np_quantile(values, 1.0 - frac, method)
    return np.clip(values, lo, hi)

def median_abs_deviation(x: np.ndarray) -> float:
    """Return the MAD of *x* rescaled by 1.4826 to be comparable to a normal std."""
    center = np.median(x)
    raw_mad = np.median(np.abs(x - center))
    # 1.4826 ~= 1 / Phi^-1(0.75): makes MAD a consistent estimator of sigma for normal data.
    return 1.4826 * raw_mad

def ref_scale_raw(arr: np.ndarray, mode: str) -> float:
    """Return a raw dispersion estimate of *arr* without any fallback.

    Modes:
      "mad" -> 1.4826 * MAD
      "std" -> sample standard deviation (ddof=1); 0.0 for a single value
      "iqr" -> (Q_high - Q_low) / 1.349 using PARAMS["IQR_Q_LOW"/"IQR_Q_HIGH"]
    Empty input yields 0.0; an unknown mode raises ValueError.
    """
    if len(arr) == 0:
        return 0.0
    if mode == "mad":
        return median_abs_deviation(arr)
    if mode == "std":
        if len(arr) > 1:
            return float(np.std(arr, ddof=1))
        return 0.0
    if mode == "iqr":
        quant_method = PARAMS["NP_QUANTILE_METHOD"]
        lower = np_quantile(arr, PARAMS["IQR_Q_LOW"], quant_method)
        upper = np_quantile(arr, PARAMS["IQR_Q_HIGH"], quant_method)
        if upper >= lower:
            return float((upper - lower) / 1.349)
        return 0.0
    raise ValueError("Unknown mode for ref_scale_raw")

def ref_scale_with_fallback(arr: np.ndarray, mode: str, prompt_iqrs: Optional[np.ndarray]) -> float:
    """Return a strictly positive reference scale for the scores in *arr*.

    1. Compute ``ref_scale_raw(finite(arr), mode)``.
    2. If that is below PARAMS["SCALE_MIN_ABS"] and a fallback strategy is configured,
       take the larger of it and the fallback estimate:
         - "prompt_iqr_median": median of the finite *prompt_iqrs*, divided by 1.349;
         - any other value: ``ref_scale_raw`` with that value as the mode.
    3. If the result is still non-finite or not positive, return SCALE_MIN_ABS.

    Args:
        arr: Scores (non-finite entries are ignored).
        mode: Primary ``ref_scale_raw`` mode.
        prompt_iqrs: Per-prompt IQRs for the "prompt_iqr_median" fallback, or None.
    """
    arr = make_finite(arr)
    s = ref_scale_raw(arr, mode)
    if not np.isfinite(s):
        s = 0.0
    strategy = PARAMS["FALLBACK_SCALE_STRATEGY"]
    if s < PARAMS["SCALE_MIN_ABS"] and strategy:
        if strategy == "prompt_iqr_median":
            finite_iqrs = make_finite(prompt_iqrs) if prompt_iqrs is not None else np.empty(0)
            # Guard explicitly: np.median([]) is NaN (with a RuntimeWarning), and max() with a
            # NaN argument is order-dependent.
            s_alt = float(np.median(finite_iqrs) / 1.349) if finite_iqrs.size > 0 else 0.0
        else:
            s_alt = ref_scale_raw(arr, strategy)
        if not np.isfinite(s_alt):
            s_alt = 0.0
        s = max(s, s_alt)
    if not np.isfinite(s) or s <= 0:
        s = PARAMS["SCALE_MIN_ABS"]
    return float(s)

def gmd(values: np.ndarray) -> float:
    """Return the Gini mean difference: the mean |x_i - x_j| over all ordered pairs i != j.

    Uses the sorted-rank identity sum_{i<j} (x_(j) - x_(i)) = sum_k (2k - n - 1) x_(k),
    which is O(n log n) instead of O(n^2). Returns NaN for fewer than two values.
    """
    count = len(values)
    if count < 2:
        return np.nan
    ordered = np.sort(values)
    ranks = np.arange(1, count + 1)
    weighted_sum = float(np.dot(2 * ranks - count - 1, ordered))
    return (2.0 / (count * (count - 1))) * weighted_sum

def softmax(z: np.ndarray) -> np.ndarray:
    """Return the softmax of *z*; the max is subtracted first so exp() cannot overflow."""
    shifted = z - np.max(z)
    weights = np.exp(shifted)
    total = np.sum(weights)
    return weights / total

def entropy(p: np.ndarray) -> float:
    """Return the Shannon entropy (natural log) of probability array *p*; zero entries are skipped."""
    support = p[p > 0]
    return -float(np.sum(support * np.log(support)))

def top2_gap(values: np.ndarray) -> float:
    """Return the difference between the largest and second-largest value (NaN if < 2 values)."""
    if len(values) < 2:
        return np.nan
    ascending = np.sort(values)
    return float(ascending[-1] - ascending[-2])

def mad_z(x: np.ndarray) -> Tuple[np.ndarray, float, float]:
    """Return robust z-scores of *x* plus the median and raw (unscaled) MAD used.

    z = (x - median) / (1.4826 * MAD), NaN-aware. When the MAD is 0 or NaN the
    divisor falls back to 1.0, so z is then just the deviation from the median.
    """
    data = np.asarray(x, dtype=float)
    center = np.nanmedian(data)
    raw_mad = np.nanmedian(np.abs(data - center))
    divisor = 1.0
    if raw_mad and raw_mad > 0:
        divisor = 1.4826 * raw_mad
    return (data - center) / divisor, center, raw_mad

def compute_prompt_metrics(scores: np.ndarray,
                           scale_global: float,
                           iqr_q_low: int,
                           iqr_q_high: int,
                           np_quant_method: str,
                           winsor_pct: float) -> Dict[str, float]:
    """Compute spread and decisiveness metrics for one prompt's candidate scores.

    Non-finite scores are dropped. When ``winsor_pct > 0`` the rest are winsorized;
    ``winsor_pct`` is a per-tail fraction and is converted to percent for ``winsorize``.

    Returned keys:
      n_cands  number of scores actually used
      iqr      percentile(iqr_q_high) - percentile(iqr_q_low)
      RSI_IQR  iqr / scale_global
      GMD      Gini mean difference; nGMD = GMD / scale_global
      Gap12    best minus second-best score; nGap = Gap12 / scale_global
      SEI      1 - H(softmax(x / tau)) / log(n), where tau = iqr / 1.349, or the
               global scale (floored at SCALE_MIN_ABS) when iqr == 0
    All metrics are NaN when fewer than two usable scores remain. The normalized
    metrics are also NaN when ``scale_global`` is not positive.
    """
    metric_names = ("iqr", "RSI_IQR", "GMD", "nGMD", "Gap12", "nGap", "SEI")

    def _nan_result(count):
        result = {"n_cands": count}
        for name in metric_names:
            result[name] = np.nan
        return result

    if len(scores) == 0:
        return _nan_result(0)

    values = np.asarray(scores, dtype=float)
    values = values[np.isfinite(values)]
    if winsor_pct > 0 and values.size > 0:
        values = winsorize(values, winsor_pct * 100, np_quant_method)
    if values.size < 2:
        return _nan_result(values.size)

    lower = np_quantile(values, iqr_q_low, np_quant_method)
    upper = np_quantile(values, iqr_q_high, np_quant_method)
    spread = float(upper - lower)

    positive_scale = scale_global > 0
    gmd_value = gmd(values)
    gap = top2_gap(values)
    rsi = spread / scale_global if positive_scale else np.nan
    ngmd = gmd_value / scale_global if positive_scale else np.nan
    ngap = gap / scale_global if (positive_scale and not np.isnan(gap)) else np.nan

    # Softmax temperature: robust sigma from the IQR, or the global scale when the IQR is 0.
    if spread > 0:
        tau = spread / 1.349
    else:
        tau = max(scale_global, PARAMS["SCALE_MIN_ABS"])
    count = len(values)
    sei = 1.0 - (entropy(softmax(values / tau)) / math.log(count)) if count >= 2 else np.nan

    return {"n_cands": count, "iqr": spread, "RSI_IQR": rsi, "GMD": gmd_value, "nGMD": ngmd,
            "Gap12": gap, "nGap": ngap, "SEI": sei}

def summarize_groups(df: pd.DataFrame, by: List[str], metric_cols: List[str]) -> pd.DataFrame:
    """Aggregate each metric per group into ``<metric>_med`` (median) and ``<metric>_iqr``.

    The IQR is the 75th minus 25th percentile of the non-NaN values (NaN for an
    all-NaN group). NaN group keys are kept as their own groups.
    """
    named_aggs = {}
    for col in metric_cols:
        named_aggs[f"{col}_med"] = (col, "median")
        named_aggs[f"{col}_iqr"] = (col, _series_iqr)
    grouped = df.groupby(by, dropna=False)
    return grouped.agg(**named_aggs).reset_index()

def parse_rm_key(path: Path) -> Tuple[str, str, str]:
    """Derive ``(rm_key, run_dir, model_name)`` from an ``*_outputs.jsonl`` path.

    ``run_dir`` is the name of ``path.parents[2]``, the org is the name of
    ``path.parents[1]``, and ``model_name`` is the file name with "_outputs.jsonl"
    removed; ``rm_key`` is "org/model_name".

    When the path has too few directory levels, ``run_dir`` is "" and both
    ``rm_key`` and ``model_name`` are the file stem with "_outputs" removed. The
    two branches therefore use the same model-name convention.
    """
    path = Path(path)
    model_name = path.name.replace("_outputs.jsonl", "")
    try:
        run_dir = path.parents[2].name
        org = path.parents[1].name
    except IndexError:
        bare_name = path.stem.replace("_outputs", "")
        return bare_name, "", bare_name
    return f"{org}/{model_name}", run_dir, model_name

def collect_model_scores(jsonl_path: Path) -> pd.DataFrame:
    """Load the per-candidate scores of one JSONL file into a DataFrame.

    Each non-empty line should be a JSON object. A line contributes a row when:
      - its prompt id (top-level ``prompt_id``, else ``meta.prompt_id``) is not None/"";
      - its ``results`` value is present and convertible with ``float()``.
    ``domain`` and ``task`` are read from the top level, else from ``meta``, else "".

    Malformed lines are skipped rather than aborting the run. This covers invalid
    JSON, non-object JSON, and a ``results`` value that is a list or a non-numeric
    string.

    Returns:
        DataFrame with columns ``prompt_id`` (str), ``domain``, ``task``, ``score`` (float).
        The columns are present even when no line qualifies.
    """
    columns = ["prompt_id", "domain", "task", "score"]
    rows = []
    with open(jsonl_path, "r", encoding="utf-8", errors="ignore") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except ValueError:  # json.JSONDecodeError is a ValueError subclass
                continue
            if not isinstance(obj, dict):
                continue
            meta = obj.get("meta")
            if not isinstance(meta, dict):
                meta = {}

            # Use explicit None/"" checks so falsy-but-valid ids such as 0 are kept.
            pid = obj.get("prompt_id")
            if pid is None or pid == "":
                pid = meta.get("prompt_id")
            r = obj.get("results")
            if r is None or pid is None or pid == "":
                continue
            try:
                score = float(r)
            except (TypeError, ValueError):
                continue

            domain = obj.get("domain") or meta.get("domain") or ""
            task = obj.get("task") or meta.get("task") or ""
            rows.append({"prompt_id": str(pid), "domain": domain, "task": task, "score": score})
    return pd.DataFrame(rows, columns=columns)

def _series_iqr(s: pd.Series) -> float:
    """Return the 75th minus 25th percentile of the non-NaN values of *s* (NaN if none)."""
    present = s.dropna().values
    if present.size == 0:
        return np.nan
    quant_method = PARAMS["NP_QUANTILE_METHOD"]
    upper = np_quantile(present, 75, quant_method)
    lower = np_quantile(present, 25, quant_method)
    return float(upper - lower)

def compute_dci_table(prompt_df: pd.DataFrame,
                      eps: float = 1e-6,
                      tiny: float = 1e-12) -> pd.DataFrame:
    """
    Computes the Decision Consistency Index (DCI) per (rm_key, run_dir, model_name).

      D_ng  = (median_p(nGMD) + eps) / max(IQR_p(nGMD), tiny)
      D_sei = (median_p(SEI) + eps) / max(IQR_p(SEI), tiny)
      DCI   = exp(-2 / (D_ng + D_sei)) in (0, 1]

    Medians and IQRs are taken over the prompts of each group, ignoring NaN.
    DCI is NaN when any of the four statistics is missing or non-finite, or when
    D_ng + D_sei <= 0. In the latter case the formula would divide by zero or
    exceed 1.

    Returns:
        DataFrame with columns rm_key, run_dir, model_name, DCI. The columns are
        present even for an empty *prompt_df*, so callers can always merge on the keys.
    """
    keys = ["rm_key", "run_dir", "model_name"]
    rows = []
    for (rm_key, run_dir, model_name), sub in prompt_df.groupby(keys, dropna=False):
        ng = sub["nGMD"]
        sei = sub["SEI"]
        iqr_ng = _series_iqr(ng)
        iqr_sei = _series_iqr(sei)
        med_ng = float(ng.median(skipna=True))
        med_sei = float(sei.median(skipna=True))

        dci = np.nan
        if all(np.isfinite(v) for v in (iqr_ng, iqr_sei, med_ng, med_sei)):
            D_ng = (med_ng + eps) / max(iqr_ng, tiny)
            D_sei = (med_sei + eps) / max(iqr_sei, tiny)
            denom = D_ng + D_sei
            if denom > 0:
                dci = math.exp(-2.0 / denom)
        rows.append({"rm_key": rm_key, "run_dir": run_dir, "model_name": model_name, "DCI": dci})
    return pd.DataFrame(rows, columns=keys + ["DCI"])

# ----------------- Stage 1: Metric Computation and Aggregation -----------------
def _prompt_iqr_array(df: pd.DataFrame) -> Optional[np.ndarray]:
    """Return the per-prompt score IQRs of *df*, or None when no prompt qualifies.

    Only prompts with at least two finite scores count. Each IQR uses the configured
    IQR_Q_LOW/IQR_Q_HIGH percentiles and is floored at 0. None is the "no data" value
    that ``ref_scale_with_fallback`` expects for its prompt-IQR fallback.
    """
    quant_method = PARAMS["NP_QUANTILE_METHOD"]
    iqrs: List[float] = []
    for _, sub in df.groupby("prompt_id", dropna=False):
        x = make_finite(sub["score"].values)
        if x.size >= 2:
            ql = np_quantile(x, PARAMS["IQR_Q_LOW"], quant_method)
            qh = np_quantile(x, PARAMS["IQR_Q_HIGH"], quant_method)
            iqrs.append(max(0.0, float(qh - ql)))
    return np.asarray(iqrs, dtype=float) if iqrs else None


# Prompt-record fields kept when a prompt has fewer than MIN_CANDS_PER_PROMPT candidates;
# every other field of that record is overwritten with NaN.
_PROMPT_RECORD_KEEP = frozenset({
    "rm_key", "run_dir", "model_name", "prompt_id", "domain", "task",
    "n_cands", "scale_global", "scale_task",
})


def _build_prompt_record(sub: pd.DataFrame, pid: Any, rm_key: str, run_dir: str, model_name: str,
                         s_global: float, task_scales: Optional[Dict[Any, float]]) -> Dict[str, Any]:
    """Build one prompt_metrics.csv row from a single prompt's candidate rows *sub*.

    ``task_scales`` maps task -> scale for the current input file only. It is None when
    task-normalized variants are disabled. A task missing from it uses SCALE_MIN_ABS.
    """
    domain = sub["domain"].iloc[0]
    task = sub["task"].iloc[0]
    m = compute_prompt_metrics(sub["score"].values, s_global,
                               PARAMS["IQR_Q_LOW"], PARAMS["IQR_Q_HIGH"],
                               PARAMS["NP_QUANTILE_METHOD"], PARAMS["WINSOR_PCT"])
    rec = {"rm_key": rm_key, "run_dir": run_dir, "model_name": model_name,
           "prompt_id": pid, "domain": domain, "task": task,
           "n_cands": int(m["n_cands"]), "iqr": m["iqr"],
           "RSI_IQR": m["RSI_IQR"], "GMD": m["GMD"], "nGMD": m["nGMD"],
           "Gap12": m["Gap12"], "nGap": m["nGap"], "SEI": m["SEI"],
           "scale_global": s_global}

    if task_scales is not None:
        s_task = task_scales.get(task, PARAMS["SCALE_MIN_ABS"])
        rec["RSI_IQR_tasknorm"] = (m["iqr"] / s_task) if s_task > 0 else np.nan
        rec["nGMD_tasknorm"] = (m["GMD"] / s_task) if s_task > 0 else np.nan
        rec["nGap_tasknorm"] = (m["Gap12"] / s_task) if (s_task > 0 and not np.isnan(m["Gap12"])) else np.nan
        rec["scale_task"] = s_task

    if rec["n_cands"] < PARAMS["MIN_CANDS_PER_PROMPT"]:
        for k in list(rec):
            if k not in _PROMPT_RECORD_KEEP:
                rec[k] = np.nan
    return rec


def stage_compute_metrics():
    """Stage 1: compute prompt-level metrics for every input file and aggregate them.

    Files are read from PARAMS["ROOT_DIR"] and outputs are written to
    PARAMS["OUTPUT_DIR"]:
      - prompt_metrics.csv, rm_scales.csv and (if enabled) rm_task_scales.csv;
      - rm_global_metrics.csv (including DCI), rm_by_task_metrics.csv and
        rm_by_domain_metrics.csv.

    Exits the process with status 1 when no input files are found, and with
    status 2 when no prompt-level records could be produced.
    """
    ROOT = Path(PARAMS["ROOT_DIR"])
    OUTDIR = Path(PARAMS["OUTPUT_DIR"])
    OUTDIR.mkdir(parents=True, exist_ok=True)

    files = sorted(ROOT.glob(PARAMS["FILE_GLOB"]))
    if not files:
        logger.error(f"No *_outputs.jsonl files found under {ROOT}."); sys.exit(1)

    task_norm = PARAMS["ENABLE_TASK_NORM_VARIANTS"]
    all_prompt_records, rm_scale_table, task_scale_table = [], [], []
    processed = skipped = 0

    for fp in files:
        if not fp.exists() or fp.stat().st_size == 0:
            skipped += 1
            continue
        rm_key, run_dir, model_name = parse_rm_key(fp)
        df = collect_model_scores(fp)
        if df.empty:
            skipped += 1
            continue

        all_scores_raw = df["score"].values
        all_scores = make_finite(all_scores_raw)
        s_global = ref_scale_with_fallback(all_scores, PARAMS["REF_SCALE_MODE"], _prompt_iqr_array(df))
        rm_scale_table.append({
            "rm_key": rm_key, "run_dir": run_dir, "model_name": model_name,
            "scale_global": s_global,
            "n_total": len(all_scores_raw),
            "n_finite": len(all_scores),
            "ref_mode": PARAMS["REF_SCALE_MODE"]
        })

        # Task scales are kept per file. rm_key alone does not identify a run, and a
        # global lookup keyed by rm_key would let a later run reuse an earlier run's scales.
        task_scales: Optional[Dict[Any, float]] = None
        if task_norm:
            task_scales = {}
            for t, sub in df.groupby("task", dropna=False):
                sval = ref_scale_with_fallback(make_finite(sub["score"].values),
                                               PARAMS["REF_SCALE_MODE"], _prompt_iqr_array(sub))
                task_scales[t] = sval
                task_scale_table.append({"rm_key": rm_key, "task": t, "scale_task": sval})

        for pid, sub in df.groupby("prompt_id", dropna=False):
            all_prompt_records.append(
                _build_prompt_record(sub, pid, rm_key, run_dir, model_name, s_global, task_scales))
        processed += 1

    if not all_prompt_records:
        logger.error("No valid prompt-level records were generated."); sys.exit(2)

    prompt_df = pd.DataFrame(all_prompt_records)
    rm_scale_df = pd.DataFrame(rm_scale_table)
    task_scale_df = (pd.DataFrame(task_scale_table) if task_norm
                     else pd.DataFrame(columns=["rm_key", "task", "scale_task"]))

    prompt_df.to_csv(OUTDIR / "prompt_metrics.csv", index=False)
    rm_scale_df.to_csv(OUTDIR / "rm_scales.csv", index=False)
    if not task_scale_df.empty:
        task_scale_df.to_csv(OUTDIR / "rm_task_scales.csv", index=False)

    metric_cols = ["RSI_IQR", "nGMD", "nGap", "SEI"]
    if task_norm:
        metric_cols += ["RSI_IQR_tasknorm", "nGMD_tasknorm", "nGap_tasknorm"]

    rm_global = summarize_groups(prompt_df, by=["rm_key", "run_dir", "model_name"], metric_cols=metric_cols)
    dci_df = compute_dci_table(prompt_df)
    rm_global = rm_global.merge(dci_df, on=["rm_key", "run_dir", "model_name"], how="left")
    rm_global.to_csv(OUTDIR / "rm_global_metrics.csv", index=False)

    rm_by_task = summarize_groups(prompt_df, by=["rm_key", "run_dir", "model_name", "task"], metric_cols=metric_cols)
    rm_by_task.to_csv(OUTDIR / "rm_by_task_metrics.csv", index=False)

    rm_by_domain = summarize_groups(prompt_df, by=["rm_key", "run_dir", "model_name", "domain"], metric_cols=metric_cols)
    rm_by_domain.to_csv(OUTDIR / "rm_by_domain_metrics.csv", index=False)

    logger.info(f"[Stage 1 Complete] Processed {processed} RMs, skipped {skipped}. Output directory: {OUTDIR.resolve()}")

# ----------------- Stage 2: Curation and Visualization -----------------
def detect_outliers_block(df: pd.DataFrame, metric_cols: list) -> pd.Series:
    """Return a boolean mask of rows that are outliers in at least one of *metric_cols*.

    For each column a row is flagged when either test holds:
      - its robust z-score (MAD-based, divisor 1.0 if the MAD is 0/NaN) exceeds
        PARAMS["OUTLIER_Z_THR"] in absolute value;
      - its log10 value (values clipped at 1e-12 first) differs from the column's median
        log10 by more than PARAMS["OUTLIER_LOG10_DELTA"].
    NaN values are never flagged.
    """
    z_limit = PARAMS["OUTLIER_Z_THR"]
    log_limit = PARAMS["OUTLIER_LOG10_DELTA"]
    flagged = pd.Series(False, index=df.index)
    for col in metric_cols:
        vals = df[col].astype(float)
        center = vals.median(skipna=True)
        spread = (vals - center).abs().median(skipna=True)
        divisor = 1.0
        if spread and spread > 0:
            divisor = 1.4826 * spread
        robust_z = (vals - center) / divisor
        with np.errstate(divide='ignore', invalid='ignore'):
            log_vals = np.log10(np.clip(vals, 1e-12, None))
        log_center = np.nanmedian(log_vals)
        flagged = flagged | (robust_z.abs() > z_limit) | (np.abs(log_vals - log_center) > log_limit)
    return flagged

def mad_z_series(s: pd.Series, cap: float) -> pd.Series:
    """Return MAD-based robust z-scores of *s*, clipped to [-cap, cap], with NaN mapped to 0.

    The divisor is 1.4826 * MAD, or 1.0 when the MAD is 0 or NaN.
    """
    center = s.median(skipna=True)
    spread = (s - center).abs().median(skipna=True)
    divisor = 1.0
    if spread and spread > 0:
        divisor = 1.4826 * spread
    scores = (s - center) / divisor
    return scores.clip(lower=-cap, upper=cap).fillna(0.0)

def weighted_mean(columns: List[pd.Series], weights: List[float]) -> pd.Series:
    """Return the row-wise weighted mean of *columns*, treating NaN entries as 0.

    Weights are normalized to sum to 1 (left as-is if they sum to 0). The result is
    indexed like the first column.
    """
    w = np.asarray(weights, dtype=float)
    total = w.sum()
    if total == 0:
        total = 1.0
    w = w / total
    stacked = pd.concat(columns, axis=1).fillna(0.0).values
    return pd.Series(stacked.dot(w), index=columns[0].index)

def _safe_filename_part(value: Any) -> str:
    """Make ``str(value)`` safe to use inside a single file-name component.

    Replaces the following with '_':
      - path separators;
      - characters Windows forbids in file names;
      - control characters.
    Everything else, including spaces, is kept, so ordinary task names produce
    unchanged file names.
    """
    forbidden = set('/\\:*?"<>|')
    return "".join("_" if (ch in forbidden or ord(ch) < 32) else ch for ch in str(value))


def _read_csv_or_empty(path: Path) -> pd.DataFrame:
    """Read *path* as CSV, or return an empty DataFrame if it is missing or has no content."""
    try:
        return pd.read_csv(path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        return pd.DataFrame()


def _attach_coverage(g: pd.DataFrame, prompt_df: pd.DataFrame, id_cols: List[str]) -> pd.DataFrame:
    """Return *g* with ``coverage``, ``valid_prompts`` and ``total_prompts`` columns added.

    Coverage is the share of an RM's prompt rows whose ``SEI`` is not NaN. Rows are
    grouped on every identifier column shared by both frames, not only ``rm_key``,
    so separate runs of the same ``rm_key`` get their own coverage. ``dropna=False``
    keeps groups whose key (e.g. an empty ``run_dir``) reads back from CSV as NaN.
    """
    if prompt_df.empty or not {"rm_key", "SEI"}.issubset(prompt_df.columns):
        return g
    keys = [c for c in id_cols if c in prompt_df.columns and c in g.columns]
    cov = (prompt_df.assign(valid=lambda d: d["SEI"].notna())
           .groupby(keys, as_index=False, dropna=False)
           .agg(total_prompts=("SEI", "size"),
                valid_prompts=("valid", "sum")))
    cov["coverage"] = cov["valid_prompts"] / cov["total_prompts"]
    return g.merge(cov[keys + ["coverage", "valid_prompts", "total_prompts"]], on=keys, how="left")


def _flag_outliers(g: pd.DataFrame, rm_task: pd.DataFrame, core: List[str], out: Path) -> None:
    """Flag metric outliers in place and write them to ``metric_outliers.csv``.

    With task-level data, outliers are detected per task and stored in
    ``rm_task["is_outlier"]``. Otherwise they are detected on the RM-level table and
    stored in ``g["is_outlier"]``.
    """
    if not rm_task.empty:
        # DCI is only computed at RM level (stage 1 merges it into rm_global only).
        tcols = [c for c in PARAMS["CORE_METRICS"] if c in rm_task.columns and c != "DCI"]
        rm_task["is_outlier"] = False
        if tcols:
            for _, sub in rm_task.groupby("task"):
                rm_task.loc[sub.index, "is_outlier"] = detect_outliers_block(sub, tcols)
        (rm_task[rm_task["is_outlier"]]
         .sort_values(["task", "model_name"])
         .to_csv(out / "metric_outliers.csv", index=False))
    else:
        g["is_outlier"] = detect_outliers_block(g, core)
        g[g["is_outlier"]].to_csv(out / "metric_outliers.csv", index=False)


def _write_redundancy_notes(corr_global: pd.DataFrame, core: List[str], out: Path) -> None:
    """Write ``redundancy_suggestion.txt``, listing metric pairs with |rho| >= CORR_DROP_THRESHOLD."""
    redundant_notes = []
    cols = [c for c in core if c in corr_global.columns]
    for i in range(len(cols)):
        for j in range(i + 1, len(cols)):
            rho = corr_global.loc[cols[i], cols[j]]
            if pd.notna(rho) and abs(rho) >= PARAMS["CORR_DROP_THRESHOLD"]:
                redundant_notes.append(f"{cols[i]} and {cols[j]} are highly correlated (|rho|={abs(rho):.2f}). Consider keeping only one.")
    with open(out / "redundancy_suggestion.txt", "w", encoding="utf-8") as f:
        if redundant_notes:
            for line in redundant_notes:
                f.write(f"- {line}\n")
        else:
            f.write("- No strong redundancies found (correlation <= threshold).\n")


def _add_composite_scores(g: pd.DataFrame) -> None:
    """Add ``Composite``, ``<metric>_rank`` and ``Composite_rank`` columns to *g* in place.

    Composite is the weighted mean of the capped MAD-z scores of MAIN_METRICS. A
    metric missing from *g* contributes 0. Ranks are descending (1 = best).
    """
    main_metrics = list(PARAMS["MAIN_METRICS"])
    z_cols = {}
    for m in main_metrics:
        if m in g.columns:
            z_cols[m] = mad_z_series(g[m], cap=PARAMS["Z_CAP"])
        else:
            z_cols[m] = pd.Series(0.0, index=g.index)
    weights = [PARAMS["COMPOSITE_WEIGHTS"].get(m, 1.0) for m in main_metrics]
    g["Composite"] = weighted_mean([z_cols[m] for m in main_metrics], weights)

    for m in main_metrics:
        if m in g.columns:
            g[m + "_rank"] = g[m].rank(method="average", ascending=False)
    g["Composite_rank"] = g["Composite"].rank(method="average", ascending=False)


def _plot_overall_ranking(g: pd.DataFrame, out: Path) -> None:
    """Horizontal bar chart of the top-N RMs by Composite score."""
    topN = PARAMS["TOPN"]
    gg = g.sort_values("Composite", ascending=False).head(topN)
    plt.figure(figsize=PARAMS["FIGSIZE"])
    plt.barh(gg["model_name"], gg["Composite"]); plt.gca().invert_yaxis()
    plt.xlabel("Composite (MAD-z weighted)"); plt.ylabel("Model")
    plt.title(f"RM Overall Ranking (Top {topN})"); plt.tight_layout()
    plt.savefig(out / "fig_overall_ranking.png", dpi=200); plt.close()


def _plot_scatter(g: pd.DataFrame, out: Path) -> None:
    """Scatter of SEI_med (x) vs nGMD_med (y); marker size encodes DCI, colour encodes Composite."""
    # Scoped with a context manager so the style applies to this figure only and
    # does not leak into figures drawn afterwards.
    with plt.style.context('seaborn-v0_8-whitegrid'):
        fig, ax = plt.subplots(figsize=(16, 6))

        # Marker area grows with DCI, min-max normalised over the RMs that have a DCI.
        min_size, max_size = 40, 600
        dci_scores = g["DCI"].dropna()
        if not dci_scores.empty:
            s_norm = (g["DCI"] - dci_scores.min()) / (dci_scores.max() - dci_scores.min() + 1e-9)
            sizes = min_size + (max_size - min_size) * (s_norm ** 1.5)
            sizes = sizes.fillna(min_size)
        else:
            sizes = pd.Series(min_size, index=g.index)

        cmap = plt.get_cmap('plasma')
        scatter = ax.scatter(
            g["SEI_med"], g["nGMD_med"],
            s=sizes, c=g["Composite"], cmap=cmap,
            alpha=0.85, edgecolors='black', linewidth=0.7
        )
        cbar = fig.colorbar(scatter, ax=ax, pad=0.015, aspect=35)
        cbar.set_label('Composite Score (Brighter is Better)', rotation=270, labelpad=22, fontsize=12, weight='bold')
        comp_scores = g["Composite"].dropna()
        if not comp_scores.empty:
            comp_min, comp_max = comp_scores.min(), comp_scores.max()
            ticks = np.linspace(comp_min, comp_max, num=5)
            cbar.set_ticks(ticks); cbar.set_ticklabels([f"{t:.2f}" for t in ticks])
        cbar.ax.tick_params(labelsize=10)

        if ADJUSTTEXT_AVAILABLE:
            texts = []
            for _, row in g.iterrows():
                if pd.notna(row["SEI_med"]) and pd.notna(row["nGMD_med"]):
                    texts.append(ax.text(row["SEI_med"], row["nGMD_med"], str(row.get("model_name", "")), fontsize=8.5))
            if texts:
                adjust_text(texts, arrowprops=dict(arrowstyle='-', color='grey', lw=0.5, alpha=0.8),
                            force_text=(0.15, 0.3), force_points=(0.1, 0.2))
        else:  # Fallback if adjustText is not available
            for _, row in g.iterrows():
                if pd.notna(row["SEI_med"]) and pd.notna(row["nGMD_med"]):
                    ax.annotate(str(row.get("model_name", "")), (row["SEI_med"], row["nGMD_med"]),
                                fontsize=8, xytext=(3, 3), textcoords="offset points")

        # Empty proxy artists that give the size legend three representative DCI levels.
        proxy_kw = dict(c='#666666', alpha=0.8, edgecolors='black', linewidth=0.6)
        p_small = ax.scatter([], [], s=min_size + (max_size - min_size) * (0.1 ** 1.5), **proxy_kw)
        p_medium = ax.scatter([], [], s=min_size + (max_size - min_size) * (0.5 ** 1.5), **proxy_kw)
        p_large = ax.scatter([], [], s=min_size + (max_size - min_size) * (0.9 ** 1.5), **proxy_kw)
        legend = ax.legend([p_large, p_medium, p_small], ["High", "Medium", "Low"],
                           title="Consistency (DCI)", scatterpoints=1, loc='center right',
                           frameon=True, edgecolor='gray', fontsize=11, labelspacing=1.2, handletextpad=1.0)
        legend.get_title().set_fontsize('13'); legend.get_title().set_fontweight('bold')

        # Median cross-hairs split the plot into quadrants.
        ax.axvline(g["SEI_med"].median(), color='grey', linestyle='--', linewidth=1.0, alpha=0.7)
        ax.axhline(g["nGMD_med"].median(), color='grey', linestyle='--', linewidth=1.0, alpha=0.7)

        ax.set_xlabel("SEI_med (Confidence Concentration ↑)", fontsize=13, weight='bold', labelpad=10)
        ax.set_ylabel("nGMD_med (Overall Separation ↑)", fontsize=13, weight='bold', labelpad=10)
        ax.set_title("RM Variance Profile: Separation (nGMD) vs. Concentration (SEI)", fontsize=16, weight='bold', pad=15)
        ax.tick_params(axis='both', which='major', labelsize=11)
        ax.grid(True, which='both', linestyle='--', linewidth=0.5)

        # Upper-right quadrant annotation
        ax.text(0.985, 0.96,
                'Ideal Quadrant\n(Decisive & Brave)',
                transform=ax.transAxes, ha='right', va='top',
                fontsize=12, style='italic', color='white',
                bbox=dict(boxstyle='round,pad=0.3', fc='green', ec='black', alpha=0.7))

        # Lower-left quadrant annotation
        ax.text(0.02, 0.05,
                'Indecisive & Cautious',
                transform=ax.transAxes, ha='left', va='bottom',
                fontsize=12, style='italic', color='white',
                bbox=dict(boxstyle='round,pad=0.3', fc='darkred', ec='black', alpha=0.7))

        plt.tight_layout(pad=1.0)
        plt.savefig(out / "fig_scatter_sei_vs_ngmd.png", dpi=300, bbox_inches='tight')
        plt.close(fig)


def _plot_task_outputs(rm_task: pd.DataFrame, g: pd.DataFrame, out: Path) -> None:
    """Write per-task composite scores, the task heatmap (viridis) and per-task top-N bars."""
    task_cols = [c for c in PARAMS["MAIN_METRICS"] if c in rm_task.columns and c != "DCI"]
    if not task_cols:
        # weighted_mean cannot combine zero columns; there is nothing to score per task.
        logger.warning("No MAIN_METRICS columns found in task-level metrics; skipping task outputs.")
        return
    tdf = rm_task[["rm_key", "model_name", "task"] + task_cols].copy()

    def _task_comp(sub: pd.DataFrame) -> pd.Series:
        parts, ws = [], []
        for m in task_cols:
            parts.append(mad_z_series(sub[m], cap=PARAMS["Z_CAP"]))
            ws.append(PARAMS["COMPOSITE_WEIGHTS"].get(m, 1.0))
        return weighted_mean(parts, ws)

    # MAD-z is computed within each task, so TaskComposite compares RMs on the same task.
    tdf["TaskComposite"] = tdf.groupby("task", group_keys=False).apply(_task_comp)
    tdf.sort_values(["rm_key", "TaskComposite"], ascending=[True, False]).to_csv(out / "curated_rm_by_task_scores.csv", index=False)

    pivot = tdf.pivot_table(index="model_name", columns="task", values="TaskComposite", aggfunc="mean")
    order = g.sort_values("Composite", ascending=False)["model_name"].tolist()
    pivot = pivot.reindex(order)

    # Use viridis colormap, not zero-centered.
    plt.figure(figsize=(max(8, 0.5 * (len(pivot.columns) + 6)), max(6, 0.35 * (len(pivot.index) + 8))))
    im = plt.imshow(pivot.values, aspect="auto", cmap='viridis')
    plt.xticks(ticks=np.arange(len(pivot.columns)), labels=pivot.columns, rotation=45, ha="right")
    plt.yticks(ticks=np.arange(len(pivot.index)), labels=pivot.index)
    cbar = plt.colorbar(im); cbar.set_label('TaskComposite (MAD-z)')
    plt.title("Task-wise Composite (MAD-z weighted)")
    plt.tight_layout(); plt.savefig(out / "fig_task_heatmap.png", dpi=220); plt.close()

    for task, sub in tdf.groupby("task"):
        gg = sub.sort_values("TaskComposite", ascending=False).head(PARAMS["TOPN"])
        plt.figure(figsize=(max(8, 0.6 * len(gg)), 6))
        plt.barh(gg["model_name"], gg["TaskComposite"]); plt.gca().invert_yaxis()
        plt.xlabel("TaskComposite (MAD-z)"); plt.title(f"Top {PARAMS['TOPN']} — {task}")
        # Task names come from input data; sanitize so e.g. "a/b" cannot escape `out`.
        plt.tight_layout(); plt.savefig(out / f"fig_top_{_safe_filename_part(task)}.png", dpi=200); plt.close()


def _plot_metric_correlation(corr: pd.DataFrame, out: Path) -> None:
    """Annotated Spearman correlation heatmap (horizontal, flat labels, black text, no grid)."""
    cols = corr.columns.tolist()
    if not cols:
        logger.warning("No core metrics available; skipping correlation heatmap.")
        return

    fig, ax = plt.subplots(figsize=(14, 3), constrained_layout=True)
    im = ax.imshow(corr.values, vmin=-1, vmax=1, cmap='coolwarm',
                   interpolation='nearest', aspect='auto')

    ax.grid(False)  # Disable grid lines
    ax.set_xticks(range(len(cols)))
    ax.set_xticklabels(cols, rotation=0, ha="center", fontsize=10)
    ax.set_yticks(range(len(cols)))
    ax.set_yticklabels(cols, fontsize=10)

    # Use black for all text annotations for better readability.
    for i in range(len(cols)):
        for j in range(len(cols)):
            ax.text(j, i, f"{corr.values[i, j]:.2f}", ha="center", va="center", fontsize=9, color="black")

    cbar = fig.colorbar(im, ax=ax, fraction=0.02, pad=0.02)
    cbar.ax.tick_params(labelsize=9)
    ax.set_title("Spearman Correlation (RM-level core metrics, outliers removed)", fontsize=12, pad=6)

    fig.savefig(out / "fig_metric_correlation.png", dpi=220, bbox_inches='tight', pad_inches=0.02)
    plt.close(fig)


def _plot_topn_each_metric(g: pd.DataFrame, out: Path) -> None:
    """One top-N bar chart per (deduplicated) MAIN_METRICS column present in *g*."""
    for m in dict.fromkeys(PARAMS["MAIN_METRICS"]):  # ordered de-duplication
        if m in g.columns:
            gg = g.sort_values(m, ascending=False).head(PARAMS["TOPN"])
            plt.figure(figsize=PARAMS["FIGSIZE"])
            plt.barh(gg["model_name"], gg[m]); plt.gca().invert_yaxis()
            plt.xlabel(m); plt.ylabel("Model"); plt.title(f"Top {PARAMS['TOPN']} by {m}")
            plt.tight_layout(); plt.savefig(out / f"fig_topN_each_metric_{m}.png", dpi=200); plt.close()


def _plot_outlier_quickview(g: pd.DataFrame, out: Path) -> None:
    """Scatter of nGMD_med per RM, highlighting RM-level outliers when they were flagged."""
    if "nGMD_med" not in g.columns:
        logger.warning("nGMD_med not available; skipping outlier quick view.")
        return
    base_x = np.arange(len(g))
    plt.figure(figsize=(10, 6))
    if "is_outlier" in g.columns:
        normal = ~g["is_outlier"].fillna(False).astype(bool)
        plt.scatter(base_x[normal.values], g.loc[normal, "nGMD_med"], s=10, label="normal")
        plt.scatter(base_x[~normal.values], g.loc[~normal.values, "nGMD_med"], s=16, label="outlier")
        plt.legend()
    else:
        plt.scatter(base_x, g["nGMD_med"], s=12)
    plt.ylabel("nGMD_med"); plt.title("Outlier quick view (nGMD_med)")
    plt.tight_layout(); plt.savefig(out / "fig_outlier_quickview.png", dpi=180); plt.close()


def stage_curate():
    """Stage 2: curate the stage-1 CSVs and render figures into ``<OUTPUT_DIR>/curated``.

    Outputs:
      - metric_outliers.csv, spearman_global.csv, redundancy_suggestion.txt;
      - curated_rm_summary.csv (with Composite and ranks);
      - task-level scores and heatmap, when rm_by_task_metrics.csv exists;
      - ranking, scatter, correlation and outlier figures.
    Logs an error and returns when rm_global_metrics.csv is missing or empty.
    """
    base = Path(PARAMS["OUTPUT_DIR"])
    out = base / "curated"
    out.mkdir(parents=True, exist_ok=True)

    rm_global = _read_csv_or_empty(base / "rm_global_metrics.csv")
    if rm_global.empty:
        logger.error("rm_global_metrics.csv is missing or empty."); return
    rm_task = _read_csv_or_empty(base / "rm_by_task_metrics.csv")
    prompt_df = _read_csv_or_empty(base / "prompt_metrics.csv")

    id_cols = ["rm_key", "run_dir", "model_name"]
    core = [m for m in PARAMS["CORE_METRICS"] if m in rm_global.columns]
    g = rm_global[id_cols + core].copy()
    g = _attach_coverage(g, prompt_df, id_cols)

    _flag_outliers(g, rm_task, core, out)

    # NOTE(review): when task-level data exists, outliers are flagged on rm_task only, so
    # no RM-level rows are removed here even though the correlation figure title says
    # "outliers removed". Confirm this is intended.
    g_clean = g[~g["is_outlier"]] if "is_outlier" in g.columns else g.copy()
    corr_global = g_clean[[c for c in core if c in g_clean.columns]].corr(method="spearman", min_periods=1)
    corr_global.to_csv(out / "spearman_global.csv")
    _write_redundancy_notes(corr_global, core, out)

    _add_composite_scores(g)
    g.sort_values("Composite", ascending=False).to_csv(out / "curated_rm_summary.csv", index=False)

    _plot_overall_ranking(g, out)
    if {"SEI_med", "nGMD_med", "DCI", "Composite"}.issubset(g.columns):
        _plot_scatter(g, out)
    if not rm_task.empty:
        _plot_task_outputs(rm_task, g, out)
    _plot_metric_correlation(corr_global, out)
    _plot_topn_each_metric(g, out)
    _plot_outlier_quickview(g, out)

    logger.info(f"[Stage 2 Complete] Curation and visualization outputs saved to: {out.resolve()}")

def main():
    """Run the pipeline: stage 1 (metric computation), then stage 2 (curation and figures)."""
    for stage in (stage_compute_metrics, stage_curate):
        stage()


if __name__ == "__main__":
    main()