import numpy as np
import torch
import torch.nn.functional as F
from collections import defaultdict
from functools import partial
from module import recur
from .utils import (
        statistical_parity_max, 
        sparsity_eo, 
        odds_diffs_mean, 
        cal_pqi, cal_gini, 
        get_metrics, mclass_spspar, 
        ks_dist, w1_dist, 
        cal_pqi_pair, cal_gini_pair
                    )


def make_metric(split, **kwargs):
    """Build the ``Metric`` tracker appropriate for a dataset.

    Args:
        split: Iterable of split names. Every split gets the base metrics;
            only the split named ``'test'`` additionally gets the fairness
            metrics.
        **kwargs: Must contain ``data_name``, which selects between the
            classification setup (tracked by ``'Accuracy'``, higher is
            better) and the regression setup (tracked by ``'MSE'``, lower
            is better).

    Returns:
        A ``Metric`` built from the per-split metric names, the initial
        best value, the improvement direction and the tracked metric name.

    Raises:
        KeyError: If ``data_name`` is not supplied.
        ValueError: If ``data_name`` is not one of the known datasets.
    """
    data_name = kwargs['data_name']
    metric_name = {k: [] for k in split}

    if data_name in ('SimulateC', 'SimulateCM', 'Adult', 'AdultM'):
        # Higher is better: start from the lowest possible value.
        best = -float('inf')
        best_direction = 'up'
        best_metric_name = 'Accuracy'
        base_names = ['Accuracy']
        test_names = ['SP', 'EO', 'SP (pqi, weak)', 'SP (gini, weak)',
                      'EO (pqi)', 'EO (gini)', 'EO (pqi, pos)', 'EO (gini, pos)']
    elif data_name in ('SimulateR', 'LawSchool', 'Community'):
        # Lower is better: start from +inf. With -inf, the 'down' comparison
        # `best > val` in Metric.compare could never be true, so the best
        # value would never be updated.
        best = float('inf')
        best_direction = 'down'
        best_metric_name = 'MSE'
        base_names = ['Loss', 'MSE']
        test_names = ['RMSE', 'SP (pqi, weak)', 'SP (weak)', 'SP (ks)', 'SP (w1)',
                      'SP (pqi, int)', 'SP (pqi, max)', 'SP (gini, int)', 'SP (gini, max)',
                      'EO', 'EO (pqi)', 'EO (gini)', 'EO (pqi, pos)', 'EO (gini, pos)']
    else:
        raise ValueError('Not valid data name')

    for k, names in metric_name.items():
        names.extend(base_names)
        if k == 'test':
            names.extend(test_names)

    return Metric(metric_name, best, best_direction, best_metric_name)


def Accuracy(output, target, topk=1):
    """Compute classification accuracy of ``output`` against ``target``.

    Two input conventions are handled:

    * ``output`` has dtype ``int64``: it is compared element-wise with
      ``target`` and the result is the mean match rate, i.e. a *fraction*
      in ``[0, 1]``.
    * otherwise ``output`` is treated as per-class scores: the ``topk``
      highest-scoring indices along the last dimension are compared with
      the target labels and the result is a *percentage* in ``[0, 100]``.
      A non-``int64`` target is first reduced to its arg-max along the last
      dimension.

    NOTE(review): the two paths use different scales (fraction vs percent).
    This is kept as-is for backward compatibility; confirm it is intended.

    Args:
        output: Predicted labels (``int64``) or class scores.
        target: Ground-truth labels, or score/one-hot rows to arg-max.
        topk: Number of top-scoring classes counted as a hit.

    Returns:
        The accuracy as a Python float, or ``nan`` if the computation
        raises (for example on malformed or empty input).
    """
    try:
        with torch.no_grad():
            if output.dtype == torch.int64:
                return (output == target).float().mean().item()
            if target.dtype != torch.int64:
                target = target.topk(1, -1, True, True)[1].view(-1)
            batch_size = torch.numel(target)
            pred_k = output.topk(topk, -1, True, True)[1]
            correct_k = pred_k.eq(target.unsqueeze(-1).expand_as(pred_k)).float().sum()
            return (correct_k * (100.0 / batch_size)).item()
    except Exception:
        # Best-effort metric: report NaN instead of aborting the run, but do
        # not swallow KeyboardInterrupt / SystemExit as a bare except would.
        return np.nan


