import random
random.seed(10)
import argparse
from torch_geometric.utils import negative_sampling
import sys
import scipy.sparse as sp
import scipy.io
import os.path as osp
import numpy as np
from time import perf_counter as t
import argparse
import torch
import torch_geometric.transforms as T
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import GCNConv
from torch import Tensor
from model import Encoder, Model
import networkx as nx
from utils import load_pokec, feature_norm, load_fb, link_prediction, print_statistics, node_stats, node_stats_msens, sens_classification
from eval import label_classification


def maybe_num_nodes_fair(edge_index, num_nodes=None):
    """Infer the number of nodes referenced by ``edge_index``.

    Args:
        edge_index: node-index tensor (e.g. an edge list), or an object that
            exposes ``size(0)`` / ``size(1)`` such as a sparse matrix.
        num_nodes: explicit node count; returned unchanged when given.

    Returns:
        ``num_nodes`` if provided. For a tensor, ``max index + 1``, or ``0``
        for an empty tensor (where ``max()`` would raise). Otherwise, the
        larger of the two dimensions.
    """
    if num_nodes is not None:
        return num_nodes
    if isinstance(edge_index, Tensor):
        # Tensor.max() raises on an empty tensor; an edgeless graph has 0 nodes.
        return int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
    return max(edge_index.size(0), edge_index.size(1))

def filter_adj_fair(row, col, mask):
    """Keep only the edges selected by ``mask``.

    Args:
        row: source indices of the edges.
        col: target indices of the edges, aligned with ``row``.
        mask: boolean selector aligned with ``row`` / ``col``.

    Returns:
        Tuple ``(row, col)`` restricted to the positions where ``mask`` is true.
    """
    kept_row = row[mask]
    kept_col = col[mask]
    return kept_row, kept_col

def dropout_adj(edge_index, p=0.25):
    """Remove each edge independently with probability ``p``.

    Args:
        edge_index: ``(2, E)`` tensor of (source, target) indices.
        p: probability of dropping an edge; edges survive with ``1 - p``.

    Returns:
        ``(2, E')`` tensor with the surviving edges in their original order.
    """
    src, dst = edge_index
    survive_prob = edge_index.new_full((src.size(0),), 1 - p, dtype=torch.float)
    survive = torch.bernoulli(survive_prob).bool()
    return torch.stack([src[survive], dst[survive]], dim=0)


def dropout_adj_fair(edge_index, sens, r1, r2, pk=0.9, pmin=0.4, pmax=1):
    """Drop edges with group-aware keep probabilities.

    Each edge survives one independent Bernoulli trial. The keep probability
    depends on the edge type:

    * inter-group edges (endpoints with different ``sens``): ``pk``;
    * intra-group edges inside group 0: ``min(pmax, max(pk * r1, pmin))``;
    * intra-group edges inside group 1: ``min(pmax, max(pk * r2, pmin))``.

    Intra-group edges whose group value is neither 0 nor 1 keep ``pk``.

    Args:
        edge_index: ``(2, E)`` LongTensor of (source, target) indices.
        sens: per-node sensitive attribute (tensor or array).
        r1, r2: scaling ratios for group-0 / group-1 intra-group edges.
        pk: base keep probability.
        pmin, pmax: bounds applied to the intra-group keep probabilities.

    Returns:
        ``(2, E')`` tensor of surviving edges, in their original order.
    """
    row, col = edge_index
    # Build masks with torch ops on the edge device; np.where on a CUDA
    # tensor (and comparing against a CPU torch.zeros) used to fail on GPU.
    sens = torch.as_tensor(sens, device=edge_index.device)
    src_sens = sens[row]
    same_group = src_sens == sens[col]

    keep_prob = edge_index.new_full((row.size(0),), pk, dtype=torch.float)
    keep_prob[same_group & (src_sens == 0)] = min(pmax, max(pk * r1, pmin))
    keep_prob[same_group & (src_sens == 1)] = min(pmax, max(pk * r2, pmin))
    keep = torch.bernoulli(keep_prob).to(torch.bool)

    return torch.stack([row[keep], col[keep]], dim=0)

def dropout_adj_fair_general(edge_index, sens, pk=0.9, pmin=0.4, pmax=1):
    """Drop edges, keeping intra-group edges at a rate scaled by the inter/intra ratio.

    Inter-group edges are kept with probability ``pk``. Intra-group edges are
    kept with ``min(pmax, max(pk * r, pmin))``, where
    ``r = #inter edges / #intra edges``.

    Args:
        edge_index: ``(2, E)`` LongTensor of (source, target) indices.
        sens: per-node sensitive attribute (tensor or array).
        pk: base keep probability.
        pmin, pmax: bounds applied to the intra-group keep probability.

    Returns:
        ``(2, E')`` tensor of surviving edges, in their original order.
    """
    row, col = edge_index
    sens = torch.as_tensor(sens, device=edge_index.device)
    same_group = sens[row] == sens[col]
    n_intra = int(same_group.sum())
    n_inter = same_group.numel() - n_intra

    keep_prob = edge_index.new_full((row.size(0),), pk, dtype=torch.float)
    # Without intra-group edges there is nothing to rescale; skipping this
    # also avoids the division by zero the ratio would cause.
    if n_intra > 0:
        r = float(n_inter) / n_intra
        keep_prob[same_group] = min(pmax, max(pk * r, pmin))
    keep = torch.bernoulli(keep_prob).to(torch.bool)

    return torch.stack([row[keep], col[keep]], dim=0)


def drop_feature(x, drop_prob):
    """Zero out whole feature columns at random.

    Each column of ``x`` is masked independently with probability ``drop_prob``.
    The same mask is applied to every row. ``x`` itself is not modified.

    Args:
        x: 2-D feature tensor (rows = nodes, columns = features).
        drop_prob: probability of zeroing a column.

    Returns:
        A masked copy of ``x``.
    """
    num_cols = x.size(1)
    draws = torch.empty((num_cols,), dtype=torch.float32, device=x.device)
    draws.uniform_(0, 1)
    masked = x.clone()
    masked[:, draws < drop_prob] = 0
    return masked

