import numpy as np
from itertools import combinations
import torch
from sklearn.metrics import confusion_matrix, multilabel_confusion_matrix, f1_score, roc_auc_score
from scipy import stats
from config import cfg
import torch.nn.functional as F



def statistical_parity_max(y, s):
    """Largest statistical-parity gap over all classes and all pairs of groups.

    For every group ``g`` in ``np.unique(s)`` and class ``c`` in
    ``np.unique(y)``, the rate ``P(y == c | s == g)`` is estimated. The result
    is the largest absolute difference of that rate between any two groups,
    taken over all classes. Works for multi-class ``y`` and any number of
    groups, whatever their label values.

    Args:
        y: 1-D array of class labels.
        s: 1-D array of sensitive-attribute labels aligned with ``y``.

    Returns:
        The maximum gap; 0.0 when fewer than two groups are present.
    """
    y = np.asarray(y)
    s = np.asarray(s)
    groups = np.unique(s)
    if len(groups) < 2:
        # No pair of groups to compare.
        return 0.0
    classes = np.unique(y)
    # rates[i, j] = P(y == classes[j] | s == groups[i]). Groups come from
    # np.unique(s), so none of them is empty.
    rates = np.array([[np.mean(y[s == g] == c) for c in classes] for g in groups])
    # Per class, the largest pairwise |difference| equals max - min.
    return np.max(rates.max(axis=0) - rates.min(axis=0))


def confusion(y, y_pred):
    """True- and false-positive rates for binary labels.

    With two observed labels, the larger one is the positive class. This
    matches the label ordering of ``sklearn.metrics.confusion_matrix``. When
    only one label is observed, ``1`` is taken as the positive label.

    Args:
        y: 1-D array of true binary labels.
        y_pred: 1-D array of predicted binary labels, aligned with ``y``.

    Returns:
        ``(TPR, FPR)`` with TPR = TP / (TP + FN) and FPR = FP / (FP + TN).
        A rate whose denominator is zero is reported as 0, matching
        ``multiclass_confusion``.

    Raises:
        ValueError: if more than two distinct labels are present.
    """
    y = np.asarray(y)
    y_pred = np.asarray(y_pred)
    labels = np.unique(np.concatenate([y, y_pred]))
    if len(labels) > 2:
        raise ValueError('confusion expects binary labels')
    # Counting directly (instead of unpacking a 2x2 confusion matrix) also
    # works when a subset contains only one label.
    positive = labels[-1] if len(labels) == 2 else 1
    true_pos = y == positive
    pred_pos = y_pred == positive
    TP = np.sum(true_pos & pred_pos)
    FN = np.sum(true_pos & ~pred_pos)
    FP = np.sum(~true_pos & pred_pos)
    TN = np.sum(~true_pos & ~pred_pos)

    TPR = TP / (TP + FN) if (TP + FN) > 0 else 0.0
    FPR = FP / (FP + TN) if (FP + TN) > 0 else 0.0
    return TPR, FPR


