import numpy as np
import scipy.sparse as sp
import scipy.io
import torch
import os
import os.path as osp
import sys
import pandas as pd
import dgl
from torch_geometric.utils import from_scipy_sparse_matrix
import functools
import networkx as nx
from sklearn.metrics import f1_score, accuracy_score, roc_auc_score
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import normalize, OneHotEncoder
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.utils import negative_sampling
from torch_geometric.utils import train_test_split_edges
# Global seed applied to NumPy's and PyTorch's RNGs when this module is imported.
g_seed=39788
# Limit PyTorch CPU intra-op parallelism to 8 threads.
torch.set_num_threads(8)
np.random.seed(g_seed)
torch.manual_seed(g_seed)
# Ask PyTorch to use deterministic algorithms (ops lacking one raise an error).
torch.use_deterministic_algorithms(True)


def fair_metric(output, labels, sens):
    """Return the statistical-parity and equal-opportunity gaps of predictions.

    Args:
        output: array of predicted labels, one per sample.
        labels: array of ground-truth labels, one per sample.
        sens: sensitive attribute (0/1) per sample; a torch tensor (any
            device) or an array-like.

    Returns:
        (parity, equality) where
        parity   = |mean(pred | s=0) - mean(pred | s=1)| and
        equality = |mean(pred | y=1, s=0) - mean(pred | y=1, s=1)|.
        For 0/1 predictions these are gaps in positive-prediction rates.
        A group with no samples yields NaN (with a NumPy RuntimeWarning).
    """
    def _to_np(a):
        # Accept torch tensors (moved to host) as well as array-likes.
        if torch.is_tensor(a):
            return a.detach().cpu().numpy()
        return np.asarray(a)

    pred_y = _to_np(output)
    val_y = _to_np(labels)
    sens_np = _to_np(sens)

    idx_s0 = sens_np == 0
    idx_s1 = sens_np == 1
    idx_s0_y1 = idx_s0 & (val_y == 1)
    idx_s1_y1 = idx_s1 & (val_y == 1)

    def _mean_pred(mask):
        # Vectorized sum / count (the builtin ``sum`` iterated in Python).
        return np.sum(pred_y[mask]) / np.sum(mask)

    parity = abs(_mean_pred(idx_s0) - _mean_pred(idx_s1))
    equality = abs(_mean_pred(idx_s0_y1) - _mean_pred(idx_s1_y1))
    return parity, equality

def fair_metric_mc(output, labels, sens):
    """Return the accuracy-parity gap between sensitive groups 0 and 1.

    Computes |err(s=0) - err(s=1)|, where err(g) is the fraction of samples in
    group g whose prediction differs from the label. Works for any number of
    classes.

    Args:
        output: array of predicted labels, one per sample.
        labels: array of ground-truth labels, one per sample.
        sens: sensitive attribute (0/1) per sample; a torch tensor (any
            device) or an array-like.

    Returns:
        The absolute difference of per-group error rates. A group with no
        samples yields NaN (with a NumPy RuntimeWarning).
    """
    def _to_np(a):
        if torch.is_tensor(a):
            return a.detach().cpu().numpy()
        return np.asarray(a)

    pred_y = _to_np(output)
    val_y = _to_np(labels)
    sens_np = _to_np(sens)

    idx_s0 = sens_np == 0
    idx_s1 = sens_np == 1

    def _error_rate(mask):
        # Normalize by the group size (mask.sum()); len(mask) would be the
        # total number of samples, not the number in the group.
        return np.sum(pred_y[mask] != val_y[mask]) / np.sum(mask)

    return abs(_error_rate(idx_s0) - _error_rate(idx_s1))

def repeat(n_times):
    """Build a decorator that runs the wrapped function ``n_times`` times.

    The wrapped function must return a dict of scalar metrics. The wrapper
    aggregates each metric (keys taken from the first run) into
    ``{'mean': ..., 'std': ...}``, prints the summary via ``print_statistics``
    and returns it.
    """
    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            runs = []
            for _ in range(n_times):
                runs.append(f(*args, **kwargs))
            summary = {}
            for metric in runs[0]:
                samples = [run[metric] for run in runs]
                summary[metric] = {'mean': np.mean(samples),
                                   'std': np.std(samples)}
            print_statistics(summary, f.__name__)
            return summary
        return wrapper
    return decorator