def W(x):
    """Clip the entries of ``x`` into ``[0, 1]`` in place and return ``x``.

    Args:
        x: NumPy array or torch tensor of any shape; it is modified in place.

    Returns:
        The same object ``x``. Values above 1 become 1 and values below 0
        become 0; NaN entries are left as they are.
    """
    if torch.is_tensor(x):
        return x.clamp_(0, 1)
    # Element-wise clip. The former np.where(...)[0] indexing only used the
    # first-axis index, so for 2-D input it overwrote entire rows.
    np.clip(x, 0, 1, out=x)
    return x

def topology_aware_disc(features, sens, mean_prob_view1, mean_prob_view2):
    """Compute per-feature drop probabilities biased towards group-discriminative features.

    For every feature column, the absolute difference between the column mean
    over ``sens == 0`` rows and over ``sens == 1`` rows is:

    1. min-max normalised to ``[0, 1]``;
    2. rescaled so its mean equals the requested drop rate of each view;
    3. clipped to ``[0, 1]``.

    Args:
        features: 2-D node-feature matrix (tensor or array), one row per node.
        sens: per-node sensitive attribute with values 0/1.
        mean_prob_view1: target mean drop probability for the first view.
        mean_prob_view2: target mean drop probability for the second view.

    Returns:
        Tuple of two 1-D NumPy arrays (one entry per feature column) of drop
        probabilities.
    """
    # Work on host NumPy copies so CUDA tensors are accepted too.
    if torch.is_tensor(features):
        features = features.detach().cpu().numpy()
    if torch.is_tensor(sens):
        sens = sens.detach().cpu().numpy()

    mean_s0 = np.mean(np.array(features[np.where(sens == 0)[0], :]), axis=0)
    mean_s1 = np.mean(np.array(features[np.where(sens == 1)[0], :]), axis=0)
    disc = np.absolute(mean_s0 - mean_s1)

    spread = np.max(disc) - np.min(disc)
    if spread > 0:
        remove_prob = (disc - np.min(disc)) / spread
    else:
        # All features are equally discriminative (or the spread is
        # undefined). Min-max normalisation would give 0/0 = NaN, so use a
        # flat profile: each view then drops features at its mean rate.
        remove_prob = np.ones_like(disc)

    mean_score = np.mean(remove_prob)
    remove_prob1 = np.clip(remove_prob * (mean_prob_view1 / mean_score), 0, 1)
    remove_prob2 = np.clip(remove_prob * (mean_prob_view2 / mean_score), 0, 1)
    return remove_prob1, remove_prob2

def drop_feature_topology_aware(x, remove_prob):
    """Zero out feature columns using a separate drop probability per column.

    Args:
        x: 2-D feature tensor (rows = nodes, columns = features).
        remove_prob: per-column drop probabilities (array-like, one entry per
            column of ``x``).

    Returns:
        A masked copy of ``x``; the input tensor is not modified.
    """
    # Put the probabilities on x's device; a CPU FloatTensor cannot be
    # compared with the CUDA random draws when x lives on the GPU.
    remove_prob = torch.as_tensor(remove_prob, dtype=torch.float32, device=x.device)
    drop_mask = torch.empty(
        (x.size(1),),
        dtype=torch.float32,
        device=x.device).uniform_(0, 1) < remove_prob
    x = x.clone()
    x[:, drop_mask] = 0
    return x

def inter_add_uniform(edges, sens):
    """Add random inter-group edges until their count matches the intra-group count.

    ``#intra - #inter`` new edges are created (none if that is not positive).
    Each new edge joins a group-0 node and a group-1 node. Both endpoints are
    drawn uniformly, with replacement, from the nodes that already have at
    least one inter-group edge. New edges are appended after the existing
    ones and oriented group 0 -> group 1.

    Args:
        edges: ``(2, E)`` LongTensor of (source, target) node indices.
        sens: per-node sensitive attribute (0/1), tensor or array.

    Returns:
        ``(2, E + added)`` LongTensor on the same device as ``edges``. If
        either group has no inter-group node to sample from, the edges are
        returned unchanged (this used to raise a division by zero).
    """
    device = edges.device if torch.is_tensor(edges) else None
    edge_arr = np.asarray(edges.detach().cpu() if torch.is_tensor(edges) else edges).T
    sens_arr = np.asarray(sens.detach().cpu() if torch.is_tensor(sens) else sens)

    src_sens = sens_arr[edge_arr[:, 0]]
    dst_sens = sens_arr[edge_arr[:, 1]]
    inter_mask = src_sens != dst_sens
    n_intra = int(np.count_nonzero(src_sens == dst_sens))
    edge_num = n_intra - int(np.count_nonzero(inter_mask))

    # Endpoints of the existing inter-group edges, split by group.
    all_tildes = np.unique(edge_arr[inter_mask].flatten())
    s0_tilde = all_tildes[sens_arr[all_tildes] == 0]
    s1_tilde = all_tildes[sens_arr[all_tildes] == 1]

    def _draw(pool, k):
        # Uniform sampling with replacement (same torch RNG call as before).
        # Indexing with a NumPy array always yields a 1-D result, even for k == 1.
        weights = torch.full([len(pool)], 1 / float(len(pool)))
        picks = torch.multinomial(weights, k, replacement=True)
        return pool[picks.numpy()]

    if edge_num > 0 and len(s0_tilde) > 0 and len(s1_tilde) > 0:
        added_s0 = _draw(s0_tilde, edge_num)
        added_s1 = _draw(s1_tilde, edge_num)
        added = np.stack((added_s0, added_s1), axis=1)
        edge_arr = np.concatenate((edge_arr, added), axis=0)

    return torch.tensor(np.ascontiguousarray(edge_arr.T), dtype=torch.long, device=device)