def multiclass_confusion(y, y_pred, nc, labels=None):
    """One-vs-rest true- and false-positive rate for each class.

    Args:
        y: 1-D array of true class labels.
        y_pred: 1-D array of predicted class labels, aligned with ``y``.
        nc: number of classes; both returned arrays have length ``nc``.
        labels: optional sequence giving the class label for each output
            slot. If omitted, the labels observed in ``y``/``y_pred`` are
            used when there are at least ``nc`` of them (sorted order).
            Otherwise the labels are assumed to be ``0 .. nc-1``, so classes
            absent from this subset still keep their own slot.

    Returns:
        ``(tprs, fprs)``: float arrays of shape ``(nc,)``. A rate whose
        denominator is zero is reported as 0.

    Raises:
        ValueError: if ``labels`` is omitted, fewer than ``nc`` labels are
            observed, and they are not a subset of ``0 .. nc-1``.
    """
    offset = 1  # substituted for a zero denominator, so the rate becomes 0
    y = np.asarray(y)
    y_pred = np.asarray(y_pred)
    if labels is None:
        observed = np.unique(np.concatenate([y, y_pred]))
        if len(observed) >= nc:
            labels = observed
        elif np.isin(observed, np.arange(nc)).all():
            # Some classes are missing from this subset (e.g. one sensitive
            # group); keep them aligned with the global 0..nc-1 encoding.
            labels = np.arange(nc)
        else:
            raise ValueError(
                'cannot align classes: pass `labels` explicitly when class '
                'labels are not 0..nc-1 and some are missing from y/y_pred')

    tprs, fprs = np.zeros(nc), np.zeros(nc)
    for i in range(nc):
        is_true = y == labels[i]
        is_pred = y_pred == labels[i]
        tp = np.sum(is_true & is_pred)
        fn = np.sum(is_true & ~is_pred)
        fp = np.sum(~is_true & is_pred)
        tn = np.sum(~is_true & ~is_pred)
        tpr_d = tp + fn
        fpr_d = fp + tn
        tprs[i] = tp / (tpr_d if tpr_d != 0 else offset)
        fprs[i] = fp / (fpr_d if fpr_d != 0 else offset)
    return tprs, fprs


# Equalized Odds
def odd_diffs_binary(y, y_pred, s):
    """Equalized-odds gap between sensitive groups 0 and 1.

    Computes TPR and FPR for the samples with ``s == 0`` and with ``s == 1``
    (via ``confusion``). Returns the average of the absolute TPR gap and the
    absolute FPR gap.
    """
    per_group = {}
    for label in (0, 1):
        members = s == label
        per_group[label] = confusion(y[members], y_pred[members])

    (tpr_a, fpr_a), (tpr_b, fpr_b) = per_group[0], per_group[1]
    return (np.abs(tpr_b - tpr_a) + np.abs(fpr_b - fpr_a)) / 2



def sparsity_eo(y_true, y_pred, groups, m_sparsity='pqi', pos=False, **kwargs):
    """
    Worst-case sparsity of group-wise performance metrics.

    ``metric_name`` (kwarg) selects the metric computed by ``get_metrics``.
    When absent, it defaults to 'tpr_fpr' for ``np.int64`` targets and 'mse'
    for ``np.float32`` targets. Each row of the resulting metric matrix is
    scored across groups with ``cal_pqi`` (kwargs ``p``, default 1, and
    ``q``, default 2) or ``cal_gini``. The largest row score is returned.
    For 'tpr_fpr', the TPR and FPR row scores are averaged before taking
    the maximum. With ``pos=True`` the metric values are passed through
    ``np.exp`` first.

    Raises:
        ValueError: for target dtypes other than int64/float32, or an
            ``m_sparsity`` other than 'pqi'/'gini'.
    """
    if y_true.dtype == np.int64:
        default_metric = 'tpr_fpr'
    elif y_true.dtype == np.float32:
        default_metric = 'mse'
    else:
        raise ValueError('target type not implemented')
    metric_name = kwargs.get('metric_name', default_metric)
    p = kwargs.get('p', 1)
    q = kwargs.get('q', 2)

    if m_sparsity == 'pqi':
        def score_row(row):
            return cal_pqi(row, p, q)
    elif m_sparsity == 'gini':
        score_row = cal_gini
    else:
        raise ValueError('Sparsity measure not implemented')

    def row_scores(matrix):
        # Apply the chosen sparsity measure to every row of ``matrix``.
        return np.array([score_row(row) for row in matrix])

    if metric_name != 'tpr_fpr':
        met_mat = get_metrics(y_true, y_pred, groups, metric_name=metric_name)
        if pos:
            met_mat = np.exp(met_mat)
        return np.max(row_scores(met_mat))

    tprs, fprs = get_metrics(y_true, y_pred, groups, metric_name=metric_name)
    if pos:
        tprs, fprs = np.exp(tprs), np.exp(fprs)
    return np.max((row_scores(tprs) + row_scores(fprs)) / 2)