def prob_to_one_hot(y_pred):
    """Convert an (n_samples, n_classes) score matrix into a boolean one-hot matrix.

    Each row gets a single True at its arg-max column (the first one on ties).
    Uses the builtin ``bool`` dtype because the ``np.bool`` alias was removed
    in NumPy 1.24.
    """
    y_pred = np.asarray(y_pred)
    ret = np.zeros(y_pred.shape, dtype=bool)
    # Vectorized replacement for the per-row Python loop.
    ret[np.arange(y_pred.shape[0]), np.argmax(y_pred, axis=1)] = True
    return ret


def print_statistics(statistics, function_name):
    """Print ``statistics`` on one line as comma-separated ``key=mean+-std`` pairs.

    ``statistics`` maps metric names to dicts with ``'mean'`` and ``'std'``
    entries. A trailing newline is written only when at least one metric is
    printed.
    """
    print(f'(E) | {function_name}:', end=' ')
    entries = [
        f"{name}={stats['mean']:.4f}+-{stats['std']:.4f}"
        for name, stats in statistics.items()
    ]
    if entries:
        print(', '.join(entries))
def maximize_over_t(inter,intra):
    """Return the largest gap between the fractions of two samples below a threshold.

    For each threshold t in ``np.arange(0, 1, 0.05)`` the gap is
    ``|#(inter < t)/len(inter) - #(intra < t)/len(intra)|``; the maximum gap is
    returned (0 if no threshold produces a positive gap).

    Note: this name is defined again later in this module, and that later
    definition is the one bound at import time.
    """
    best = 0
    for thr in np.arange(0, 1, 0.05):
        frac_inter = len(np.where(inter < thr)[0]) / len(inter)
        frac_intra = len(np.where(intra < thr)[0]) / len(intra)
        gap = np.absolute(frac_inter - frac_intra)
        if gap > best:
            best = gap
    return best
@repeat(3)
def label_classification_embed(embeddings, y, sens, indices_train, indices_test):
    """Evaluate embeddings on node classification using a caller-supplied split.

    Fits a one-vs-rest logistic regression (liblinear; C grid-searched with
    5-fold CV) on the L2-normalized embeddings of ``indices_train`` and
    evaluates it on ``indices_test``.

    Args:
        embeddings: torch tensor of node embeddings (one row per node).
        y: torch tensor of node labels.
        sens: per-node sensitive attribute (0/1), indexable by the test
            indices (as passed to ``fair_metric``).
        indices_train: node indices (or boolean mask) used for training.
        indices_test: node indices (or boolean mask) used for evaluation.

    Returns:
        dict with 'F1Mi', 'F1Ma', 'parity' and 'equality'.
    """
    def _as_index(idx):
        # Accept torch tensors as well as lists / NumPy arrays.
        if torch.is_tensor(idx):
            return idx.detach().cpu().numpy()
        return np.asarray(idx)

    X = normalize(embeddings.detach().cpu().numpy(), norm='l2')
    Y_raw = y.detach().cpu().numpy().reshape(-1, 1)
    onehot_encoder = OneHotEncoder(categories='auto').fit(Y_raw)
    # Builtin bool: the np.bool alias was removed in NumPy 1.24.
    Y = onehot_encoder.transform(Y_raw).toarray().astype(bool)

    # Honour the split supplied by the caller. The previous implementation
    # overwrote these arguments with a fresh random 10%/90% split.
    idx_train = _as_index(indices_train)
    idx_test = _as_index(indices_test)
    X_train, y_train = X[idx_train], Y[idx_train]
    X_test, y_test = X[idx_test], Y[idx_test]

    logreg = LogisticRegression(solver='liblinear')
    c = 2.0 ** np.arange(-10, 10)

    clf = GridSearchCV(estimator=OneVsRestClassifier(logreg),
                       param_grid=dict(estimator__C=c), n_jobs=8, cv=5,
                       verbose=0)
    clf.fit(X_train, y_train)

    y_pred = prob_to_one_hot(clf.predict_proba(X_test))

    micro = f1_score(y_test, y_pred, average="micro")
    macro = f1_score(y_test, y_pred, average="macro")
    parity, equality = fair_metric(np.argmax(y_pred, axis=1),
                                   np.argmax(y_test, axis=1),
                                   sens[idx_test])
    return {
        'F1Mi': micro,
        'F1Ma': macro,
        'parity': parity,
        'equality': equality
    }

