import argparse
import os
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from scipy.stats import spearmanr, kendalltau
from sklearn.metrics import average_precision_score, roc_auc_score, precision_recall_curve, roc_curve

# Optional plotting dependency
try:
    import matplotlib.pyplot as plt
    import seaborn as sns
except ImportError:
    plt = None
    sns = None


# --- Metric and Predictor Configuration ---

# Higher score should correspond to higher risk: each raw column is multiplied by its
# "sign" before analysis, so after that adjustment a larger value always means riskier.
# "label" is a short display name for the predictor.
PREDICTOR_CONFIG = {
    key: {"sign": sign, "label": label}
    for key, sign, label in (
        ("trace_score", 1, "trace"),
        ("trace_w1_avgpool", 1, "TRACE W1 penult"),
        ("trace_w1_layer2", 1, "TRACE W1 L2"),
        ("trace_w1_layer3", 1, "TRACE W1 L3"),
        ("trace_mmd_avgpool", 1, "TRACE MMD"),
        ("trace_energy_avgpool", 1, "TRACE Energy"),
        ("trace_sw2_avgpool", 1, "TRACE SW2"),
        # Low max-softmax probability indicates risk, so MSP is negated.
        ("msp_score", -1, "-MSP"),
        ("energy_score", 1, "Energy"),
        ("mmd_score", 1, "MMD"),
        ("entropy", 1, "Entropy"),
        ("mahalanobis", 1, "Mahalanobis"),
        ("kl_disagreement", 1, "KL Disagreement"),
        ("iw_risk", 1, "IW Risk"),
        ("domain_classifier_auc", 1, "Domain AUC"),
        ("a_distance_proxy", 1, "A-distance"),
        # Multi-layer and rich metrics if present
        ("outdisc_l2_mean", 1, "OutDisc L2"),
        ("outdisc_l1_mean", 1, "OutDisc L1"),
        ("outdisc_cosine_mean", 1, "OutDisc Cosine"),
        ("prob_js", 1, "JS"),
        ("disagree_rate", 1, "Disagree"),
        ("maxprob_abs_shift", 1, "MaxProbΔ"),
        ("margin_mean_shift", 1, "MarginΔ"),
        ("entropy_abs_shift", 1, "EntropyΔ"),
        ("ece_shift", 1, "ECEΔ"),
        ("w1_layer2", 1, "W1 L2"),
        ("w1_layer3", 1, "W1 L3"),
        ("w1_avgpool", 1, "W1 Penult"),
    )
}

# Names of the cost measurements a results file may record per predictor.
COST_METRICS = ["time_s", "cuda_time_ms", "gpu_mem_mb", "cpu_peak_rss_mb"]


# --- Core Computations ---

def compute_auc_metrics(y_true: np.ndarray, scores: np.ndarray) -> Tuple[float, float]:
    """Return ``(AUROC, AUPRC)`` of ``scores`` against the 0/1 labels ``y_true``.

    When the labels sum to 0 or to their length (i.e. a single class for 0/1 labels,
    including the empty case), both metrics are undefined and ``(nan, nan)`` is returned
    instead of calling scikit-learn.
    """
    positives = y_true.sum()
    if positives == 0 or positives == len(y_true):
        undefined = float("nan")
        return undefined, undefined
    auroc = roc_auc_score(y_true, scores)
    auprc = average_precision_score(y_true, scores)
    return auroc, auprc


