import pandas as pd
from pathlib import Path
import argparse
import sys
from sklearn.metrics import roc_auc_score, average_precision_score

def _gate_auroc_auprc(labels, scores):
    """Return ``(auroc, auprc)`` for one gate score against binary harm labels.

    When the scores are constant or only one class is present, the gate
    carries no ranking information, so chance-level values are reported:
    AUROC 0.5 and AUPRC equal to the fraction of positive labels (the
    no-skill precision-recall baseline), rather than a fixed 0.5.
    """
    if scores.nunique() < 2 or labels.nunique() < 2:
        return 0.5, float(labels.mean())
    return roc_auc_score(labels, scores), average_precision_score(labels, scores)


def compute_gate_metrics(results_csv: Path, thresholds: list):
    """
    Computes AUROC and AUPRC for different gate scores at various harm thresholds.

    For every threshold ``tau`` a row of ``results_csv`` is labelled harmful
    when its ``delta_R`` value is strictly greater than ``tau``. Each gate
    score column that is present in the file (``trace_w_score``,
    ``trace_mmd_score``, ``msp_score``) is then scored against those labels.
    Thresholds that yield no harmful rows are skipped. Progress is printed to
    stdout.

    Two files are written next to ``results_csv``:
    ``deployment_gate_summary.csv`` (one row per evaluated threshold) and
    ``deployment_gate_results.tex`` (the same data as a LaTeX table).

    Args:
        results_csv: Path to the CSV holding ``delta_R`` and the score columns.
        thresholds: Harm thresholds (tau) to evaluate, in output order.

    Raises:
        SystemExit: With status 1 if ``results_csv`` is not an existing file.
    """
    if not results_csv.is_file():
        # Diagnostics belong on stderr so they do not mix with the report.
        print(f"Error: Results file not found at {results_csv}", file=sys.stderr)
        sys.exit(1)

    df = pd.read_csv(results_csv)

    # Display name -> column name in the results CSV.
    gate_scores = {
        "TRACE-W": "trace_w_score",
        "TRACE-MMD": "trace_mmd_score",
        "-MSP": "msp_score"
    }

    output_data = []

    print("--- Deployment Gate Results ---")
    for tau in thresholds:
        # Define 'harmful' based on the threshold; kept as a local Series so
        # the loaded frame is not polluted with a scratch column.
        is_harmful = (df['delta_R'] > tau).astype(int)
        n_harmful = int(is_harmful.sum())

        # Skip if no harmful samples exist at this threshold
        if n_harmful == 0:
            print(f"\nSkipping threshold tau={tau:.2f} (no harmful updates)")
            continue

        print(f"\nThreshold tau = {tau:.2f} ({n_harmful} harmful updates)")

        row = {"tau": tau}
        for score_name, col_name in gate_scores.items():
            # Score columns are optional; absent ones are simply not reported.
            if col_name not in df.columns:
                continue

            auroc, auprc = _gate_auroc_auprc(is_harmful, df[col_name])

            row[f"AUROC_{score_name}"] = auroc
            row[f"AUPRC_{score_name}"] = auprc
            print(f"  {score_name}: AUROC = {auroc:.3f}, AUPRC = {auprc:.3f}")

        output_data.append(row)

    results_df = pd.DataFrame(output_data)

    # --- Save results to CSV and LaTeX ---
    output_dir = results_csv.parent
    results_df.to_csv(output_dir / "deployment_gate_summary.csv", index=False, float_format='%.3f')

    # Generate LaTeX table
    latex_path = output_dir / "deployment_gate_results.tex"
    with open(latex_path, "w", encoding="utf-8") as f:
        f.write("% Generated by post_processing/compute_gate_metrics.py\n")
        # Customize columns for the paper's table style. Only the TRACE-W and
        # -MSP columns are renamed; TRACE-MMD columns keep their raw names.
        latex_df = results_df.rename(columns={
            "tau": r"$\tauval$",
            "AUROC_TRACE-W": r"AUROC$_{\text{TRACE}}$",
            "AUPRC_TRACE-W": r"AUPRC$_{\text{TRACE}}$",
            "AUROC_-MSP": r"AUROC$_{\text{-MSP}}$",
            "AUPRC_-MSP": r"AUPRC$_{\text{-MSP}}$"
        })

        # Assuming TRACE-W and TRACE-MMD are similar, we only show one TRACE column
        # and one -MSP column as in the paper. If any of those columns is
        # missing, fall back to dumping every available column.
        table_cols = [r"$\tauval$", r"AUROC$_{\text{TRACE}}$", r"AUPRC$_{\text{TRACE}}$", r"AUROC$_{\text{-MSP}}$", r"AUPRC$_{\text{-MSP}}$"]
        if all(c in latex_df.columns for c in table_cols):
            f.write(latex_df[table_cols].to_latex(index=False, float_format='%.3f', escape=False))
        else:
            f.write(latex_df.to_latex(index=False, float_format='%.3f', escape=False))

    print("\nSuccessfully wrote summary to:")
    print(f"- CSV: {output_dir / 'deployment_gate_summary.csv'}")
    print(f"- LaTeX: {latex_path}")

if __name__ == "__main__":
    # Command-line entry point: one positional CSV path plus an optional
    # list of harm thresholds, handed straight to compute_gate_metrics.
    cli = argparse.ArgumentParser(
        description="Compute AUROC/AUPRC for TRACE Deployment Gate experiment.",
    )
    cli.add_argument(
        "results_csv",
        type=Path,
        help="Path to the gate_metrics.csv file from the experiment.",
    )
    cli.add_argument(
        "--thresholds",
        type=float,
        nargs='+',
        default=[0.10, 0.13, 0.23],
        help="List of harm thresholds (tau) to evaluate.",
    )
    options = cli.parse_args()

    compute_gate_metrics(options.results_csv, options.thresholds)