def maximize_over_t(inter, intra, thresholds=None):
    """Return the largest gap between two empirical CDFs over a threshold grid.

    For each threshold t the gap is
    ``|#(inter < t) / len(inter) - #(intra < t) / len(intra)|``. The maximum
    gap is returned (0 if no threshold gives a positive gap). This is a
    Kolmogorov-Smirnov-style statistic restricted to the grid.

    Args:
        inter, intra: array-likes of scores to compare (lists are accepted).
        thresholds: optional grid of thresholds; defaults to
            ``np.arange(0, 1, 0.05)`` (the previously hard-coded grid).

    Returns:
        The maximum gap (``0`` when every gap is zero or the grid is empty).

    Raises:
        ValueError: if ``inter`` or ``intra`` is empty.
    """
    inter = np.asarray(inter)
    intra = np.asarray(intra)
    if len(inter) == 0 or len(intra) == 0:
        raise ValueError("inter and intra must both be non-empty")
    t = np.arange(0, 1, 0.05) if thresholds is None else np.asarray(thresholds, dtype=float)
    if t.size == 0:
        return 0
    # searchsorted(side='left') on the sorted values = count strictly below t.
    below_inter = np.searchsorted(np.sort(inter, axis=None), t, side='left')
    below_intra = np.searchsorted(np.sort(intra, axis=None), t, side='left')
    gaps = np.abs(below_inter / len(inter) - below_intra / len(intra))
    return max(0, gaps.max())
@repeat(3)
def label_classification(embeddings, y, sens, ratio):
    """Evaluate embeddings on node classification with a random train/test split.

    Fits a one-vs-rest logistic regression (liblinear; C grid-searched with
    5-fold CV) on the L2-normalized embeddings of a random ``ratio`` fraction
    of the nodes and evaluates it on the rest.

    Args:
        embeddings: torch tensor of node embeddings.
        y: torch tensor of node labels.
        sens: per-node sensitive attribute (0/1), indexable by test indices.
        ratio: fraction of nodes used for training.

    Returns:
        dict with 'roc_auc' (computed from predicted probabilities),
        'accuracy', 'F1Mi', 'F1Ma', plus 'acc_parity' when there are more than
        two classes, or 'parity' and 'equality' otherwise.
    """
    X = normalize(embeddings.detach().cpu().numpy(), norm='l2')
    Y_raw = y.detach().cpu().numpy().reshape(-1, 1)
    onehot_encoder = OneHotEncoder(categories='auto').fit(Y_raw)
    # Builtin bool: the np.bool alias was removed in NumPy 1.24.
    Y = onehot_encoder.transform(Y_raw).toarray().astype(bool)

    indices = range(np.shape(X)[0])
    X_train, X_test, y_train, y_test, _, indices_test = train_test_split(
        X, Y, indices, test_size=1 - ratio)

    logreg = LogisticRegression(solver='liblinear')
    c = 2.0 ** np.arange(-10, 10)

    clf = GridSearchCV(estimator=OneVsRestClassifier(logreg),
                       param_grid=dict(estimator__C=c), n_jobs=5, cv=5,
                       verbose=0)
    clf.fit(X_train, y_train)

    y_score = clf.predict_proba(X_test)
    y_pred = prob_to_one_hot(y_score)

    metrics = {
        # ROC-AUC ranks continuous scores; thresholded one-hot labels discard
        # the ranking information.
        'roc_auc': roc_auc_score(y_test, y_score),
        'accuracy': accuracy_score(y_test, y_pred),
        'F1Mi': f1_score(y_test, y_pred, average="micro"),
        'F1Ma': f1_score(y_test, y_pred, average="macro"),
    }
    pred_cls = np.argmax(y_pred, axis=1)
    true_cls = np.argmax(y_test, axis=1)
    sens_test = sens[indices_test]
    if np.shape(y_pred)[1] > 2:
        metrics['acc_parity'] = fair_metric_mc(pred_cls, true_cls, sens_test)
    else:
        metrics['parity'], metrics['equality'] = fair_metric(pred_cls, true_cls, sens_test)
    return metrics