def get_predictor_costs(df: pd.DataFrame, predictor_keys: Optional[List[str]] = None,
                        cost_metrics: Optional[List[str]] = None) -> Dict[str, Dict[str, float]]:
    """Average per-predictor cost measurements found in ``df``.

    A cost column is matched by its exact name ``cost_<metric>_<name>``. Here ``<metric>``
    is one of ``cost_metrics`` (default: ``COST_METRICS``). ``<name>`` is either the
    predictor key with its ``_score`` suffix removed (e.g. ``cost_time_s_trace`` for
    ``trace_score``) or the full predictor key.

    Exact matching avoids two problems with a substring search. First, one predictor's
    name is often a prefix of another's (``trace`` vs ``trace_w1_avgpool``, ``entropy``
    vs ``entropy_abs_shift``), so a substring search mixes up their costs. Second, it
    keeps multi-word metric names such as ``time_s`` intact.

    Args:
        df: Results table; every matched cost column is averaged over all rows.
        predictor_keys: Predictor keys to look up (default: keys of ``PREDICTOR_CONFIG``).
        cost_metrics: Cost metric names to look up (default: ``COST_METRICS``).

    Returns:
        ``{predictor_key: {metric: mean}}``. Metrics without a matching column are
        omitted, so a predictor with no cost columns maps to an empty dict.
    """
    if predictor_keys is None:
        predictor_keys = list(PREDICTOR_CONFIG)
    if cost_metrics is None:
        cost_metrics = COST_METRICS

    available = set(df.columns)
    costs: Dict[str, Dict[str, float]] = {}
    for pred_key in predictor_keys:
        short = pred_key[:-len("_score")] if pred_key.endswith("_score") else pred_key
        # Prefer the short form used by the cost columns, then fall back to the full key.
        names = [short] if short == pred_key else [short, pred_key]
        found: Dict[str, float] = {}
        for metric in cost_metrics:
            col = next((f"cost_{metric}_{n}" for n in names if f"cost_{metric}_{n}" in available), None)
            if col is not None:
                found[metric] = float(df[col].mean())
        costs[pred_key] = found
    return costs


def calculate_tau_sweep(df: pd.DataFrame, predictors: Dict[str, np.ndarray], delta_R: np.ndarray,
                        taus: np.ndarray) -> pd.DataFrame:
    """Evaluate every predictor as a harm classifier at a range of harm thresholds.

    For each ``tau`` in ``taus`` the binary label is ``delta_R > tau``. The resulting row
    holds ``tau``, the ``positive_rate`` of that label, and an ``AUROC_<name>`` /
    ``AUPRC_<name>`` pair per predictor, computed via ``compute_auc_metrics``.
    ``df`` is accepted for interface compatibility but not used.
    """
    def _row_for(threshold):
        labels = (delta_R > threshold).astype(int)
        record = {"tau": threshold, "positive_rate": float(labels.mean())}
        for pred_name, pred_scores in predictors.items():
            record[f"AUROC_{pred_name}"], record[f"AUPRC_{pred_name}"] = compute_auc_metrics(labels, pred_scores)
        return record

    return pd.DataFrame([_row_for(threshold) for threshold in taus])


def calculate_correlations(predictors: Dict[str, np.ndarray], delta_R: np.ndarray) -> pd.DataFrame:
    """Spearman and Kendall rank correlation between each predictor's scores and ``delta_R``.

    If scipy raises ``ValueError`` or ``TypeError`` for a predictor, both of its
    correlations are reported as NaN. Returns one row per predictor (in dict order) with
    columns ``predictor``, ``spearman_rho`` and ``kendall_tau``.
    """
    def _rank_stats(values):
        try:
            return (spearmanr(values, delta_R).correlation,
                    kendalltau(values, delta_R).correlation)
        except (ValueError, TypeError):
            return float("nan"), float("nan")

    records = []
    for pred_name, values in predictors.items():
        rho, tau = _rank_stats(values)
        records.append({"predictor": pred_name, "spearman_rho": rho, "kendall_tau": tau})
    return pd.DataFrame(records)


def _bootstrap_ci(metric_fn, y_true: np.ndarray, scores: np.ndarray, n_boot: int = 1000, seed: int = 42) -> Tuple[float, float, float]:
    rng = np.random.RandomState(seed)
    n = len(y_true)
    vals = []
    for _ in range(n_boot):
        idx = rng.randint(0, n, size=n)
        yt = y_true[idx]
        sc = scores[idx]
        try:
            vals.append(metric_fn(yt, sc))
        except Exception:
            continue
    if len(vals) == 0:
        return float("nan"), float("nan"), float("nan")
    vals = np.array(vals, dtype=float)
    return float(np.mean(vals)), float(np.percentile(vals, 2.5)), float(np.percentile(vals, 97.5))


