from typing import Dict, Union

import numpy as np
import sklearn
import sklearn.metrics
import torch
from torch import Tensor


def fpr_at_fixed_tpr(fprs: np.ndarray, tprs: np.ndarray, thresholds: np.ndarray, tpr_level: float = 0.95):
    """Return the first ROC operating point whose TPR reaches ``tpr_level``.

    With the ascending curves produced by ``sklearn.metrics.roc_curve`` the first
    qualifying index is also the qualifying point with the lowest FPR.

    Args:
        fprs: False positive rates, one per threshold (array-like).
        tprs: True positive rates aligned with ``fprs`` (array-like).
        thresholds: Decision thresholds aligned with ``fprs``/``tprs``.
        tpr_level: Minimum TPR the selected operating point must reach.

    Returns:
        ``(fpr, tpr, threshold)`` at the selected index, as Python floats.

    Raises:
        ValueError: If every TPR is strictly below ``tpr_level`` (this includes
            empty input).
    """
    tprs = np.asarray(tprs)
    if np.all(tprs < tpr_level):
        raise ValueError(f"No threshold allows for TPR at least {tpr_level}.")
    hits = np.flatnonzero(tprs >= tpr_level)
    # NaN TPRs fail both comparisons, so ``hits`` can still be empty here (e.g. a
    # curve without positive samples); keep the best-effort fallback to index 0.
    idx = int(hits[0]) if hits.size else 0
    return float(fprs[idx]), float(tprs[idx]), float(thresholds[idx])


def fnr_at_fixed_tnr(fprs: np.ndarray, tprs: np.ndarray, thresholds: np.ndarray, tnr_level: float = 0.95):
    """Return the ROC operating point with the lowest FNR whose TNR is at least ``tnr_level``.

    TNR is computed as ``1 - fpr`` and FNR as ``1 - tpr`` for every point.

    Args:
        fprs: False positive rates, one per threshold (array-like).
        tprs: True positive rates aligned with ``fprs`` (array-like).
        thresholds: Decision thresholds aligned with ``fprs``/``tprs``.
        tnr_level: Minimum TNR the selected operating point must keep.

    Returns:
        ``(fnr, tnr, threshold)`` at the selected index, as Python floats.

    Raises:
        ValueError: If no point has TNR >= ``tnr_level`` (including empty input).
    """
    tnrs = 1 - np.asarray(fprs)
    fnrs = 1 - np.asarray(tprs)

    eligible = np.flatnonzero(tnrs >= tnr_level)
    if eligible.size == 0:
        raise ValueError(f"No threshold allows for TNR at least {tnr_level}.")
    # On roc_curve output FPR ascends, so TNR descends and the eligible points form a
    # prefix whose first element (FPR = 0) has the worst FNR. Choose the minimum FNR
    # among eligible points instead; ties resolve to the earliest index.
    idx = int(eligible[np.argmin(fnrs[eligible])])
    return float(fnrs[idx]), float(tnrs[idx]), float(thresholds[idx])


def compute_detection_error(fpr: float, tpr: float, pos_ratio: float):
    """Return the class-prior-weighted error rate at a single ROC operating point.

    The error is ``pos_ratio * (1 - tpr) + (1 - pos_ratio) * fpr``: missed
    positives weighted by the positive share, plus false alarms weighted by the
    negative share.

    Args:
        fpr: False positive rate at the operating point.
        tpr: True positive rate at the operating point.
        pos_ratio: Fraction of all samples that are positives.

    Returns:
        The weighted error.
    """
    miss_rate = 1 - tpr
    negative_share = 1 - pos_ratio
    return pos_ratio * miss_rate + negative_share * fpr


def minimum_detection_error(fprs: np.ndarray, tprs: np.ndarray, pos_ratio: float):
    """Return the lowest detection error over a sequence of ROC operating points.

    Each ``(fpr, tpr)`` pair (paired with ``zip``, so extra elements of the longer
    input are ignored) is scored with ``compute_detection_error``. The value at the
    index chosen by ``np.argmin`` is returned.

    Raises:
        ValueError: From ``np.argmin`` when there are no operating points.
    """
    errors = []
    for point_fpr, point_tpr in zip(fprs, tprs):
        errors.append(compute_detection_error(point_fpr, point_tpr, pos_ratio))
    best_idx = np.argmin(errors)
    return errors[best_idx]


