from __future__ import division
from __future__ import print_function

import time
import argparse
import numpy as np
import os
import torch
import torch.nn as nn  
import torch.nn.functional as F
import torch.optim as optim
import shap
from opacus import PrivacyEngine
from opacus.utils.batch_memory_manager import BatchMemoryManager
from utils import *
from models import GCN_pia, GCN_Net, GAT_Net, SGC_Net, SAGE_Net, GPRGNN, GCN_PLUS_Net, GAT_PLUS_Net, SGC_PLUS_Net, NLGCN, NLGAT, NLMLP
from torch_geometric.datasets import Planetoid
from torch_geometric.datasets import Amazon
from torch_geometric.datasets import WikipediaNetwork
from torch_geometric.datasets import AttributedGraphDataset, LastFMAsia
from torch_geometric.transforms import RandomNodeSplit
from sklearn.metrics import roc_auc_score, recall_score, precision_score
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
import torch_geometric.transforms as T
import pandas as pd
import pickle as pkl
import networkx as nx
import torch.distributions.laplace as laplace
from deepset import HalfNLHconv, MLP

def edge_index_to_adjacency_matrix(edge_index, num_nodes):
    """Build a dense, symmetric 0/1 adjacency matrix from an edge index.

    Args:
        edge_index: Index tensor whose rows 0 and 1 hold the two endpoint
            node indices of each edge.
        num_nodes: Number of nodes; the result is ``num_nodes x num_nodes``.

    Returns:
        A ``float32`` tensor with ones at ``[a, b]`` and ``[b, a]`` for every
        listed edge ``(a, b)`` and zeros everywhere else.
    """
    first_ends, second_ends = edge_index[0], edge_index[1]
    dense_adj = torch.zeros((num_nodes, num_nodes), dtype=torch.float32)
    # Write both directions so the matrix is symmetric even when the edge
    # list stores only one direction of an edge.
    dense_adj[first_ends, second_ends] = 1
    dense_adj[second_ends, first_ends] = 1
    return dense_adj

def load_data(dataset_name, device):
    """Load the dataset called ``dataset_name`` and move its first graph to ``device``.

    For most datasets a dense adjacency matrix built from ``edge_index`` is
    attached as ``data.adj_t``. The Facebook and LastFM branches do not attach
    one, and ``ogbn_arxiv`` keeps the symmetrised sparse ``adj_t`` produced by
    its ``ToSparseTensor`` transform.

    Args:
        dataset_name: One of "Cora", "CiteSeer", "PubMed", "computers",
            "photo", "chameleon", "squirrel", "texas", "ogbn_arxiv", or any
            name containing "cSBM", "Facebook" or "LastFM" (checked in that
            order, after the exact names).
        device: Target passed to ``data.to(...)``.

    Returns:
        A ``(dataset, data)`` tuple where ``data`` is ``dataset[0]`` after the
        per-dataset adjustments, moved to ``device``.

    Raises:
        ValueError: If ``dataset_name`` matches none of the supported names.
    """
    # Whether to attach a dense adjacency built from edge_index as data.adj_t.
    attach_dense_adj = True

    if dataset_name in ("Cora", "CiteSeer", "PubMed"):
        dataset = Planetoid(root='/data', name=dataset_name)
        data = dataset[0]
    elif dataset_name in ("computers", "photo"):
        dataset = Amazon(root='/data/' + dataset_name, name=dataset_name)
        data = dataset[0]
    elif dataset_name in ("chameleon", "squirrel"):
        # Features/labels come from the geom-gcn preprocessed variant, while
        # the edge list is taken from the non-preprocessed variant.
        raw_dataset = WikipediaNetwork(root='/data', name=dataset_name, geom_gcn_preprocess=False, transform=T.NormalizeFeatures())
        dataset = WikipediaNetwork(root='/data', name=dataset_name, geom_gcn_preprocess=True, transform=T.NormalizeFeatures())
        data = dataset[0]
        data.edge_index = raw_dataset[0].edge_index
    elif dataset_name == "texas":
        # WebKB is not among the explicit torch_geometric imports at the top
        # of this file, so bring it into scope here.
        from torch_geometric.datasets import WebKB
        dataset = WebKB(root='/data', name="Texas")
        data = dataset[0]
    elif dataset_name == "ogbn_arxiv":
        # NOTE(review): PygNodePropPredDataset is not imported explicitly in
        # this file; presumably it is provided by ``from utils import *`` --
        # confirm.
        dataset = PygNodePropPredDataset(root='/data', name='ogbn-arxiv', transform=T.ToSparseTensor())
        data = dataset[0]
        data.adj_t = data.adj_t.to_symmetric()
        edge_index = data.adj_t.coo()
        data.edge_index = torch.stack([edge_index[0], edge_index[1]], dim=0)
        data.y = data.y.squeeze(1)
        attach_dense_adj = False
    elif 'cSBM' in dataset_name:
        # NOTE(review): dataset_ContextualSBM is presumably provided by
        # ``from utils import *`` -- confirm.
        dataset = dataset_ContextualSBM('/data/', name=dataset_name)
        data = dataset[0]
    elif 'Facebook' in dataset_name:
        dataset = AttributedGraphDataset('/data/Facebook', name=dataset_name)
        data = dataset[0]
        # Collapse the per-class label matrix to the index of its row maximum.
        data.y = data.y.max(1)[1]
        attach_dense_adj = False
    elif 'LastFM' in dataset_name:
        dataset = LastFMAsia('/data/LastFM')
        data = dataset[0]
        attach_dense_adj = False
    else:
        # Without this, the return below would fail with UnboundLocalError.
        raise ValueError(f"Unknown dataset name: {dataset_name!r}")

    if attach_dense_adj:
        data.adj_t = edge_index_to_adjacency_matrix(data.edge_index, data.num_nodes)
        if dataset_name == "CiteSeer":
            print(data.adj_t.shape)

    data = data.to(device)
    return dataset, data