def compute_classification_metrics(df: pd.DataFrame, predictors: Dict[str, np.ndarray], label: np.ndarray,
                                   outdir: str, boot: int = 1000, seed: int = 42) -> pd.DataFrame:
    """Bootstrap AUROC/AUPRC with 95% percentile intervals for each predictor vs. binary ``label``.

    When the optional plotting libraries imported successfully, ROC and PR curves are also
    written to ``outdir`` as ``roc_<name>.png`` / ``pr_<name>.png``. Plotting is
    best-effort: failures are reported and never abort the analysis.

    NOTE: the ``AUROC``/``AUPRC`` columns hold the mean over bootstrap replicates (the first
    value ``_bootstrap_ci`` returns), not the full-sample estimate. ``df`` is unused.

    Returns:
        One row per predictor with ``AUROC``, ``AUROC_lo``, ``AUROC_hi``, ``AUPRC``,
        ``AUPRC_lo`` and ``AUPRC_hi``.
    """
    rows = []
    for name, scores in predictors.items():
        # AUCs with bootstrap CIs
        auroc_mean, auroc_lo, auroc_hi = _bootstrap_ci(roc_auc_score, label, scores, n_boot=boot, seed=seed)
        auprc_mean, auprc_lo, auprc_hi = _bootstrap_ci(average_precision_score, label, scores, n_boot=boot, seed=seed)
        rows.append({
            "predictor": name,
            "AUROC": auroc_mean, "AUROC_lo": auroc_lo, "AUROC_hi": auroc_hi,
            "AUPRC": auprc_mean, "AUPRC_lo": auprc_lo, "AUPRC_hi": auprc_hi,
        })
        if plt is not None:
            _save_roc_pr_plots(label, scores, outdir, name)
    return pd.DataFrame(rows)


def _save_roc_pr_plots(label: np.ndarray, scores: np.ndarray, outdir: str, name: str) -> None:
    """Write ROC and PR curve PNGs for one predictor, reporting (not raising) any failure."""
    try:
        fpr, tpr, _ = roc_curve(label, scores)
        prec, rec, _ = precision_recall_curve(label, scores)
    except ValueError as err:
        print(f"Skipping ROC/PR plots for {name}: {err}")
        return
    curves = (
        ("roc", fpr, tpr, "FPR", "TPR", f"ROC {name}"),
        ("pr", rec, prec, "Recall", "Precision", f"PR {name}"),
    )
    for prefix, xs, ys, xlabel, ylabel, title in curves:
        fig = None
        try:
            fig = plt.figure()
            plt.plot(xs, ys)
            plt.xlabel(xlabel)
            plt.ylabel(ylabel)
            plt.title(title)
            plt.grid(True, ls="--")
            plt.tight_layout()
            plt.savefig(os.path.join(outdir, f"{prefix}_{name}.png"))
        except Exception as err:  # best-effort plotting: report and continue
            print(f"Could not save {prefix.upper()} plot for {name}: {err}")
        finally:
            # Close even on failure so open figures do not pile up across predictors.
            if fig is not None:
                plt.close(fig)


def _expected_utility(y_true: np.ndarray, y_pred: np.ndarray, c_fn: float, c_fp: float) -> float:
    fn = np.logical_and(y_true == 1, y_pred == 0).sum()
    fp = np.logical_and(y_true == 0, y_pred == 1).sum()
    return - (c_fn * fn + c_fp * fp) / float(len(y_true))


def _calibrate_threshold(scores: np.ndarray, y_true: np.ndarray, method: str = "utility",
                         c_fn: float = 5.0, c_fp: float = 1.0, target_pos_rate: float = None) -> float:
    # If targeting a positive rate, choose quantile threshold
    if target_pos_rate is not None:
        q = 1.0 - float(target_pos_rate)
        return float(np.quantile(scores, q))
    # Grid over unique scores
    uniq = np.unique(scores)
    if uniq.size > 512:
        # downsample thresholds for speed
        uniq = np.quantile(scores, np.linspace(0, 1, 512))
    best_thr = uniq[0]
    best_val = -np.inf
    for t in uniq:
        y_pred = (scores >= t).astype(int)
        if method == "utility":
            val = _expected_utility(y_true, y_pred, c_fn=c_fn, c_fp=c_fp)
        elif method == "youden":
            # TPR - FPR
            tp = np.logical_and(y_true == 1, y_pred == 1).sum()
            fn = np.logical_and(y_true == 1, y_pred == 0).sum()
            fp = np.logical_and(y_true == 0, y_pred == 1).sum()
            tn = np.logical_and(y_true == 0, y_pred == 0).sum()
            tpr = tp / max(tp + fn, 1)
            fpr = fp / max(fp + tn, 1)
            val = tpr - fpr
        elif method == "f1":
            tp = np.logical_and(y_true == 1, y_pred == 1).sum()
            fp = np.logical_and(y_true == 0, y_pred == 1).sum()
            fn = np.logical_and(y_true == 1, y_pred == 0).sum()
            prec = tp / max(tp + fp, 1)
            rec = tp / max(tp + fn, 1)
            val = 2 * prec * rec / max(prec + rec, 1e-8)
        else:
            raise ValueError(f"Unknown method {method}")
        if val > best_val:
            best_val = val
            best_thr = t
    return float(best_thr)