def MSE(output, target):
    """Return the mean squared error between ``output`` and ``target``.

    The value is computed with ``F.mse_loss`` under ``torch.no_grad()`` and
    returned as a Python float.
    """
    with torch.no_grad():
        squared_error = F.mse_loss(output, target)
    return squared_error.item()


class RMSE:
    """Accumulate squared error over batches and report the overall RMSE.

    ``add`` is called once per batch; calling the instance returns the root
    of (summed squared error / number of elements) and resets the
    accumulators.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the running squared-error sum and element count."""
        self.se = torch.zeros((1,))
        self.count = torch.zeros((1,))
        return

    def add(self, input, output):
        """Accumulate one batch.

        Args:
            input: Mapping whose ``'target'`` entry is the ground truth.
            output: Mapping whose ``'target'`` entry is the prediction.
        """
        # Compute without autograd and move to CPU before accumulating, so
        # the running sum never joins (and keeps alive) a computation graph
        # and always lives on the same device as the accumulator.
        with torch.no_grad():
            batch_se = F.mse_loss(output['target'], input['target'], reduction='sum')
        self.se += batch_se.detach().cpu()
        self.count += output['target'].numel()
        return

    def __call__(self, input, output):
        """Return the RMSE over everything added so far, then reset.

        Both arguments are unused; they are accepted so this object can be
        called like the other metric callables.
        """
        rmse = ((self.se / self.count) ** 0.5).item()
        self.reset()
        return rmse


class Fairness:
    """Buffer batches and evaluate a fairness metric over all of them.

    ``fair_metric`` is called as
    ``fair_metric(predictions, targets, sensitive, **setting)`` on the
    concatenation of every batch passed to ``add`` since the last reset.
    """

    def __init__(self, fair_metric):
        self.fair_metric = fair_metric
        self.reset()

    def reset(self):
        """Drop all buffered batches."""
        self.input_target, self.output_target, self.sensitive = [], [], []
        return

    def add(self, input, output):
        """Buffer the targets, predictions and sensitive attribute of a batch."""
        self.input_target.append(input['target'])
        self.output_target.append(output['target'])
        self.sensitive.append(input['sensitive'])
        return

    def __call__(self, input, output):
        """Evaluate the metric on everything buffered, then reset.

        Optional keyword arguments for the metric (e.g. p, q or a metric
        name) are read from ``input['setting']`` when present. ``output``
        is not used.
        """
        targets, predictions, groups = (
            torch.cat(buffered)
            for buffered in (self.input_target, self.output_target, self.sensitive)
        )
        extra = input.get('setting', {})
        value = self.fair_metric(predictions, targets, groups, **extra)
        self.reset()
        return value