def odds_diffs_mean(y_true, y_pred, groups, ns, nc):
    """Per-class worst pairwise equalized-odds gap.

    For every class and every pair of groups, computes
    ``(|TPR_a - TPR_b| + |FPR_a - FPR_b|) / 2`` from the group-wise rates
    returned by ``get_metrics(..., metric_name='tpr_fpr')``. Returns the
    maximum over group pairs.

    Args:
        ns: number of groups.
        nc: number of classes.

    Returns:
        Array of shape ``(nc,)``.
    """
    tprs, fprs = get_metrics(y_true, y_pred, groups, metric_name='tpr_fpr')
    n_pairs = ns * (ns - 1) // 2
    gaps = np.zeros((nc, n_pairs))
    for c in range(nc):
        tpr_pairs = np.array([a - b for a, b in combinations(tprs[c, :], 2)])
        fpr_pairs = np.array([a - b for a, b in combinations(fprs[c, :], 2)])
        gaps[c, :] = (np.abs(tpr_pairs) + np.abs(fpr_pairs)) / 2
    # worst (largest) pair for each class
    return gaps.max(axis=1)



def cal_gini(array):
    """Calculate the Gini coefficient of an array-like.

    Based on the bottom equation of
    http://www.statsdirect.com/help/generatedimages/equations/equation154.svg
    from http://www.statsdirect.com/help/default.htm#nonparametric_methods/gini.htm

    All values are treated equally; the input is flattened to 1-D and copied
    (the caller's data is never modified). Negative inputs are shifted so the
    minimum is 0. A tiny constant is added so no value is exactly 0.

    Returns:
        The Gini coefficient; 0 for all-equal values.
    """
    # Convert to float so the in-place shift/offset below also works for
    # integer input (``int_array += 1e-7`` raises a numpy casting error).
    array = np.array(array, dtype=float).flatten()
    if np.amin(array) < 0:
        # Values cannot be negative:
        array -= np.amin(array)
    # Values cannot be 0:
    array += 0.0000001
    # Values must be sorted:
    array = np.sort(array)
    # 1-based index per element and number of elements:
    index = np.arange(1, array.shape[0] + 1)
    n = array.shape[0]
    return np.sum((2 * index - n - 1) * array) / (n * np.sum(array))


def cal_pqi(x, p=1, q=2):
    """PQ-mean sparsity index of a vector.

    Computes ``1 - M_p / M_q``, where ``M_r = (||x||_r ** r / d) ** (1 / r)``
    is the power mean of order ``r`` and ``d = len(x)``. Returns 0 when every
    entry is equal (this also avoids 0/0 for an all-zero vector). ``-inf``
    results and tiny negative results in (-1e-5, 0) are clamped to 0.
    """
    d = len(x)
    vec = np.array(x)
    if (vec == vec[0]).all():
        return 0

    mean_p = (np.linalg.norm(vec, ord=p, axis=0) ** p / d) ** (1 / p)
    mean_q = (np.linalg.norm(vec, ord=q, axis=0) ** q / d) ** (1 / q)
    si = 1 - mean_p / mean_q

    if si == -np.inf or -1e-5 < si < 0:
        si = 0
    return si


def mclass_spspar(rate, m_sparsity='pqi', **kwargs):
    """Largest per-class sparsity of a (class x group) metric matrix.

    Each row ``rate[i, :]`` (one class) is flattened and scored with the
    chosen sparsity measure. The maximum over classes is returned.

    Args:
        rate: 2-D array indexed as ``rate[class, group]``.
        m_sparsity: 'pqi' (uses kwargs ``p``, default 1, and ``q``,
            default 2) or 'gini'.

    Raises:
        ValueError: if ``m_sparsity`` is not 'pqi' or 'gini'.
    """
    if m_sparsity == 'pqi':
        p = kwargs.get('p', 1)
        q = kwargs.get('q', 2)

        def measure(vec):
            return cal_pqi(vec, p=p, q=q)
    elif m_sparsity == 'gini':
        measure = cal_gini
    else:
        # Raising a plain string is itself a TypeError in Python 3.
        raise ValueError('sparsity measure not implemented')

    class_metrics = [measure(rate[i, :].flatten()) for i in range(rate.shape[0])]
    return np.max(class_metrics)


