import numpy as np
import scipy.sparse as sp
import scipy.io
import torch
import os
import os.path as osp

from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.utils import negative_sampling
from torch_geometric.utils import train_test_split_edges
import sys
import pandas as pd
import dgl
from torch_geometric.utils import from_scipy_sparse_matrix
import functools
import networkx as nx
from sklearn.metrics import f1_score, accuracy_score, roc_auc_score
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import normalize as nor_sp
from scipy.spatial import distance_matrix
from sklearn.preprocessing import  OneHotEncoder

def fair_metric(output, labels, sens):
    """Group-fairness gaps of predictions between ``sens == 0`` and ``sens == 1``.

    Args:
        output: predicted labels, one per sample (expected to be 0/1 so that
            the group mean is the positive-prediction rate).
        labels: ground-truth labels aligned with ``output``.
        sens: sensitive attribute per sample, either a torch tensor (moved to
            CPU) or any array-like. Samples whose value is neither 0 nor 1 are
            ignored.

    Returns:
        ``(parity, equality)``. ``parity`` is the absolute difference of the
        mean prediction between the two groups. ``equality`` is the same
        difference restricted to samples with ``labels == 1``. A group with no
        samples yields ``nan`` (with a NumPy RuntimeWarning).
    """
    # Accept torch tensors (possibly on GPU) as well as NumPy / array-likes.
    sens = sens.cpu().numpy() if hasattr(sens, "cpu") else np.asarray(sens)
    pred_y = np.asarray(output)
    val_y = np.asarray(labels)

    idx_s0 = sens == 0
    idx_s1 = sens == 1
    idx_s0_y1 = idx_s0 & (val_y == 1)
    idx_s1_y1 = idx_s1 & (val_y == 1)

    parity = abs(pred_y[idx_s0].mean() - pred_y[idx_s1].mean())
    equality = abs(pred_y[idx_s0_y1].mean() - pred_y[idx_s1_y1].mean())
    return parity, equality
def fair_metric_mc(output, labels, sens):
    """Accuracy-parity gap between ``sens == 0`` and ``sens == 1`` (any number of classes).

    Args:
        output: predicted class per sample.
        labels: true class per sample, aligned with ``output``.
        sens: sensitive attribute per sample, as a torch tensor (moved to CPU)
            or any array-like. Samples whose value is neither 0 nor 1 are
            ignored.

    Returns:
        ``|err(s=0) - err(s=1)|``, where ``err(g)`` is the fraction of
        misclassified samples *within group g*. This is equivalently the
        absolute difference of per-group accuracies.
    """
    sens = sens.cpu().numpy() if hasattr(sens, "cpu") else np.asarray(sens)
    wrong = np.asarray(output) != np.asarray(labels)

    # Normalise by the group size. The previous ``len(mask)`` divided by the
    # total sample count, so it did not measure per-group error rates.
    err_s0 = wrong[sens == 0].mean()
    err_s1 = wrong[sens == 1].mean()
    return abs(err_s0 - err_s1)
def maximize_over_t(inter, intra):
    """Largest gap between the empirical CDFs of ``inter`` and ``intra`` on a fixed grid.

    For every threshold ``t`` in ``np.arange(0, 1, 0.05)`` (0, 0.05, ..., 0.95)
    the gap ``|P(inter < t) - P(intra < t)|`` is computed. The largest gap is
    returned, or ``0`` if every gap is 0.

    Args:
        inter, intra: array-likes of values (lists are accepted).

    Raises:
        ZeroDivisionError: if either input is empty.
    """
    inter = np.asarray(inter)
    intra = np.asarray(intra)
    cur_max = 0
    for t in np.arange(0, 1, 0.05):
        gap = np.absolute(np.count_nonzero(inter < t) / len(inter)
                          - np.count_nonzero(intra < t) / len(intra))
        # Strict ">" keeps the first threshold reaching the maximum.
        if gap > cur_max:
            cur_max = gap
    return cur_max
def repeat(n_times):
    """Decorator factory: call the wrapped function ``n_times`` and aggregate.

    The wrapped function must return a dict of scalar metrics. The wrapper
    returns ``{metric: {'mean': ..., 'std': ...}}`` computed with
    ``np.mean`` / ``np.std`` over the runs. Metric names and their order come
    from the first run. The summary is printed via ``print_statistics``
    before it is returned.
    """
    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            runs = []
            for _ in range(n_times):
                runs.append(f(*args, **kwargs))
            summary = {}
            for metric in runs[0]:
                column = [run[metric] for run in runs]
                summary[metric] = {'mean': np.mean(column),
                                   'std': np.std(column)}
            print_statistics(summary, f.__name__)
            return summary
        return wrapper
    return decorator