@repeat(3)
def sens_classification(embeddings, y, ratio):
    """Measure how well a (sensitive) attribute can be predicted from embeddings.

    Fits a one-vs-rest logistic regression (liblinear; C grid-searched with
    5-fold CV) on a random ``ratio`` fraction of the nodes and evaluates it on
    the rest.

    Args:
        embeddings: torch tensor of node embeddings.
        y: torch tensor with the target attribute of each node.
        ratio: fraction of nodes used for training.

    Returns:
        dict with 'rb', 'roc_auc' (computed from predicted probabilities) and
        'accuracy' (overall test accuracy).

    NOTE(review): ``rb`` sums, over predicted classes k, the share of test
    nodes predicted k times the accuracy among them. Algebraically this equals
    the overall accuracy (up to float rounding). If a different definition
    was intended, revisit it.
    """
    X = normalize(embeddings.detach().cpu().numpy(), norm='l2')
    Y_raw = y.detach().cpu().numpy().reshape(-1, 1)
    onehot_encoder = OneHotEncoder(categories='auto').fit(Y_raw)
    # Builtin bool: the np.bool alias was removed in NumPy 1.24.
    Y = onehot_encoder.transform(Y_raw).toarray().astype(bool)

    X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=1 - ratio)

    logreg = LogisticRegression(solver='liblinear')
    c = 2.0 ** np.arange(-10, 10)

    clf = GridSearchCV(estimator=OneVsRestClassifier(logreg),
                       param_grid=dict(estimator__C=c), n_jobs=5, cv=5,
                       verbose=0)
    clf.fit(X_train, y_train)

    y_score = clf.predict_proba(X_test)
    y_pred = prob_to_one_hot(y_score)

    # Overall accuracy, kept in its own name (it used to be overwritten by the
    # per-group accuracy inside the loop below).
    acc = accuracy_score(y_test, y_pred)
    roc_auc = roc_auc_score(y_test, y_score)

    pred_cls = np.argmax(y_pred, axis=1)
    n_test = np.shape(y_pred)[0]
    rb = 0
    # One-hot column k corresponds to the k-th sorted category, so compare
    # predicted column indices with k rather than with the raw label value.
    for k in range(len(onehot_encoder.categories_[0])):
        ind_s = np.where(pred_cls == k)[0]
        if len(ind_s) == 0:
            continue  # an empty group has zero weight (and would give 0 * NaN)
        group_acc = accuracy_score(y_test[ind_s], y_pred[ind_s])
        rb = rb + (float(len(ind_s)) / n_test) * group_acc

    return {'rb' : rb,
            'roc_auc' : roc_auc,
            'accuracy' : acc}