class Metric:
    """Registry of per-split metrics plus a tracker for the best value seen.

    ``metric_name`` maps each split to a list of metric names. Each name is
    turned into an entry ``{'mode': ..., 'metric': callable}``, where mode
    ``'batch'`` means the callable is evaluated directly on an
    (input, output) pair and mode ``'full'`` means batches are first
    accumulated through the callable's ``add`` method.
    """

    def __init__(self, metric_name, best, best_direction, best_metric_name):
        self.metric_name = metric_name
        self.best = best
        self.best_direction = best_direction
        self.best_metric_name = best_metric_name
        self.metric = self.make_metric(metric_name)

    def make_metric(self, metric_name):
        """Instantiate the metric entries for every split.

        A fresh callable is created for every (split, name) pair, so
        stateful accumulators are never shared between splits.

        Raises:
            ValueError: If a metric name is not recognised.
        """
        def batch(fn):
            return {'mode': 'batch', 'metric': fn}

        def full(fn):
            return {'mode': 'full', 'metric': fn}

        # Zero-argument builders: nothing is constructed (and no metric
        # function is looked up) until the name is actually requested.
        builders = {
            'Loss': lambda: batch(lambda input, output: output['loss'].item()),
            'Accuracy': lambda: batch(
                lambda input, output: recur(Accuracy, output['target'], input['target'])),
            'MSE': lambda: batch(
                lambda input, output: recur(MSE, output['target'], input['target'])),
            'RMSE': lambda: full(RMSE()),
            'SP (weak)': lambda: full(Fairness(partial(SP_r, mode='weak'))),
            'SP (ks)': lambda: full(Fairness(partial(SP_r, mode='ks'))),
            'SP (w1)': lambda: full(Fairness(partial(SP_r, mode='w1'))),
            'SP': lambda: full(Fairness(SP_c)),
            'EO': lambda: full(Fairness(EO)),
            'SP (pqi, weak)': lambda: full(Fairness(partial(SP_s_weak, m_sparsity='pqi'))),
            'SP (pqi, int)': lambda: full(
                Fairness(partial(SP_s_strong, m_sparsity='pqi', mode='integral'))),
            'SP (pqi, max)': lambda: full(
                Fairness(partial(SP_s_strong, m_sparsity='pqi', mode='max'))),
            'EO (pqi)': lambda: full(Fairness(partial(EO_s, m_sparsity='pqi'))),
            'EO (pqi, pos)': lambda: full(Fairness(partial(EO_s, m_sparsity='pqi', pos=True))),
            'SP (gini, weak)': lambda: full(Fairness(partial(SP_s_weak, m_sparsity='gini'))),
            'SP (gini, int)': lambda: full(
                Fairness(partial(SP_s_strong, m_sparsity='gini', mode='integral'))),
            'SP (gini, max)': lambda: full(
                Fairness(partial(SP_s_strong, m_sparsity='gini', mode='max'))),
            'EO (gini)': lambda: full(Fairness(partial(EO_s, m_sparsity='gini'))),
            'EO (gini, pos)': lambda: full(Fairness(partial(EO_s, m_sparsity='gini', pos=True))),
        }

        metric = defaultdict(dict)
        for split in metric_name:
            for m in metric_name[split]:
                build = builders.get(m)
                if build is None:
                    raise ValueError('Not valid metric name')
                metric[split][m] = build()
        return metric

    def add(self, split, input, output):
        """Feed one batch to every 'full'-mode metric of ``split``."""
        for name in self.metric_name[split]:
            entry = self.metric[split][name]
            if entry['mode'] == 'full':
                entry['metric'].add(input, output)
        return

    def evaluate(self, split, mode, input, output, metric_name):
        """Evaluate the metrics of ``split`` whose mode equals ``mode``.

        If ``input['target']`` is not a tensor, ``input['target']`` and
        ``input['sensitive']`` are replaced in place by flattened tensors.

        Returns:
            Dict mapping each evaluated metric name to its value.
        """
        if not isinstance(input['target'], torch.Tensor):
            # convert to torch tensor for evaluation
            for key in ('target', 'sensitive'):
                input[key] = torch.tensor(input[key]).view(-1)

        return {
            name: self.metric[split][name]['metric'](input, output)
            for name in metric_name[split]
            if self.metric[split][name]['mode'] == mode
        }

    def compare(self, val, if_update):
        """Report whether ``val`` beats the stored best; optionally store it.

        Raises:
            ValueError: If ``best_direction`` is neither 'down' nor 'up'.
        """
        direction = self.best_direction
        if direction not in ('down', 'up'):
            raise ValueError('Not valid best direction')
        improved = self.best > val if direction == 'down' else self.best < val
        if improved and if_update:
            self.best = val
        return improved

    def load_state_dict(self, state_dict):
        """Restore the best value, tracked metric name and direction."""
        self.best = state_dict['best']
        self.best_metric_name = state_dict['best_metric_name']
        self.best_direction = state_dict['best_direction']
        return

    def state_dict(self):
        """Return the best value, tracked metric name and direction as a dict."""
        return {
            'best': self.best,
            'best_metric_name': self.best_metric_name,
            'best_direction': self.best_direction,
        }