def node_sample_uni(edges, features, sens, s0_tilde, s1_tilde, s0_hat, s1_hat):
    """Sample nodes of each sensitive group independently and return the induced subgraph.

    The four node pools follow ``graph_attrs_ns``: ``s*_tilde`` are the
    group-0 / group-1 nodes that touch an inter-group edge, and ``s*_hat``
    are the remaining nodes of each group. For each group, the kept set is
    one whole pool plus a uniform random subset (drawn without replacement)
    of the other pool:

    * If both groups have more tilde than hat nodes: keep all hat nodes.
      Add ``int(1.1 * |hat|)`` tilde nodes if ``|hat| >= int(|tilde| / 2)``
      in both groups; otherwise add ``int(|tilde| / 2)`` tilde nodes.
    * Otherwise: keep all tilde nodes. Add ``|tilde|`` hat nodes if
      ``|tilde| >= int(|hat| / 4)`` in both groups; otherwise add
      ``int(|hat| / 4)`` hat nodes.

    Requests larger than a pool are clamped to the pool size.

    Args:
        edges: ``(2, E)`` LongTensor of (source, target) indices.
        features: node-feature matrix indexed by node id along dim 0.
        sens: per-node sensitive attribute.
        s0_tilde, s1_tilde, s0_hat, s1_hat: 1-D arrays of node ids.

    Returns:
        ``(edges, features, sens, degs)`` of the induced subgraph:

        * ``edges`` is re-indexed to ``0..n-1`` (kept nodes in ascending
          original-id order) and placed on the input edges' device;
        * ``degs`` is a float array with each node's out-degree in the
          subgraph plus one.
    """
    def _draw(pool, k):
        # Uniform sampling without replacement. The size is clamped to the
        # pool so an oversized request keeps the whole pool instead of raising.
        pool = np.asarray(pool)
        k = min(int(k), len(pool))
        if k <= 0:
            return pool[:0]
        weights = torch.full([len(pool)], 1 / float(len(pool)))
        picks = torch.multinomial(weights, k, replacement=False)
        # NumPy-array indexing keeps the result 1-D even when k == 1.
        return pool[picks.numpy()]

    n_s0t, n_s1t = len(s0_tilde), len(s1_tilde)
    n_s0h, n_s1h = len(s0_hat), len(s1_hat)

    if n_s1t > n_s1h and n_s0t > n_s0h:
        if n_s0h >= int(n_s0t / 2) and n_s1h >= int(n_s1t / 2):
            sampled_s1 = _draw(s1_tilde, int(1.1 * n_s1h))
            used_s1 = np.concatenate((s1_hat, sampled_s1))
            sampled_s0 = _draw(s0_tilde, int(1.1 * n_s0h))
            used_s0 = np.concatenate((s0_hat, sampled_s0))
        else:
            sampled_s0 = _draw(s0_tilde, int(n_s0t / 2))
            # Was s0_tilde: joining s0_tilde with a subset of itself made the
            # sampling a no-op. Mirrors the s1 line below.
            used_s0 = np.concatenate((s0_hat, sampled_s0))
            sampled_s1 = _draw(s1_tilde, int(n_s1t / 2))
            used_s1 = np.concatenate((s1_hat, sampled_s1))
    else:
        if n_s0t >= int(n_s0h / 4) and n_s1t >= int(n_s1h / 4):
            sampled_s0 = _draw(s0_hat, n_s0t)
            used_s0 = np.concatenate((s0_tilde, sampled_s0))
            sampled_s1 = _draw(s1_hat, n_s1t)
            used_s1 = np.concatenate((s1_tilde, sampled_s1))
        else:
            sampled_s0 = _draw(s0_hat, int(n_s0h / 4))
            used_s0 = np.concatenate((s0_tilde, sampled_s0))
            sampled_s1 = _draw(s1_hat, int(n_s1h / 4))
            used_s1 = np.concatenate((s1_tilde, sampled_s1))

    # An empty float "hat" array would otherwise turn the ids into floats.
    sampled_nodes = np.unique(np.concatenate((used_s0, used_s1))).astype(np.int64)

    device = edges.device if torch.is_tensor(edges) else None
    edge_arr = np.asarray(edges.detach().cpu() if torch.is_tensor(edges) else edges).T
    # Vectorised membership test and re-indexing (sampled_nodes is sorted),
    # replacing a per-edge linear scan and a dict lookup.
    keep = np.isin(edge_arr[:, 0], sampled_nodes) & np.isin(edge_arr[:, 1], sampled_nodes)
    sub_edges = np.searchsorted(sampled_nodes, edge_arr[keep]).astype(np.int64)

    features = features[sampled_nodes, :]
    sens = sens[sampled_nodes]
    # Row sums of the (multi-)adjacency + 1, computed without a dense N x N matrix.
    degs = np.bincount(sub_edges[:, 0], minlength=len(sampled_nodes)).astype(np.float64) + 1.0

    edges = torch.tensor(np.ascontiguousarray(sub_edges.T), dtype=torch.long, device=device)
    return edges, features, sens, degs


