import numpy as np
import scipy.sparse as sp
import torch
import os
import sys
import pandas as pd
import dgl
from torch_geometric.utils import from_scipy_sparse_matrix
import functools
import networkx as nx
from sklearn.metrics import f1_score, accuracy_score, roc_auc_score
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import normalize, OneHotEncoder


def fair_metric(output, labels, sens):
    """Return the statistical-parity and equal-opportunity gaps between two groups.

    parity   = |mean(pred | s=0) - mean(pred | s=1)|
    equality = |mean(pred | y=1, s=0) - mean(pred | y=1, s=1)|

    For 0/1 predictions these are differences in positive-prediction rates.

    Args:
        output: predictions, array-like of length n.
        labels: ground-truth labels, array-like of length n.
        sens: sensitive attribute per sample (groups coded 0 and 1); either a
            torch tensor (any device) or an array-like.

    Returns:
        (parity, equality). A gap is nan if one of the groups it compares is
        empty.
    """
    # Accept both torch tensors and NumPy arrays (link prediction passes NumPy).
    sens = sens.detach().cpu().numpy() if torch.is_tensor(sens) else np.asarray(sens)
    pred_y = np.asarray(output)
    val_y = np.asarray(labels)

    idx_s0 = sens == 0
    idx_s1 = sens == 1
    idx_s0_y1 = idx_s0 & (val_y == 1)
    idx_s1_y1 = idx_s1 & (val_y == 1)

    parity = abs(pred_y[idx_s0].mean() - pred_y[idx_s1].mean())
    equality = abs(pred_y[idx_s0_y1].mean() - pred_y[idx_s1_y1].mean())
    return parity, equality

def fair_metric_mc(output, labels, sens):
    """Return the absolute gap in error rate between groups s=0 and s=1.

    The error rate of a group is the fraction of its samples whose prediction
    differs from the label. The absolute gap in error rates equals the
    absolute gap in accuracies.

    Args:
        output: predicted class ids, array-like of length n.
        labels: true class ids, array-like of length n.
        sens: sensitive attribute per sample (groups coded 0 and 1); torch
            tensor (any device) or array-like.

    Returns:
        The absolute error-rate gap; nan if either group is empty.
    """
    sens = sens.detach().cpu().numpy() if torch.is_tensor(sens) else np.asarray(sens)
    errors = np.asarray(output) != np.asarray(labels)

    # Normalise by each group's own size. Dividing by the mask length would
    # divide by the total sample count instead.
    parity = abs(errors[sens == 0].mean() - errors[sens == 1].mean())
    return parity

def repeat(n_times):
    """Build a decorator that runs a metric function several times and summarises it.

    The decorated function is called ``n_times`` with the same arguments and
    must return a dict of scalar metrics. For every key of the first result,
    the mean and standard deviation across runs are collected into
    ``{key: {'mean': ..., 'std': ...}}``. The summary is printed with
    ``print_statistics`` and returned instead of the individual results.
    """
    def decorate(func):
        @functools.wraps(func)
        def run_repeatedly(*args, **kwargs):
            runs = []
            for _ in range(n_times):
                runs.append(func(*args, **kwargs))

            def summarise(name):
                column = [run[name] for run in runs]
                return {'mean': np.mean(column), 'std': np.std(column)}

            summary = {name: summarise(name) for name in runs[0]}
            print_statistics(summary, func.__name__)
            return summary

        return run_repeatedly

    return decorate

def prob_to_one_hot(y_pred):
    """Turn a 2-D score matrix into a boolean one-hot matrix of row-wise argmaxes.

    Args:
        y_pred: array of shape (n_samples, n_classes).

    Returns:
        Boolean array of the same shape with exactly one True per row, at the
        column of the highest score (the first one on ties).
    """
    y_pred = np.asarray(y_pred)
    # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
    ret = np.zeros(y_pred.shape, dtype=bool)
    ret[np.arange(y_pred.shape[0]), np.argmax(y_pred, axis=1)] = True
    return ret


def print_statistics(statistics, function_name):
    """Print a one-line summary ``(E) | name: key=mean+-std, ...``.

    ``statistics`` maps each metric name to a dict with 'mean' and 'std'
    entries; values are printed with four decimals. With an empty mapping only
    the ``(E) | name: `` prefix is printed (no trailing newline).
    """
    entries = [
        f'{key}={stats["mean"]:.4f}+-{stats["std"]:.4f}'
        for key, stats in statistics.items()
    ]
    print(f'(E) | {function_name}:', end=' ')
    if entries:
        print(', '.join(entries))