def SP_r(output, target, sensitive, mode = 'ks', **kwargs):
    """Statistical-parity measure for regression outputs.

    Args:
        output: Tensor of predictions.
        target: Unused; accepted so every fairness metric shares the
            ``(output, target, sensitive)`` call signature.
        sensitive: Tensor holding the sensitive-group value of each sample.
        mode: ``'ks'`` delegates to ``ks_dist``, ``'w1'`` delegates to
            ``w1_dist``, and ``'weak'`` returns the difference between the
            largest and smallest per-group mean prediction.
        **kwargs: Ignored.

    Returns:
        The value of the selected measure.

    Raises:
        ValueError: If ``mode`` is not one of the supported values
            (previously this surfaced as an ``UnboundLocalError``).
    """
    output = output.detach().cpu().numpy()
    sensitive = sensitive.detach().cpu().numpy()

    if mode == 'ks':
        return ks_dist(output, sensitive)
    if mode == 'w1':
        return w1_dist(output, sensitive)
    if mode == 'weak':
        group_means = np.array(
            [np.mean(output[sensitive == s]) for s in np.unique(sensitive)]
        )
        return np.max(group_means) - np.min(group_means)
    raise ValueError('Not valid mode')


def SP_c(output, target, sensitive, **kwargs):
    """Statistical parity for classification.

    Delegates to ``statistical_parity_max`` (multi-class, multi-group
    statistical parity, max over all classes).

    Args:
        output: Predicted labels (``int64``) or per-class scores; scores
            are reduced to the arg-max class along dimension 1.
        target: Unused; accepted for a uniform metric signature.
        sensitive: Tensor holding the sensitive-group value of each sample.
        **kwargs: Ignored.

    Returns:
        ``0`` when there are no samples, otherwise the value returned by
        ``statistical_parity_max``.
    """
    sensitive = sensitive.ravel().detach().cpu().numpy()
    # Guard before doing any work: the original check only ran after the
    # parity computation had already been attempted on the empty input.
    if sensitive.size == 0:
        return 0

    if output.dtype != torch.int64:
        output = output.topk(1, 1, True, True)[1].ravel().type(torch.float64)
    output = output.detach().cpu().numpy()

    # multi class multi group statistical parity, max over all classes
    return statistical_parity_max(output, sensitive)



def EO(output, target, sensitive, **kwargs):
    """Equalized-odds style disparity across sensitive groups.

    * ``float32`` target (regression): the mean squared error is computed
      per sensitive group and the result is the largest pairwise
      difference, i.e. max minus min of the per-group errors.
    * ``int64`` target (classification): scores are reduced to the arg-max
      class along dimension 1, then the maximum of ``odds_diffs_mean`` is
      returned; ``0`` is returned when only one group is present.

    Args:
        output: Predictions (values, labels, or per-class scores).
        target: Ground truth; its dtype selects the branch.
        sensitive: Tensor holding the sensitive-group value of each sample.
        **kwargs: Ignored.

    Raises:
        ValueError: If the target dtype is neither ``float32`` nor
            ``int64`` (previously this surfaced as ``UnboundLocalError``).
    """
    if target.dtype == torch.float32:
        sensitive = sensitive.ravel().detach().cpu().numpy()
        output = output.detach().cpu().numpy()
        target = target.detach().cpu().numpy()

        # get g(f_a(X), y_a) for each group
        group_mse = []
        for g in np.unique(sensitive):
            mask = sensitive == g
            group_mse.append(np.mean((target[mask] - output[mask]) ** 2))
        # Accumulate in float64, as the original zero-initialised buffer did.
        group_mse = np.array(group_mse, dtype=np.float64)
        # maximum pairwise difference of the per-group errors
        return np.max(group_mse) - np.min(group_mse)

    if target.dtype == torch.int64:
        if output.dtype != torch.int64:
            output = output.topk(1, 1, True, True)[1].ravel().type(torch.float64)
        sensitive = sensitive.ravel()
        output = output.detach().cpu().numpy()
        sensitive = sensitive.detach().cpu().numpy()
        target = target.detach().cpu().numpy()
        ns = len(np.unique(sensitive).astype(int))
        nc = len(np.unique(target).astype(int))

        if ns == 1:
            # A single group has nothing to be compared against.
            return 0
        odds = odds_diffs_mean(target, output, sensitive, ns, nc)
        return np.max(odds)

    raise ValueError('Not valid target')


