import torch
import numpy as np
import json
import torch.nn.functional as F
from sklearn.metrics import accuracy_score


class FairnessMetrics:
    """
    Collects accuracy, loss-variance and subgroup-accuracy metrics over
    several training runs, averages them across runs and writes the
    averages to a JSON file.
    """

    def __init__(self, averaged_over, subgroup_idx, subgroup_minority, eval_every=5):
        """
        Implements evaluation metrics for measuring performance of the VFair model.

        Args:
            averaged_over: the amount of iterations the model is ran (for averaging
                     the results). One history list is kept per iteration.
            subgroup_idx: iterable of index objects; each one is used to index
                     the prediction and target tensors in `set_acc_other`.
            subgroup_minority: stored on the instance as-is; it is not read by
                     any method of this class.
            eval_every: the amount of steps between evaluation of the model.
        """
        self.logging_dict = {"acc": 0}
        self.eval_every = eval_every

        # One list of recorded values per run, indexed by the run number.
        self.acc = [[] for _ in range(averaged_over)]
        self.var = [[] for _ in range(averaged_over)]
        self.worst = [[] for _ in range(averaged_over)]
        self.diff = [[] for _ in range(averaged_over)]
        self.sum = [[] for _ in range(averaged_over)]
        # NOTE: the per-group histories are allocated (and averaged) but no
        # method of this class appends to them.
        self.group1 = [[] for _ in range(averaged_over)]
        self.group2 = [[] for _ in range(averaged_over)]
        self.group3 = [[] for _ in range(averaged_over)]
        self.group4 = [[] for _ in range(averaged_over)]

        self.subgroup_indexes = subgroup_idx
        self.subgroup_minority = subgroup_minority

    def set_acc(self, pred, targets, n_iter):
        """
        Calculates the accuracy score.

        Args:
            pred: prediction (Torch tensor).
            targets: target variables (Torch tensor).
            n_iter: iteration of this training loop.

        Returns:
            The accuracy computed by `accuracy_score`; it is also appended to
            the history of run `n_iter` and stored in `logging_dict["acc"]`.
        """
        acc = accuracy_score(targets.cpu().detach().numpy(), pred.cpu().detach().numpy())
        self.acc[n_iter].append(acc)
        self.logging_dict["acc"] = acc

        return acc

    @staticmethod
    def _mean_over_runs(runs):
        """
        Returns the element-wise mean across runs as a float64 array.

        Every recorded value (Python float, 0-d NumPy array or 0-d tensor) is
        converted with `float()` first, so the result does not depend on how
        NumPy treats nested lists of tensors.
        """
        as_floats = [[float(value) for value in run] for run in runs]
        return np.mean(np.array(as_floats, dtype=np.float64), axis=0)

    def average_results(self):
        """
        Averages the results of all iterations.

        For every history `<name>` an attribute `<name>_avg` is set that holds
        the mean over the runs for each recorded evaluation.
        """
        self.acc_avg = self._mean_over_runs(self.acc)
        self.var_avg = self._mean_over_runs(self.var)
        self.worst_avg = self._mean_over_runs(self.worst)
        self.diff_avg = self._mean_over_runs(self.diff)
        self.sum_avg = self._mean_over_runs(self.sum)

        self.group1_avg = self._mean_over_runs(self.group1)
        self.group2_avg = self._mean_over_runs(self.group2)
        self.group3_avg = self._mean_over_runs(self.group3)
        self.group4_avg = self._mean_over_runs(self.group4)

    def save_metrics(self, res_dir):
        """
        Saves the averaged metrics in a json file.

        `average_results` has to be called first, because the `*_avg`
        attributes are read here.

        Args:
            res_dir: path of the output file without extension; ".json" is
                     appended to it.
        """

        def as_floats(values):
            # tolist() yields built-in floats; NumPy scalars such as float32
            # are not accepted by the json module.
            return np.asarray(values, dtype=np.float64).tolist()

        var_avg = np.asarray(self.var_avg, dtype=np.float64).ravel()

        metrics = {
            "acc_avg": as_floats(self.acc_avg),
            # Only the last recorded variance is stored; null when no
            # variance was ever recorded.
            "var_avg": float(var_avg[-1]) if var_avg.size else None,
            "worst_avg": as_floats(self.worst_avg),
            "diff_avg": as_floats(self.diff_avg),
            "sum_avg": as_floats(self.sum_avg),
        }
        # The context manager closes (and flushes) the file even on errors.
        with open("{}.json".format(res_dir), "w", encoding="utf-8") as out_file:
            json.dump(metrics, out_file)

    def set_var(self, pred, targets, n_iters):
        """
        Records the variance of the per-sample loss.

        A prediction with a single column is treated as binary logits and
        scored with binary cross entropy; otherwise the multi-class cross
        entropy is used.

        Args:
            pred: logits (Torch tensor) with the outputs along dimension 1.
            targets: target variables (Torch tensor).
            n_iters: iteration of this training loop.
        """
        if pred.shape[1] == 1:
            # squeeze(1) keeps the batch dimension even for a batch of one.
            logits = pred.squeeze(1)
            # Follow the device and dtype of the logits instead of assuming a
            # CUDA device and float targets.
            labels = targets.to(device=logits.device, dtype=logits.dtype)
            loss = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
        else:
            loss = F.cross_entropy(pred, targets.to(pred.device), reduction="none")
        variance = torch.var(loss).cpu().detach().numpy()
        self.var[n_iters].append(variance)
        self.logging_dict["var"] = variance

    def set_acc_other(self, pred, targets, n_iter):
        """
        Records accuracy-based metrics over the subgroups.

        The accuracy is computed for each entry of `subgroup_indexes`; from
        these the lowest accuracy ("worst"), the gap between highest and
        lowest accuracy ("diff") and the sum of absolute deviations from the
        mean accuracy ("sum") are appended to the history of run `n_iter` and
        stored in `logging_dict`. All three are 0-d tensors.

        Args:
            pred: prediction (Torch tensor).
            targets: target variables (Torch tensor).
            n_iter: iteration of this training loop.
        """
        group_accs = []
        for group_idx in self.subgroup_indexes:
            group_pred = pred[group_idx].cpu().detach().numpy()
            group_tar = targets[group_idx].cpu().detach().numpy()
            group_accs.append(torch.tensor(accuracy_score(group_tar, group_pred)))
        accs = torch.stack(group_accs)

        worst = torch.min(accs)
        gap = torch.max(accs) - worst
        total_abs_dev = torch.sum(torch.abs(accs - torch.mean(accs)))

        self.worst[n_iter].append(worst)
        self.diff[n_iter].append(gap)
        self.sum[n_iter].append(total_abs_dev)

        self.logging_dict["worst"] = worst
        self.logging_dict["diff"] = gap
        self.logging_dict["sum"] = total_abs_dev
        return

