import itertools
import warnings
import torch

from sklearn import metrics as sklearn_metrics

from ..utils.misc import all_gather

# NOTE(review): this runs at import time and silences *every* warning for the
# whole process (no category/module filter), not only warnings raised by this
# module; consider scoping it with warnings.catch_warnings() where the noise
# actually originates.
warnings.filterwarnings('ignore')


class FairnessEvaluator:
    """Accumulate classification predictions and report group-fairness metrics.

    Hard predictions are the arg-max over dim 1 of the model logits. Counts are
    kept per predicted class (``p0``/``p1``), target class (``t0``/``t1``) and
    sensitive group (``s0``/``s1``); only the values 0 and 1 are counted for
    each. Every ``metric_<name>`` static method reads those counts from
    ``matrix`` and returns a percentage (or a tuple of percentages).
    """

    # Keys of the count matrix; ``_count`` produces exactly these keys.
    _MATRIX_KEYS = (
        'sum', 'T', 's0', 's1', 'T_s0', 'T_s1',
        'p0_s0', 'p0_s1', 'p1_s0', 'p1_s1',
        't0_s0', 't0_s1', 't1_s0', 't1_s1',
        'p0_t0_s0', 'p0_t0_s1', 'p1_t0_s0', 'p1_t0_s1',
        'p0_t1_s0', 'p0_t1_s1', 'p1_t1_s0', 'p1_t1_s1',
    )

    def __init__(self, metrics: list):
        """
        Args:
            metrics: metric names ``<name>`` for which a ``metric_<name>`` method
                exists, e.g. ``'acc'``, ``'dp'``, ``'eopp'``, ``'eodd'``, ``'aod'``,
                ``'2ed_indicators'``.
        """
        self.metrics = metrics
        self.outputs = []      # predicted class indices from all batches
        self.targets = []      # ground-truth labels from all batches
        self.sensitivity = []  # sensitive-attribute values from all batches
        self.eval = {metric: None for metric in metrics}  # rounded results of summarize()
        self.matrix = dict.fromkeys(self._MATRIX_KEYS, 0)

    @staticmethod
    def _count(outputs, targets, sensitivity):
        """Return the count matrix (keys ``_MATRIX_KEYS``) for one set of hard predictions.

        ``outputs``, ``targets`` and ``sensitivity`` are tensors compared
        element-wise; every count except ``'sum'`` is a float tensor.
        """
        correct = outputs == targets
        pred = {v: outputs == v for v in (0, 1)}
        true = {v: targets == v for v in (0, 1)}
        group = {s: sensitivity == s for s in (0, 1)}

        counts = {
            'sum': 1. * targets.size(0),
            'T': correct.float().sum(),  # TP + TN
        }
        for s in (0, 1):
            counts[f's{s}'] = group[s].float().sum()
            counts[f'T_s{s}'] = (correct & group[s]).float().sum()
            for v in (0, 1):
                counts[f'p{v}_s{s}'] = (pred[v] & group[s]).float().sum()
                counts[f't{v}_s{s}'] = (true[v] & group[s]).float().sum()
            for p in (0, 1):
                for t in (0, 1):
                    counts[f'p{p}_t{t}_s{s}'] = (pred[p] & true[t] & group[s]).float().sum()
        return counts

    def update(self, outputs, targets, sensitivity):
        """Add one batch of predictions to the running counts and lists.

        Args:
            outputs: logits tensor (arg-max is taken over dim 1, so presumably
                ``(batch, num_classes)``), or a dict holding it under ``'logits'``.
            targets: ground-truth labels, compared element-wise with the predictions.
            sensitivity: sensitive-attribute values; groups 0 and 1 are counted.
        """
        if isinstance(outputs, dict):
            assert 'logits' in outputs, \
                f"When using 'update(self, outputs, targets, sensitivity)' in '{self.__class__.__name__}', " \
                f"if 'outputs' is a dict, it MUST contain the key 'logits'."
            outputs = outputs['logits']
        outputs = outputs.max(1)[1]  # hard predictions: arg-max indices over dim 1

        for key, value in self._count(outputs, targets, sensitivity).items():
            self.matrix[key] += value

        self.outputs += outputs.tolist()
        self.targets += targets.tolist()
        self.sensitivity += sensitivity.tolist()

    def synchronize_between_processes(self):
        """Gather predictions from all processes and rebuild the counts from them.

        Every metric reads only ``self.matrix``, so the matrix is recomputed from
        the combined lists; without this each process would keep reporting
        metrics for its own shard. ``all_gather`` is assumed to return one list
        per process.
        """
        self.outputs = list(itertools.chain.from_iterable(all_gather(self.outputs)))
        self.targets = list(itertools.chain.from_iterable(all_gather(self.targets)))
        self.sensitivity = list(itertools.chain.from_iterable(all_gather(self.sensitivity)))
        self.matrix = self._count(
            torch.tensor(self.outputs),
            torch.tensor(self.targets),
            torch.tensor(self.sensitivity),
        )

    @staticmethod
    def _gap(matrix, numerator, denominator):
        """Return ``|rate(s0) - rate(s1)| * 100`` for a rate given by two count keys.

        ``numerator``/``denominator`` are key templates containing ``{s}``, filled
        with the group index 0 or 1. The 1e-6 added to the denominator keeps an
        empty group from dividing by zero (its rate becomes 0). The result keeps
        the type of the counts (0-dim tensor, or Python float before any update).
        """
        rate0 = matrix[numerator.format(s=0)] / (matrix[denominator.format(s=0)] + 1e-6)
        rate1 = matrix[numerator.format(s=1)] / (matrix[denominator.format(s=1)] + 1e-6)
        return abs(rate0 - rate1) * 100

    @staticmethod
    def metric_acc(outputs, targets, sensitivity, matrix, **kwargs):
        """Accuracy in percent: (TN + TP) / (TN + TP + FN + FP); 0.0 when nothing was counted."""
        total = matrix['sum']
        if not total:
            return 0.0
        return float(matrix['T'] / total * 100)

    @staticmethod
    def metric_2ed_indicators(outputs, targets, sensitivity, matrix, **kwargs):
        """Group gaps (percent) of PPV, NPV, TPR, TNR, FPR, FNR, FDR and FOR, in that order."""
        rates = (
            ('p1_t1_s{s}', 'p1_s{s}'),  # PPV = TP / (TP + FP)
            ('p0_t0_s{s}', 'p0_s{s}'),  # NPV = TN / (TN + FN)
            ('p1_t1_s{s}', 't1_s{s}'),  # TPR = TP / (TP + FN)
            ('p0_t0_s{s}', 't0_s{s}'),  # TNR = TN / (TN + FP)
            ('p1_t0_s{s}', 't0_s{s}'),  # FPR = FP / (FP + TN)
            ('p0_t1_s{s}', 't1_s{s}'),  # FNR = FN / (FN + TP)
            ('p1_t0_s{s}', 'p1_s{s}'),  # FDR = FP / (TP + FP)
            ('p0_t1_s{s}', 'p0_s{s}'),  # FOR = FN / (TN + FN)
        )
        return tuple(float(FairnessEvaluator._gap(matrix, num, den)) for num, den in rates)

    @staticmethod
    def metric_dp(outputs, targets, sensitivity, matrix, **kwargs):
        """Demographic parity gap: |P(pred=1 | s=0) - P(pred=1 | s=1)| in percent."""
        return float(FairnessEvaluator._gap(matrix, 'p1_s{s}', 's{s}'))

    @staticmethod
    def metric_eopp(outputs, targets, sensitivity, matrix, **kwargs):
        """Equal-opportunity gap, measured as the group gap of FNR = FN / (FN + TP), in percent."""
        return float(FairnessEvaluator._gap(matrix, 'p0_t1_s{s}', 't1_s{s}'))

    @staticmethod
    def metric_eodd(outputs, targets, sensitivity, matrix, **kwargs):
        """Equalized-odds gap: the larger of the TPR gap and the FPR gap, in percent."""
        d_tpr = FairnessEvaluator._gap(matrix, 'p1_t1_s{s}', 't1_s{s}')  # TPR = TP / (TP + FN)
        d_fpr = FairnessEvaluator._gap(matrix, 'p1_t0_s{s}', 't0_s{s}')  # FPR = FP / (FP + TN)
        return float(max(d_tpr, d_fpr))

    @staticmethod
    def metric_aod(outputs, targets, sensitivity, matrix, **kwargs):
        """Average odds difference: mean of the TPR gap and the FPR gap, in percent."""
        d_fpr = FairnessEvaluator._gap(matrix, 'p1_t0_s{s}', 't0_s{s}')  # FPR = FP / (FP + TN)
        d_tpr = FairnessEvaluator._gap(matrix, 'p1_t1_s{s}', 't1_s{s}')  # TPR = TP / (TP + FN)
        return float(0.5 * (d_tpr + d_fpr))

    def summarize(self):
        """Compute every configured metric, store it rounded in ``self.eval`` and print it.

        Tuple-valued metrics (``2ed_indicators``) are rounded and printed element-wise.
        """
        print('Fairness Classification Metrics:')
        for metric in self.metrics:
            value = getattr(self, f'metric_{metric}')(self.outputs, self.targets, self.sensitivity, self.matrix)
            if isinstance(value, tuple):
                self.eval[metric] = tuple(round(v, 2) for v in value)
                shown = ', '.join('{:.3f}'.format(v) for v in value)
            else:
                self.eval[metric] = round(value, 2)
                shown = '{:.3f}'.format(value)
            print('{}:{}'.format(metric, shown), end='    ')
        print('\n')