def evaluate_threshold_transfer(cal_scores: np.ndarray, cal_labels: np.ndarray,
                                test_scores: np.ndarray, test_labels: np.ndarray,
                                method: str, c_fn: float, c_fp: float,
                                outdir: str, name: str) -> Dict[str, float]:
    """Fit a threshold on the calibration split and measure it on the held-out test split.

    The threshold comes from ``_calibrate_threshold`` on ``(cal_scores, cal_labels)``. Test
    predictions are ``test_scores >= threshold``. ``outdir`` and ``name`` are accepted but
    not used here.

    Returns:
        Dict with ``threshold``, ``acc``, ``precision``, ``recall``, ``fpr``, ``fnr`` and
        ``utility`` (``_expected_utility`` on the test split). Rate denominators are
        clamped to at least 1.
    """
    threshold = _calibrate_threshold(cal_scores, cal_labels, method=method, c_fn=c_fn, c_fp=c_fp)
    flagged = (test_scores >= threshold).astype(int)

    actual_pos = test_labels == 1
    actual_neg = test_labels == 0
    pred_pos = flagged == 1
    pred_neg = flagged == 0
    tp = np.logical_and(actual_pos, pred_pos).sum()
    fp = np.logical_and(actual_neg, pred_pos).sum()
    fn = np.logical_and(actual_pos, pred_neg).sum()
    tn = np.logical_and(actual_neg, pred_neg).sum()
    n_actual_pos = max(tp + fn, 1)

    return {
        "threshold": threshold,
        "acc": (flagged == test_labels).mean(),
        "precision": tp / max(tp + fp, 1),
        "recall": tp / n_actual_pos,
        "fpr": fp / max(fp + tn, 1),
        "fnr": fn / n_actual_pos,
        "utility": _expected_utility(test_labels, flagged, c_fn=c_fn, c_fp=c_fp),
    }


def compute_ndcg_at_k(scores: np.ndarray, relevance: np.ndarray, k: int) -> float:
    """NDCG@k of the ranking obtained by sorting ``scores`` in descending order.

    Gains are ``2**relevance - 1`` and the discount for 1-based rank ``r`` is
    ``log2(r + 1)``. Returns 0.0 when the ideal DCG is not positive (e.g. all-zero
    relevance).
    """
    def _dcg(gains_in_rank_order):
        discounts = np.log2(np.arange(2, 2 + len(gains_in_rank_order)))
        return np.sum((2 ** gains_in_rank_order - 1) / discounts)

    ranked = relevance[np.argsort(-scores)][:k]
    ideal = np.sort(relevance)[::-1][:k]
    dcg, idcg = _dcg(ranked), _dcg(ideal)
    if idcg <= 0:
        return 0.0
    return float(dcg / idcg)


def compute_ndcg_table(predictors: Dict[str, np.ndarray], relevance: np.ndarray, ks: List[int]) -> pd.DataFrame:
    """Tabulate ``compute_ndcg_at_k`` for every predictor and every cutoff in ``ks``.

    Returns one row per predictor with a ``predictor`` column followed by ``NDCG@<k>``
    columns in the order of ``ks``.
    """
    return pd.DataFrame([
        {"predictor": name, **{f"NDCG@{k}": compute_ndcg_at_k(scores, relevance, k) for k in ks}}
        for name, scores in predictors.items()
    ])


def calculate_selection_regret(df: pd.DataFrame, predictors: Dict[str, np.ndarray], k: int = 1) -> pd.DataFrame:
    """Deployment-selection regret of each predictor, measured on ``df["delta_R_true"]``.

    Each predictor's scores are oriented so that higher means riskier. Each predictor
    "deploys" the ``k`` candidates it scores as least risky. Regret is the mean true ΔR of
    those picks minus the mean true ΔR of the ``k`` truly least risky candidates. A perfect
    ranking therefore has regret 0, and larger values are worse.

    Returns:
        DataFrame with columns ``predictor`` and ``regret_at_<k>``.
    """
    delta_R = df["delta_R_true"].values
    # Oracle baseline: the k smallest true ΔR values. Comparing a k-average against the
    # single minimum would give even a perfect predictor positive regret when k > 1.
    best_possible_risk = np.sort(delta_R)[:k].mean()
    rows = []
    for name, scores in predictors.items():
        # Deploy the lowest-predicted-risk candidates. Picking the highest-risk candidate
        # and comparing it with the minimum would reward predictors for choosing the worst
        # candidate. A stable sort keeps ties in row order; NaN scores sort last.
        selected_idx = np.argsort(scores, kind="stable")[:k]
        regret = delta_R[selected_idx].mean() - best_possible_risk
        rows.append({"predictor": name, f"regret_at_{k}": regret})
    return pd.DataFrame(rows)