def link_prediction(embeddings, edges_tr, edges_t, neg_edges_tr, neg_edges_t, sens):
    """Evaluate embeddings on link prediction and report dyadic fairness gaps.

    Each candidate edge (u, v) is represented by concatenating the L2-normalized
    embeddings of u and v. Positive edges get label 1 and negative edges
    label 0. A pair is flagged 1 ("inter") when its endpoints have different
    sensitive values, else 0. A one-vs-rest logistic regression (liblinear;
    C grid-searched with 5-fold CV) is trained on the training pairs and
    evaluated on the test pairs. Both pair sets are shuffled with Python's
    ``random`` reseeded to 19.

    Args:
        embeddings: torch tensor of node embeddings.
        edges_tr, edges_t: positive train/test edges given as 2 x E index
            tensors/arrays; they are transposed to E x 2.
        neg_edges_tr, neg_edges_t: negative train/test edges, used as E x 2
            (one pair per row, no transpose).
        sens: per-node sensitive attribute (torch tensor or array-like).

    Returns:
        dict with 'roc_auc' (computed from predicted probabilities),
        'accuracy', 'F1Mi', 'F1Ma', and either 'acc_parity' (more than two
        one-hot columns) or 'parity'/'equality'. The fairness groups are
        intra-group (0) vs inter-group (1) test pairs.
    """
    import random

    def _to_numpy(a):
        # Accept torch tensors (any device) as well as NumPy arrays / array-likes.
        if torch.is_tensor(a):
            return a.detach().cpu().numpy()
        return np.asarray(a)

    X = normalize(_to_numpy(embeddings), norm='l2')
    sens_np = _to_numpy(sens)

    def _pair_dataset(pos_edges, neg_edges):
        """Build shuffled (features, labels, inter-group flags) for pos/neg pairs."""
        feats, targets, inter = [], [], []
        for pairs, target in ((pos_edges, 1.0), (neg_edges, 0.0)):
            feats.append(np.concatenate((X[pairs[:, 0]], X[pairs[:, 1]]), axis=1))
            targets.append(np.full(pairs.shape[0], target))
            inter.append((sens_np[pairs[:, 0]] != sens_np[pairs[:, 1]]).astype(float))
        X_all = np.concatenate(feats, axis=0)
        y_all = np.concatenate(targets, axis=0)
        s_all = np.concatenate(inter, axis=0)
        # Reseed before every shuffle so train and test orderings are reproducible.
        order = np.arange(X_all.shape[0])
        random.seed(19)
        random.shuffle(order)
        return X_all[order, :], y_all[order], s_all[order]

    X_all_tr, y_all_tr, _ = _pair_dataset(_to_numpy(edges_tr).T, _to_numpy(neg_edges_tr))
    X_all_t, y_all_t, sens_all_t = _pair_dataset(_to_numpy(edges_t).T, _to_numpy(neg_edges_t))

    onehot_encoder = OneHotEncoder(categories='auto').fit(y_all_tr.reshape(-1, 1))
    # Builtin bool: the np.bool alias was removed in NumPy 1.24.
    Y_all_tr = onehot_encoder.transform(y_all_tr.reshape(-1, 1)).toarray().astype(bool)
    # Reuse the training encoder so train and test columns always line up.
    Y_all_t = onehot_encoder.transform(y_all_t.reshape(-1, 1)).toarray().astype(bool)

    logreg = LogisticRegression(solver='liblinear')
    c = 2.0 ** np.arange(-10, 10)

    clf = GridSearchCV(estimator=OneVsRestClassifier(logreg),
                       param_grid=dict(estimator__C=c), n_jobs=5, cv=5,
                       verbose=0)
    clf.fit(X_all_tr, Y_all_tr)

    y_score = clf.predict_proba(X_all_t)
    y_pred = prob_to_one_hot(y_score)

    metrics = {
        # ROC-AUC ranks continuous scores, not thresholded labels.
        'roc_auc': roc_auc_score(Y_all_t, y_score),
        'accuracy': accuracy_score(Y_all_t, y_pred),
        'F1Mi': f1_score(Y_all_t, y_pred, average="micro"),
        'F1Ma': f1_score(Y_all_t, y_pred, average="macro"),
    }
    pred_cls = np.argmax(y_pred, axis=1)
    true_cls = np.argmax(Y_all_t, axis=1)
    # Both fairness helpers call .cpu() on ``sens``, so pass a tensor in both branches.
    pair_groups = torch.LongTensor(sens_all_t)
    if np.shape(y_pred)[1] > 2:
        metrics['acc_parity'] = fair_metric_mc(pred_cls, true_cls, pair_groups)
    else:
        metrics['parity'], metrics['equality'] = fair_metric(pred_cls, true_cls, pair_groups)
    return metrics

def sensitive_edges(edges, sens):
    """Split edge columns by whether both endpoints share a sensitive value.

    Args:
        edges: 2 x E edge index (row 0 = one endpoint, row 1 = the other).
        sens: per-node sensitive attribute, indexable by the rows of ``edges``.

    Returns:
        (same_edges, diff_edges): the columns of ``edges`` whose endpoints have
        equal / different sensitive values, in their original column order.
    """
    same_group = sens[edges[0, :]] == sens[edges[1, :]]
    same_cols = np.flatnonzero(same_group)
    diff_cols = np.flatnonzero(~same_group)
    return edges[:, same_cols], edges[:, diff_cols]

def feature_norm(features):
    """Standardize each column to zero mean and unit (sample) standard deviation.

    Args:
        features: 2-D torch tensor (rows = samples, columns = features).

    Returns:
        ``(features - mean) / std`` computed column-wise. Constant columns
        (std == 0) are only centered, so they become 0 instead of NaN.
    """
    feat_mean = torch.mean(features, 0)
    feat_std = torch.std(features, 0)
    # Avoid 0/0 -> NaN for constant columns by dividing those by 1.
    feat_std = torch.where(feat_std == 0, torch.ones_like(feat_std), feat_std)
    return (features - feat_mean) / feat_std