@repeat(3)
def label_classification(embeddings, y, sens, ratio):
    """Evaluate node embeddings on node classification, with fairness metrics.

    Embeddings are L2-normalised and labels are one-hot encoded. The nodes are
    split at random: ``train_test_split`` is called without a fixed seed, and
    ``ratio`` is the training fraction. A one-vs-rest liblinear logistic
    regression is then tuned over C in 2**-10 .. 2**9 by 5-fold grid search.
    The ``@repeat(3)`` decorator runs this three times and reports the mean and
    std of each metric.

    Args:
        embeddings: torch tensor of node embeddings, one row per node.
        y: torch tensor of node labels.
        sens: sensitive attribute per node (groups 0/1), indexable by an
            integer index array (torch tensor or NumPy array).
        ratio: fraction of nodes used for training.

    Returns:
        dict with 'roc_auc', 'accuracy', 'F1Mi', 'F1Ma' (computed on hard
        one-hot predictions). For more than two classes it also has
        'acc_parity'; for two classes it has 'parity' and 'equality'.
    """
    X = normalize(embeddings.detach().cpu().numpy(), norm='l2')
    Y = y.detach().cpu().numpy().reshape(-1, 1)
    # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
    Y = OneHotEncoder(categories='auto').fit_transform(Y).toarray().astype(bool)

    # Split row indices alongside X/Y so the test rows' sensitive attributes
    # can be looked up afterwards.
    indices = np.arange(X.shape[0])
    X_train, X_test, y_train, y_test, indices_train, indices_test = train_test_split(
        X, Y, indices, test_size=1 - ratio)

    logreg = LogisticRegression(solver='liblinear')
    c = 2.0 ** np.arange(-10, 10)
    clf = GridSearchCV(estimator=OneVsRestClassifier(logreg),
                       param_grid=dict(estimator__C=c), n_jobs=5, cv=5,
                       verbose=0)
    clf.fit(X_train, y_train)

    y_pred = prob_to_one_hot(clf.predict_proba(X_test))

    metrics = {
        'roc_auc': roc_auc_score(y_test, y_pred),
        'accuracy': accuracy_score(y_test, y_pred),
        'F1Mi': f1_score(y_test, y_pred, average="micro"),
        'F1Ma': f1_score(y_test, y_pred, average="macro"),
    }

    pred_cls = np.argmax(y_pred, axis=1)
    true_cls = np.argmax(y_test, axis=1)
    sens_test = sens[indices_test]
    if y_pred.shape[1] > 2:
        metrics['acc_parity'] = fair_metric_mc(pred_cls, true_cls, sens_test)
    else:
        metrics['parity'], metrics['equality'] = fair_metric(pred_cls, true_cls, sens_test)
    return metrics

def _lp_uniform_sample(pool, k):
    """Return ``k`` distinct entries of ``pool`` drawn uniformly using torch's global RNG."""
    if k == 0:
        # torch.multinomial refuses to draw zero samples.
        return pool[:0]
    if k > len(pool):
        raise ValueError(
            f'cannot sample {k} negative pairs from {len(pool)} candidates')
    weights = torch.full([len(pool)], 1 / float(len(pool)))
    picks = torch.multinomial(weights, k, replacement=False)
    return pool[picks.numpy()]


def _lp_sample_negatives(pos_edges, neg_pool, same_idx, diff_idx, sens):
    """Sample rows of ``neg_pool`` with as many inter/intra-group pairs as ``pos_edges`` has."""
    src = sens[pos_edges[:, 0]]
    dst = sens[pos_edges[:, 1]]
    # Draw order (inter-group first, then intra-group) fixes torch RNG consumption.
    diff_neg = _lp_uniform_sample(diff_idx, int(np.count_nonzero(src != dst)))
    same_neg = _lp_uniform_sample(same_idx, int(np.count_nonzero(src == dst)))
    return neg_pool[np.unique(np.concatenate((diff_neg, same_neg))), :]


