import numpy as np
import scipy.sparse as sp
import scipy.io
import torch
import os
import os.path as osp
import sys
import pandas as pd
import dgl
from torch_geometric.utils import from_scipy_sparse_matrix
import functools
import networkx as nx
from sklearn.metrics import f1_score, accuracy_score, roc_auc_score
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import normalize, OneHotEncoder
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.utils import negative_sampling
from torch_geometric.utils import train_test_split_edges
# Process-wide reproducibility and threading setup, applied when the module loads.
g_seed = 39788

# Seed NumPy's global RNG and PyTorch's default generator with the same value.
np.random.seed(g_seed)
torch.manual_seed(g_seed)

# Ask PyTorch to use 8 threads for its CPU work.
torch.set_num_threads(8)

def _graph_from_sparse(adj):
    """Build a networkx graph from a SciPy sparse adjacency matrix."""
    # networkx 2.7 introduced from_scipy_sparse_array and 3.0 removed
    # from_scipy_sparse_matrix; use whichever this installation provides.
    builder = getattr(nx, 'from_scipy_sparse_array', None)
    if builder is None:
        builder = nx.from_scipy_sparse_matrix
    return builder(adj)


def _rows_with_kept_endpoints(edges, dropped_nodes):
    """Return the row indices of *edges* whose two endpoints both avoid *dropped_nodes*.

    *edges* is an (E, 2) integer array and *dropped_nodes* an iterable of node
    ids.  Membership is tested against a set, so the cost is linear in E.

    The indices come back in set-iteration order (not guaranteed sorted).
    That ordering is kept on purpose: the seeded shuffle applied later
    depends on it, so changing it would silently change the produced split.
    """
    dropped = set(dropped_nodes)
    src_rows = [i for i, node in enumerate(edges[:, 0].tolist()) if node not in dropped]
    dst_rows = [i for i, node in enumerate(edges[:, 1].tolist()) if node not in dropped]
    return list(set(src_rows) & set(dst_rows))