def node_sample_group(edges, features, sens, s0_tilde, s1_tilde, s0_hat, s1_hat):
    """Sample nodes so both sensitive groups end up the same size; return the induced subgraph.

    The node pools follow ``graph_attrs_ns``: ``s*_tilde`` are the group nodes
    touching an inter-group edge, and ``s*_hat`` are the rest of each group.
    One pool of each group is kept whole and the other is subsampled
    uniformly without replacement. The sample sizes are cross-coupled so that
    ``|used_s0| == |used_s1|`` whenever no request has to be clamped to its
    pool size.

    Args:
        edges: ``(2, E)`` LongTensor of (source, target) indices.
        features: node-feature matrix indexed by node id along dim 0.
        sens: per-node sensitive attribute.
        s0_tilde, s1_tilde, s0_hat, s1_hat: 1-D arrays of node ids.

    Returns:
        ``(edges, features, sens, degs)`` of the induced subgraph:

        * ``edges`` is re-indexed to ``0..n-1`` (kept nodes in ascending
          original-id order) and placed on the input edges' device;
        * ``degs`` is a float array with each node's out-degree in the
          subgraph plus one.
    """
    def _draw(pool, k):
        # Uniform sampling without replacement. Negative or oversized
        # requests are clamped to [0, len(pool)] instead of raising.
        pool = np.asarray(pool)
        k = min(int(k), len(pool))
        if k <= 0:
            return pool[:0]
        weights = torch.full([len(pool)], 1 / float(len(pool)))
        picks = torch.multinomial(weights, k, replacement=False)
        # NumPy-array indexing keeps the result 1-D even when k == 1.
        return pool[picks.numpy()]

    n_s0t, n_s1t = len(s0_tilde), len(s1_tilde)
    n_s0h, n_s1h = len(s0_hat), len(s1_hat)

    if n_s1t > n_s0h and n_s0t > n_s1h:
        if n_s0h >= int(n_s1t / 2) and n_s1h >= int(n_s0t / 2):
            sampled_s1 = _draw(s1_tilde, int(1.1 * n_s0h))
            used_s1 = np.concatenate((s1_hat, sampled_s1))
            sampled_s0 = _draw(s0_tilde, int(1.1 * n_s1h))
            used_s0 = np.concatenate((s0_hat, sampled_s0))
        elif n_s0t / 2 > n_s1t / 2:
            sampled_s0 = _draw(s0_tilde, int(n_s0t / 2))
            # Was s0_tilde: that ignored the sample and broke the size balance
            # with used_s1 (= s1_hat + int(n_s0t/2) - n_s1h + n_s0h nodes).
            used_s0 = np.concatenate((s0_hat, sampled_s0))
            sampled_s1 = _draw(s1_tilde, int(n_s0t / 2) - n_s1h + n_s0h)
            used_s1 = np.concatenate((s1_hat, sampled_s1))
        else:
            sampled_s1 = _draw(s1_tilde, int(n_s1t / 2))
            used_s1 = np.concatenate((s1_hat, sampled_s1))
            sampled_s0 = _draw(s0_tilde, int(n_s1t / 2) - n_s0h + n_s1h)
            used_s0 = np.concatenate((s0_hat, sampled_s0))
    else:
        if n_s0t >= int(n_s0h / 4) and n_s1t >= int(n_s1h / 4):
            sampled_s0 = _draw(s0_hat, int(n_s1t))
            used_s0 = np.concatenate((s0_tilde, sampled_s0))
            sampled_s1 = _draw(s1_hat, int(n_s0t))
            used_s1 = np.concatenate((s1_tilde, sampled_s1))
        elif n_s0h > n_s1h:
            sampled_s0 = _draw(s0_hat, int(n_s0h / 4))
            used_s0 = np.concatenate((s0_tilde, sampled_s0))
            sampled_s1 = _draw(s1_hat, int(n_s0h / 4) - n_s1t + n_s0t)
            used_s1 = np.concatenate((s1_tilde, sampled_s1))
        else:
            sampled_s1 = _draw(s1_hat, int(n_s1h / 4))
            used_s1 = np.concatenate((s1_tilde, sampled_s1))
            sampled_s0 = _draw(s0_hat, int(n_s1h / 4) - n_s0t + n_s1t)
            used_s0 = np.concatenate((s0_tilde, sampled_s0))

    # An empty float "hat" array would otherwise turn the ids into floats.
    sampled_nodes = np.unique(np.concatenate((used_s0, used_s1))).astype(np.int64)

    device = edges.device if torch.is_tensor(edges) else None
    edge_arr = np.asarray(edges.detach().cpu() if torch.is_tensor(edges) else edges).T
    # Vectorised membership test and re-indexing (sampled_nodes is sorted).
    keep = np.isin(edge_arr[:, 0], sampled_nodes) & np.isin(edge_arr[:, 1], sampled_nodes)
    sub_edges = np.searchsorted(sampled_nodes, edge_arr[keep]).astype(np.int64)

    features = features[sampled_nodes, :]
    sens = sens[sampled_nodes]
    # Row sums of the (multi-)adjacency + 1, computed without a dense N x N matrix.
    degs = np.bincount(sub_edges[:, 0], minlength=len(sampled_nodes)).astype(np.float64) + 1.0

    edges = torch.tensor(np.ascontiguousarray(sub_edges.T), dtype=torch.long, device=device)
    return edges, features, sens, degs

def graph_attrs_ns(edges, sens):
    """Split each sensitive group into boundary ("tilde") and interior ("hat") nodes.

    A node is in ``s*_tilde`` if it is an endpoint of at least one
    inter-group edge (endpoints with different ``sens``). Otherwise it is in
    ``s*_hat`` of its group.

    Args:
        edges: ``(2, E)`` tensor/array of (source, target) indices.
        sens: per-node sensitive attribute (0/1), tensor or array.

    Returns:
        ``(s0_tilde, s1_tilde, s0_hat, s1_hat)`` as sorted integer NumPy
        arrays. Empty sets are empty integer arrays; the previous
        ``np.array(list(set()))`` produced float64 arrays, which broke later
        indexing.
    """
    edge_arr = np.asarray(edges.detach().cpu() if torch.is_tensor(edges) else edges).T
    sens_arr = np.asarray(sens.detach().cpu() if torch.is_tensor(sens) else sens)

    inter_mask = sens_arr[edge_arr[:, 0]] != sens_arr[edge_arr[:, 1]]
    all_tildes = np.unique(edge_arr[inter_mask].flatten())
    s0_tilde = all_tildes[sens_arr[all_tildes] == 0]
    s1_tilde = all_tildes[sens_arr[all_tildes] == 1]

    node_ids_s0 = np.where(sens_arr == 0)[0]
    node_ids_s1 = np.where(sens_arr == 1)[0]
    s0_hat = np.setdiff1d(node_ids_s0, s0_tilde)
    s1_hat = np.setdiff1d(node_ids_s1, s1_tilde)
    return s0_tilde, s1_tilde, s0_hat, s1_hat

def graph_attrs_idel(edges, sens):
    """Return the inter-edge ratios ``(#inter / (2 * intra_s0), #inter / (2 * intra_s1))``.

    ``intra_s0`` counts intra-group edges whose source node has ``sens == 0``;
    ``intra_s1`` is the remaining intra-group edges (binary ``sens`` assumed).

    Args:
        edges: ``(2, E)`` tensor/array of (source, target) indices.
        sens: per-node sensitive attribute (0/1), tensor or array.

    Returns:
        Tuple of two floats. A group without intra-group edges gets
        ``float('inf')`` instead of raising ZeroDivisionError. In
        ``dropout_adj_fair`` that value would clamp to ``pmax``, but no edge
        of that kind exists to be affected.
    """
    edge_arr = np.asarray(edges.detach().cpu() if torch.is_tensor(edges) else edges).T
    sens_arr = np.asarray(sens.detach().cpu() if torch.is_tensor(sens) else sens)

    src_sens = sens_arr[edge_arr[:, 0]]
    dst_sens = sens_arr[edge_arr[:, 1]]
    same_group = src_sens == dst_sens
    n_inter = int(np.count_nonzero(src_sens != dst_sens))
    intra_s0 = int(np.count_nonzero(src_sens[same_group] == 0))
    intra_s1 = int(np.count_nonzero(same_group)) - intra_s0

    def _ratio(n_intra):
        return float(n_inter) / (2 * n_intra) if n_intra > 0 else float('inf')

    return _ratio(intra_s0), _ratio(intra_s1)