def get_metrics(y_true, y_pred, sens, **kwargs):
    """Compute a metric separately for every sensitive group.

    Groups are the values of ``np.unique(sens)`` (matrix columns, in that
    order). Integer ``y_true`` is treated as class labels. Classes are the
    values of ``np.unique(y_true)`` (matrix rows, in that order).

    Keyword Args:
        metric_name: one of
            - 'rate': rate[c, s] = P(y_pred == c | sens == s)
            - 'loss': mean of ``y_pred`` (must be ``np.float32``) over the
              samples with y_true == c and sens == s
            - 'accuracy': per-group accuracy, repeated on every class row
            - 'tpr_fpr': one-vs-rest TPR and FPR per class (returns a pair)
            - 'f1': per-class F1 score
            - 'auroc': ``roc_auc_score(..., average=None)`` per group; 0 when
              it raises ``ValueError`` (e.g. only one class in the group)
            - 'mae', 'rmse', 'mse', 'r2': regression metrics

    Returns:
        ``np.ndarray`` of shape (nc, ns) for classification metrics, a
        ``(tpr, fpr)`` pair of such arrays for 'tpr_fpr', or shape (1, ns)
        for regression metrics.

    Raises:
        ValueError: unknown ``metric_name``; a classification metric with
            non-integer ``y_true``; 'loss' with non-float32 ``y_pred``.
    """
    metric_name = kwargs.get('metric_name')
    group = np.unique(sens)
    ns = len(group)

    classification_metrics = ('rate', 'loss', 'accuracy', 'tpr_fpr', 'f1', 'auroc')
    if np.issubdtype(y_true.dtype, np.integer):
        cls = np.unique(y_true)
        nc = len(cls)
    elif metric_name in classification_metrics:
        # Without class labels ``nc``/``cls`` would be undefined below.
        raise ValueError(f"metric '{metric_name}' requires integer class labels in y_true")

    if metric_name == 'rate':
        rate = np.zeros((nc, ns))
        # rate(y,s) = P(y_pred = y | sens = s) = P(y_pred = y and sens = s) / P(sens = s)
        for i, g in enumerate(group):
            pred_g = y_pred[sens == g]
            if len(pred_g) == 0:
                continue
            for j, c in enumerate(cls):
                rate[j, i] = np.sum(pred_g == c) / len(pred_g)
        return rate

    if metric_name == 'loss':
        if y_pred.dtype != np.float32:
            raise ValueError('y_pred must be np.float32')
        loss_mat = np.zeros((nc, ns))
        for i, g in enumerate(group):
            for j, c in enumerate(cls):
                loss_mat[j, i] = np.mean(y_pred[(sens == g) & (y_true == c)])
        return loss_mat

    if metric_name == 'accuracy':
        acc = np.zeros((nc, ns))
        for i, g in enumerate(group):
            mask = sens == g
            # Accuracy is not class specific; broadcast it over all rows.
            acc[:, i] = np.mean(y_true[mask] == y_pred[mask])
        return acc

    if metric_name == 'tpr_fpr':
        tpr, fpr = np.zeros((nc, ns)), np.zeros((nc, ns))
        for i, g in enumerate(group):
            mask = sens == g
            tpr[:, i], fpr[:, i] = multiclass_confusion(y_true[mask], y_pred[mask], nc)
        return tpr, fpr

    if metric_name == 'f1':
        f1 = np.zeros((nc, ns))
        for i, g in enumerate(group):
            mask = sens == g
            # labels=cls keeps one entry per global class even when this
            # group lacks some classes (otherwise the length != nc).
            f1[:, i] = f1_score(y_true[mask], y_pred[mask], labels=cls, average=None)
        return f1

    if metric_name == 'auroc':
        roc_auc = np.zeros((nc, ns))
        for i, g in enumerate(group):
            mask = sens == g
            try:
                roc_score = roc_auc_score(y_true[mask], y_pred[mask], average=None)
            except ValueError:
                # e.g. only a single class present in this group: AUC undefined
                roc_score = 0
            roc_auc[:, i] = roc_score
        return roc_auc

    # Regression metrics: one row, one column per group.
    regression_metrics = {
        'mae': lambda t, p: np.mean(np.abs(t - p)),
        'rmse': lambda t, p: np.sqrt(np.mean((t - p) ** 2)),
        'mse': lambda t, p: np.mean((t - p) ** 2),
        'r2': lambda t, p: 1 - np.sum((t - p) ** 2) / np.sum((t - np.mean(t)) ** 2),
    }
    if metric_name in regression_metrics:
        score_fn = regression_metrics[metric_name]
        out = np.zeros((1, ns))
        for i, g in enumerate(group):
            mask = sens == g
            out[0, i] = score_fn(y_true[mask], y_pred[mask])
        return out

    # Raising a plain string is itself a TypeError in Python 3.
    raise ValueError('metric not implemented')


