#!/usr/bin/env python3
"""Reproduce the master_keys.ipynb analysis from the command line.

This script loads the filtered master-key judge outputs, performs model filtering
based on combined invalid ratios (NaN raw output or parsed_output == -1), tunes
CARESL gamma by minimising the false positive rate (ties broken by higher judge
accuracy), and reports summary metrics for multiple aggregators along with
per-attack diagnostics.
"""

from __future__ import annotations

import argparse
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable

import numpy as np
import pandas as pd

# Ensure project modules resolve when the script is invoked directly: add the
# repository's ``src`` directory to ``sys.path`` (once) before importing the
# project-local helpers below.
import sys

SCRIPT_DIR = Path(__file__).resolve().parent
REPO_ROOT = SCRIPT_DIR.parent
SRC_DIR = REPO_ROOT / "src"
# Check sys.path itself (the list the import system searches) so the entry is
# added exactly when it is missing, independent of the current working directory.
if str(SRC_DIR) not in sys.path:
    sys.path.append(str(SRC_DIR))

from eval_tools import evaluate_multisubject_rlvr, map_yes_no_mixed  # type: ignore  # noqa: E402
from pgm_tools import (  # type: ignore  # noqa: E402
    caret_aggregate,
    caresl_aggregate,
    ws_aggregate,
)

# Row metadata kept next to the per-model vote columns (when present in the CSVs).
BASE_COLUMNS = [
    "question",
    "response",
    "reference",
    "subject",
    "subset",
    "attack_id",
]
# CARESL gamma values swept when ``--gamma-grid`` is not given.
DEFAULT_GAMMA_GRID = [0.05, 0.1, 0.2, 0.5, 1.0, 2.0]
# Attack IDs summarised per attack when ``--attacks`` is not given.
DEFAULT_ATTACKS = "colon comma period solution space step_by_step thought_process".split()


@dataclass
class PredictionMetrics:
    """Headline scores for one set of aggregated judge predictions."""

    # ``judge_accuracy`` as reported by the evaluator (NaN when unavailable).
    judge_accuracy: float
    # Share of incorrect responses (``response_correct == 0``) the judge voted 1 on.
    false_positive_rate: float


def compute_valid_models(directory: Path, threshold: float) -> list[str]:
    """Return the names of models whose share of invalid outputs is at most ``threshold``.

    Every ``*.csv`` in ``directory`` (visited in sorted order) is one model, named
    by its file stem. A row is invalid when ``raw_output`` is NaN or when
    ``parsed_output`` maps to ``-1`` via ``map_yes_no_mixed``. A missing column
    contributes no invalid rows, and empty CSV files are skipped entirely.
    """
    accepted: list[str] = []
    for path in sorted(directory.glob("*.csv")):
        frame = pd.read_csv(path)
        if frame.empty:
            continue
        invalid = pd.Series(False, index=frame.index)
        if "raw_output" in frame:
            invalid |= frame["raw_output"].isna()
        if "parsed_output" in frame:
            invalid |= frame["parsed_output"].apply(map_yes_no_mixed) == -1
        if invalid.mean() <= threshold:
            accepted.append(path.stem)
    return accepted


def load_masterkey_votes(directory: Path, models: Iterable[str]) -> pd.DataFrame:
    """Load the vote matrix and shared row metadata for the requested models.

    Each ``<model>.csv`` in ``directory`` contributes one vote column: its
    ``parsed_output`` mapped through ``map_yes_no_mixed``. Rows without an
    ``attack_id`` are dropped per model. Models whose file is missing, or whose
    file has no rows left after that, are skipped. The metadata (whichever
    ``BASE_COLUMNS`` are present) is taken from the first model loaded.

    Args:
        directory: Folder holding one CSV per model.
        models: Model names (CSV stems) to load. Any iterable is accepted; it is
            materialised once because it is consumed twice.

    Returns:
        Metadata columns followed by one vote column per loaded model. Only rows
        where every vote is present and not ``-1`` are kept. If nothing could be
        loaded, returns an empty frame whose columns are ``BASE_COLUMNS`` plus
        every requested model name.
    """
    # Materialise first. Both the loop and the empty-result branch iterate
    # ``models``, and a one-shot iterator would be exhausted on the second pass.
    models = list(models)
    base_df: pd.DataFrame | None = None
    vote_data: dict[str, pd.Series] = {}
    for model in models:
        csv_path = directory / f"{model}.csv"
        if not csv_path.exists():
            continue
        df_model = pd.read_csv(csv_path)
        df_model = df_model[df_model["attack_id"].notna()].copy()
        if df_model.empty:
            continue
        df_model["parsed_output"] = df_model["parsed_output"].apply(map_yes_no_mixed)
        df_model.reset_index(drop=True, inplace=True)
        if base_df is None:
            cols = [c for c in BASE_COLUMNS if c in df_model.columns]
            base_df = df_model[cols].copy()
        # NOTE(review): votes are joined to the first model's metadata by row
        # position. This assumes every model CSV lists the same items in the same
        # order; confirm.
        vote_data[model] = df_model["parsed_output"]
    if base_df is None or not vote_data:
        return pd.DataFrame(columns=BASE_COLUMNS + models)
    vote_columns = list(vote_data)
    # base_df already has a fresh RangeIndex (df_model was reset above).
    combined = pd.concat([base_df, pd.DataFrame(vote_data)], axis=1)
    vote_block = combined[vote_columns]
    # Keep rows where every model produced a usable vote (not NaN and not -1).
    mask = vote_block.notna().all(axis=1) & vote_block.ne(-1).all(axis=1)
    return combined.loc[mask].reset_index(drop=True)


