from __future__ import annotations

import math
from typing import Any, Dict, List

import numpy as np


def _precision_at_k(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
    """Return the mean label among the ``k`` highest-scoring items.

    Args:
        scores: Ranking scores. Flattened and cast to float so that unsigned
            or boolean inputs are ordered correctly (negating them in their
            native dtype would wrap around or raise).
        labels: Per-item labels (bool or 0/1), same length as ``scores``.
        k: Cut-off. Truncated to ``int`` and clamped to ``len(scores)``.

    Returns:
        Precision@k as a float, or NaN when ``k`` is non-positive (before or
        after truncation) or ``scores`` is empty.

    Raises:
        ValueError: If ``scores`` and ``labels`` differ in length.

    Ties are broken by input order (stable sort), so the result does not
    depend on the platform's default sorting algorithm. NaN scores sort last.
    """
    s = np.asarray(scores, dtype=float).reshape(-1)
    lab = np.asarray(labels, dtype=float).reshape(-1)
    if s.shape[0] != lab.shape[0]:
        raise ValueError("scores and labels must have the same length")
    if k <= 0 or s.size == 0:
        return float("nan")
    k = min(int(k), int(s.size))
    if k <= 0:  # e.g. k=0.5 passes the first check but truncates to 0
        return float("nan")
    top = np.argsort(-s, kind="mergesort")[:k]
    return float(np.mean(lab[top]))


def _roc_auc_score_binary(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Return the binary ROC AUC, computed without sklearn.

    The AUC equals the Mann–Whitney U statistic of the positive class divided
    by ``n_pos * n_neg``. Scores are ranked ascending (1-based) and every group
    of tied scores receives the average of the ranks it spans.

    Raises:
        ValueError: If the inputs differ in length, or only one class occurs.
    """
    labels = np.asarray(y_true, dtype=bool).reshape(-1)
    values = np.asarray(y_score, dtype=float).reshape(-1)
    if labels.shape[0] != values.shape[0]:
        raise ValueError("y_true and y_score must have the same length")

    total = int(labels.shape[0])
    positives = int(labels.sum())
    negatives = total - positives
    if 0 in (positives, negatives):
        raise ValueError("ROC AUC is undefined when only one class is present")

    perm = np.argsort(values, kind="mergesort")
    ascending = values[perm]

    # A tie group begins wherever a sorted value differs from its predecessor.
    begins_group = np.ones(total, dtype=bool)
    begins_group[1:] = ascending[1:] != ascending[:-1]
    group_start = np.flatnonzero(begins_group)
    group_stop = np.append(group_start[1:], total)

    # The 1-based ranks start+1 .. stop are shared as their mean.
    mean_rank = 0.5 * ((group_start + 1) + group_stop)
    rank_of = np.empty(total, dtype=float)
    rank_of[perm] = np.repeat(mean_rank, group_stop - group_start)

    u_stat = float(rank_of[labels].sum()) - positives * (positives + 1) / 2.0
    return float(u_stat / float(positives * negatives))


def _average_precision_score_binary(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Compute Average Precision (AUPRC) for binary labels without sklearn.

    Uses the step-wise definition (the one scikit-learn's
    ``average_precision_score`` implements)::

        AP = sum_n (R_n - R_{n-1}) * P_n

    where ``n`` runs over the distinct score thresholds in decreasing order,
    and ``P_n`` / ``R_n`` are precision / recall when every item scoring at
    least that threshold is predicted positive. Items with tied scores enter
    together at one threshold, so the result does not depend on the input
    order of tied items. Without ties this equals the mean precision at the
    rank of each positive.

    Raises:
        ValueError: If the inputs differ in length, or only one class occurs.
    """
    y = np.asarray(y_true, dtype=bool).reshape(-1)
    s = np.asarray(y_score, dtype=float).reshape(-1)
    if y.shape[0] != s.shape[0]:
        raise ValueError("y_true and y_score must have the same length")

    n_pos = int(y.sum())
    n = int(y.shape[0])
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("Average precision is undefined when only one class is present")

    order = np.argsort(-s, kind="mergesort")
    s_sorted = s[order]
    y_sorted = y[order]

    # One threshold per distinct score: the last sorted index of each tie
    # group. NaN != NaN, so NaN scores (sorted last) each form their own group.
    group_last = np.flatnonzero(s_sorted[1:] != s_sorted[:-1])
    thresholds = np.append(group_last, n - 1)

    tp = np.cumsum(y_sorted, dtype=float)[thresholds]
    n_predicted = thresholds + 1.0
    precision = tp / n_predicted
    recall = tp / float(n_pos)
    # Recall gained at each threshold weights the precision reached there.
    recall_gain = np.diff(recall, prepend=0.0)
    return float(np.sum(recall_gain * precision))


def _oriented_scores(
    scores: np.ndarray, finite: np.ndarray, sign: float, margin: float
) -> np.ndarray:
    """Return ``sign * scores`` with non-finite entries forced to rank last.

    Entries where ``finite`` is False are set to ``margin`` below the smallest
    oriented finite value, so they rank worst whichever ``sign`` is used. If no
    entry is finite, every entry is 0.0 (all tied).
    """
    out = np.zeros(scores.shape[0], dtype=float)
    if finite.any():
        out[finite] = sign * scores[finite]
        out[~finite] = float(out[finite].min()) - margin
    return out


def _average_ranks_desc(values: np.ndarray) -> np.ndarray:
    """Return 0-based ranks in decreasing order of ``values``.

    Tied values share the mean of the positions they occupy, so the ranks do
    not depend on how ties happen to be ordered in the input.
    """
    n = int(values.shape[0])
    ranks = np.empty(n, dtype=float)
    if n == 0:
        return ranks
    order = np.argsort(-values, kind="mergesort")
    ordered = values[order]
    starts_group = np.ones(n, dtype=bool)
    starts_group[1:] = ordered[1:] != ordered[:-1]
    starts = np.flatnonzero(starts_group)
    stops = np.append(starts[1:], n)
    # Sorted positions starts..stops-1 (0-based) share their mean position.
    ranks[order] = np.repeat(0.5 * (starts + stops - 1), stops - starts)
    return ranks


def poison_detection_report_v2(
    scores: np.ndarray,
    poison_mask: np.ndarray,
    topk_list: List[int],
    *,
    candidate_only: bool = False,
    nonfinite_margin: float = 1.0,
) -> Dict[str, Any]:
    """Compute auditable poison-detection metrics.

    Semantics:
    - Default (candidate_only=False): evaluate on the full dataset, ranking
      non-finite scores (NaN/Inf) as worst in *whichever* score direction is
      evaluated (the fill is applied after orienting the scores).
    - candidate_only=True: evaluate only on finite-scored candidates (explicit).
    - Score direction (``chosen_sign``) is chosen to maximize AUROC, falling
      back to AUPRC if either AUROC is NaN.

    Args:
        scores: Per-sample detection scores (flattened, cast to float).
        poison_mask: Per-sample ground truth, truthy for poisoned samples.
        topk_list: Cut-offs for ``precision_at_k`` (``None`` means none).
        candidate_only: Restrict evaluation to finite-scored samples.
        nonfinite_margin: Distance below the worst finite (oriented) score at
            which non-finite samples are placed. Non-finite or non-positive
            values fall back to 1.0.

    Returns:
        A dict of counts, prevalences, per-direction and chosen-direction
        AUROC/AUPRC, ``precision_at_k`` keyed by ``int(k)``, and
        ``mean_rank_poison`` (0-based; tied scores share their mean rank).
        ``nonfinite_fill_value`` is the raw score substituted for non-finite
        entries in the chosen direction (NaN when no fill is applied). Metric
        fields stay NaN/empty when the evaluated set lacks either class.

    Raises:
        ValueError: If ``scores`` and ``poison_mask`` differ in length.
    """
    s = np.asarray(scores, dtype=float).reshape(-1)
    y = np.asarray(poison_mask, dtype=bool).reshape(-1)
    if s.shape[0] != y.shape[0]:
        raise ValueError("scores and poison_mask must have the same length")

    n_total_raw = int(s.shape[0])
    finite = np.isfinite(s)
    n_finite = int(finite.sum())
    n_nan = int(np.isnan(s).sum())
    n_inf = int(np.isinf(s).sum())

    poison_raw = int(y.sum())
    poison_finite = int(y[finite].sum())

    prevalence_raw = float(poison_raw / max(1, n_total_raw))
    prevalence_finite = float(poison_finite / n_finite) if n_finite > 0 else float("nan")
    coverage = float(n_finite / max(1, n_total_raw))

    # Filling only happens in full-dataset mode with at least one finite score.
    fills_applied = (not candidate_only) and n_finite > 0
    margin = 1.0
    if fills_applied:
        margin = float(nonfinite_margin)
        if not math.isfinite(margin) or margin <= 0.0:
            margin = 1.0

    if candidate_only:
        y_eval = y[finite]
        s_base = s[finite]
        eval_finite = np.ones(n_finite, dtype=bool)
    else:
        y_eval = y
        s_base = s
        eval_finite = finite

    # Raw-score value that non-finite entries effectively receive in each
    # direction: below the minimum for sign=+1, above the maximum for sign=-1.
    if fills_applied:
        finite_vals = s[finite]
        fill_raw = {
            1.0: float(np.min(finite_vals)) - margin,
            -1.0: float(np.max(finite_vals)) + margin,
        }
    else:
        fill_raw = {1.0: float("nan"), -1.0: float("nan")}

    n_eval = int(y_eval.shape[0])
    n_pos = int(y_eval.sum())
    n_neg = n_eval - n_pos

    report: Dict[str, Any] = {
        "candidate_only": bool(candidate_only),
        "n_total_raw": n_total_raw,
        "n_eval": n_eval,
        "n_finite": n_finite,
        # Counts every non-finite score. In the default mode these samples are
        # kept (filled), not dropped; the key name is kept for compatibility.
        "n_dropped_nonfinite": int(n_total_raw - n_finite),
        "n_nan": n_nan,
        "n_inf": n_inf,
        "coverage": coverage,
        "poison_raw": poison_raw,
        "poison_finite": poison_finite,
        "prevalence_raw": prevalence_raw,
        "prevalence_finite": prevalence_finite,
        "nonfinite_fill_value": fill_raw[1.0],
        "chosen_sign": 1.0,
        "auroc": float("nan"),
        "auprc": float("nan"),
        "auroc_pos": float("nan"),
        "auroc_neg": float("nan"),
        "auprc_pos": float("nan"),
        "auprc_neg": float("nan"),
        "precision_at_k": {},
        "mean_rank_poison": float("nan"),
        # Backward-compatible field (deprecated)
        "n_poison": poison_raw,
    }

    if n_pos == 0 or n_neg == 0:
        return report

    def _safe_roc(yb: np.ndarray, sb: np.ndarray) -> float:
        try:
            return float(_roc_auc_score_binary(yb, sb))
        except ValueError:
            # Undefined AUROC (e.g. a single class) is reported as NaN.
            return float("nan")

    def _safe_ap(yb: np.ndarray, sb: np.ndarray) -> float:
        try:
            return float(_average_precision_score_binary(yb, sb))
        except ValueError:
            # Undefined AP (e.g. a single class) is reported as NaN.
            return float("nan")

    # Orient first, then fill: each direction ranks non-finite samples last.
    s_pos = _oriented_scores(s_base, eval_finite, 1.0, margin)
    s_neg = _oriented_scores(s_base, eval_finite, -1.0, margin)

    auroc_pos = _safe_roc(y_eval, s_pos)
    auroc_neg = _safe_roc(y_eval, s_neg)
    auprc_pos = _safe_ap(y_eval, s_pos)
    auprc_neg = _safe_ap(y_eval, s_neg)

    report["auroc_pos"] = auroc_pos
    report["auroc_neg"] = auroc_neg
    report["auprc_pos"] = auprc_pos
    report["auprc_neg"] = auprc_neg

    if math.isfinite(auroc_pos) and math.isfinite(auroc_neg):
        chosen_sign = 1.0 if auroc_pos >= auroc_neg else -1.0
    else:
        chosen_sign = 1.0 if (not math.isfinite(auprc_neg) or auprc_pos >= auprc_neg) else -1.0

    use_pos = chosen_sign > 0
    s_rank = s_pos if use_pos else s_neg
    report["chosen_sign"] = float(chosen_sign)
    report["auroc"] = auroc_pos if use_pos else auroc_neg
    report["auprc"] = auprc_pos if use_pos else auprc_neg
    report["nonfinite_fill_value"] = fill_raw[chosen_sign]

    # y_eval contains at least one positive here, so the mean is defined.
    ranks = _average_ranks_desc(s_rank)
    report["mean_rank_poison"] = float(ranks[y_eval].mean())

    labels = y_eval.astype(float)
    ks = () if topk_list is None else topk_list
    report["precision_at_k"] = {
        int(k): float(_precision_at_k(s_rank, labels, int(k))) for k in ks
    }

    return report
