import numpy as np
import random
random.seed(10)
import torch
g_seed=39788
np.random.seed(g_seed)
torch.manual_seed(g_seed)
import os
import os.path as osp
import random
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.utils import negative_sampling
from torch_geometric.nn import GCNConv, GATConv, SAGEConv
import numpy as np
import argparse
import torch
from torch.nn import Sequential, Linear, ReLU
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, accuracy_score
import random
# Project-local helpers: edge label construction and the fairness metrics.
# The parentheses are required for a multi-line `from ... import` list.
from utils import (
    get_link_labels,
    prediction_fairness,
)

from torch_geometric.utils import train_test_split_edges
# Everything runs on the CPU; tensors and the model are moved here via .to(device).
device='cpu'
# Seed shared by NumPy and PyTorch below (repeats the seeding done next to the imports).
g_seed=39788
torch.set_num_threads(8)
np.random.seed(g_seed)
torch.manual_seed(g_seed)
# Ask PyTorch to use deterministic implementations, for reproducibility.
torch.use_deterministic_algorithms(True)

# Command-line options.
parser = argparse.ArgumentParser()

# Dataset name: used for the data directory, passed to Planetoid, and in the output file name.
parser.add_argument('--dataset', type=str, default='cora')
# Convolution type the GNN is built from: 'gcn', 'gat' or 'sage'.
parser.add_argument('--method', type=str, default='gcn')

# parse_known_args() returns (namespace, leftover_args); only the namespace is
# kept, so unrecognized command-line arguments are ignored rather than rejected.
args = parser.parse_known_args()[0]
# NOTE: `dataset` is the name string here; it is later rebound to the loaded dataset object.
dataset=args.dataset
# Maps each supported --method value to the graph convolution layer it selects.
_CONV_LAYERS = {
    'gcn': GCNConv,
    'gat': GATConv,
    'sage': SAGEConv,
}
if args.method not in _CONV_LAYERS:
    # Fail here with a clear message instead of leaving GNN undefined and
    # hitting a NameError only when the model is constructed.
    raise ValueError(
        f"Unsupported --method {args.method!r}; "
        f"expected one of {sorted(_CONV_LAYERS)}"
    )


class GNN(torch.nn.Module):
    """Two-layer graph encoder with an inner-product link decoder.

    Both layers use the convolution class selected by ``--method``; the
    hidden width is fixed at 128.

    Args:
        in_channels: Size of the input node features.
        out_channels: Size of the node embeddings produced by ``encode``.
    """

    # Convolution class picked once from the command line.
    conv_cls = _CONV_LAYERS[args.method]

    def __init__(self, in_channels, out_channels):
        super(GNN, self).__init__()
        self.conv1 = self.conv_cls(in_channels, 128)
        self.conv2 = self.conv_cls(128, out_channels)

    def encode(self, x, pos_edge_index):
        """Return node embeddings computed from features ``x`` over ``pos_edge_index``."""
        hidden = F.relu(self.conv1(x, pos_edge_index))
        return self.conv2(hidden, pos_edge_index)

    def decode(self, z, pos_edge_index, neg_edge_index):
        """Score candidate edges by the dot product of their endpoint embeddings.

        Positive and negative edges are concatenated along the last dimension
        (positives first).

        Returns:
            A ``(logits, edge_index)`` pair: one logit per concatenated edge,
            and the concatenated edge index itself.
        """
        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
        logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)
        return logits, edge_index



def dropout_adj_fair(edge_index, y, sens, r, pk=0.9, pmin=0.4, pmax=1):
    """Randomly drop edges, with a label-dependent keep probability.

    Edges whose endpoints have different labels are kept with probability
    ``pk``. Edges whose endpoints share label ``s`` are kept with probability
    ``min(pmax, max(pk * r[s], pmin))``.

    Args:
        edge_index: Tensor with two rows (source and target node ids).
        y: Tensor of per-node labels, indexed by node id.
        sens: Unused; kept so existing call sites keep working.
        r: Mapping from label value to its scaling ratio (the script builds
            it with ``graph_attrs_idel``). A label missing from ``r`` leaves
            its same-label edges at ``pk`` instead of raising ``KeyError``.
        pk: Base keep probability.
        pmin: Lower clamp for the same-label keep probability.
        pmax: Upper clamp for the same-label keep probability.

    Returns:
        ``edge_index`` restricted to the sampled (kept) columns.
    """
    row, col = edge_index
    row_labels = y[row]
    same_label = row_labels == y[col]

    # Every edge starts at the base probability; same-label edges are
    # overwritten per label below.
    keep_prob = edge_index.new_full((row.size(0),), pk, dtype=torch.float)
    for label in torch.unique(y).tolist():
        if label not in r:
            # No ratio was recorded for this label (e.g. it had no
            # same-label edges), so there is nothing to rescale.
            continue
        prob = min(pmax, max(pk * r[label], pmin))
        keep_prob.masked_fill_(same_label & (row_labels == label), float(prob))

    keep = torch.bernoulli(keep_prob).to(torch.bool)
    return edge_index[:, keep]