def node_stats_msens(edges, sens):
    """Print boundary statistics for a sensitive attribute with any number of values.

    An edge is "inter" when its endpoints have different sensitive values and
    "intra" otherwise. For every sensitive value, the function prints how
    many nodes of that value touch an inter edge ("tilde") and how many do not
    ("hat"). It then prints the inter/intra edge counts.

    Args:
        edges: 2 x E edge index (unpacked as two rows).
        sens: NumPy array of per-node sensitive values.
    """
    src, dst = edges
    intra = np.flatnonzero(sens[src] == sens[dst])
    inter = np.flatnonzero(sens[src] != sens[dst])

    edge_rows = np.array(edges).T
    boundary = np.unique(edge_rows[inter, :].flatten())

    for value in np.unique(sens):
        members = np.where(sens == value)[0]
        tilde = boundary[sens[boundary] == value]
        n_hat = len(set(members).difference(set(tilde)))
        print('the number of '+str(value)+' tilde: ', len(tilde))
        print('the number of '+str(value)+' hat: ', n_hat)

    print('the number of inter edges: ', len(inter))
    print('the number of intra edges: ', len(intra))

def load_fb(path, dataset):
    """Load a Facebook100-style ``.mat`` graph and keep its largest connected component.

    Nodes whose ``local_info`` row contains a 0 are dropped. Column 1 of
    ``local_info`` minus 1 becomes the sensitive attribute (gender, per the
    original comment). Columns 0 and 2-6 are kept as features, with constant
    columns removed. Nodes of the largest connected component are shuffled
    with ``random.seed(19)`` and relabelled 0..n-1 in that order.

    Args:
        path: directory containing the ``.mat`` file.
        dataset: file name inside ``path``.

    Returns:
        edges: int array of shape (E, 2) with relabelled endpoints.
        feats: torch.FloatTensor of node features.
        sens: torch.LongTensor of sensitive values.
    """
    import random

    # NetworkX 3.0 removed ``from_scipy_sparse_matrix``; prefer the array variant.
    from_sparse = getattr(nx, 'from_scipy_sparse_array', None) or nx.from_scipy_sparse_matrix

    def _drop_edges_touching(edge_arr, excluded):
        """Keep rows of ``edge_arr`` whose two endpoints are both outside ``excluded``."""
        # ``excluded`` is a set, so membership is O(1). Row order is still
        # produced by the same set intersection of row indices as before.
        used_ind1 = [i for i, elem in enumerate(edge_arr[:, 0]) if elem not in excluded]
        used_ind2 = [i for i, elem in enumerate(edge_arr[:, 1]) if elem not in excluded]
        return edge_arr[list(set(used_ind1) & set(used_ind2)), :]

    mat = scipy.io.loadmat(path+'/'+dataset)
    Adj = mat['A']
    feats = mat['local_info']

    # Keep only nodes whose attribute row contains no 0 (presumably "missing").
    n_rows = np.shape(feats)[0]
    idx_used = [i for i in range(n_rows) if 0 not in feats[i, :]]
    idx_nonused = set(range(n_rows)).difference(idx_used)

    # Sensitive attribute is gender (column 1), shifted down by 1.
    sens = np.array(feats[idx_used, 1] - 1)
    feats = feats[idx_used, :][:, [0, 2, 3, 4, 5, 6]]

    rows, cols, _ = scipy.sparse.find(Adj)
    edges = np.stack((rows, cols), axis=1)
    edges = _drop_edges_touching(edges, idx_nonused)

    idx_map = {j: i for i, j in enumerate(idx_used)}
    edges = np.array(list(map(idx_map.get, edges.flatten())),
                     dtype=int).reshape(edges.shape)
    n_nodes = sens.shape[0]
    adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                        shape=(n_nodes, n_nodes),
                        dtype=np.float32)

    G = from_sparse(adj)
    g_nx = max((G.subgraph(c).copy() for c in nx.connected_components(G)), key=len)

    random.seed(19)
    idx_s = list(g_nx.nodes())
    random.shuffle(idx_s)

    feats = feats[idx_s, :]
    # Drop constant columns (zero std over the retained nodes).
    feats = feats[:, np.where(np.std(np.array(feats), axis=0) != 0)[0]]
    feats = torch.FloatTensor(np.array(feats, dtype=float))

    sens = torch.LongTensor(np.array(sens[idx_s], dtype=int))

    # Keep only edges inside the largest component and relabel them.
    idx_map_n = {j: int(i) for i, j in enumerate(idx_s)}
    outside_lcc = set(range(G.number_of_nodes())).difference(idx_s)
    edges = _drop_edges_touching(edges, outside_lcc)
    edges = np.array(list(map(idx_map_n.get, edges.flatten())),
                     dtype=int).reshape(edges.shape)
    return edges, feats, sens