# Use the GPU whenever PyTorch reports one as available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Per-dataset loss threshold read by train_soft(): while the hard-label loss
# is above this value the hard loss is optimised, otherwise the soft-label
# loss is used. These values must be set by hand for each dataset.
alpha_soft_set = {
    "Cora": 0.3,
    "CiteSeer": 0.22,
    "PubMed": 0.55,
    "computers": 0.32,
    "photo": 0.25,
    "texas": 0.02,
    "chameleon": 0.35,
    "squirrel": 0.93,
    "ogbn_arxiv": 0.8,
    "Facebook": 0.80,
    "LastFM": 0.60,
}
# Training settings
parser = argparse.ArgumentParser()
# ``nargs='+'`` collects one or more space-separated values into a list of
# strings. (``type=list`` would instead split a single command-line string
# into its individual characters.)
parser.add_argument('--Datasets', type=str, nargs='+',
                    default=['CiteSeer', "Facebook", "LastFM", 'chameleon'],
                    help='Dataset names, paired position-wise with --model_name.')
parser.add_argument('--model_name', type=str, nargs='+',
                    default=['MLP', 'MLP', 'MLP', 'MLP'],
                    help='Model names, paired position-wise with --Datasets.')
# NOTE(review): with action='store_true' and default=True this flag is always
# True, so args.cuda below is always False. The module-level ``device`` is
# chosen from torch.cuda.is_available() alone and does not consult this flag.
parser.add_argument('--no-cuda', action='store_true', default=True,
                    help='Disables CUDA training.')
parser.add_argument('--fastmode', action='store_true', default=False,
                    help='Validate during training pass.')
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
parser.add_argument('--epochs', type=int, default=200,
                    help='Number of epochs to train.')
parser.add_argument('--lr', type=float, default=1e-3,#5e-2,#0.02,
                    help='Initial learning rate.')
parser.add_argument('--weight_decay', type=float, default=5e-4,#0,#5e-5,
                    help='Weight decay (L2 loss on parameters).')
parser.add_argument('--hidden', type=int, default=64,#32
                    help='Number of hidden units.')
parser.add_argument('--heads', type=int, default=8,
                    help='Number of attension head.')
parser.add_argument('--num_layers', type=int, default=3,
                    help='Number of layers.')
parser.add_argument('--output_heads', type=int, default=3)
parser.add_argument('--Init', type=str, default='PPR')
parser.add_argument('--K', type=int, default=10)
parser.add_argument('--ppnp', type=str, default='GPR_prop')
parser.add_argument('--Gamma', type=float, default=None)
parser.add_argument('--dprate', type=float, default=0.5)#0.7)#
parser.add_argument('--alpha', type=float, default=1.0)
parser.add_argument('--disp', type=int, default=10)
parser.add_argument('--dropout', type=float, default=0,#0.5,#
                    help='Dropout rate (1 - keep probability).')