def evaluate_predictions(subset: pd.DataFrame, preds: pd.Series | np.ndarray) -> PredictionMetrics:
    """Score aggregated predictions on ``subset`` with ``evaluate_multisubject_rlvr``.

    Args:
        subset: Rows to score. Any ``BASE_COLUMNS`` present are forwarded.
        preds: Predictions, cast to ``int``. A Series is aligned to
            ``subset.index`` via ``.loc``. Any other array-like is taken
            positionally.

    Returns:
        ``PredictionMetrics``. Accuracy is the evaluator's ``judge_accuracy`` (NaN
        if absent). The false-positive rate is the share of rows with
        ``response_correct == 0`` whose ``judge_vote`` is 1 (NaN if there are no
        such rows). Both are NaN for an empty ``subset``.
    """
    if subset.empty:
        nan = float("nan")
        return PredictionMetrics(nan, nan)

    aligned = (
        preds.loc[subset.index] if isinstance(preds, pd.Series) else pd.Series(preds, index=subset.index)
    ).astype(int)
    keep = [column for column in BASE_COLUMNS if column in subset.columns]
    records = subset[keep].assign(parsed_output=aligned).to_dict(orient="records")

    all_results, acc_dict, _ = evaluate_multisubject_rlvr(records)
    negative_votes = [row["judge_vote"] for row in all_results if row["response_correct"] == 0]
    if negative_votes:
        fpr = sum(vote == 1 for vote in negative_votes) / len(negative_votes)
    else:
        fpr = float("nan")
    return PredictionMetrics(float(acc_dict.get("judge_accuracy", float("nan"))), float(fpr))