def prob_to_one_hot(y_pred):
    """Turn a 2-D score matrix into a boolean one-hot matrix of per-row argmaxes.

    Each row gets exactly one ``True``, at the column of its largest score.
    Ties resolve to the first such column, per ``np.argmax``.
    """
    y_pred = np.asarray(y_pred)
    # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
    ret = np.zeros(y_pred.shape, dtype=bool)
    ret[np.arange(y_pred.shape[0]), np.argmax(y_pred, axis=1)] = True
    return ret


def print_statistics(statistics, function_name):
    """Print a one-line ``key=mean+-std`` summary for ``function_name``.

    ``statistics`` maps metric name -> ``{'mean': ..., 'std': ...}``. Numbers
    are shown with four decimals and entries are separated by ", ". An empty
    mapping prints only the ``(E) | name: `` prefix, with no trailing newline.
    """
    print(f'(E) | {function_name}:', end=' ')
    entries = [f'{key}={stat["mean"]:.4f}+-{stat["std"]:.4f}'
               for key, stat in statistics.items()]
    if entries:
        print(', '.join(entries))

@repeat(3)
def label_classification(embeddings, y, sens, ratio):
    """Probe ``embeddings`` for the target ``y`` and report accuracy and fairness.

    A one-vs-rest logistic regression (``C`` grid-searched with 5-fold CV) is
    trained on a random ``ratio`` fraction of the L2-normalised embeddings
    and evaluated on the remaining rows. ``train_test_split`` is called
    without a ``random_state``, so each of the three ``@repeat(3)`` runs uses
    a different split. The decorator returns the mean/std of every metric.

    Args:
        embeddings: torch tensor with one embedding row per sample.
        y: torch tensor of target labels, one per sample.
        sens: sensitive attribute per sample. It is indexed with the test
            positions and passed to ``fair_metric`` / ``fair_metric_mc``.
        ratio: fraction of samples used for training.

    Returns:
        Dict with ``roc_auc``, ``accuracy``, ``F1Mi`` and ``F1Ma``, plus
        ``acc_parity`` when there are more than two classes, otherwise
        ``parity`` and ``equality``.
    """
    X = embeddings.detach().cpu().numpy()
    Y = y.detach().cpu().numpy().reshape(-1, 1)
    onehot_encoder = OneHotEncoder(categories='auto').fit(Y)
    # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
    Y = onehot_encoder.transform(Y).toarray().astype(bool)

    X = nor_sp(X, norm='l2')
    # Split row positions alongside X/Y so ``sens`` can be aligned to the test rows.
    indices = range(np.shape(X)[0])
    X_train, X_test, y_train, y_test, _, indices_test = train_test_split(
        X, Y, indices, test_size=1 - ratio)

    logreg = LogisticRegression(solver='liblinear')
    c = 2.0 ** np.arange(-10, 10)

    clf = GridSearchCV(estimator=OneVsRestClassifier(logreg),
                       param_grid=dict(estimator__C=c), n_jobs=5, cv=5,
                       verbose=0)
    clf.fit(X_train, y_train)

    y_pred = prob_to_one_hot(clf.predict_proba(X_test))

    micro = f1_score(y_test, y_pred, average="micro")
    macro = f1_score(y_test, y_pred, average="macro")
    acc = accuracy_score(y_test, y_pred)
    # NOTE(review): AUC is computed from the one-hot *hard* predictions, not
    # from the predict_proba scores; kept so reported numbers stay comparable.
    roc_auc = roc_auc_score(y_test, y_pred)

    pred_cls = np.argmax(y_pred, axis=1)
    true_cls = np.argmax(y_test, axis=1)
    test_sens = sens[indices_test]

    if np.shape(y_pred)[1] > 2:
        acc_parity = fair_metric_mc(pred_cls, true_cls, test_sens)
        return {
            'roc_auc': roc_auc,
            'accuracy': acc,
            'F1Mi': micro,
            'F1Ma': macro,
            'acc_parity': acc_parity
        }

    parity, equality = fair_metric(pred_cls, true_cls, test_sens)
    return {
        'roc_auc': roc_auc,
        'accuracy': acc,
        'F1Mi': micro,
        'F1Ma': macro,
        'parity': parity,
        'equality': equality
    }