def filter_adj_fair(row, col, mask):
    """Apply the same ``mask`` to both endpoint arrays and return the pair."""
    kept_row = row[mask]
    kept_col = col[mask]
    return kept_row, kept_col

def graph_attrs_idel(edges, y, sens):
    """Compute, per label, the ratio of inter-group to intra-group edges.

    For each label ``s`` in ``y``::

        r[s] = (# inter-group edge endpoints labelled s)
               / (2 * # intra-group edges whose source is labelled s)

    Labels with no intra-group edge get no entry and ``'check'`` is printed.

    Args:
        edges: Edge index with two rows (source ids, target ids).
        y: Per-node labels, indexed by node id.
        sens: Per-edge flag; entries equal to 1 mark inter-group edges and
            entries equal to 0 mark intra-group edges.

    Returns:
        Dict mapping label value to its ratio (a Python float).
    """
    inter_idx = np.where(sens == 1)[0]
    intra_idx = np.where(sens == 0)[0]
    endpoints = np.array(edges).T  # one (source, target) pair per row

    # Gather the endpoint labels once; they do not depend on the label
    # being processed in the loop below.
    intra_src_labels = y[endpoints[intra_idx, 0]]
    inter_src_labels = y[endpoints[inter_idx, 0]]
    inter_dst_labels = y[endpoints[inter_idx, 1]]

    r = {}
    for label in np.unique(y):
        n_intra = int((intra_src_labels == label).sum())
        if n_intra > 0:
            n_inter_ends = int((inter_dst_labels == label).sum()) + int(
                (inter_src_labels == label).sum()
            )
            r[label] = float(n_inter_ends) / (2 * n_intra)
        else:
            print('check')
    return r
    
# NOTE(review): realpath() receives the literal string '__file__', not the
# module's __file__ variable, so the data directory resolves relative to the
# current working directory (<cwd>/../data/<dataset name>) — confirm intended.
path = osp.join(osp.dirname(osp.realpath('__file__')), "..", "data", dataset)
# Rebinds `dataset` from the name string to the Planetoid dataset object,
# loaded with the NormalizeFeatures transform.
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())