parser.add_argument('--dropout1', type=float, default=0.5)
parser.add_argument('--dropout2', type=float, default=0.5)
parser.add_argument('--kernel', type=int, default=5)
# Fraction of nodes handed to RandomNodeSplit as the test split.
parser.add_argument('--testratio', type=float, default=0.2)
# Probability mass assigned to the labelled class when building soft labels.
parser.add_argument('--upper', type=float, default=0.8)
parser.add_argument('--r', type=float, default=0.4)
parser.add_argument('--b', type=float, default=0.3)

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Seed the NumPy and PyTorch random number generators from --seed.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Module-level accumulators; the per-dataset loop appends [loss, accuracy]
# pairs to the *_loss_acc lists.
train_loss_acc, test_loss_acc = [], []
result_train, result_test = [], []

class LaplaceNoiseGenerator:
    """Callable that draws zero-mean Laplace noise shaped like a reference tensor.

    Args:
        noise_multiplier: Scale parameter of the Laplace distribution the
            noise is sampled from.
    """

    def __init__(self, noise_multiplier):
        self.noise_multiplier = noise_multiplier

    def __call__(self, reference):
        """Return Laplace noise with ``reference``'s shape, on ``reference``'s device."""
        # Use the scale this generator was constructed with. Previously the
        # value was copied into an unused local and the global ``args.b`` was
        # used instead, so the constructor argument had no effect.
        laplace_dist = torch.distributions.Laplace(loc=0.0, scale=self.noise_multiplier)
        noise = laplace_dist.sample(reference.shape).to(reference.device)
        return noise