@repeat(3)
def sens_classification(embeddings, y, ratio):
    """Probe how well ``y`` (per the name, the sensitive attribute) is recoverable from ``embeddings``.

    The embeddings are L2-normalised and split at random into a ``ratio``
    fraction for training and the rest for testing. A one-vs-rest logistic
    regression (``C`` grid-searched with 5-fold CV) is trained and evaluated.
    ``@repeat(3)`` runs this three times and returns the mean/std of each
    metric.

    Args:
        embeddings: torch tensor with one embedding row per sample.
        y: torch tensor of per-sample classes to predict.
        ratio: fraction of samples used for training.

    Returns:
        Dict with:
        * ``rb``: sum over predicted classes of (share of test rows
          predicted as that class) * (accuracy on those rows).
        * ``roc_auc``: computed from the one-hot hard predictions.
        * ``accuracy``: overall test accuracy.

        NOTE(review): because the predicted classes partition the test rows,
        ``rb`` equals the overall subset accuracy. Confirm this is the
        intended definition.
    """
    X = embeddings.detach().cpu().numpy()
    Y = y.detach().cpu().numpy().reshape(-1, 1)
    onehot_encoder = OneHotEncoder(categories='auto').fit(Y)
    # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
    Y = onehot_encoder.transform(Y).toarray().astype(bool)

    X = nor_sp(X, norm='l2')
    X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=1 - ratio)

    logreg = LogisticRegression(solver='liblinear')
    c = 2.0 ** np.arange(-10, 10)

    clf = GridSearchCV(estimator=OneVsRestClassifier(logreg),
                       param_grid=dict(estimator__C=c), n_jobs=5, cv=5,
                       verbose=0)
    clf.fit(X_train, y_train)

    y_pred = prob_to_one_hot(clf.predict_proba(X_test))

    acc = accuracy_score(y_test, y_pred)
    roc_auc = roc_auc_score(y_test, y_pred)

    pred_cls = np.argmax(y_pred, axis=1)
    n_test = np.shape(y_pred)[0]
    rb = 0
    # One-hot columns are indexed 0..k-1 in the encoder's (sorted) category
    # order, so compare argmax against the column index. Comparing against
    # the raw label value only worked when the labels were already 0..k-1.
    for col in range(np.shape(y_pred)[1]):
        ind_s = np.where(pred_cls == col)[0]
        if len(ind_s) == 0:
            # Weight would be 0; skip rather than score an empty slice.
            continue
        # Separate name: reusing ``acc`` overwrote the overall accuracy above.
        group_acc = accuracy_score(y_test[ind_s], y_pred[ind_s])
        rb = rb + (float(len(ind_s)) / n_test) * group_acc

    return {'rb': rb,
            'roc_auc': roc_auc,
            'accuracy': acc}


def _drop_edges_touching(edges, excluded):
    """Return the rows of ``edges`` whose two endpoints are both outside ``excluded``."""
    # A hash set makes each membership test O(1); ``x in ndarray`` scans the array.
    excluded = set(excluded)
    keep_src = [i for i, elem in enumerate(edges[:, 0]) if elem not in excluded]
    keep_dst = [i for i, elem in enumerate(edges[:, 1]) if elem not in excluded]
    return edges[list(set(keep_src) & set(keep_dst)), :]