# --- Plotting ---

def plot_pareto(summary_df: pd.DataFrame, cost_metric: str, perf_metric: str, out_path: str):
    """Scatter ``perf_metric`` against ``cost_metric`` with one colour/marker per predictor.

    The figure is saved to ``out_path`` and closed. If the optional matplotlib/seaborn
    import failed, a notice is printed and nothing is plotted.
    """
    if plt is None:
        print("Plotting skipped: matplotlib/seaborn not installed.")
        return

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.scatterplot(data=summary_df, x=cost_metric, y=perf_metric, hue="predictor", s=200,
                    style="predictor", ax=ax)
    ax.set_title(f"Performance vs. Cost ({perf_metric} vs {cost_metric})")
    ax.set_xlabel(f"Cost: {cost_metric}")
    ax.set_ylabel(f"Performance: {perf_metric}")
    ax.grid(True, which="both", ls="--")
    # Place the legend outside the axes so many predictors do not cover the points.
    ax.legend(title="Predictor", bbox_to_anchor=(1.05, 1), loc="upper left")
    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)
    print(f"Saved Pareto plot to {out_path}")


# --- Main Orchestration ---

def main():
    """Command-line entry point: run every analysis on a results CSV and write the outputs.

    Reads ``--csv`` and builds sign-adjusted scores (higher = riskier) for each column
    listed in ``PREDICTOR_CONFIG``. It then writes the following into ``--outdir``:
    rank correlations, absolute/signed tau sweeps, selection regret, NDCG@k, bootstrap
    classification metrics, threshold-transfer results, a cost summary and a merged
    summary. Pareto plots are added when cost columns and plotting libraries are available.
    Exits with an error message if the CSV contains none of the known predictor columns.
    """
    parser = argparse.ArgumentParser(description="Extended analysis for deployment gate results.")
    parser.add_argument("--csv", type=str, required=True, help="Path to deployment_gate_results.csv")
    parser.add_argument("--outdir", type=str, required=True, help="Directory to save analysis outputs")
    parser.add_argument("--tau-min", type=float, default=0.01)
    parser.add_argument("--tau-max", type=float, default=0.35)
    parser.add_argument("--tau-steps", type=int, default=35)
    parser.add_argument("--calib-frac", type=float, default=0.3)
    parser.add_argument("--utility-cfn", type=float, default=5.0)
    parser.add_argument("--utility-cfp", type=float, default=1.0)
    parser.add_argument("--calib-method", type=str, default="utility", choices=["utility", "youden", "f1"])
    args = parser.parse_args()

    os.makedirs(args.outdir, exist_ok=True)
    df = pd.read_csv(args.csv)

    # Prepare predictors and ground truth; "sign" orients every score so higher = riskier.
    predictors = {
        key: df[key].values * config["sign"]
        for key, config in PREDICTOR_CONFIG.items() if key in df.columns
    }
    if not predictors:
        # Every summary table below is keyed on "predictor". With no predictors the final
        # merges would raise a KeyError after partial outputs had already been written.
        raise SystemExit(f"No known predictor columns found in {args.csv}")
    delta_R_abs = df["delta_R_true"].abs().values
    delta_R_signed = df["delta_R_signed"].values if "delta_R_signed" in df.columns else None
    # Binary harm label (signed ΔR > 0); only defined when the signed column exists.
    harmful = (delta_R_signed > 0).astype(int) if delta_R_signed is not None else None

    # --- Analyses ---
    # 1. Rank correlations
    correlations = calculate_correlations(predictors, delta_R_abs)
    corr_path = os.path.join(args.outdir, "correlations.csv")
    correlations.to_csv(corr_path, index=False)
    print(f"Saved rank correlations to {corr_path}")

    # 2. Tau sweeps for AUROC/AUPRC
    taus = np.linspace(args.tau_min, args.tau_max, args.tau_steps)
    sweep_abs = calculate_tau_sweep(df, predictors, delta_R_abs, taus)
    sweep_abs_path = os.path.join(args.outdir, "tau_sweep_absolute.csv")
    sweep_abs.to_csv(sweep_abs_path, index=False)
    print(f"Saved absolute-label tau sweep to {sweep_abs_path}")

    if delta_R_signed is not None:
        sweep_signed = calculate_tau_sweep(df, predictors, delta_R_signed, taus)
        sweep_signed_path = os.path.join(args.outdir, "tau_sweep_signed.csv")
        sweep_signed.to_csv(sweep_signed_path, index=False)
        print(f"Saved signed-label tau sweep to {sweep_signed_path}")

    # 3. Selection regret
    regret = calculate_selection_regret(df, predictors, k=1)
    regret_path = os.path.join(args.outdir, "selection_regret.csv")
    regret.to_csv(regret_path, index=False)
    print(f"Saved selection regret to {regret_path}")

    # 4. NDCG@k for harmful identification using absolute |ΔR|
    ndcg = compute_ndcg_table(predictors, delta_R_abs, ks=[3, 5, 10])
    ndcg_path = os.path.join(args.outdir, "ndcg.csv")
    ndcg.to_csv(ndcg_path, index=False)
    print(f"Saved NDCG table to {ndcg_path}")

    # 5. Classification metrics with bootstrap CIs for signed harmfulness (ΔR>0)
    if harmful is not None:
        clf_df = compute_classification_metrics(df, predictors, harmful, outdir=args.outdir, boot=1000)
        clf_path = os.path.join(args.outdir, "classification_metrics.csv")
        clf_df.to_csv(clf_path, index=False)
        print(f"Saved classification metrics with CIs to {clf_path}")

    # 6. Threshold calibration-transfer evaluation on a split (no leakage)
    rng = np.random.RandomState(42)
    n = len(df)
    idx = rng.permutation(n)
    n_cal = int(args.calib_frac * n)
    cal_idx, test_idx = idx[:n_cal], idx[n_cal:]
    if harmful is not None and n_cal > 0 and len(test_idx) > 0:
        cal_label = harmful[cal_idx]
        test_label = harmful[test_idx]
        cal_res = []
        for name, scores in predictors.items():
            stats = evaluate_threshold_transfer(
                scores[cal_idx], cal_label, scores[test_idx], test_label,
                method=args.calib_method, c_fn=args.utility_cfn, c_fp=args.utility_cfp,
                outdir=args.outdir, name=name,
            )
            stats["predictor"] = name
            cal_res.append(stats)
        cal_df = pd.DataFrame(cal_res)
        cal_path = os.path.join(args.outdir, "threshold_transfer.csv")
        cal_df.to_csv(cal_path, index=False)
        print(f"Saved threshold transfer results to {cal_path}")

    # 7. Cost summary. Rows are built explicitly so "predictor" is always a string column,
    # even when the CSV has no cost columns. Predictors without any cost data get no row.
    costs_by_predictor = get_predictor_costs(df)
    cost_rows = [{"predictor": pred, **metrics} for pred, metrics in costs_by_predictor.items() if metrics]
    cost_df = pd.DataFrame(cost_rows) if cost_rows else pd.DataFrame(columns=["predictor"])
    cost_path = os.path.join(args.outdir, "costs.csv")
    cost_df.to_csv(cost_path, index=False)
    print(f"Saved cost summary to {cost_path}")

    # 8. Combined summary and Pareto plots
    summary_df = correlations.merge(regret, on="predictor")
    summary_df = summary_df.merge(ndcg, on="predictor")
    # Left join: a predictor without cost measurements keeps its row (with NaN costs).
    summary_df = summary_df.merge(cost_df, on="predictor", how="left")
    summary_path = os.path.join(args.outdir, "summary.csv")
    summary_df.to_csv(summary_path, index=False)
    print(f"Saved combined summary to {summary_path}")

    for cost_metric in ["time_s", "gpu_mem_mb"]:
        if cost_metric in summary_df.columns and summary_df[cost_metric].notna().any():
            plot_path = os.path.join(args.outdir, f"pareto_rho_vs_{cost_metric}.png")
            plot_pareto(summary_df, cost_metric, "spearman_rho", plot_path)


# Run the command-line analysis only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()


