"""Metric computation and sanity checks for BioDimBench."""

from __future__ import annotations

import numpy as np
import pandas as pd

from .utils import INVALID_CANDIDATE_TYPES, METHOD_ORDER, safe_divide


# Names of the metric columns reported for each (method, split) group.
# The entries mirror the keys of the dict returned by `_binary_metrics`, and
# `_empty_metrics` uses this list to build an all-NaN placeholder for groups
# that have no scored rows.
METRIC_COLUMNS = [
    "accuracy",
    "valid_precision",
    "valid_recall",
    "invalid_precision",
    "invalid_recall",
    "invalid_f1",
    "macro_f1",
    "false_accept_rate",
    "false_reject_rate",
]


def compute_aggregate_metrics(results: pd.DataFrame) -> pd.DataFrame:
    """Compute aggregate binary classification metrics for each method.

    Rows are grouped by ``(method, split)``. Within a group only rows whose
    ``prediction`` is non-null are scored; the remaining rows still count
    toward ``total_rows`` and therefore lower ``coverage``.

    Args:
        results: Frame providing the columns ``method``, ``split``,
            ``prediction`` and ``is_valid``.

    Returns:
        One row per ``(method, split)`` group holding ``method``, ``split``,
        every column named in ``METRIC_COLUMNS``, ``n_evaluated``,
        ``total_rows`` and ``coverage``. Rows are ordered by the position of
        the method in ``METHOD_ORDER`` (methods not listed there sort last)
        and then by ``split``. When ``results`` contains no groups, an empty
        frame with the same columns is returned.
    """

    rows: list[dict[str, object]] = []
    for (method, split), group in results.groupby(["method", "split"], sort=False):
        # Rows without a prediction are excluded from scoring only.
        evaluated = group[group["prediction"].notna()]
        if evaluated.empty:
            # Nothing to score: report NaN for every metric.
            metrics = _empty_metrics()
        else:
            y_true = evaluated["is_valid"].astype(bool).to_numpy()
            y_pred = evaluated["prediction"].map(_to_bool).to_numpy(dtype=bool)
            metrics = _binary_metrics(y_true, y_pred)

        rows.append(
            {
                "method": method,
                "split": split,
                **metrics,
                "n_evaluated": int(len(evaluated)),
                "total_rows": int(len(group)),
                "coverage": safe_divide(len(evaluated), len(group)),
            }
        )

    if not rows:
        # A frame built from an empty list has no columns, so the ordering
        # step below would raise KeyError on "method". Return a well-formed
        # empty frame instead.
        return pd.DataFrame(
            columns=["method", "split", *METRIC_COLUMNS, "n_evaluated", "total_rows", "coverage"]
        )

    frame = pd.DataFrame(rows)
    # Temporary sort key: rank within METHOD_ORDER, 99 for unknown methods.
    method_rank = {name: position for position, name in enumerate(METHOD_ORDER)}
    frame["method_order"] = frame["method"].map(method_rank).fillna(99)
    return frame.sort_values(["method_order", "split"]).drop(columns=["method_order"]).reset_index(drop=True)


