import numpy as np
from typing import Callable, NamedTuple, Dict
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, average_precision_score, balanced_accuracy_score

from dataset import CompData
from model import IndivFairModel


class FairRes(NamedTuple):
    """Summary statistics of prediction gaps over one group of sample pairs.

    Attributes:
        mean: Mean gap.
        q1: Lower quartile (25th percentile) of the gap.
        q3: Upper quartile (75th percentile) of the gap.
        N: Number of comparable sample pairs the statistics are based on.
    """

    mean: float
    q1: float
    q3: float
    N: int


class UtilRes(NamedTuple):
    """Utility scores of a predictor, in the order they are reported.

    Attributes:
        roc: ROC AUC score.
        ap: Average precision score.
        acc: Accuracy.
        bal_acc: Balanced accuracy.
        f1: F1 score.
        label_vio: Label violation rate.
    """

    roc: float
    ap: float
    acc: float
    bal_acc: float
    f1: float
    label_vio: float


class Evaluator():
    """Evaluate predictive utility and individual fairness on a fixed data set.

    Utility is measured with scikit-learn metrics on ``(X, y)``.  Fairness is
    measured as the absolute difference between the predicted probabilities
    of the two members of each sample pair yielded by the loaders in
    ``comp_data.loaders``.
    """

    def __init__(self, X: np.ndarray, y: np.ndarray, comp_data: "CompData"):
        """Store the evaluation data.

        Args:
            X: Feature matrix passed to the prediction functions.
            y: Ground-truth labels for ``X``.
            comp_data: Object exposing ``loaders``.  Each loader must iterate
                over ``(batch_1, batch_2)`` pairs and provide ``name`` and
                ``cond_idx`` (a ``(true_idx, false_idx)`` pair of indices
                into the loader's pairs).
        """
        self.X = X
        self.y = y
        self.comp_data = comp_data

    def _eval_print(self, memo: Dict) -> None:
        """Print the utility scores and the fairness summary of the "or" group.

        Args:
            memo: Mapping with a ``"util"`` entry (six percentages) and an
                ``"or"`` entry holding ``"all"``, ``"true"`` and ``"false"``
                summaries of the form ``(mean, q1, q3, N)``.
        """
        # Only the "or" group is reported; "and" / "not" summaries share the
        # same layout and can be appended to the template if needed.
        template = (
            "===== Evaluation =====\n"
            "ROC: {:.2f}%; AP: {:.2f}%; Acc.: {:.2f}%; Bal. Acc.: {:.2f}%; F1: {:.2f}%; Vio.: {:.2f}%\n"
            "Or : all - {:.2f}%|{:.2f}%|{:.2f}%|{:d}; true - {:.2f}%|{:.2f}%|{:.2f}%|{:d}; false - {:.2f}%|{:.2f}%|{:.2f}%|{:d}"
        )
        or_group = memo["or"]
        print(template.format(
            *memo["util"],
            *or_group["all"], *or_group["true"], *or_group["false"],
        ))

    @staticmethod
    def _group_gaps(gap, cond_true_idx, cond_false_idx) -> Dict:
        """Split pair gaps into the "all" / "true" / "false" groups.

        Args:
            gap: Absolute prediction gaps, one per sample pair.
            cond_true_idx: Indices into ``gap`` forming the "true" group.
            cond_false_idx: Indices into ``gap`` forming the "false" group.

        Returns:
            Dict mapping group name to ``(values, n_pairs)``.  ``n_pairs`` is
            the real number of pairs in the group.  ``values`` is the group's
            gaps, or ``[nan]`` when the group is empty so that downstream
            statistics evaluate to NaN instead of raising.
        """
        gap = np.asarray(gap)
        groups = {
            "all": gap,
            "true": gap.take(cond_true_idx),
            "false": gap.take(cond_false_idx),
        }

        result = {}
        for name, values in groups.items():
            # Count before substituting the NaN placeholder; otherwise an
            # empty group would be reported as containing one pair.
            n_pairs = len(values)
            if n_pairs == 0:
                values = np.array([np.nan])
            result[name] = (values, n_pairs)
        return result

    @staticmethod
    def _label_violation(memo: Dict) -> float:
        """Return the share of "or"/"not" pairs in neither condition subset.

        Args:
            memo: Mapping whose ``"or"`` and ``"not"`` entries each hold
                ``"all"``, ``"true"`` and ``"false"`` summaries with an
                ``N`` attribute.

        Returns:
            ``1 - (true + false pairs) / (all pairs)`` accumulated over both
            groups, or NaN when there are no pairs at all.
        """
        covered = 0
        total = 0
        for name in ("or", "not"):
            group = memo[name]
            covered += group["true"].N + group["false"].N
            total += group["all"].N

        if total == 0:
            return float("nan")
        return 1 - covered / total

    def __call__(self, pred_func: Callable, pred_prob_func: Callable) -> Dict:
        """Evaluate a predictor, print a summary and return the raw results.

        Args:
            pred_func: Maps a feature matrix to predicted labels.
            pred_prob_func: Maps a feature matrix to predicted scores or
                probabilities.

        Returns:
            Dict with the utility scores as fractions (``"roc"``, ``"ap"``,
            ``"acc"``, ``"bal_acc"``, ``"f1"``, ``"label_vio"``) and, per
            loader name, the gap lists of the ``"all"``, ``"true"`` and
            ``"false"`` groups (``[nan]`` for an empty group).

        Raises:
            KeyError: If a loader name is not one of ``"and"``, ``"or"``,
                ``"not"``, or if the ``"or"`` / ``"not"`` loaders are missing.
        """
        pred = pred_func(self.X)
        pred_prob = pred_prob_func(self.X)

        roc = roc_auc_score(self.y, pred_prob)
        ap = average_precision_score(self.y, pred_prob)
        acc = accuracy_score(self.y, pred)
        bal_acc = balanced_accuracy_score(self.y, pred)
        f1 = f1_score(self.y, pred)

        # memo holds the percentage summaries used for printing; res holds the
        # raw values handed back to the caller.
        memo = {"util": None, "and": {}, "or": {}, "not": {}}
        res = {"roc": roc, "ap": ap, "acc": acc, "bal_acc": bal_acc, "f1": f1}
        for loader in self.comp_data.loaders:
            gap = []
            for batch_1, batch_2 in loader:
                gap.extend(np.abs(pred_prob_func(batch_1) - pred_prob_func(batch_2)))

            cond_true_idx, cond_false_idx = loader.cond_idx
            groups = self._group_gaps(gap, cond_true_idx, cond_false_idx)

            for name, (values, n_pairs) in groups.items():
                memo[loader.name][name] = FairRes(
                    np.mean(values) * 100,
                    np.quantile(values, q=0.25) * 100,
                    np.quantile(values, q=0.75) * 100,
                    n_pairs,
                )

            res[loader.name] = {name: values.tolist() for name, (values, _) in groups.items()}

        label_vio = self._label_violation(memo)

        memo["util"] = UtilRes(roc * 100, ap * 100, acc * 100, bal_acc * 100, f1 * 100, label_vio * 100)
        res["label_vio"] = label_vio

        self._eval_print(memo)

        return res


class LearnEfficacy():
    """Fit a model on one data split and report its scores on another."""

    def __init__(self, model: IndivFairModel):
        """Keep a reference to the model that will be trained and scored."""
        self.model = model

    def __call__(self, train_X: np.ndarray, train_y: np.ndarray, test_X: np.ndarray, test_y: np.ndarray) -> None:
        """Train on the training split and print test-split metrics.

        Args:
            train_X: Features the model is fitted on.
            train_y: Labels the model is fitted on.
            test_X: Features the fitted model is scored on.
            test_y: Ground-truth labels for ``test_X``.
        """
        clf = self.model
        clf.fit(train_X, train_y)
        labels = clf.pred(test_X)
        scores = clf.pred_proba(test_X)

        # Accuracy and F1 use the hard labels, ROC AUC uses the scores.
        metrics = (
            accuracy_score(test_y, labels),
            f1_score(test_y, labels),
            roc_auc_score(test_y, scores),
        )

        report = (
            "===== Machine Learning Efficacy =====\n"
            "Accuracy: {:.2f}%; F1 Score: {:.2f}%; ROC Score: {:.2f}%;\n"
        )
        print(report.format(*(value * 100 for value in metrics)))
