from .pred_rej_area import PredictionRejectionArea
from .rev_pairs_prop import ReversedPairsProportion
from .risk_cov_curve import RiskCoverageCurveAUC
from .spearmanr import SpearmanRankCorrelation
from .kendalltau import KendallTauCorrelation
from .roc_auc import ROCAUC
from .pr_auc import PRAUC


import numpy as np

from typing import List

from lm_polygraph.ue_metrics.ue_metric import UEMetric, normalize


class PredictionRejectionAreaNormalized(UEMetric):
    """
    Normalized area under the Prediction-Rejection curve (normalized PRR).

    For every rejection level (0, 1, 2, ... instances rejected, up to
    ``max_rejection`` of the set) the mean quality of the retained instances
    is taken under three rejection orders:

    * the estimator's order: the most uncertain instances are rejected first;
    * the oracle order: the lowest-quality instances are rejected first;
    * random rejection, whose expected mean quality is the mean of the whole set.

    The score is ``(mean_ue - mean_random) / (mean_oracle - mean_random)``,
    where each ``mean_*`` averages its curve over the rejection levels:
    1.0 matches the oracle, 0.0 is no better than random rejection, and
    negative values are worse than random.
    """

    def __init__(self, max_rejection: float = 1.0):
        """
        Parameters:
            max_rejection (float): a maximum proportion of instances that will be rejected.
                1.0 indicates entire set, 0.5 - half of the set.
                Values above 1.0 behave like 1.0.
        """
        super().__init__()
        self.max_rejection = max_rejection

    def __str__(self):
        if self.max_rejection == 1:
            return "prr_norm"
        return f"prr_norm_{self.max_rejection}"

    @staticmethod
    def _retained_means(ordered: np.ndarray, kept_sizes: np.ndarray) -> np.ndarray:
        """Mean of the first ``k`` entries of ``ordered`` for every ``k`` in ``kept_sizes``."""
        # Prefix sums give every prefix mean in one pass instead of
        # re-averaging a fresh slice per rejection level.
        return np.cumsum(ordered)[kept_sizes - 1] / kept_sizes

    def __call__(self, estimator: List[float], target: List[float]) -> float:
        """
        Computes the normalized Prediction-Rejection score of `estimator` w.r.t. `target`.

        Instances whose target is NaN are dropped before scoring.

        Parameters:
            estimator (List[float]): a batch of uncertainty estimations.
                Higher values indicate more uncertainty.
            target (List[float]): a batch of ground-truth uncertainty estimations.
                Higher values indicate more uncertainty.
        Returns:
            float: normalized PRR score. Higher values indicate better
                uncertainty estimations. ``nan`` when the score is undefined:
                fewer than two rejection levels are available, or all targets
                are equal (the oracle is then no better than random rejection).
        Raises:
            ValueError: if `estimator` and `target` have different lengths.
        """
        ue = np.asarray(estimator)
        target = np.asarray(target, dtype=float)
        if len(ue) != len(target):
            raise ValueError(
                "estimator and target must have the same length, "
                f"got {len(ue)} and {len(target)}"
            )

        # Instances without a ground-truth value cannot be scored.
        keep = ~np.isnan(target)
        ue = ue[keep]
        # Flip ground-truth uncertainty into quality (higher is better).
        quality = 1 - target[keep]

        num_obs = len(ue)
        # Rejection levels 0 .. num_levels - 1. Never reject the whole set:
        # an empty retained set has no mean quality.
        num_levels = min(int(self.max_rejection * num_obs), num_obs)
        # With fewer than two levels, or constant quality, the oracle curve
        # coincides with the random baseline and the ratio is 0 / 0.
        if num_levels < 2 or np.ptp(quality) == 0:
            return np.nan

        quality = np.asarray(normalize(quality), dtype=float)

        # Sizes of the retained sets: n, n - 1, ..., n - num_levels + 1.
        kept_sizes = np.arange(num_obs, num_obs - num_levels, -1)

        # argsort is ascending, so the least uncertain come first and each
        # rejection step drops the most uncertain remaining instance.
        ue_curve = self._retained_means(quality[np.argsort(ue)], kept_sizes)
        # Oracle keeps the highest-quality instances.
        oracle_curve = self._retained_means(np.sort(quality)[::-1], kept_sizes)
        random_quality = quality.mean()
        return (ue_curve.mean() - random_quality) / (
            oracle_curve.mean() - random_quality
        )