def _lp_edge_examples(X, pos_edges, neg_edges, sens, seed=19):
    """Build shuffled (features, labels, inter-group flags) for one split.

    A pair's features are its endpoints' embedding rows concatenated. Labels
    are 1.0 for ``pos_edges`` and 0.0 for ``neg_edges``. The flag is 1.0 when
    the endpoints' sensitive attributes differ.
    """
    import random

    pairs = np.concatenate((pos_edges, neg_edges), axis=0)
    features = np.concatenate((X[pairs[:, 0]], X[pairs[:, 1]]), axis=1)
    labels = np.concatenate((np.ones(len(pos_edges)), np.zeros(len(neg_edges))))
    inter = (sens[pairs[:, 0]] != sens[pairs[:, 1]]).astype(float)

    order = np.arange(len(pairs))
    # A private Random(seed) gives the same permutation as random.seed(seed)
    # followed by random.shuffle, without resetting the global `random` state.
    random.Random(seed).shuffle(order)
    return features[order], labels[order], inter[order]


@repeat(3)
def link_prediction(embeddings, edges_tr, edges_t, sens):
    """Evaluate node embeddings on link prediction, with dyadic fairness metrics.

    Positive examples are the given edges. Negative candidates are all
    ordered pairs (i, j) with no i->j entry in either split, self-pairs
    included. For each split, negatives are drawn uniformly without
    replacement (torch global RNG) so that:
    - the number of inter-group negatives (endpoints' ``sens`` differ) equals
      the number of inter-group positives, and
    - the number of intra-group negatives equals the number of intra-group
      positives.

    A pair is represented by its endpoints' L2-normalised embeddings
    concatenated. A one-vs-rest liblinear logistic regression, with C tuned
    over 2**-10 .. 2**9 by 5-fold grid search, is fit on the training split
    and scored on the test split. ``@repeat(3)`` runs this three times and
    reports the mean/std of each metric.

    Args:
        embeddings: torch tensor of node embeddings, one row per node id.
        edges_tr: torch tensor of training edges, expected shape (2, E_tr).
        edges_t: torch tensor of test edges, expected shape (2, E_t).
        sens: sensitive attribute per node (torch tensor or array-like).

    Returns:
        dict with:
        - 'roc_auc', 'accuracy', 'F1Mi', 'F1Ma': computed on hard one-hot
          predictions.
        - 'parity', 'equality': ``fair_metric`` between intra-group (0) and
          inter-group (1) test pairs.
        - 'delta_true' / 'delta_false': absolute accuracy gap between intra-
          and inter-group pairs among true links / non-links.
        - 'delta_fnr' / 'delta_tnr': the largest gap over thresholds (see
          ``maximize_over_t``) in the fraction of predicted link
          probabilities below the threshold, for true links / non-links.
    """
    X = normalize(embeddings.detach().cpu().numpy(), norm='l2')
    edges_tr = edges_tr.detach().cpu().numpy().T
    edges_t = edges_t.detach().cpu().numpy().T
    # Plain NumPy copy: indexing with NumPy index arrays then works the same
    # for CPU and GPU inputs.
    sens = sens.detach().cpu().numpy() if torch.is_tensor(sens) else np.asarray(sens)
    n_nodes = sens.shape[0]

    # Candidate negatives come from a dense N x N adjacency of both splits,
    # so memory grows quadratically with the number of nodes.
    edges_all = np.concatenate((edges_tr, edges_t), axis=0)
    adj = sp.coo_matrix((np.ones(edges_all.shape[0]), (edges_all[:, 0], edges_all[:, 1])),
                        shape=(n_nodes, n_nodes),
                        dtype=np.float32).toarray()
    candids1, candids2 = np.nonzero(adj == 0)
    neg_pool = np.column_stack((candids1, candids2))
    same_idx = np.flatnonzero(sens[candids1] == sens[candids2])
    diff_idx = np.flatnonzero(sens[candids1] != sens[candids2])

    neg_edges_tr = _lp_sample_negatives(edges_tr, neg_pool, same_idx, diff_idx, sens)
    neg_edges_t = _lp_sample_negatives(edges_t, neg_pool, same_idx, diff_idx, sens)

    X_all_tr, y_all_tr, _ = _lp_edge_examples(X, edges_tr, neg_edges_tr, sens)
    X_all_t, y_all_t, sens_all_t = _lp_edge_examples(X, edges_t, neg_edges_t, sens)

    # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
    Y_all_tr = OneHotEncoder(categories='auto').fit_transform(
        y_all_tr.reshape(-1, 1)).toarray().astype(bool)
    Y_all_t = OneHotEncoder(categories='auto').fit_transform(
        y_all_t.reshape(-1, 1)).toarray().astype(bool)

    # Test pairs grouped by label (link / non-link) x (inter / intra group).
    # np.flatnonzero keeps integer dtype even when a group is empty.
    is_link = y_all_t == 1
    is_non_link = y_all_t == 0
    is_inter = sens_all_t == 1
    is_intra = sens_all_t == 0
    inter_idx_y1 = np.flatnonzero(is_link & is_inter)
    intra_idx_y1 = np.flatnonzero(is_link & is_intra)
    inter_idx_y0 = np.flatnonzero(is_non_link & is_inter)
    intra_idx_y0 = np.flatnonzero(is_non_link & is_intra)

    logreg = LogisticRegression(solver='liblinear')
    c = 2.0 ** np.arange(-10, 10)
    clf = GridSearchCV(estimator=OneVsRestClassifier(logreg),
                       param_grid=dict(estimator__C=c), n_jobs=5, cv=5,
                       verbose=0)
    clf.fit(X_all_tr, Y_all_tr)

    proba = clf.predict_proba(X_all_t)
    # OneHotEncoder orders categories [0., 1.], so column 1 scores "link".
    link_score = proba[:, 1]
    y_pred = prob_to_one_hot(proba)

    def group_accuracy(idx):
        return accuracy_score(Y_all_t[idx, :], y_pred[idx, :])

    delta_true = np.absolute(group_accuracy(intra_idx_y1) - group_accuracy(inter_idx_y1))
    delta_false = np.absolute(group_accuracy(intra_idx_y0) - group_accuracy(inter_idx_y0))
    # maximize_over_t returns (gap, threshold). Keep only the gap so that
    # @repeat averages one scalar per key.
    delta_fnr = maximize_over_t(link_score[inter_idx_y1], link_score[intra_idx_y1])[0]
    delta_tnr = maximize_over_t(link_score[inter_idx_y0], link_score[intra_idx_y0])[0]

    # Labels are binary, so predictions always have two columns and only the
    # binary fairness metrics apply. sens_all_t is wrapped as a tensor because
    # fair_metric reads it via .cpu().
    parity, equality = fair_metric(np.argmax(y_pred, axis=1),
                                   np.argmax(Y_all_t, axis=1),
                                   torch.as_tensor(sens_all_t))
    return {
        'roc_auc': roc_auc_score(Y_all_t, y_pred),
        'accuracy': accuracy_score(Y_all_t, y_pred),
        'F1Mi': f1_score(Y_all_t, y_pred, average="micro"),
        'F1Ma': f1_score(Y_all_t, y_pred, average="macro"),
        'parity': parity,
        'equality': equality,
        'delta_true': delta_true,
        'delta_false': delta_false,
        'delta_fnr': delta_fnr,
        'delta_tnr': delta_tnr,
    }