def node_stats(edges, sens):
    """Print boundary/interior statistics for a binary sensitive attribute (0/1).

    An edge is "inter" when its endpoints have different sensitive values and
    "intra" otherwise. For each group s in {0, 1}, the function prints how many
    of its nodes touch an inter edge ("tilde"), how many do not ("hat"), and
    the group size. It then prints the edge counts and
    r_s = #inter / (2 * #intra edges whose first endpoint is in group s).

    Args:
        edges: 2 x E edge index (unpacked as two rows).
        sens: NumPy array of per-node sensitive values.

    Raises:
        ZeroDivisionError: if a group has no intra edges (r_s is undefined).
    """
    src, dst = edges
    intra = np.flatnonzero(sens[src] == sens[dst])
    inter = np.flatnonzero(sens[src] != sens[dst])

    edge_rows = np.array(edges).T
    boundary = np.unique(edge_rows[inter, :].flatten())

    def _group_counts(value):
        # (group size, #boundary nodes, #interior nodes) for one sensitive value.
        members = np.where(sens == value)[0]
        tilde = boundary[sens[boundary] == value]
        hat = set(members).difference(set(tilde))
        return len(members), len(tilde), len(hat)

    s0, n_s0_tilde, n_s0_hat = _group_counts(0)
    s1, n_s1_tilde, n_s1_hat = _group_counts(1)

    intra_s0 = int(np.count_nonzero(sens[edge_rows[intra, 0]] == 0))
    intra_s1 = len(intra) - intra_s0

    print('the number of s_0 tilde: ', n_s0_tilde)
    print('the number of s_0 hat: ', n_s0_hat)
    print('the number of s_0: ', s0)
    print('the number of s_1 tilde: ', n_s1_tilde)
    print('the number of s_1 hat: ', n_s1_hat)
    print('the number of s_1: ', s1)
    print('the number of inter edges: ', len(inter))
    print('the number of intra edges: ', len(intra))
    print('the value of r0: ', float(len(inter)) / (2 * intra_s0))
    print('The value of r1: ', float(len(inter)) / (2 * intra_s1))
                     