def get_distribution_params(output, target, sens, condition='sp'):
    """Estimate per-group and global distributions of the model output.

    Classification (``target.dtype == np.int64``): ``output`` is treated as
    discrete predictions. Empirical frequencies are computed over the fixed
    support ``np.unique(output)``, so every ``'dist'`` vector has the same
    length and entry ``k`` always refers to the same predicted value. An
    empty selection gives an all-zero vector.

    * ``condition='sp'``: ``global_params['dist']`` is q(yhat) and
      ``group_params[s]['dist']`` is q(yhat | s).
    * ``condition='eo'``: ``global_params[c]['dist']`` is q(yhat | y=c) and
      ``group_params[(c, s)]['dist']`` is q(yhat | y=c, s). ``c`` covers
      every value in ``target`` or ``output``; a value that never occurs in
      ``target`` gets all-zero vectors.

    Regression (``target.dtype == np.float32``): mean (``'mu'``), population
    standard deviation (``'sigma'``) and residual sum of squares (``'rss'``)
    of ``output``, both globally and per group ``s``.

    Returns:
        ``(group_params, global_params)`` dictionaries as described above.

    Raises:
        ValueError: if ``target`` has any other dtype.
    """
    group_params = {}
    global_params = {}
    if target.dtype == np.int64:  # classification problem
        support = np.unique(output)

        def _dist(values):
            # Every value is taken from ``output``, so searchsorted returns
            # its exact slot in ``support``.
            counts = np.bincount(np.searchsorted(support, values), minlength=len(support))
            total = counts.sum()
            return counts / total if total > 0 else counts.astype(float)

        if condition == 'eo':
            # q(yhat | y, s) and q(yhat | y): condition on the TRUE class.
            for c in np.union1d(target, output):
                in_class = target == c
                ypred_c = output[in_class]
                sens_c = sens[in_class]
                global_params[c] = {'dist': _dist(ypred_c)}
                for s in np.unique(sens):
                    group_params[(c, s)] = {'dist': _dist(ypred_c[sens_c == s])}
        elif condition == 'sp':
            # q(yhat) and q(yhat | s)
            global_params['dist'] = _dist(output)
            for s in np.unique(sens):
                group_params[s] = {'dist': _dist(output[sens == s])}
    elif target.dtype == np.float32:  # regression problem
        global_params['mu'] = np.mean(output)
        global_params['sigma'] = np.sqrt(np.var(output))
        global_params['rss'] = np.sum((output - target) ** 2)
        # p(yhat | s)
        for s in np.unique(sens):
            output_s = output[sens == s]
            target_s = target[sens == s]
            group_params[s] = {
                'mu': np.mean(output_s),
                'sigma': np.sqrt(np.var(output_s)),
                'rss': np.sum((output_s - target_s) ** 2),  # used to calculate EO
            }
    else:
        # Raising a plain string is itself a TypeError in Python 3.
        raise ValueError('target type not implemented')
    return group_params, global_params