# One full train/evaluate run is performed per entry.
test_seeds = [0,1,2,3,4,5]
# Shorter alternative for quick runs: test_seeds = [0,1]
# Per-run results, appended to by the loop over test_seeds:
acc_auc = []  # [accuracy in %, AUC in %]
fairness = []  # fairness metrics in %
budget=[]  # per-run edge budget (kept training edges averaged over epochs)
# Train and evaluate one model per seed, collecting accuracy/AUC, fairness
# metrics and the average number of training edges kept per epoch.
for random_seed in test_seeds:
    # NOTE: only NumPy's generator is reseeded per run; the `random` and torch
    # generators continue from wherever the previous run left them.
    np.random.seed(random_seed)
    data = dataset[0]
    # Node labels act as the protected attribute; keep them before they are
    # cleared from `data`.
    protected_attr = data.y
    data.train_mask = data.val_mask = data.test_mask = data.y = None
    data = train_test_split_edges(data, val_ratio=0.1, test_ratio=0.2)
    data = data.to(device)

    num_classes = len(np.unique(protected_attr))
    N = data.num_nodes

    epochs = 101
    # Epochs are numbered from 1, so the loop runs epochs - 1 iterations.
    epoch_range = range(1, epochs)
    model = GNN(data.num_features, 128).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

    Y = torch.LongTensor(protected_attr).to(device)

    # True for training edges whose two endpoints carry different labels.
    sens = (
        Y[data.train_pos_edge_index[0, :]] != Y[data.train_pos_edge_index[1, :]]
    ).to(device)

    # Per-label ratios, computed once per run from the full training graph.
    r = graph_attrs_idel(data.train_pos_edge_index, Y, sens)

    best_val_perf = test_perf = 0

    # Running total of training edges kept by the dropout, summed over epochs.
    b = 0
    for epoch in epoch_range:
        neg_edges_tr = negative_sampling(
            edge_index=data.train_pos_edge_index,
            num_nodes=N,
            num_neg_samples=data.train_pos_edge_index.size(1) // 2,
        ).to(device)

        model.train()
        optimizer.zero_grad()

        # Sample this epoch's training graph; the kept edges serve both as
        # the message-passing graph and as the positive examples.
        new_edges = dropout_adj_fair(data.train_pos_edge_index, Y, sens, r, 1, 0, 1)
        b = b + new_edges.size(1)
        z = model.encode(data.x, new_edges)
        link_logits, _ = model.decode(z, new_edges, neg_edges_tr)
        tr_labels = get_link_labels(new_edges, neg_edges_tr).to(device)

        loss = F.binary_cross_entropy_with_logits(link_logits, tr_labels)
        loss.backward()
        optimizer.step()

        # EVALUATION
        model.eval()
        # The embeddings depend only on the full training graph, so compute
        # them once and reuse them for both splits.
        with torch.no_grad():
            z = model.encode(data.x, data.train_pos_edge_index)
        perfs = []
        for prefix in ["val", "test"]:
            pos_edge_index = data[f"{prefix}_pos_edge_index"]
            neg_edge_index = data[f"{prefix}_neg_edge_index"]
            with torch.no_grad():
                link_logits, edge_idx = model.decode(z, pos_edge_index, neg_edge_index)
            link_probs = link_logits.sigmoid()
            link_labels = get_link_labels(pos_edge_index, neg_edge_index)
            auc = roc_auc_score(link_labels.cpu(), link_probs.cpu())
            perfs.append(auc)

        # Model selection: report the test AUC from the epoch with the best
        # validation AUC so far.
        val_perf, tmp_test_perf = perfs
        if val_perf > best_val_perf:
            best_val_perf = val_perf
            test_perf = tmp_test_perf
        if epoch % 10 == 0:
            log = "Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f}, Test: {:.4f}"
            print(log.format(epoch, loss, best_val_perf, test_perf))

    # Average over the epochs that actually ran (epochs - 1); dividing by
    # `epochs` would understate the budget.
    budget.append(float(b) / len(epoch_range))

    # FAIRNESS
    # NOTE(review): link_labels, link_probs and edge_idx are the ones left by
    # the final epoch's "test" evaluation, whereas `auc` is the test AUC at
    # the best-validation epoch.
    auc = test_perf
    cut = [0.45, 0.50, 0.55, 0.60, 0.65, 0.70, 0.75]
    best_acc = 0
    best_cut = 0.5
    # Pick the decision threshold with the highest accuracy on these predictions.
    for i in cut:
        acc = accuracy_score(link_labels.cpu(), link_probs.cpu() >= i)
        if acc > best_acc:
            best_acc = acc
            best_cut = i
    f = prediction_fairness(
        edge_idx.cpu(), link_labels.cpu(), link_probs.cpu() >= best_cut, Y.cpu()
    )
    acc_auc.append([best_acc * 100, auc * 100])
    fairness.append([x * 100 for x in f])
# Persist the per-run edge budgets, keyed by dataset and method.
budget_file = 'budget_ours_'+args.dataset+'_'+args.method+'.npy'
np.save(budget_file, np.array(budget))

# Column-wise mean and standard deviation across runs.
acc_auc_runs = np.asarray(acc_auc)
fairness_runs = np.asarray(fairness)
ma, mf = acc_auc_runs.mean(axis=0), fairness_runs.mean(axis=0)
sa, sf = acc_auc_runs.std(axis=0), fairness_runs.std(axis=0)

# NOTE(review): ':2f' is a minimum field width of 2 with the default six
# decimals; '.2f' (two decimals) may have been intended.
print(f"ACC: {ma[0]:2f} +- {sa[0]:2f}")
print(f"AUC: {ma[1]:2f} +- {sa[1]:2f}")

# The first two fairness columns are reported as "DP mix" and "EoP mix".
print(f"DP mix: {mf[0]:2f} +- {sf[0]:2f}")
print(f"EoP mix: {mf[1]:2f} +- {sf[1]:2f}")