def maximize_over_t(inter, intra, thresholds=None):
    """Find the threshold that maximises the gap between two score distributions.

    For each threshold t, the gap is
    |frac(inter < t) - frac(intra < t)|.

    Args:
        inter: 1-D array-like of scores for the first group.
        intra: 1-D array-like of scores for the second group.
        thresholds: candidate thresholds; defaults to 0, 0.05, ..., 0.95.

    Returns:
        (max_gap, t) for the first threshold reaching the largest gap.
        (0, 0) if no threshold gives a positive gap. (nan, 0) if either input
        is empty, since the fractions are then undefined.
    """
    t = np.arange(0, 1, 0.05) if thresholds is None else np.asarray(thresholds)
    inter = np.asarray(inter).ravel()
    intra = np.asarray(intra).ravel()
    if inter.size == 0 or intra.size == 0:
        return float('nan'), 0
    if t.size == 0:
        return 0, 0

    # Rows: thresholds; columns: scores. Row means are fractions below t.
    frac_inter = (inter[None, :] < t[:, None]).mean(axis=1)
    frac_intra = (intra[None, :] < t[:, None]).mean(axis=1)
    gaps = np.absolute(frac_inter - frac_intra)

    # argmax returns the first maximum, matching a strict ">" running-max scan.
    best = int(np.argmax(gaps))
    if not gaps[best] > 0:
        return 0, 0
    return gaps[best], t[best]