def gaussian_nll(x, mu, sigma):
    """
    Mean negative log-likelihood of ``x`` under N(mu, sigma**2).

    Used for EO (dist) for regression. A ``sigma`` of exactly 0 is replaced
    by 1e-6 to avoid dividing by zero.
    """
    sd = 1e-6 if sigma == 0 else sigma
    var = sd ** 2
    per_sample = 1 / 2 * np.log(2 * np.pi * var) + (x - mu) ** 2 / (2 * var)
    return np.mean(per_sample)

def ks_dist(scores, groups):
    """Maximum pairwise two-sample Kolmogorov-Smirnov statistic between groups.

    Args:
        scores: 1-D array of scores.
        groups: 1-D array of group labels aligned with ``scores``. Any label
            values are accepted, not only ``0..n-1``.

    Returns:
        The largest KS statistic over all pairs of distinct groups, or 0.0
        when fewer than two groups are present.
    """
    scores = np.asarray(scores)
    groups = np.asarray(groups)
    max_ks = 0.0
    # Pair up the actual label values instead of assuming 0..n-1.
    for a, b in combinations(np.unique(groups), 2):
        ks = stats.ks_2samp(scores[groups == a], scores[groups == b]).statistic
        max_ks = max(max_ks, ks)
    return max_ks

def w1_dist(scores, groups):
    """Maximum pairwise Wasserstein-1 distance between group score distributions.

    Args:
        scores: 1-D array of scores.
        groups: 1-D array of group labels aligned with ``scores``. Any label
            values are accepted, not only ``0..n-1``.

    Returns:
        The largest distance over all pairs of distinct groups, or 0.0 when
        fewer than two groups are present (the original raised on
        ``np.max([])``).
    """
    scores = np.asarray(scores)
    groups = np.asarray(groups)
    distances = [
        stats.wasserstein_distance(scores[groups == a], scores[groups == b])
        for a, b in combinations(np.unique(groups), 2)
    ]
    return max(distances, default=0.0)

def cal_gini_pair(output, sensitive):
    """Pointwise Gini sparsity of the per-group empirical CDFs.

    The pooled ``output`` values (sorted, with the last one dropped) are the
    evaluation points. At each point, the empirical CDF of every group in
    ``np.unique(sensitive)`` is evaluated, and ``cal_gini`` is applied to
    the resulting vector of group CDF values.

    Returns:
        list with one Gini value per evaluation point.
    """
    # Evaluation points: all pooled outputs except the largest one.
    eval_points = np.sort(output, kind='mergesort')[:-1]
    group_ecdfs = []
    for label in np.unique(sensitive):
        members = np.sort(output[sensitive == label])
        # Fraction of the group's outputs that are <= each evaluation point.
        group_ecdfs.append(members.searchsorted(eval_points, 'right') / members.size)
    return [cal_gini([ecdf[k] for ecdf in group_ecdfs]) for k in range(len(output) - 1)]

def cal_pqi_pair(output, sensitive, p, q):
    """Pointwise PQ-index sparsity of the per-group empirical CDFs.

    The ECDF evaluation mirrors scipy's implementation:
    https://github.com/scipy/scipy/blob/v1.14.1/scipy/stats/_stats_py.py#L10210

    The pooled ``output`` values (sorted, with the last one dropped) are the
    evaluation points. At each point, the empirical CDF of every group in
    ``np.unique(sensitive)`` is evaluated, and ``cal_pqi(..., p, q)`` is
    applied to the resulting vector of group CDF values.

    Returns:
        list with one PQ-index value per evaluation point.
    """
    # Evaluation points: all pooled outputs except the largest one.
    eval_points = np.sort(output, kind='mergesort')[:-1]
    group_ecdfs = []
    for label in np.unique(sensitive):
        members = np.sort(output[sensitive == label])
        # Fraction of the group's outputs that are <= each evaluation point.
        group_ecdfs.append(members.searchsorted(eval_points, 'right') / members.size)
    return [cal_pqi([ecdf[k] for ecdf in group_ecdfs], p, q) for k in range(len(output) - 1)]