def SP_s_weak(output, target, sensitive, m_sparsity='pqi', **kwargs):
    """Sparsity-based (weak) statistical parity.

    * ``int64`` target: predictions are reduced to labels (arg-max along
      dimension 1 for non-``int64`` output), per-group rates are obtained
      from ``get_metrics(..., metric_name='rate')`` and passed to
      ``mclass_spspar`` together with ``m_sparsity`` and ``**kwargs``.
    * ``float32`` target: the sparsity measure is applied to the vector of
      per-group mean predictions, using ``cal_pqi`` (with ``p``/``q`` from
      ``kwargs``, defaulting to 1 and 2) or ``cal_gini``.

    Raises:
        ValueError: For a target dtype other than ``int64``/``float32``,
            or an unknown ``m_sparsity`` in the ``float32`` case.
    """
    is_classification = target.dtype == torch.int64
    if not is_classification and target.dtype != torch.float32:
        raise ValueError('Not valid target')

    if is_classification and output.dtype != torch.int64:
        # get predicted label
        output = output.topk(1, 1, True, True)[1].ravel().type(torch.float64)

    y_pred = output.detach().cpu().numpy()
    y_true = target.detach().cpu().numpy()
    groups = sensitive.detach().cpu().numpy()

    if is_classification:
        rate = get_metrics(y_true, y_pred, groups, metric_name='rate')
        return mclass_spspar(rate, m_sparsity=m_sparsity, **kwargs)

    group_means = np.array([np.mean(y_pred[groups == s]) for s in np.unique(groups)])
    if m_sparsity == 'pqi':
        # set p and q
        return cal_pqi(group_means, p=kwargs.get('p', 1), q=kwargs.get('q', 2))
    if m_sparsity == 'gini':
        return cal_gini(group_means)
    raise ValueError('sparsity measure not implemented')


def SP_s_strong(output, target, sensitive, m_sparsity='pqi', mode = 'integral', **kwargs):
    """Sparsity-based (strong) statistical parity for regression.

    A sparsity profile is obtained from ``cal_pqi_pair`` (``m_sparsity ==
    'pqi'``, using ``p``/``q`` from ``kwargs`` with defaults 1 and 2) or
    ``cal_gini_pair`` (``m_sparsity == 'gini'``) and then summarised:

    * ``mode == 'max'``: the maximum of the profile;
    * ``mode == 'integral'``: the sum of the profile weighted by the gaps
      between successive sorted predictions, which uses the empirical
      min and max of the predictions to approximate the integral.

    Only ``float32`` targets are supported. For classification the
    calculation is the same as ``SP_s_weak``, differing by the scale of the
    number of classes.

    NOTE(review): an earlier note in this function mentioned fixing the
    integration domain by using ``target`` instead of ``output``; the code
    integrates over the sorted outputs. Confirm which is intended.

    Raises:
        ValueError: If the target dtype is not ``float32``, or if
            ``m_sparsity`` / ``mode`` is not a supported value (the latter
            two previously surfaced as ``UnboundLocalError``).
    """
    if target.dtype != torch.float32:
        raise ValueError('Not valid target')
    if mode not in ('integral', 'max'):
        raise ValueError('Not valid mode')
    if m_sparsity not in ('pqi', 'gini'):
        raise ValueError('sparsity measure not implemented')

    output = output.detach().cpu().numpy()
    sensitive = sensitive.detach().cpu().numpy()

    if m_sparsity == 'pqi':
        profile = cal_pqi_pair(output, sensitive, kwargs.get('p', 1), kwargs.get('q', 2))
    else:
        profile = cal_gini_pair(output, sensitive)

    if mode == 'max':
        return np.max(profile)

    # Differences between successive sorted predictions act as step sizes
    # for the weighted sum (assumes the profile aligns with these gaps —
    # TODO confirm against cal_pqi_pair / cal_gini_pair).
    deltas = np.diff(np.sort(output, kind='mergesort'))
    return np.sum(np.multiply(profile, deltas))


def EO_s(output, target, sensitive, m_sparsity='pqi', pos = False, **kwargs):
    """Sparsity-based equalized-odds measure.

    Flattens ``sensitive``, converts all three tensors to NumPy arrays and
    returns ``sparsity_eo(target, output, sensitive, m_sparsity, pos,
    **kwargs)``.
    """
    def to_numpy(tensor):
        return tensor.detach().cpu().numpy()

    flat_sensitive = sensitive.ravel()
    y_pred = to_numpy(output)
    y_true = to_numpy(target)
    groups = to_numpy(flat_sensitive)
    # calculate eo sparsity
    return sparsity_eo(y_true, y_pred, groups, m_sparsity, pos, **kwargs)

