import os
import random
random.seed(10)
import os.path as osp
import random
import argparse
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.utils import negative_sampling
from torch_geometric.nn import GCNConv, GATConv, SAGEConv
import numpy as np
import torch
from torch.nn import Sequential, Linear, ReLU
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, accuracy_score
import random
from utils import (
    get_link_labels,
    prediction_fairness,
)

from torch_geometric.utils import train_test_split_edges
# Global run configuration: everything runs on the CPU, and a single seed
# initialises both the NumPy and the PyTorch random number generators.
device = "cpu"
g_seed = 39788

# PyTorch runtime settings: cap intra-op parallelism at 8 threads and make
# PyTorch raise on operations that have no deterministic implementation.
torch.set_num_threads(8)
torch.use_deterministic_algorithms(True)

# Seed the PyTorch and NumPy generators with the shared global seed.
torch.manual_seed(g_seed)
np.random.seed(g_seed)

# Command-line options: which dataset to load and which GNN layer type to use.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", default="cora", type=str)
parser.add_argument("--method", default="gcn", type=str)

# parse_known_args() tolerates unrecognised arguments instead of exiting;
# only the parsed namespace is kept, the leftover arguments are discarded.
args, _unparsed = parser.parse_known_args()
dataset = args.dataset
# Map each supported --method value to the convolution layer class it selects.
_CONV_LAYERS = {
    'gcn': GCNConv,
    'gat': GATConv,
    'sage': SAGEConv,
}

# Fail fast on an unsupported method instead of leaving GNN undefined and
# hitting a NameError later, after the dataset has already been loaded.
if args.method not in _CONV_LAYERS:
    raise ValueError(
        "Unsupported --method {!r}; expected one of {}".format(
            args.method, sorted(_CONV_LAYERS)
        )
    )
_conv_layer = _CONV_LAYERS[args.method]


class GNN(torch.nn.Module):
    """Two-layer graph encoder with a dot-product edge decoder.

    The convolution layer type is chosen once, at module level, from
    ``args.method`` ('gcn', 'gat' or 'sage').

    Args:
        in_channels: Size of the input node features.
        out_channels: Size of the node embeddings produced by ``encode``.
        hidden_channels: Output size of the first convolution layer.
            Defaults to 128, the value previously hard-coded.
    """

    def __init__(self, in_channels, out_channels, hidden_channels=128):
        super().__init__()
        self.conv1 = _conv_layer(in_channels, hidden_channels)
        self.conv2 = _conv_layer(hidden_channels, out_channels)

    def encode(self, x, pos_edge_index):
        """Return node embeddings: conv1 -> ReLU -> conv2 over ``pos_edge_index``."""
        x = F.relu(self.conv1(x, pos_edge_index))
        x = self.conv2(x, pos_edge_index)
        return x

    def decode(self, z, pos_edge_index, neg_edge_index):
        """Score candidate edges by the dot product of their endpoint embeddings.

        The positive and negative edge indices are concatenated along the
        last dimension (positives first).

        Returns:
            A tuple ``(logits, edge_index)``: one raw score per edge, and the
            concatenated edge index the scores correspond to.
        """
        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
        logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)
        return logits, edge_index
# NOTE(review): '__file__' is a quoted string literal here, not the module's
# __file__ variable, so the data directory is resolved relative to the current
# working directory rather than this script's location. Possibly intentional
# (e.g. to keep the script usable where __file__ is undefined) -- confirm
# before changing.
data_root = osp.dirname(osp.realpath('__file__'))
path = osp.join(data_root, "..", "data", dataset)

# From here on `dataset` is rebound from the dataset name to the loaded
# Planetoid dataset object, with feature normalisation applied as a transform.
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())

# One full train/evaluate run is performed per seed; per-run results are
# accumulated in the two lists below.
test_seeds = list(range(6))
acc_auc, fairness = [], []