def compute_error_type_recall(results: pd.DataFrame) -> pd.DataFrame:
    """Compute rejection recall for each corrupted candidate type.

    For every ``(method, split)`` group and every entry of
    ``INVALID_CANDIDATE_TYPES``, recall is the fraction of rows with a
    non-null ``prediction`` that were predicted ``False`` (i.e. rejected).

    Args:
        results: Frame providing the columns ``method``, ``split``,
            ``candidate_type`` and ``prediction``.

    Returns:
        One row per ``(method, split, candidate_type)`` combination with the
        columns ``method``, ``split``, ``candidate_type``, ``recall``,
        ``n_evaluated``, ``total_rows`` and ``coverage``. ``recall`` is NaN
        when no row of that type has a prediction. Rows are ordered by the
        position of the method in ``METHOD_ORDER`` (methods not listed there
        sort last), then by ``candidate_type`` and ``split``. When no rows are
        produced, an empty frame with the same columns is returned.
    """

    rows: list[dict[str, object]] = []
    for (method, split), method_group in results.groupby(["method", "split"], sort=False):
        for candidate_type in INVALID_CANDIDATE_TYPES:
            group = method_group[method_group["candidate_type"] == candidate_type]
            # Rows without a prediction count toward coverage but not recall.
            evaluated = group[group["prediction"].notna()]
            recall = np.nan
            if not evaluated.empty:
                predictions = evaluated["prediction"].map(_to_bool).to_numpy(dtype=bool)
                # Every row here is a corrupted candidate, so a False
                # prediction is a correct rejection.
                recall = float((~predictions).mean())
            rows.append(
                {
                    "method": method,
                    "split": split,
                    "candidate_type": candidate_type,
                    "recall": recall,
                    "n_evaluated": int(len(evaluated)),
                    "total_rows": int(len(group)),
                    "coverage": safe_divide(len(evaluated), len(group)),
                }
            )

    if not rows:
        # A frame built from an empty list has no columns, so the ordering
        # step below would raise KeyError on "method". Return a well-formed
        # empty frame instead.
        return pd.DataFrame(
            columns=["method", "split", "candidate_type", "recall", "n_evaluated", "total_rows", "coverage"]
        )

    frame = pd.DataFrame(rows)
    # Temporary sort key: rank within METHOD_ORDER, 99 for unknown methods.
    method_rank = {name: position for position, name in enumerate(METHOD_ORDER)}
    frame["method_order"] = frame["method"].map(method_rank).fillna(99)
    return frame.sort_values(["method_order", "candidate_type", "split"]).drop(columns=["method_order"]).reset_index(
        drop=True
    )


def run_sanity_checks(results: pd.DataFrame) -> None:
    """Raise an error if core expected verifier behaviors are violated.

    Only rows whose ``split`` equals ``"all"`` are inspected. Each check picks
    the rows of one method and one candidate type, measures either the share
    of accepted (prediction true) or rejected (prediction false) rows, and
    requires that share to reach a minimum threshold.

    Raises:
        RuntimeError: If a checked subset has no rows with a prediction, or
            if an observed rate is non-finite or below its threshold.
    """

    # (method, candidate_type, measure_accepts, threshold, message), evaluated
    # in order; the first failing check raises.
    checks = (
        ("numeric_plus_unit", "correct", True, 0.99, "numeric_plus_unit should accept correct rows"),
        ("unit_only", "wrong_unit", False, 0.99, "unit_only should reject wrong_unit rows"),
        ("numeric_plus_unit", "wrong_unit", False, 0.99, "numeric_plus_unit should reject wrong_unit rows"),
        (
            "answer_only",
            "plausible_scalar_wrong_unit",
            True,
            0.95,
            "answer_only should false-accept plausible_scalar_wrong_unit rows",
        ),
        (
            "numeric_plus_unit",
            "missing_conversion",
            False,
            0.80,
            "numeric_plus_unit should reject most missing_conversion rows",
        ),
    )

    overall = results[results["split"] == "all"].copy()
    for method, candidate_type, measure_accepts, threshold, message in checks:
        accepted = _subset(overall, method, candidate_type)["prediction"].map(_to_bool)
        if measure_accepts:
            observed = accepted.mean()
        else:
            observed = (~accepted).mean()
        _require_rate(observed, threshold, message)


def _subset(results: pd.DataFrame, method: str, candidate_type: str) -> pd.DataFrame:
    """Return a copy of the rows for one method and candidate type that have a prediction.

    Raises:
        RuntimeError: If no row matches, so a sanity check cannot silently
            pass on an empty selection.
    """
    is_target = (results["method"] == method) & (results["candidate_type"] == candidate_type)
    has_prediction = results["prediction"].notna()
    selected = results.loc[is_target & has_prediction].copy()
    if selected.empty:
        raise RuntimeError(f"Sanity check subset is empty: method={method}, candidate_type={candidate_type}")
    return selected