def load_pokec(dataset, sens_attr, predict_attr, path="pokec_dataset/", tris=False, degs=False):
    """Load a Pokec-style dataset restricted to its largest connected component.

    Reads ``<path>/<dataset>.csv`` (node table with a ``user_id`` column) and
    ``<path>/<dataset>_relationship.txt`` (edge list of user ids). Only nodes
    with non-negative label and sensitive values are kept. The largest
    connected component is shuffled with ``random.seed(19)`` and relabelled
    0..n-1. Labels above 1 are clipped to 1 and positive sensitive values are
    set to 1.

    Args:
        dataset: base file name of the dataset.
        sens_attr: column name of the sensitive attribute.
        predict_attr: column name of the prediction target.
        path: directory containing the files.
        tris: if true, also enumerate triangles (3-cliques) of the graph.
        degs: accepted for backward compatibility but currently ignored; the
            degree vector is returned whenever ``tris`` is false.

    Returns:
        (edges, features, labels, sens, extra) where ``edges`` is a 2 x E
        LongTensor of the symmetrized graph. ``extra`` is an array of
        triangles if ``tris`` is true, else per-node degree + 1 (float array).
    """
    import itertools
    import random

    # NetworkX 3.0 removed ``from_scipy_sparse_matrix``; prefer the array variant.
    from_sparse = getattr(nx, 'from_scipy_sparse_array', None) or nx.from_scipy_sparse_matrix

    def _drop_edges_touching(edge_arr, excluded):
        """Keep rows of ``edge_arr`` whose two endpoints are both outside ``excluded``."""
        # ``excluded`` is a set, so membership is O(1) instead of a linear scan
        # per endpoint. Row order comes from the same set intersection as before.
        used_ind1 = [i for i, elem in enumerate(edge_arr[:, 0]) if elem not in excluded]
        used_ind2 = [i for i, elem in enumerate(edge_arr[:, 1]) if elem not in excluded]
        return edge_arr[list(set(used_ind1) & set(used_ind2)), :]

    print('Loading {} dataset from {}'.format(dataset, path))
    idx_features_labels = pd.read_csv(os.path.join(path, "{}.csv".format(dataset)))

    header = list(idx_features_labels.columns)
    header.remove("user_id")
    header.remove(sens_attr)
    header.remove(predict_attr)

    features = sp.csr_matrix(idx_features_labels[header], dtype=np.float32)
    labels = idx_features_labels[predict_attr].values
    sens = idx_features_labels[sens_attr].values
    # Only nodes for which label and sensitive attribute are available
    # (non-negative) are used. dtype=int keeps empty results usable as indices.
    sens_idx = set(np.where(sens >= 0)[0])
    label_idx = np.where(labels >= 0)[0]
    idx_used = np.asarray(list(sens_idx & set(label_idx)), dtype=int)
    idx_nonused = np.asarray(list(set(np.arange(len(labels))).difference(set(idx_used))), dtype=int)

    features = features[idx_used, :]
    labels = labels[idx_used]
    sens = sens[idx_used]

    idx = np.array(idx_features_labels["user_id"], dtype=int)
    edges_unordered = np.genfromtxt(os.path.join(path, "{}_relationship.txt".format(dataset)), dtype=int)

    # Drop edges touching a user without label / sensitive attribute.
    excluded_ids = set(idx[idx_nonused].tolist())
    idx = idx[idx_used]
    edges_unordered = _drop_edges_touching(edges_unordered, excluded_ids)

    # Build the graph on positions 0..n-1 of the retained users.
    idx_map = {j: i for i, j in enumerate(idx)}
    edges_un = np.array(list(map(idx_map.get, edges_unordered.flatten())),
                        dtype=int).reshape(edges_unordered.shape)

    n_nodes = labels.shape[0]
    adj = sp.coo_matrix((np.ones(edges_un.shape[0]), (edges_un[:, 0], edges_un[:, 1])),
                        shape=(n_nodes, n_nodes),
                        dtype=np.float32)
    G = from_sparse(adj)
    g_nx = max((G.subgraph(c).copy() for c in nx.connected_components(G)), key=len)

    random.seed(19)
    idx_s = list(g_nx.nodes())
    random.shuffle(idx_s)

    # Densify only the retained rows once and drop constant columns.
    dense = features[idx_s, :].toarray()
    dense = dense[:, np.where(np.std(dense, axis=0) != 0)[0]]
    features = torch.FloatTensor(dense)

    labels = torch.LongTensor(labels[idx_s])
    sens = torch.LongTensor(sens[idx_s])
    labels[labels > 1] = 1
    sens[sens > 0] = 1

    # Keep only edges inside the largest component and relabel them.
    idx_map_n = {j: int(i) for i, j in enumerate(idx_s)}
    outside_lcc = set(range(G.number_of_nodes())).difference(idx_s)
    edges_un = _drop_edges_touching(edges_un, outside_lcc)
    edges = np.array(list(map(idx_map_n.get, edges_un.flatten())),
                     dtype=int).reshape(edges_un.shape)
    edges = np.unique(edges, axis=0)

    n_lcc = len(idx_s)
    adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                        shape=(n_lcc, n_lcc),
                        dtype=np.float32)
    # Symmetrize: entry-wise max(adj, adj.T).
    adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
    # Degree + 1, computed sparsely (densifying an N x N matrix can need many GB).
    node_degs = np.asarray(adj.sum(axis=1)).ravel() + np.ones(adj.shape[0])

    rows, cols, _ = scipy.sparse.find(adj)
    edges = torch.LongTensor(np.vstack((rows, cols)))

    if tris:
        g_nx = from_sparse(adj)
        # enumerate_all_cliques yields cliques ordered by size, so stop once
        # cliques grow beyond 3 nodes instead of enumerating all of them.
        small_cliques = itertools.takewhile(lambda clq: len(clq) <= 3,
                                            nx.enumerate_all_cliques(g_nx))
        triad_cliques = [clq for clq in small_cliques if len(clq) == 3]
        return edges, features, labels, sens, np.asarray(triad_cliques)
    return edges, features, labels, sens, node_degs