def run_analysis(
    directory: Path,
    threshold: float,
    gamma_grid: list[float],
    class_balance: float,
    attack_ids: list[str],
) -> tuple[pd.DataFrame, pd.DataFrame, float, list[str]]:
    models = compute_valid_models(directory, threshold)
    data = load_masterkey_votes(directory, models)
    if data.empty:
        raise ValueError("No rows remaining after filtering invalid judge outputs.")

    votes = data[models]
    mv_preds = votes.mode(axis=1)[0].astype(int)
    ws_scores = ws_aggregate(votes)
    ws_preds = pd.Series((np.asarray(ws_scores) >= 0.5).astype(int), index=data.index)
    caret_preds = pd.Series(caret_aggregate(votes.values), index=data.index)

    gamma_records = []
    for gamma in gamma_grid:
        caresl_scores = caresl_aggregate(votes, gamma=gamma)
        caresl_preds_gamma = pd.Series((np.asarray(caresl_scores) >= 0.5).astype(int), index=data.index)
        metrics = evaluate_predictions(data, caresl_preds_gamma)
        gamma_records.append(
            {
                "gamma": gamma,
                "judge_accuracy": metrics.judge_accuracy,
                "false_positive_rate": metrics.false_positive_rate,
            }
        )

    gamma_df = (
        pd.DataFrame(gamma_records)
        .sort_values(["false_positive_rate", "judge_accuracy"], ascending=[True, False])
        .reset_index(drop=True)
    )
    best_gamma = float(gamma_df.iloc[0]["gamma"]) if not gamma_df.empty else gamma_grid[0]

    caresl_scores = caresl_aggregate(votes, gamma=best_gamma)
    caresl_preds = pd.Series((np.asarray(caresl_scores) >= 0.5).astype(int), index=data.index)

    aggregator_metrics = pd.DataFrame(
        [
            {
                "aggregator": "majority_vote",
                "judge_accuracy": (mv := evaluate_predictions(data, mv_preds)).judge_accuracy,
                "false_positive_rate": mv.false_positive_rate,
            },
            {
                "aggregator": "weighted_soft",
                "judge_accuracy": (ws := evaluate_predictions(data, ws_preds)).judge_accuracy,
                "false_positive_rate": ws.false_positive_rate,
            },
            {
                "aggregator": f"caresl_gamma_{best_gamma}",
                "judge_accuracy": (ca := evaluate_predictions(data, caresl_preds)).judge_accuracy,
                "false_positive_rate": ca.false_positive_rate,
            },
            {
                "aggregator": "caret",
                "judge_accuracy": (ct := evaluate_predictions(data, caret_preds)).judge_accuracy,
                "false_positive_rate": ct.false_positive_rate,
            },
        ]
    ).sort_values(["false_positive_rate", "judge_accuracy"], ascending=[True, False])

    # Attack-level breakdown
    data = data.assign(
        mv_pred=mv_preds.astype(int),
        ws_pred=ws_preds.astype(int),
        caresl_pred=caresl_preds.astype(int),
        caret_pred=caret_preds.astype(int),
    )
    attack_records = []
    for attack_id in attack_ids:
        subset = data[data["attack_id"] == attack_id]
        if subset.empty:
            attack_records.append(
                {
                    "attack_id": attack_id,
                    "n_rows": 0,
                    "mv_accuracy": float("nan"),
                    "mv_false_positive_rate": float("nan"),
                    "ws_accuracy": float("nan"),
                    "ws_false_positive_rate": float("nan"),
                    "caresl_accuracy": float("nan"),
                    "caresl_false_positive_rate": float("nan"),
                    "caret_accuracy": float("nan"),
                    "caret_false_positive_rate": float("nan"),
                }
            )
            continue
        mv_metrics = evaluate_predictions(subset, subset["mv_pred"])
        ws_metrics = evaluate_predictions(subset, subset["ws_pred"])
        caresl_metrics = evaluate_predictions(subset, subset["caresl_pred"])
        caret_metrics = evaluate_predictions(subset, subset["caret_pred"])
        attack_records.append(
            {
                "attack_id": attack_id,
                "n_rows": int(len(subset)),
                "mv_accuracy": mv_metrics.judge_accuracy,
                "mv_false_positive_rate": mv_metrics.false_positive_rate,
                "ws_accuracy": ws_metrics.judge_accuracy,
                "ws_false_positive_rate": ws_metrics.false_positive_rate,
                "caresl_accuracy": caresl_metrics.judge_accuracy,
                "caresl_false_positive_rate": caresl_metrics.false_positive_rate,
                "caret_accuracy": caret_metrics.judge_accuracy,
                "caret_false_positive_rate": caret_metrics.false_positive_rate,
            }
        )

    attack_df = pd.DataFrame(attack_records)
    return aggregator_metrics, gamma_df, best_gamma, models


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Master-keys aggregation diagnostics")
    parser.add_argument(
        "--base-path",
        type=Path,
        default=REPO_ROOT / "judge_outputs" / "miscellaneous" / "master_keys",
        help="Directory containing per-model master key CSV files.",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.1,
        help="Maximum combined invalid ratio allowed for models.",
    )
    parser.add_argument(
        "--gamma-grid",
        type=float,
        nargs="*",
        default=DEFAULT_GAMMA_GRID,
        help="Gamma values to evaluate for CARESL (space separated).",
    )
    parser.add_argument(
        "--class-balance",
        type=float,
        default=50.0,
        help="Class balance percentile used inside CARET (kept for parity with notebook).",
    )
    parser.add_argument(
        "--attacks",
        nargs="*",
        default=DEFAULT_ATTACKS,
        help="Attack IDs to include in the per-attack summary.",
    )
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    aggregator_df, gamma_df, best_gamma, models = run_analysis(
        directory=args.base_path,
        threshold=args.threshold,
        gamma_grid=list(args.gamma_grid),
        class_balance=args.class_balance,
        attack_ids=list(args.attacks),
    )

    print("Models used:")
    for model in models:
        print(f"  - {model}")
    print()

    print("Aggregator metrics (sorted by FPR, then accuracy):")
    print(aggregator_df.to_string(index=False))
    print()

    print("CARESL gamma sweep:")
    print(gamma_df.to_string(index=False))
    print(f"\nSelected gamma: {best_gamma}")


# Run the command-line analysis only when executed as a script, not on import.
if __name__ == "__main__":
    main()