def _require_rate(rate: float, threshold: float, message: str) -> None:
    """Raise ``RuntimeError`` unless ``rate`` is finite and not below ``threshold``.

    ``message`` describes the expectation and is embedded in the error text
    together with the observed rate and the threshold.
    """
    if np.isfinite(rate) and not rate < threshold:
        return
    raise RuntimeError(f"Sanity check failed: {message}. Observed rate={rate:.3f}, threshold={threshold:.3f}")


def _binary_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict[str, float]:
    """Derive per-class and aggregate metrics from boolean labels and predictions.

    ``True`` means "valid" in both arrays. Precision and recall are computed
    once with "valid" as the positive class and once with "invalid" as the
    positive class. ``macro_f1`` is the NaN-ignoring mean of the two class F1
    scores. The returned keys are the ones listed in ``METRIC_COLUMNS``.

    NOTE(review): the result for a zero denominator depends on
    ``safe_divide``, which is defined outside this module.
    """
    actual_invalid = ~y_true
    rejected = ~y_pred

    # The four confusion-matrix cells, named from the verifier's perspective.
    accepted_valid = int(np.sum(y_true & y_pred))
    accepted_invalid = int(np.sum(actual_invalid & y_pred))
    rejected_valid = int(np.sum(y_true & rejected))
    rejected_invalid = int(np.sum(actual_invalid & rejected))

    # "valid" as the positive class.
    valid_precision = safe_divide(accepted_valid, accepted_valid + accepted_invalid)
    valid_recall = safe_divide(accepted_valid, accepted_valid + rejected_valid)

    # "invalid" as the positive class.
    invalid_precision = safe_divide(rejected_invalid, rejected_invalid + rejected_valid)
    invalid_recall = safe_divide(rejected_invalid, rejected_invalid + accepted_invalid)

    class_f1 = [_f1(valid_precision, valid_recall), _f1(invalid_precision, invalid_recall)]

    metrics: dict[str, float] = {}
    metrics["accuracy"] = float(np.mean(y_true == y_pred))
    metrics["valid_precision"] = valid_precision
    metrics["valid_recall"] = valid_recall
    metrics["invalid_precision"] = invalid_precision
    metrics["invalid_recall"] = invalid_recall
    metrics["invalid_f1"] = class_f1[1]
    metrics["macro_f1"] = float(np.nanmean(class_f1))
    # Share of truly invalid rows that were accepted, and of truly valid rows
    # that were rejected.
    metrics["false_accept_rate"] = safe_divide(accepted_invalid, int(np.sum(actual_invalid)))
    metrics["false_reject_rate"] = safe_divide(rejected_valid, int(np.sum(y_true)))
    return metrics


def _f1(precision: float, recall: float) -> float:
    """Return the F1 score (harmonic mean) of ``precision`` and ``recall``.

    Returns:
        NaN when either input is non-finite, i.e. the score is undefined.
        ``0.0`` when both inputs are finite and sum to zero: the class then
        has no true positives but does have errors, so F1 is a genuine zero.
        Reporting NaN there would let a NaN-ignoring macro average silently
        drop the failed class and overstate the result.
    """
    if not (np.isfinite(precision) and np.isfinite(recall)):
        return np.nan
    total = precision + recall
    if total == 0:
        return 0.0
    return 2 * precision * recall / total


def _empty_metrics() -> dict[str, float]:
    """Return a fresh dict mapping every name in ``METRIC_COLUMNS`` to NaN."""
    return dict.fromkeys(METRIC_COLUMNS, np.nan)


def _to_bool(value: object) -> bool:
    """Coerce a raw prediction value to ``bool``.

    Strings are true only when they equal ``"true"`` ignoring case; any other
    string is false. Non-string values use Python truthiness.
    """
    if not isinstance(value, str):
        return bool(value)
    return value.lower() == "true"