def load_pokec(dataset, sens_attr, predict_attr, path="../pokec/pokec_dataset/", tris=False, degs=False):
    """Load a Pokec CSV + relationship file as a graph with node features, labels and a sensitive attribute.

    Reads ``{path}/{dataset}.csv``, which must contain ``user_id``, ``sens_attr``
    and ``predict_attr`` columns; every other column is a feature. Also reads
    ``{path}/{dataset}_relationship.txt``, whitespace-separated integer
    user-id pairs parsed with ``np.genfromtxt``.

    Processing:
      1. Keep users whose sensitive attribute and label are both >= 0, and
         drop relationships touching any other user.
      2. Keep only the largest connected component. Its nodes are shuffled
         after ``random.seed(19)``, which reseeds Python's global ``random``.
      3. Binarise ``sens`` (> 0 -> 1) and labels (> 1 -> 1). Drop feature
         columns that are constant over the kept nodes, and prepend ``sens``
         as feature column 0.
      4. Build a symmetrised adjacency matrix and add the identity.

    Args:
        dataset: file stem of the CSV and relationship files.
        sens_attr: column name of the sensitive attribute.
        predict_attr: column name of the prediction target.
        path: directory containing both files.
        tris, degs: unused; kept for interface compatibility.

    Returns:
        ``(edges, adj, features, labels, sens)``:
        * ``edges``: LongTensor of shape (2, nnz) with the row/col coordinates
          of every non-zero of ``adj``, self-loops included.
        * ``adj``: SciPy sparse adjacency matrix.
        * ``features``: FloatTensor with ``sens`` in column 0.
        * ``labels``, ``sens``: LongTensors.
    """
    print('Loading {} dataset from {}'.format(dataset, path))

    idx_features_labels = pd.read_csv(os.path.join(path, "{}.csv".format(dataset)))

    header = list(idx_features_labels.columns)
    header.remove("user_id")
    header.remove(sens_attr)
    header.remove(predict_attr)

    features = sp.csr_matrix(idx_features_labels[header], dtype=np.float32)
    labels = idx_features_labels[predict_attr].values
    sens = idx_features_labels[sens_attr].values

    # Rows usable for the task: both the sensitive attribute and the label are >= 0.
    sens_idx = set(np.where(sens >= 0)[0])
    label_idx = np.where(labels >= 0)[0]
    # dtype=int: an empty difference would otherwise be a float array, and a
    # float array cannot be used as an index below.
    idx_used = np.asarray(list(sens_idx & set(label_idx)), dtype=int)
    idx_nonused = np.asarray(list(set(np.arange(len(labels))).difference(set(idx_used))), dtype=int)

    features = features[idx_used, :]
    labels = labels[idx_used]
    sens = sens[idx_used]

    idx = np.array(idx_features_labels["user_id"], dtype=int)
    edges_unordered = np.genfromtxt(os.path.join(path, "{}_relationship.txt".format(dataset)), dtype=int)

    idx_n = idx[idx_nonused]
    idx = idx[idx_used]
    edges_unordered = _drop_edges_touching(edges_unordered, idx_n)

    # Map user ids to row positions of the kept users.
    # NOTE(review): an endpoint id absent from the CSV maps to None and makes
    # the int conversion below raise; the data are presumably consistent.
    idx_map = {j: i for i, j in enumerate(idx)}
    edges_un = np.array(list(map(idx_map.get, edges_unordered.flatten())),
                        dtype=int).reshape(edges_unordered.shape)

    adj = sp.coo_matrix((np.ones(edges_un.shape[0]), (edges_un[:, 0], edges_un[:, 1])),
                        shape=(labels.shape[0], labels.shape[0]),
                        dtype=np.float32)

    # networkx 3.0 removed from_scipy_sparse_matrix; from_scipy_sparse_array
    # (networkx >= 2.7) is its replacement.
    from_sparse = getattr(nx, 'from_scipy_sparse_array', None) or nx.from_scipy_sparse_matrix
    G = from_sparse(adj)
    g_nx_ccs = (G.subgraph(c).copy() for c in nx.connected_components(G))
    g_nx = max(g_nx_ccs, key=len)
    import random
    seed = 19
    random.seed(seed)
    idx_s = list(g_nx.nodes())
    random.shuffle(idx_s)

    sens = sens[idx_s]
    sens[sens > 0] = 1
    dense = np.array(features[idx_s, :].todense())
    # Drop feature columns that are constant over the kept nodes.
    dense = dense[:, np.std(dense, axis=0) != 0]
    # Special to this study: the sensitive attribute is prepended to the node features.
    features = torch.FloatTensor(np.concatenate((np.reshape(sens, (len(sens), 1)), dense), axis=1))

    labels = torch.LongTensor(labels[idx_s])
    labels[labels > 1] = 1
    sens = torch.LongTensor(sens)

    # Re-index edges to the shuffled largest component, dropping edges that leave it.
    idx_map_n = {j: int(i) for i, j in enumerate(idx_s)}
    outside_lcc = set(range(G.number_of_nodes())).difference(idx_s)
    edges_un = _drop_edges_touching(edges_un, outside_lcc)
    edges = np.array(list(map(idx_map_n.get, edges_un.flatten())),
                     dtype=int).reshape(edges_un.shape)
    edges = np.unique(edges, axis=0)
    adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                        shape=(labels.shape[0], labels.shape[0]),
                        dtype=np.float32)
    # Symmetrise: take the element-wise max of adj and adj.T.
    adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
    adj = adj + sp.eye(adj.shape[0])
    rows, cols, _ = sp.find(adj)
    edges = torch.LongTensor(np.stack((rows, cols)))
    return edges, adj, features, labels, sens