def stats(edges, sens):
    """Print structural fairness statistics of a graph.

    Printed quantities, in order:

    * ``gamma1 = |1 - (|s0_tilde| / |S0| + |s1_tilde| / |S1|)|``, where
      ``s*_tilde`` are the nodes of each group that touch an inter-group edge
      and ``S*`` are all nodes of that group.
    * ``gamma2 = 1 - 2 * min(mean ratio over group 0, mean ratio over group 1)``,
      where a node's ratio is (#inter-group edge endpoints at the node) /
      (#edge endpoints at the node), and 0 for nodes without edges.
    * Intra-edge counts by source group, the inter-edge count, and the sizes
      of the tilde / hat node sets.

    Args:
        edges: ``(2, E)`` tensor/array of (source, target) indices.
        sens: per-node sensitive attribute (0/1), tensor or array.

    Returns:
        None; all output goes to stdout.
    """
    edge_arr = np.asarray(edges.detach().cpu() if torch.is_tensor(edges) else edges).T
    sens_arr = np.asarray(sens.detach().cpu() if torch.is_tensor(sens) else sens)
    n_nodes = len(sens_arr)

    src_sens = sens_arr[edge_arr[:, 0]]
    dst_sens = sens_arr[edge_arr[:, 1]]
    intra_mask = src_sens == dst_sens
    inter_mask = src_sens != dst_sens

    node_ids_s0 = np.where(sens_arr == 0)[0]
    node_ids_s1 = np.where(sens_arr == 1)[0]
    s0, s1 = len(node_ids_s0), len(node_ids_s1)

    all_tildes = np.unique(edge_arr[inter_mask].flatten())
    s0_tilde = all_tildes[sens_arr[all_tildes] == 0]
    s1_tilde = all_tildes[sens_arr[all_tildes] == 1]
    s0_hat = np.setdiff1d(node_ids_s0, s0_tilde)
    s1_hat = np.setdiff1d(node_ids_s1, s1_tilde)

    intra_s0 = int(np.count_nonzero(src_sens[intra_mask] == 0))
    intra_s1 = int(np.count_nonzero(intra_mask)) - intra_s0
    n_inter = int(np.count_nonzero(inter_mask))

    # An empty group contributes 0 instead of raising ZeroDivisionError.
    frac_s0 = float(len(s0_tilde)) / s0 if s0 else 0.0
    frac_s1 = float(len(s1_tilde)) / s1 if s1 else 0.0
    gamma1 = np.absolute(1 - (frac_s0 + frac_s1))

    # Per-node endpoint counts via bincount, instead of one O(E) scan per node.
    inter_deg = np.bincount(edge_arr[inter_mask].ravel(), minlength=n_nodes)
    total_deg = inter_deg + np.bincount(edge_arr[intra_mask].ravel(), minlength=n_nodes)
    ratios = np.zeros(n_nodes)
    np.divide(inter_deg, total_deg, out=ratios, where=total_deg > 0)

    gamma2 = 1 - 2 * min(ratios[node_ids_s0].mean(), ratios[node_ids_s1].mean())
    print('gamma1 is: ', gamma1)
    print('gamma2 is: ', gamma2)
    print('resulted intra for s0: ', intra_s0)
    print('resulted intra for s1: ', intra_s1)
    print('resulted inter edges: ', n_inter)
    print('resulted s0_tilde: ', len(s0_tilde))
    print('resulted s1_tilde: ', len(s1_tilde))
    print('resulted s0_hat: ', len(s0_hat))
    print('resulted s1_hat: ', len(s1_hat))

def intra_pres(edges, features, sens, degs=None):
    """Print intra/inter edge counts and a "required intra" estimate per group.

    Args:
        edges: ``(2, E)`` tensor/array of (source, target) indices.
        features: unused; kept for call compatibility.
        sens: per-node sensitive attribute (0/1), tensor or array.
        degs: optional per-node degree array. The previous body read an
            undefined (module-global) ``degs``. When omitted, it is derived
            from ``edges`` as out-degree + 1, the same convention the
            node-sampling helpers use for the ``degs`` they return.

    Prints ``|group| * max_degree_in_group / 4`` as the required intra count
    for each group.
    """
    edge_arr = np.asarray(edges.detach().cpu() if torch.is_tensor(edges) else edges).T
    sens_arr = np.asarray(sens.detach().cpu() if torch.is_tensor(sens) else sens)

    src_sens = sens_arr[edge_arr[:, 0]]
    dst_sens = sens_arr[edge_arr[:, 1]]
    intra_src = src_sens[src_sens == dst_sens]
    n_intra = len(intra_src)
    n_inter = int(np.count_nonzero(src_sens != dst_sens))

    node_ids_s0 = np.where(sens_arr == 0)[0]
    node_ids_s1 = np.where(sens_arr == 1)[0]
    s0, s1 = len(node_ids_s0), len(node_ids_s1)

    if degs is None:
        degs = np.bincount(edge_arr[:, 0], minlength=len(sens_arr)) + 1.0
    degs = np.asarray(degs.detach().cpu() if torch.is_tensor(degs) else degs)
    # An empty group gets a max degree of 0 instead of max() raising.
    deg_max_s0 = degs[node_ids_s0].max() if s0 else 0
    deg_max_s1 = degs[node_ids_s1].max() if s1 else 0

    intra_0 = int(np.count_nonzero(intra_src == 0))
    intra_1 = int(np.count_nonzero(intra_src == 1))

    print('All intra edges: ', n_intra)
    print('All inter edges: ', n_inter)

    print('All intra edges in 0: ', intra_0)
    # Label fixed: this is group 1's intra-edge count, not an inter count.
    print('All intra edges in 1: ', intra_1)

    print('required intra in 0: ', (s0*deg_max_s0)/4)
    print('required intra in 1: ', (s1*deg_max_s1)/4)

def flatten(t):
    """Concatenate the items of every inner iterable of ``t`` into one list."""
    merged = []
    for inner in t:
        merged.extend(inner)
    return merged