# Precomputed per-run edge budgets, stored per dataset/method combination.
budget = np.load(f'budget_ours_{args.dataset}_{args.method}.npy')
# Train and evaluate one link-prediction model per seed, collecting the
# accuracy/AUC and fairness figures of every run.
for random_seed in test_seeds:
    # Only the NumPy generator is re-seeded per run.
    np.random.seed(random_seed)
    data = dataset[0]
    # The node labels serve as the protected attribute for the fairness metrics.
    protected_attribute = data.y
    data = train_test_split_edges(data, val_ratio=0.1, test_ratio=0.2)
    data = data.to(device)

    N = data.num_nodes

    epochs = 101  # range(1, epochs) below gives 100 training epochs
    model = GNN(data.num_features, 128).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

    Y = torch.LongTensor(protected_attribute).to(device)

    # Loop invariants: the training edges, how many there are, and how many
    # of them may be used per epoch (budget is indexed by the seed value).
    train_pos_edge_index = data.train_pos_edge_index
    num_train_edges = train_pos_edge_index.size(1)
    edge_budget = int(budget[random_seed])

    best_val_perf = test_perf = 0
    for epoch in range(1, epochs):
        # Draw fresh negative edges each epoch: half as many as there are
        # training positives.
        neg_edges_tr = negative_sampling(
            edge_index=train_pos_edge_index,
            num_nodes=N,
            num_neg_samples=num_train_edges // 2,
        ).to(device)

        # Pick a random subset of the training edges, limited by the budget.
        arr = np.arange(num_train_edges)
        np.random.shuffle(arr)
        used_edges = arr[:edge_budget]
        used_pos_edge_index = train_pos_edge_index[:, used_edges]

        model.train()
        optimizer.zero_grad()

        # Both message passing and the positive targets use only the
        # sampled edge subset.
        z = model.encode(data.x, used_pos_edge_index)
        link_logits, _ = model.decode(z, used_pos_edge_index, neg_edges_tr)
        tr_labels = get_link_labels(used_pos_edge_index, neg_edges_tr).to(device)

        loss = F.binary_cross_entropy_with_logits(link_logits, tr_labels)
        loss.backward()
        optimizer.step()

        # EVALUATION
        model.eval()
        # The embeddings depend only on the model and the full training
        # graph, so they are computed once and shared by both splits.
        with torch.no_grad():
            z = model.encode(data.x, train_pos_edge_index)

        perfs = []
        for prefix in ["val", "test"]:
            pos_edge_index = data[f"{prefix}_pos_edge_index"]
            neg_edge_index = data[f"{prefix}_neg_edge_index"]
            with torch.no_grad():
                link_logits, edge_idx = model.decode(z, pos_edge_index, neg_edge_index)
            link_probs = link_logits.sigmoid()
            link_labels = get_link_labels(pos_edge_index, neg_edge_index)
            auc = roc_auc_score(link_labels.cpu(), link_probs.cpu())
            perfs.append(auc)

        # Report the test AUC of the epoch with the best validation AUC.
        val_perf, tmp_test_perf = perfs
        if val_perf > best_val_perf:
            best_val_perf = val_perf
            test_perf = tmp_test_perf
        if epoch % 10 == 0:
            log = "Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f}, Test: {:.4f}"
            print(log.format(epoch, loss.item(), best_val_perf, test_perf))

    # FAIRNESS
    # NOTE(review): link_labels, link_probs and edge_idx still hold the test
    # split values from the final epoch, whereas `auc` is the test AUC at the
    # best validation epoch -- confirm this mismatch is intended.
    auc = test_perf
    cut = [0.45, 0.50, 0.55, 0.60, 0.65, 0.70, 0.75]
    best_acc = 0
    best_cut = 0.5
    # Choose the probability threshold that maximises accuracy on these
    # predictions; the first best threshold wins on ties.
    for i in cut:
        acc = accuracy_score(link_labels.cpu(), link_probs.cpu() >= i)
        if acc > best_acc:
            best_acc = acc
            best_cut = i
    f = prediction_fairness(
        edge_idx.cpu(), link_labels.cpu(), link_probs.cpu() >= best_cut, Y.cpu()
    )
    # Store everything as percentages.
    acc_auc.append([best_acc * 100, auc * 100])
    fairness.append([x * 100 for x in f])


# Aggregate the per-seed results: one row per run, one column per metric.
acc_auc_arr = np.asarray(acc_auc)
fairness_arr = np.asarray(fairness)

# Mean and standard deviation of each metric across runs.
ma = np.mean(acc_auc_arr, axis=0)
mf = np.mean(fairness_arr, axis=0)

sa = np.std(acc_auc_arr, axis=0)
sf = np.std(fairness_arr, axis=0)

# Use ".2f" (two decimals). The previous ":2f" spec only set a minimum field
# width of 2 and therefore printed the default six decimals.
print(f"ACC: {ma[0]:.2f} +- {sa[0]:.2f}")
print(f"AUC: {ma[1]:.2f} +- {sa[1]:.2f}")

print(f"DP mix: {mf[0]:.2f} +- {sf[0]:.2f}")
print(f"EoP mix: {mf[1]:.2f} +- {sf[1]:.2f}")