def feature_norm(features):
    """Standardise ``features`` along dimension 0 (zero mean, unit std per column).

    Uses ``torch.std``'s default (unbiased) estimator. Columns whose standard
    deviation is exactly zero are only centred, so they become 0 instead of
    NaN (0 / 0).
    """
    feat_mean = torch.mean(features, 0)
    feat_std = torch.std(features, 0)
    # Guard constant columns against division by zero.
    feat_std = torch.where(feat_std == 0, torch.ones_like(feat_std), feat_std)
    return (features - feat_mean) / feat_std


def encode_onehot(labels):
    """One-hot encode ``labels`` into an int32 array of shape (len(labels), n_classes).

    Columns follow the sorted order of the distinct labels, so the encoding
    is deterministic across runs. Plain ``set`` iteration order is not, for
    example for strings under hash randomisation. Unorderable label mixes
    fall back to first-appearance order. An empty input returns an empty
    int32 array of shape (0,).
    """
    labels = list(labels)  # materialise once so generators are not consumed twice
    try:
        classes = sorted(set(labels))
    except TypeError:
        classes = list(dict.fromkeys(labels))
    eye = np.identity(len(classes), dtype=np.int32)
    position = {c: i for i, c in enumerate(classes)}
    return np.array([eye[position[c]] for c in labels], dtype=np.int32)


def build_relationship(x, thresh=0.25):
    """Build a directed edge list from pairwise Euclidean similarity of the rows of ``x``.

    Similarity is ``1 / (1 + ||x_i - x_j||_2)``. For each row ``i``, let
    ``m`` be the second-largest entry of its similarity row; the largest is
    normally the self-similarity of 1. Every other row ``j`` with similarity
    ``> thresh * m`` becomes an edge ``[i, j]``. Each row's neighbours are
    shuffled after ``random.seed(912)`` is reset, which reseeds Python's
    global ``random`` module. The shuffle order therefore depends only on
    the neighbour count.

    Args:
        x: 2-D array-like (DataFrame or ndarray) with one sample per row.
        thresh: fraction of the second-largest similarity a neighbour must exceed.

    Returns:
        ``np.ndarray`` of ``[i, j]`` pairs (shape (E, 2), or (0,) if no edges).
    """
    import random

    similarity = 1 / (1 + distance_matrix(x, x))
    pairs = []
    for ind in range(similarity.shape[0]):
        row = similarity[ind, :]
        max_sim = np.sort(row)[-2]
        neig_id = np.where(row > thresh * max_sim)[0]
        random.seed(912)
        random.shuffle(neig_id)
        pairs.extend([ind, neig] for neig in neig_id if neig != ind)
    return np.array(pairs)


def normalize(mx):
    """Row-normalise a SciPy sparse matrix: divide each row by its sum.

    Rows that sum to zero stay all-zero instead of becoming inf/nan.
    Floating-point inputs keep their dtype. Integer/boolean inputs are
    promoted to float64, because NumPy rejects integers raised to negative
    integer powers.
    """
    rowsum = np.array(mx.sum(1))
    if not np.issubdtype(rowsum.dtype, np.inexact):
        rowsum = rowsum.astype(np.float64)
    # Zero rows give inf here; they are zeroed just below, so silence the warning.
    with np.errstate(divide='ignore'):
        r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    return r_mat_inv.dot(mx)


def accuracy(output, labels):
    """Fraction of samples whose logit, thresholded at ``> 0``, equals the label.

    ``output`` is squeezed first, so an (N, 1) logit column lines up with
    (N,) labels. The result is a 0-d double tensor.
    """
    hits = (output.squeeze() > 0).type_as(labels).eq(labels)
    return hits.double().sum() / len(labels)

def accuracy_softmax(output, labels):
    """Fraction of rows whose highest-scoring column (``max`` over dim 1) equals the label.

    The result is a 0-d double tensor.
    """
    _, predicted = output.max(1)
    matches = predicted.type_as(labels) == labels
    return matches.double().sum() / len(labels)

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a SciPy sparse matrix to a float32 torch sparse COO tensor.

    The matrix is converted to COO form and cast to float32. Its row/col
    coordinates become int64 indices. The result is not explicitly coalesced.
    """
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    # ``torch.sparse.FloatTensor`` is a deprecated legacy constructor;
    # ``sparse_coo_tensor`` is the supported equivalent.
    return torch.sparse_coo_tensor(indices, values, shape)