for (ego_user, modelname) in zip(args.Datasets, args.model_name):
    # Each iteration handles one (dataset, model) pair. zip() stops at the
    # shorter of the two lists, so surplus entries in either are ignored.
    print(ego_user, modelname)
    target_dataset, target_data = load_data(ego_user, device)
    # Replace any existing split with a random one: no validation nodes and
    # ``args.testratio`` passed as the test size.
    target_data = RandomNodeSplit(num_val=0, num_test=args.testratio)(target_data)

    res_dir = "/data/embed_result/" + modelname + "/" + ego_user
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(res_dir, exist_ok=True)

    result_data_file = res_dir + "/smia_results.txt"

    # Build the baseline model for this run.
    if modelname == "GCN":
        model = GCN_Net(target_dataset, args).to(device)
    elif modelname == "GAT":
        model = GAT_Net(target_dataset, args).to(device)
    elif modelname == "SGC":
        model = SGC_Net(target_dataset, args).to(device)
    elif modelname == "SAGE":
        model = SAGE_Net(target_dataset, args).to(device)
    elif modelname == "GPRGNN":
        model = GPRGNN(target_dataset, args).to(device)
    elif modelname == "GCN_PLUS":
        model = GCN_PLUS_Net(target_data.num_features, args.hidden, target_dataset.num_classes, args.num_layers, args.dropout).to(device)
    elif modelname == "GAT_PLUS":
        model = GAT_PLUS_Net(target_data.num_features, args.hidden, target_dataset.num_classes, args.num_layers, args.dropout, args.heads, args.output_heads).to(device)
    elif modelname == "SGC_PLUS":
        model = SGC_PLUS_Net(target_data.num_features, args.hidden, target_dataset.num_classes, args.num_layers, args.dropout).to(device)
    elif modelname == "NLGCN":
        model = NLGCN(target_dataset, args).to(device)
    elif modelname == "NLGAT":
        model = NLGAT(target_dataset, args).to(device)
    elif modelname == "NLMLP":
        model = NLMLP(target_dataset, args).to(device)
    elif modelname == "MLP":
        model = MLP(
            in_channels=target_dataset.num_features,
            hidden_channels=args.hidden,
            out_channels=target_dataset.num_classes,
            num_layers=2,
            dropout=args.dropout,
            InputNorm=False).to(device)
    else:
        # Otherwise ``model`` would be unbound on the first iteration, or
        # silently reuse the previous iteration's model on later ones.
        raise ValueError(f"Unknown model name: {modelname!r}")

    if modelname == "GPRGNN":
        # GPRGNN: weight decay on the two linear layers only, none on the
        # propagation parameters.
        optimizer = optim.Adam([{'params': model.lin1.parameters(), 'weight_decay': args.weight_decay, 'lr': args.lr},
            {'params': model.lin2.parameters(), 'weight_decay': args.weight_decay, 'lr': args.lr},
            {'params': model.prop1.parameters(), 'weight_decay': 0.0, 'lr': args.lr}], lr=args.lr)
    else:
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    def train(epoch):
        """Run one full-graph optimisation step on ``model``.

        Args:
            epoch: Zero-based epoch index; progress is printed whenever
                ``epoch + 1`` is a multiple of 100.

        Returns:
            ``(loss, accuracy, output)`` where loss and accuracy are Python
            numbers measured on the training mask before the optimiser step,
            and ``output`` is the model output -- recomputed in eval mode
            after the step unless ``args.fastmode`` is set.
        """
        started = time.time()
        criterion = nn.CrossEntropyLoss(reduction='mean')

        model.train()
        optimizer.zero_grad()
        output, embed = model(target_data)
        train_nodes = target_data.train_mask
        loss_train = criterion(output[train_nodes], target_data.y[train_nodes])
        acc_train = accuracy(output[train_nodes], target_data.y[train_nodes])
        loss_train.backward()
        optimizer.step()

        if not args.fastmode:
            # Second forward pass in eval mode; this replaces the returned output.
            model.eval()
            output, embed = model(target_data)

        if (epoch + 1) % 100 == 0:
            print(f'Epoch: {epoch + 1:04d}',
                  f'loss_train: {loss_train.item():.4f}',
                  f'acc_train: {acc_train.item():.4f}',
                  f'time: {time.time() - started:.4f}s')
        return loss_train.item(), acc_train.item(), output

    def train_soft(epoch, stage='first'):
        """Run one optimisation step on ``model_2stage`` with hard/soft label switching.

        The hard-label cross-entropy is always computed. While it is above
        ``alpha_soft_set[ego_user]`` it is the loss that is optimised;
        otherwise the cross-entropy against the soft labels is optimised.

        Args:
            epoch: Zero-based epoch index; progress is printed whenever
                ``epoch + 1`` is a multiple of 100.
            stage: ``'first'`` uses ``train_mask`` / ``y`` / ``soft_y``; any
                other value uses ``second_train_mask`` / ``second_y`` /
                ``second_soft_y``.

        Returns:
            ``(loss, accuracy, output)`` with loss and accuracy as Python
            numbers and ``output`` the train-mode model output.
        """
        started = time.time()
        criterion = nn.CrossEntropyLoss(reduction='mean')
        model_2stage.train()
        optimizer.zero_grad()

        # Names of the target_data attributes that this stage reads.
        if stage == 'first':
            mask_attr, hard_attr, soft_attr = 'train_mask', 'y', 'soft_y'
        else:
            mask_attr, hard_attr, soft_attr = 'second_train_mask', 'second_y', 'second_soft_y'

        output, embed = model_2stage(target_data)
        nodes = getattr(target_data, mask_attr)
        hard_labels = getattr(target_data, hard_attr)[nodes]
        loss_hard = criterion(output[nodes], hard_labels)
        acc_train = accuracy(output[nodes], hard_labels)

        if loss_hard > alpha_soft_set[ego_user]:
            loss = loss_hard
        else:
            # Soft labels are only looked up once the hard loss is low enough.
            loss = criterion(output[nodes], getattr(target_data, soft_attr)[nodes])

        loss.backward()
        optimizer.step()

        if (epoch + 1) % 100 == 0:
            print(f'Epoch: {epoch + 1:04d}',
                  f'loss_train: {loss.item():.4f}',
                  f'acc_train: {acc_train.item():.4f}',
                  f'time: {time.time() - started:.4f}s')
        return loss.item(), acc_train.item(), output

    def test():
        """Evaluate ``model`` on the test mask of ``target_data``.

        Returns:
            ``(output, loss, accuracy, para, embed)``: the model output for
            the whole graph, the test loss and accuracy as Python numbers, a
            dict mapping each parameter's position to a NumPy copy of it, and
            the second value returned by the model.
        """
        model.eval()
        criterion = nn.CrossEntropyLoss(reduction='mean')

        # Snapshot every parameter as a NumPy array, keyed by its position.
        para = {
            idx: weight.cpu().detach().numpy()
            for idx, weight in enumerate(model.parameters())
        }

        output, embed = model(target_data)
        held_out = target_data.test_mask
        loss_test = criterion(output[held_out], target_data.y[held_out])
        acc_test = accuracy(output[held_out], target_data.y[held_out])
        print("Test set results:",
              f"loss= {loss_test.item():.4f}",
              f"accuracy= {acc_test.item():.4f}")
        return output, loss_test.item(), acc_test.item(), para, embed

    def test_soft():
        """Evaluate ``model_2stage`` on the test mask of ``target_data``.

        Loss and accuracy are always measured on ``test_mask`` against ``y``,
        whichever training stage ran last.

        Returns:
            ``(output, loss, accuracy, para, embed)``: the model output for
            the whole graph, the test loss and accuracy as Python numbers, a
            dict mapping each parameter's position to a NumPy copy of it, and
            the second value returned by the model.
        """
        model_2stage.eval()
        criterion = nn.CrossEntropyLoss(reduction='mean')

        # Snapshot every parameter as a NumPy array, keyed by its position.
        para = {
            idx: weight.cpu().detach().numpy()
            for idx, weight in enumerate(model_2stage.parameters())
        }

        output, embed = model_2stage(target_data)
        held_out = target_data.test_mask
        loss_test = criterion(output[held_out], target_data.y[held_out])
        acc_test = accuracy(output[held_out], target_data.y[held_out])
        print("Test set results:",
              f"loss= {loss_test.item():.4f}",
              f"accuracy= {acc_test.item():.4f}")
        return output, loss_test.item(), acc_test.item(), para, embed




    # Train the baseline model for args.epochs steps; only the final epoch's
    # loss and accuracy are recorded.
    t_total = time.time()
    for epoch in range(args.epochs):
        loss_train, acc_train, output_train = train(epoch)
    train_loss_acc.append([loss_train, acc_train])

    print("Optimization Finished!")
    print(f"Total time elapsed: {time.time() - t_total:.4f}s")
    print(output_train)

    # Evaluate, append the test accuracy to the results file, and save the
    # baseline model's output for every node.
    output_test, loss_test, acc_test, para, embed = test()
    test_loss_acc.append([loss_test, acc_test])
    with open(result_data_file, 'a') as data_file:
        print(acc_test, file=data_file)
    output_test = output_test.detach().cpu().numpy()
    savepath = res_dir + '/Normal-Posterior-layer2-3smia.npy'
    print('***')
    np.save(savepath, np.array(output_test))

    # Build a fresh model of the same architecture for two-stage training.
    if modelname == "GCN":
        model_2stage = GCN_Net(target_dataset, args).to(device)
    elif modelname == "GAT":
        model_2stage = GAT_Net(target_dataset, args).to(device)
    elif modelname == "SGC":
        model_2stage = SGC_Net(target_dataset, args).to(device)
    elif modelname == "SAGE":
        model_2stage = SAGE_Net(target_dataset, args).to(device)
    elif modelname == "GPRGNN":
        model_2stage = GPRGNN(target_dataset, args).to(device)
    elif modelname == "GCN_PLUS":
        model_2stage = GCN_PLUS_Net(target_data.num_features, args.hidden, target_dataset.num_classes, args.num_layers, args.dropout).to(device)
    elif modelname == "GAT_PLUS":
        model_2stage = GAT_PLUS_Net(target_data.num_features, args.hidden, target_dataset.num_classes, args.num_layers, args.dropout, args.heads, args.output_heads).to(device)
    elif modelname == "SGC_PLUS":
        model_2stage = SGC_PLUS_Net(target_data.num_features, args.hidden, target_dataset.num_classes, args.num_layers, args.dropout).to(device)
    elif modelname == "NLGCN":
        model_2stage = NLGCN(target_dataset, args).to(device)
    elif modelname == "NLGAT":
        model_2stage = NLGAT(target_dataset, args).to(device)
    elif modelname == "NLMLP":
        model_2stage = NLMLP(target_dataset, args).to(device)
    elif modelname == "MLP":
        # Same construction as the baseline MLP. Without this branch the
        # default --model_name ("MLP") left model_2stage undefined.
        model_2stage = MLP(
            in_channels=target_dataset.num_features,
            hidden_channels=args.hidden,
            out_channels=target_dataset.num_classes,
            num_layers=2,
            dropout=args.dropout,
            InputNorm=False).to(device)
    else:
        # Fail loudly rather than reuse a model from a previous iteration.
        raise ValueError(f"Unknown model name: {modelname!r}")

    # Rebind the module-level optimizer to the new model; train_soft() reads
    # this name when it steps.
    if modelname == "GPRGNN":
        optimizer = optim.Adam([{'params': model_2stage.lin1.parameters(), 'weight_decay': args.weight_decay, 'lr': args.lr},
            {'params': model_2stage.lin2.parameters(), 'weight_decay': args.weight_decay, 'lr': args.lr},
            {'params': model_2stage.prop1.parameters(), 'weight_decay': 0.0, 'lr': args.lr}], lr=args.lr)
    else:
        optimizer = optim.Adam(model_2stage.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    
    # Build soft labels from the true labels: the labelled class receives
    # args.upper and the remaining (1 - args.upper) is shared equally among
    # the other num_classes - 1 classes.
    onehot = F.one_hot(target_data.y, num_classes=target_dataset.num_classes).to(torch.float).to(device)
    inverse_onthot = 1.0 - onehot
    target_data.soft_y = args.upper * onehot + (1.0 - args.upper) / (target_dataset.num_classes - 1) * inverse_onthot
    print(target_data.soft_y.shape, target_data.y.shape)
    
    # First stage: train model_2stage on the original training mask, then
    # evaluate it on the test mask.
    t_total = time.time()
    for epoch in range(args.epochs):
        loss_train, acc_train, output_train = train_soft(epoch, stage='first')
    train_loss_acc.append([loss_train, acc_train])
    output_test, loss_test, acc_test, para, embed = test_soft()
    test_loss_acc.append([loss_test, acc_test])
    
    # Collect the predicted class (argmax over dim 1) for the nodes selected
    # by each mask that target_data(...) yields.
    preds = []
    for _, mask in target_data('train_mask', 'test_mask'):
        pred = output_test[mask].max(1)[1]
        preds.append(pred.detach().cpu())

    # Second stage: use the first-stage predictions as labels for the test
    # nodes and train on those nodes.
    # NOTE(review): preds[-1] is treated as the test-mask predictions, which
    # assumes 'test_mask' is the last item yielded above -- TODO confirm the
    # iteration order of Data.__call__.
    test_preds = preds[-1]
    target_data.second_y = target_data.y.clone()
    target_data.second_y[target_data.test_mask] = test_preds.to(device)

    # Soft labels for the second stage, built the same way from second_y.
    onehot = F.one_hot(target_data.second_y, num_classes=target_dataset.num_classes).to(torch.float).to(device)
    inverse_onthot = 1.0 - onehot
    target_data.second_soft_y = args.upper * onehot + (1.0 - args.upper) / (target_dataset.num_classes - 1) * inverse_onthot
    # Swap the roles of the two masks: the second stage trains on the
    # original test nodes.
    target_data.second_train_mask, target_data.second_test_mask = target_data.test_mask, target_data.train_mask
    
    # model_2stage and optimizer are not re-created here, so the second stage
    # continues from the first-stage weights and optimizer state.
    for epoch in range(args.epochs):
        loss_train, acc_train, output_train = train_soft(epoch, stage='second')
    train_loss_acc.append([loss_train, acc_train])
    # NOTE(review): test_soft() scores test_mask against the true labels y,
    # i.e. the same nodes the second stage just trained on (with predicted
    # labels); second_test_mask is not read anywhere in the visible code.
    output_test, loss_test, acc_test, para, embed = test_soft()
    test_loss_acc.append([loss_test, acc_test])

    # Append the second-stage test accuracy below the baseline accuracy
    # already written to the same file.
    with open(result_data_file, 'a') as data_file:
        data_file.write(f"{acc_test}\n") 

    # Save the two-stage model's output for every node.
    output_test = output_test.cpu().detach().numpy()
    savepath = res_dir + '/2stage-Posterior-layer2-3smia.npy'
    np.save(savepath, np.array(output_test))