def load_fbucsd(path='../../datasets/socfb-UNC28', *, seed=19, train_frac=0.8,
                edgelist_path='lp_unc_08.edgelist'):
    """Load ``UNC28.mat`` and build a train/test edge split for link prediction.

    Despite its name, the function reads ``UNC28.mat`` from *path*.  Steps:

    1. Keep only the rows of ``local_info`` that contain no 0 entry.
    2. Use column 1 (minus 1) as the sensitive attribute and columns
       0, 2, 3, 4, 5, 6 as features.
    3. Restrict the graph to its largest connected component, shuffle the
       node order, and relabel nodes to ``0..n-1`` in that order.
    4. Shuffle the edge rows and split them into train / test parts.
    5. Optionally write the training graph as an edge list.

    Args:
        path: Directory containing ``UNC28.mat``.
        seed: Seed passed to ``random.seed`` before shuffling.  Note that this
            reseeds Python's *global* ``random`` module, as the function
            always has.
        train_frac: Fraction of edge rows placed in the training split;
            the remainder forms the test split.
        edgelist_path: Where to write the training graph with
            ``nx.write_edgelist``; pass ``None`` to skip writing.

    Returns:
        A tuple ``(edges_train, edges_test, feats, sens, degs)``:
        ``edges_train`` / ``edges_test`` are ``LongTensor`` of shape (2, E),
        ``feats`` is a ``FloatTensor`` with constant columns removed,
        ``sens`` is a ``LongTensor`` with one entry per node, and ``degs`` is
        a float64 NumPy array holding each node's training-adjacency row sum
        plus one.

    Raises:
        ValueError: If *train_frac* is outside ``[0, 1]``, or if the filtered
            graph has no nodes (raised by ``max`` on an empty sequence).
    """
    import random

    if not 0.0 <= train_frac <= 1.0:
        raise ValueError("train_frac must be within [0, 1], got %r" % (train_frac,))

    mat = scipy.io.loadmat(osp.join(path, 'UNC28.mat'))
    adj_full = mat['A']
    feats = mat['local_info']

    # Keep only rows without any 0 entry.
    # NOTE(review): 0 presumably encodes a missing value -- confirm against the dataset docs.
    n_total = np.shape(feats)[0]
    kept_rows = np.flatnonzero((feats != 0).all(axis=1)).tolist()
    dropped_rows = set(range(n_total)).difference(kept_rows)

    # Sensitive attr is gender
    sens = np.array(feats[kept_rows, 1] - 1)

    feats = feats[kept_rows, :]
    feats = feats[:, [0, 2, 3, 4, 5, 6]]

    # One (row, col) pair per non-zero entry of the adjacency matrix.
    src, dst, _ = scipy.sparse.find(adj_full)
    edges = np.column_stack((src, dst))

    # Drop edges touching a removed node, then relabel kept nodes to 0..n_kept-1.
    edges = edges[_rows_with_kept_endpoints(edges, dropped_rows), :]
    n_kept = sens.shape[0]
    relabel = np.full(n_total, -1, dtype=int)
    relabel[kept_rows] = np.arange(n_kept)
    edges = relabel[edges]

    adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                        shape=(n_kept, n_kept),
                        dtype=np.float32)

    # Restrict to the largest connected component (first one wins on ties).
    graph = _graph_from_sparse(adj)
    largest_cc = max(nx.connected_components(graph), key=len)
    node_order = list(graph.subgraph(largest_cc).nodes())

    random.seed(seed)
    random.shuffle(node_order)

    # Reorder node data to the shuffled order and drop constant feature columns.
    feats = feats[node_order, :]
    varying_cols = np.where(np.std(np.array(feats), axis=0) != 0)[0]
    feats = torch.FloatTensor(np.array(feats[:, varying_cols], dtype=float))
    sens = torch.LongTensor(np.array(sens[node_order], dtype=int))
    n_nodes = sens.shape[0]

    # Keep edges inside the component and relabel nodes to their shuffled position.
    outside_cc = set(range(n_kept)).difference(node_order)
    edges = edges[_rows_with_kept_endpoints(edges, outside_cc), :]
    relabel = np.full(n_kept, -1, dtype=int)
    relabel[node_order] = np.arange(len(node_order))
    edges = relabel[edges]

    # Shuffle edge rows (continuing the same seeded RNG stream) and split.
    # NOTE(review): if 'A' stores each undirected edge in both directions (as a
    # symmetric adjacency matrix would), (u, v) and (v, u) are separate rows here
    # and can land on opposite sides of the split -- confirm this is intended.
    num_edges = edges.shape[0]
    edge_order = np.arange(num_edges)
    random.shuffle(edge_order)
    edges = edges[edge_order, :]
    n_train = int(train_frac * num_edges)
    edges_train = edges[:n_train, :]
    edges_test = edges[n_train:, :]

    adj_train = sp.coo_matrix(
        (np.ones(edges_train.shape[0]), (edges_train[:, 0], edges_train[:, 1])),
        shape=(n_nodes, n_nodes),
        dtype=np.float32)
    if edgelist_path is not None:
        nx.write_edgelist(_graph_from_sparse(adj_train), edgelist_path)

    # Row sums of the training adjacency plus one, computed sparsely instead of
    # materialising a dense n_nodes x n_nodes array.
    degs = np.asarray(adj_train.sum(axis=1), dtype=np.float64).ravel() + 1.0

    edges_train = torch.LongTensor(edges_train.T)
    edges_test = torch.LongTensor(edges_test.T)
    return edges_train, edges_test, feats, sens, degs

# Runs the loader with its default arguments whenever this module is executed
# or imported; the returned tuple is discarded.
# NOTE(review): consider an `if __name__ == "__main__":` guard if this module
# is ever imported by other code -- confirm no caller relies on the import-time run.
load_fbucsd()