def train_grace(model: Model, x, edge_index):
    """Run one GRACE optimisation step and return the scalar loss.

    Builds two augmented views, each with uniform edge dropout and
    feature-column masking. Both views are encoded with ``model``, and
    ``model.loss`` between the embeddings is minimised.

    Relies on module-level globals assigned in the ``__main__`` block:
    ``optimizer``, ``drop_edge_rate_1`` / ``drop_edge_rate_2`` and
    ``drop_feature_rate_1`` / ``drop_feature_rate_2``.
    """
    model.train()
    optimizer.zero_grad()

    # Keep the RNG order: both edge views first, then both feature views,
    # then view 1 is encoded before view 2.
    edge_views = [dropout_adj(edge_index, p=rate)
                  for rate in (drop_edge_rate_1, drop_edge_rate_2)]
    feat_views = [drop_feature(x, rate)
                  for rate in (drop_feature_rate_1, drop_feature_rate_2)]
    z1, z2 = [model(feats, view) for feats, view in zip(feat_views, edge_views)]

    step_loss = model.loss(z1, z2, batch_size=0)
    step_loss.backward()
    optimizer.step()
    return step_loss.item()

def test(model: Model, x, edge_index, y, sens):
    """Evaluate node classification on the embeddings produced by ``model``.

    Switches the model to eval mode, embeds all nodes, and delegates scoring
    to ``label_classification`` (imported from ``eval``) with ``ratio=0.1``.
    """
    model.eval()
    embeddings = model(x, edge_index)
    label_classification(embeddings, y, sens, ratio=0.1)
def test_lp(model: Model, x, edge_index, edges_t, sens,neg_edges_tr, neg_edges_t):
    """Evaluate ``model`` embeddings on link prediction and print aggregated metrics.

    Args:
        model: trained model; called as ``model(x, edge_index)``.
        x: node features.
        edge_index: training edges, also used to compute the embeddings.
        edges_t: held-out positive test edges.
        sens: per-node sensitive attribute, forwarded to ``link_prediction``.
        neg_edges_tr, neg_edges_t: negative train / test edges.

    Prints, via ``print_statistics``, the mean and standard deviation of
    every metric returned by ``link_prediction``.
    """
    model.eval()
    z = model(x, edge_index)
    results = link_prediction(z, edge_index, edges_t, neg_edges_tr, neg_edges_t, sens)
    # The main script consumes one link_prediction() result as a single
    # metric -> value mapping, while this function indexed it as a list of
    # mappings. Accept both shapes.
    if hasattr(results, 'keys'):
        results = [results]
    statistics = {}
    for key in results[0].keys():
        values = [res[key] for res in results]
        statistics[key] = {
            'mean': np.mean(values),
            'std': np.std(values)}
    print_statistics(statistics, 'Link prediction')
    