def _scores_to_numpy(scores) -> np.ndarray:
    """Convert a tensor (any device, possibly requiring grad) or an array-like to a NumPy array."""
    if isinstance(scores, Tensor):
        return scores.detach().cpu().double().numpy()
    return np.asarray(scores)


def get_ood_results(in_scores: Union[Tensor, np.ndarray], ood_scores: Union[Tensor, np.ndarray]) -> Dict[str, float]:
    """Compute OOD-detection metrics, treating in-distribution samples as positives.

    Higher scores are assumed to indicate in-distribution samples: in-distribution
    samples get label 1 and the ROC / AUPR-in curves use the scores as-is.

    Args:
        in_scores: Scores of in-distribution samples (tensor on any device,
            ndarray, or list).
        ood_scores: Scores of out-of-distribution samples (same accepted types).

    Returns:
        Dict with keys "fpr_at_0.95_tpr", "detection_error", "auroc", "aupr_in",
        "aupr_out", "f1" and "thr" (the threshold giving at least 95% TPR).

    Raises:
        ValueError: Propagated from ``fpr_at_fixed_tpr`` if no threshold reaches
            95% TPR.
    """
    # Convert each input separately so a GPU tensor can be mixed with a NumPy array.
    in_np = _scores_to_numpy(in_scores)
    ood_np = _scores_to_numpy(ood_scores)

    # float64 keeps the ranking unchanged and makes the negation below safe.
    test_scores = np.concatenate([in_np, ood_np]).astype(np.float64)
    test_labels = np.concatenate([np.ones(len(in_np)), np.zeros(len(ood_np))])

    fprs, tprs, thrs = sklearn.metrics.roc_curve(test_labels, test_scores)
    fpr, _, thr = fpr_at_fixed_tpr(fprs, tprs, thrs, 0.95)
    auroc = sklearn.metrics.auc(fprs, tprs)

    precision, recall, _ = sklearn.metrics.precision_recall_curve(test_labels, test_scores, pos_label=1)
    # sklearn treats higher scores as more likely ``pos_label``; with OOD (label 0) as
    # the positive class the scores must be negated, otherwise the ranking is reversed.
    precision_out, recall_out, _ = sklearn.metrics.precision_recall_curve(test_labels, -test_scores, pos_label=0)
    aupr_in = sklearn.metrics.auc(recall, precision)
    aupr_out = sklearn.metrics.auc(recall_out, precision_out)
    # roc_curve counts ``score >= threshold`` as positive; match that operating point.
    f1 = sklearn.metrics.f1_score(test_labels, test_scores >= thr)

    pos_ratio = np.mean(test_labels == 1)
    detection_error = minimum_detection_error(fprs, tprs, pos_ratio)

    results = {
        "fpr_at_0.95_tpr": fpr,
        "detection_error": detection_error,
        "auroc": auroc,
        "aupr_in": aupr_in,
        "aupr_out": aupr_out,
        "f1": f1,
        "thr": thr,
    }
    return results


# Human-readable display labels for metric keys, in display order. Some keys
# ("tnr_at_0.95_tpr", "time") are not produced by get_ood_results in this file;
# presumably they are filled in by callers (TODO confirm).
METRICS_NAMES_PRETTY = dict(
    [
        ("fpr_at_0.95_tpr", "FPR at 95% TPR"),
        ("tnr_at_0.95_tpr", "TNR at 95% TPR"),
        ("detection_error", "Detection error"),
        ("auroc", "AUROC"),
        ("aupr_in", "AUPR in"),
        ("aupr_out", "AUPR out"),
        ("f1", "F1"),
        ("thr", "Threshold"),
        ("time", "Time"),
    ]
)