if __name__ == '__main__':
    # Script entry point:
    # 1. parse hyper-parameters and load the dataset;
    # 2. for each repeat, build an Encoder/Model pair and train it for
    #    `--epochs` epochs with the augmentation pipeline chosen by `--method`;
    # 3. evaluate: node classification on Pokec, link prediction over the
    #    pre-computed edge orders for the Facebook graphs.
    parser = argparse.ArgumentParser()
    # NOTE(review): action='store_true' combined with default=True means this
    # flag can never be False, so args.cuda is always False (CPU-only runs).
    parser.add_argument('--no-cuda', action='store_true', default=True,
                        help='Disables CUDA training.')
    # `choices` rejects unsupported values at parse time. An unknown dataset
    # or method used to fail much later with a NameError (`features` / `loss`).
    parser.add_argument('--dataset', type=str, default='pokec',
                        choices=['pokec', 'pokec2', 'fbberkeley', 'fbucsd'])
    parser.add_argument('--method', type=str, default='grace',
                        choices=['grace', 'fclgraph', 'fclgraph_old',
                                 'ns_edge_del', 'edge_del', 'node_sample',
                                 'fm', 'fm_iadd'])
    parser.add_argument('--gpu_id', type=int, default=0)
    parser.add_argument('--epochs', type=int, default=400,
                        help='Number of epochs to train.')
    parser.add_argument('--lr',type=float, default=0.0005)
    parser.add_argument('--weight-decay', type=float, default=1e-5,
                        help='Weight decay (L2 loss on parameters).')
    parser.add_argument('--num-hidden', type=int, default=256)
    parser.add_argument('--num-proj-hidden', type=int, default=256)
    parser.add_argument('--drop-edge-rate-1', type=float, default=0.3)
    parser.add_argument('--drop-edge-rate-2', type=float, default=0.4)
    parser.add_argument('--drop-feature-rate-1', type=float, default=0.1)
    parser.add_argument('--drop-feature-rate-2', type=float, default=0.0)
    # --pi and --pmin are passed to dropout_adj_fair as its pk / pmin arguments.
    parser.add_argument('--pi', type=float, default=1)
    parser.add_argument('--pmin', type=float, default=0.5)
    parser.add_argument('--tau',type=float,default=0.4)
    parser.add_argument('--layer-num',type=int, default=2)
    parser.add_argument('--activation',type=str,default='relu')
    parser.add_argument('--base-model',type=str, default='GCNConv')
    parser.add_argument("--seed",type=int, default=39788, help='Random seed.')

    args = parser.parse_known_args()[0]
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    torch.set_num_threads(8)
    # NOTE(review): only NumPy and torch are seeded with --seed; Python's
    # `random` module keeps the fixed seed set at import time.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.use_deterministic_algorithms(True)

    learning_rate = args.lr
    num_hidden = args.num_hidden
    num_proj_hidden = args.num_proj_hidden
    activation = ({'relu': F.relu, 'prelu': nn.PReLU()})[args.activation]
    base_model = ({'GCNConv': GCNConv})[args.base_model]
    num_layers = args.layer_num
    drop_edge_rate_1 = args.drop_edge_rate_1
    drop_edge_rate_2 = args.drop_edge_rate_2
    drop_feature_rate_1 = args.drop_feature_rate_1
    drop_feature_rate_2 = args.drop_feature_rate_2
    tau = args.tau
    method=args.method
    pi=args.pi
    pmin=args.pmin

    num_epochs = args.epochs
    weight_decay = args.weight_decay

    if args.dataset == 'pokec':
        dataset = 'region_job'
        sens_attr = "region"
        predict_attr = "I_am_working_in_field"
        path = "datasets/pokec_dataset/"
        edges,features,labels,sens, degs_org=load_pokec(dataset, sens_attr, predict_attr, path)
        features = feature_norm(features)
    elif args.dataset =='pokec2':
        dataset = 'region_job_2'
        sens_attr = "region"
        predict_attr = "I_am_working_in_field"
        path = "datasets/pokec_dataset/"
        edges,features,labels,sens, degs_org=load_pokec(dataset, sens_attr, predict_attr, path)
        features = feature_norm(features)
    elif args.dataset == 'fbberkeley':
        dset='Berkeley13.mat'
        path='datasets/socfb-Berkeley13'
        edges_org, features, sens=load_fb(path, dset)
        name='berkeley'
        features = feature_norm(features)
    elif args.dataset == 'fbucsd':
        path='datasets/socfb-UCSD34'
        dset='UCSD34.mat'
        edges_org, features, sens=load_fb(path, dset)
        name='ucsd'
        features = feature_norm(features)
    #print('The stats for the original data will be ready: ')
    #node_stats(torch.LongTensor(edges_org.T), sens)

    # Facebook graphs are evaluated over 5 pre-computed edge orders.
    if args.dataset[:2]=='fb':
        repeat=5
    else:
        repeat=1
    results=[]
    for r in range(repeat):
        if args.dataset == 'fbberkeley' or  args.dataset == 'fbucsd':
            # 90/10 train/test split of the edges in the r-th stored order.
            edge_idx=np.load('lp_orders/'+name+'_edge_order'+str(r+1)+'.npy')

            edges=edges_org[edge_idx,:]
            num_edges=np.shape(edges)[0]
            edges_train = edges[:int(0.9*num_edges),:]

            edges_test = edges[int(0.9*num_edges):,:]
            edges = torch.LongTensor(edges_train.T)
            edges_t = torch.LongTensor(edges_test.T)
            neg_edges_tr = negative_sampling(
                edge_index=edges,
                num_nodes=len(sens),
                num_neg_samples=edges.size(1),
            )
            neg_edges_tr=np.array(neg_edges_tr).T
            neg_edges_t = negative_sampling(
                edge_index=edges_t,
                num_nodes=len(sens),
                num_neg_samples=edges_t.size(1),
            )
            neg_edges_t=np.array(neg_edges_t).T
            device = torch.device('cuda' if args.cuda else 'cpu')
            edges = edges.to(device)
            edges_t = edges_t.to(device)
            features= features.to(device)
            sens=sens.to(device)

        else:
            device = torch.device('cuda' if args.cuda else 'cpu')
            edges = edges.to(device)
            features= features.to(device)
            labels=labels.to(device)
            sens=sens.to(device)
        print('original stats are: ')
        stats(edges,sens)
        # Node-sampling methods need the tilde/hat node pools; the others use
        # the inter/intra edge ratios r1, r2 of the full graph.
        if method=='fclgraph' or method=='fclgraph_old' or method=='node_sample' or method=='ns_edge_del':
            s0_tilde, s1_tilde, s0_hat, s1_hat=graph_attrs_ns(edges, sens)
        else:
            r1, r2 = graph_attrs_idel(edges, sens)


        encoder = Encoder(features.shape[1], num_hidden, activation,
                          base_model=base_model, k=num_layers).to(device)
        model = Model(encoder, num_hidden, num_proj_hidden, tau).to(device)
        optimizer = torch.optim.Adam(
            model.parameters(), lr=learning_rate, weight_decay=weight_decay)

        start = t()
        prev = start
        for epoch in range(1, num_epochs + 1):
            if method=='grace':
                loss = train_grace(model, features, edges)
            elif method=='fclgraph':
                # Node sampling + fair edge deletion + inter-edge addition
                # + topology-aware feature masking.
                model.train()
                optimizer.zero_grad()
                s_edges, s_features, s_sens, degs =node_sample_uni(edges, features, sens, s0_tilde, s1_tilde, s0_hat, s1_hat)
                #print('The stats after node sampling will be ready: ')
                # NOTE(review): prints full statistics every epoch; the matching
                # calls in the other branches are commented out. Confirm intended.
                stats(s_edges, s_sens)
                remove_probs1, remove_probs2=topology_aware_disc(s_features, s_sens, drop_feature_rate_1, drop_feature_rate_2)

                r1, r2 = graph_attrs_idel(s_edges, s_sens)

                s3_edges_1=dropout_adj_fair(s_edges, s_sens, r1, r2, pi, pmin, 1)
                s3_edges_2=dropout_adj_fair(s_edges, s_sens, r1, r2, pi, pmin, 1)
                #print('The stats after node sampling+edge del  will be ready: ')
                #stats(s3_edges_1, s_sens)
                s4_edges_1=inter_add_uniform(s3_edges_1,s_sens)
                s4_edges_2=inter_add_uniform(s3_edges_2,s_sens)
                #print('The stats after node sampling+edge del+iadd  will be ready: ')
                #stats(s4_edges_1, s_sens)

                x_1=drop_feature_topology_aware(s_features,remove_probs1)
                x_2=drop_feature_topology_aware(s_features,remove_probs2)

                z1 = model(x_1, s4_edges_1)
                z2 = model(x_2, s4_edges_2)

                loss = model.loss(z1, z2, batch_size=0)
                loss.backward()
                optimizer.step()
            elif method == 'ns_edge_del':
                # Node sampling + fair edge deletion + inter-edge addition
                # (no feature masking).
                model.train()
                optimizer.zero_grad()
                s_edges, s_features, s_sens, degs =node_sample_uni(edges, features, sens, s0_tilde, s1_tilde, s0_hat, s1_hat)

                #print('The stats after node sampling will be ready: ')
                #node_stats(s_edges, s_sens, degs)

                r1, r2 = graph_attrs_idel(s_edges, s_sens)

                s2_edges_1=dropout_adj_fair(s_edges, s_sens, r1, r2, pi, pmin, 1)
                s2_edges_2=dropout_adj_fair(s_edges, s_sens, r1, r2, pi, pmin, 1)
                s3_edges_1=inter_add_uniform(s2_edges_1,s_sens)
                s3_edges_2=inter_add_uniform(s2_edges_2,s_sens)

                #print('The stats after edge deletion will be ready: ')
                #node_stats(s2_edges_1, s_sens, degs)

                z1 = model(s_features, s3_edges_1)
                z2 = model(s_features, s3_edges_2)
                loss = model.loss(z1, z2, batch_size=0)
                loss.backward()
                optimizer.step()
            elif method=='edge_del':
                # Fair edge deletion + inter-edge addition + feature masking
                # on the full graph.
                model.train()
                optimizer.zero_grad()
                remove_probs1, remove_probs2=topology_aware_disc(features, sens, drop_feature_rate_1, drop_feature_rate_2)

                s2_edges_1=dropout_adj_fair(edges, sens, r1, r2, pi, pmin, 1)
                s2_edges_2=dropout_adj_fair(edges, sens, r1, r2, pi, pmin, 1)
                #print('The stats after edge deletion will be ready: ')
                #stats(s2_edges_1, sens)

                s3_edges_1=inter_add_uniform(s2_edges_1,sens)
                s3_edges_2=inter_add_uniform(s2_edges_2,sens)
                #print('The stats after edge deletion+edge addition will be ready: ')
                #stats(s3_edges_1, sens)
                x_1=drop_feature_topology_aware(features,remove_probs1)
                x_2=drop_feature_topology_aware(features,remove_probs2)

                z1 = model(x_1, s3_edges_1)
                z2 = model(x_2, s3_edges_2)
                #z1 = model(x_1, s2_edges_1)
                #z2 = model(x_2, s2_edges_2)
                loss = model.loss(z1, z2, batch_size=0)
                loss.backward()
                optimizer.step()
            elif method == 'node_sample':
                # Node sampling + inter-edge addition + feature masking.
                model.train()
                optimizer.zero_grad()
                s_edges, s_features, s_sens, degs =node_sample_uni(edges, features, sens, s0_tilde, s1_tilde, s0_hat, s1_hat)
                #print('The stats after node sampling will be ready: ')
                #node_stats(s_edges, s_sens, degs)
                remove_probs1, remove_probs2=topology_aware_disc(s_features, s_sens, drop_feature_rate_1, drop_feature_rate_2)
                #print('The stats after node sampling will be ready: ')
                #stats(s_edges, s_sens)
                s2_edges_1=inter_add_uniform(s_edges,s_sens)
                s2_edges_2=inter_add_uniform(s_edges,s_sens)
                #print('The stats after ns+edge addition will be ready: ')
                #stats(s2_edges_1, s_sens)
                x_1=drop_feature_topology_aware(s_features,remove_probs1)
                x_2=drop_feature_topology_aware(s_features,remove_probs2)

                z1 = model(x_1, s2_edges_1)
                z2 = model(x_2, s2_edges_2)
                #z1 = model(x_1, s_edges)
                #z2 = model(x_2, s_edges)
                loss = model.loss(z1, z2, batch_size=0)
                loss.backward()
                optimizer.step()
            elif method=='fclgraph_old':
                # Node sampling + fair edge deletion + feature masking
                # (no inter-edge addition).
                model.train()
                optimizer.zero_grad()
                s_edges, s_features, s_sens, degs =node_sample_uni(edges, features, sens, s0_tilde, s1_tilde, s0_hat, s1_hat)
                #print('The stats after node sampling will be ready: ')
                #node_stats(s_edges, s_sens, degs)
                remove_probs1, remove_probs2=topology_aware_disc(s_features, s_sens, drop_feature_rate_1, drop_feature_rate_2)

                r1, r2 = graph_attrs_idel(s_edges, s_sens)

                s3_edges_1=dropout_adj_fair(s_edges, s_sens, r1, r2, pi, pmin, 1)
                s3_edges_2=dropout_adj_fair(s_edges, s_sens, r1, r2, pi, pmin, 1)

                #print('The stats after NS + edge deletion: ')
                #node_stats(s3_edges_1, s_sens, degs)

                x_1=drop_feature_topology_aware(s_features,remove_probs1)
                x_2=drop_feature_topology_aware(s_features,remove_probs2)

                z1 = model(x_1, s3_edges_1)
                z2 = model(x_2, s3_edges_2)

                loss = model.loss(z1, z2, batch_size=0)
                loss.backward()
                optimizer.step()
            elif method=='fm':
                # Topology-aware feature masking only.
                model.train()
                optimizer.zero_grad()

                remove_probs1, remove_probs2=topology_aware_disc(features, sens, drop_feature_rate_1, drop_feature_rate_2)
                x_1=drop_feature_topology_aware(features,remove_probs1)
                x_2=drop_feature_topology_aware(features,remove_probs2)

                z1 = model(x_1, edges)
                z2 = model(x_2, edges)

                loss = model.loss(z1, z2, batch_size=0)
                loss.backward()
                optimizer.step()
            elif method=='fm_iadd':
                # Feature masking + inter-edge addition.
                model.train()
                optimizer.zero_grad()

                remove_probs1, remove_probs2=topology_aware_disc(features, sens, drop_feature_rate_1, drop_feature_rate_2)
                x_1=drop_feature_topology_aware(features,remove_probs1)
                x_2=drop_feature_topology_aware(features,remove_probs2)

                s2_edges_1=inter_add_uniform(edges,sens)
                s2_edges_2=inter_add_uniform(edges,sens)
                #print('The stats after edge addition: ')
                #stats(s2_edges_1, sens)
                z1 = model(x_1, s2_edges_1)
                z2 = model(x_2, s2_edges_2)

                loss = model.loss(z1, z2, batch_size=0)
                loss.backward()
                optimizer.step()

            now = t()
            print(f'(T) | Epoch={epoch:03d}, loss={loss:.4f}, '
                  f'this epoch {now - prev:.4f}, total {now - start:.4f}')
            prev = now

        print("=== Final ===")
        if args.dataset == 'fbberkeley' or args.dataset == 'fbucsd':
            model.eval()
            z = model(features, edges)
            # .cpu() is required before .numpy() when the model runs on a GPU.
            np.save('datasets/'+args.dataset+str(r)+'_embeds.npy',z.detach().cpu().numpy())
            results.append(link_prediction(z, edges, edges_t, neg_edges_tr, neg_edges_t, sens))
            for key in results[0].keys():
                print(key+' :',results[r][key] )
        else:
            test(model, features, edges, labels,sens)
    if args.dataset == 'fbberkeley' or args.dataset == 'fbucsd':
        # Aggregate mean/std of every link-prediction metric over the repeats.
        statistics = {}

        for key in results[0].keys():
            values = [r[key] for r in results]
            statistics[key] = {
                'mean': np.mean(values),
                'std': np.std(values)}

        print_statistics(statistics, 'Link prediction')
        #sens_classification(z, sens, ratio=0.1)
