from __future__ import division
from __future__ import print_function

import time
import argparse
import numpy as np
import os
import torch
import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim
import shap
from utils import *
from models import GCN_pia, GCN_Net, GAT_Net, SGC_Net, SAGE_Net, GPRGNN, GCN_PLUS_Net, GAT_PLUS_Net, SGC_PLUS_Net, NLGCN, NLGAT, NLMLP
from torch_geometric.datasets import Planetoid
from torch_geometric.datasets import Amazon
from torch_geometric.datasets import WikipediaNetwork
from torch_geometric.datasets import AttributedGraphDataset, LastFMAsia
from torch_geometric.transforms import RandomNodeSplit
from sklearn.metrics import roc_auc_score, recall_score, precision_score
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
import torch_geometric.transforms as T
import pandas as pd
import pickle as pkl
import networkx as nx

def edge_index_to_adjacency_matrix(edge_index, num_nodes):
    """Build a dense, symmetric 0/1 adjacency matrix from an edge list.

    Args:
        edge_index: Indexable pair of integer index sequences (typically a
            tensor with two rows); ``edge_index[0]`` holds one endpoint of
            every edge and ``edge_index[1]`` the other.
        num_nodes: Number of nodes; the result is ``num_nodes x num_nodes``.

    Returns:
        A ``torch.float32`` tensor holding 1 at ``[u, v]`` and ``[v, u]`` for
        every listed edge ``(u, v)`` and 0 elsewhere. It lives on the same
        device as ``edge_index`` when that object exposes a ``device``.
    """
    # Allocate next to the indices so index and matrix never end up on
    # different devices; fall back to the default device for plain sequences.
    adjacency_matrix = torch.zeros(
        num_nodes,
        num_nodes,
        dtype=torch.float32,
        device=getattr(edge_index, "device", None),
    )
    src, dst = edge_index[0], edge_index[1]
    # Write both directions so the matrix is symmetric even when the edge
    # list only stores one direction per edge.
    adjacency_matrix[src, dst] = 1
    adjacency_matrix[dst, src] = 1
    return adjacency_matrix

def load_data(dataset_name, device):
    """Load a graph dataset by name and move its first graph to ``device``.

    Recognised names (checked in this order): ``"Cora"``, ``"CiteSeer"``,
    ``"PubMed"``, ``"computers"``, ``"photo"``, ``"chameleon"``,
    ``"squirrel"``, ``"texas"``, ``"ogbn_arxiv"``, then any name containing
    ``"cSBM"``, ``"Facebook"`` or ``"LastFM"``.

    Except for ``ogbn_arxiv``, Facebook and LastFM, a dense adjacency matrix
    built from ``data.edge_index`` is attached as ``data.adj_t``.

    Args:
        dataset_name: Name selecting the dataset (see above).
        device: Target passed to ``data.to(...)``.

    Returns:
        Tuple ``(dataset, data)``: the dataset object and its first graph,
        the latter already moved to ``device``.

    Raises:
        ValueError: If ``dataset_name`` matches none of the known datasets.
    """
    # Most branches attach a dense adjacency matrix; the ones that do not
    # switch this off.
    needs_dense_adj = True

    if dataset_name in ("Cora", "CiteSeer", "PubMed"):
        dataset = Planetoid(root='/data', name=dataset_name)
        data = dataset[0]
    elif dataset_name in ("computers", "photo"):
        # Each Amazon graph is stored under its own sub-directory.
        dataset = Amazon(root='/data/' + dataset_name, name=dataset_name)
        data = dataset[0]
    elif dataset_name in ("chameleon", "squirrel"):
        # Node data comes from the geom_gcn_preprocess=True variant, but the
        # edge list is taken from the geom_gcn_preprocess=False variant.
        preProcDs = WikipediaNetwork(root='/data', name=dataset_name, geom_gcn_preprocess=False, transform=T.NormalizeFeatures())
        dataset = WikipediaNetwork(root='/data', name=dataset_name, geom_gcn_preprocess=True, transform=T.NormalizeFeatures())
        data = dataset[0]
        data.edge_index = preProcDs[0].edge_index
    elif dataset_name == "texas":
        # NOTE(review): WebKB is not imported explicitly in this file;
        # presumably it is provided by `from utils import *` - confirm.
        dataset = WebKB(root='/data', name="Texas")
        data = dataset[0]
    elif dataset_name == "ogbn_arxiv":
        # NOTE(review): PygNodePropPredDataset is not imported explicitly in
        # this file; presumably provided by `from utils import *` - confirm.
        dataset = PygNodePropPredDataset(root='/data', name='ogbn-arxiv', transform=T.ToSparseTensor())
        data = dataset[0]
        # Here adj_t is the (symmetrised) sparse adjacency produced by the
        # transform; edge_index is rebuilt from its first two COO components.
        data.adj_t = data.adj_t.to_symmetric()
        coo = data.adj_t.coo()
        data.edge_index = torch.stack([coo[0], coo[1]], dim=0)
        data.y = data.y.squeeze(1)
        needs_dense_adj = False
    elif 'cSBM' in dataset_name:
        # NOTE(review): dataset_ContextualSBM is not imported explicitly in
        # this file; presumably provided by `from utils import *` - confirm.
        dataset = dataset_ContextualSBM('/data/', name=dataset_name)
        data = dataset[0]
    elif 'Facebook' in dataset_name:
        dataset = AttributedGraphDataset('/data/Facebook', name=dataset_name)
        data = dataset[0]
        # Reduce the 2-D label matrix to the index of its maximum per row.
        data.y = data.y.max(1)[1]
        needs_dense_adj = False
    elif 'LastFM' in dataset_name:
        dataset = LastFMAsia('/data/LastFM')
        data = dataset[0]
        needs_dense_adj = False
    else:
        # Previously an unknown name fell through and crashed at `return`
        # with an UnboundLocalError that did not mention the bad name.
        raise ValueError("Unknown dataset name: {!r}".format(dataset_name))

    if needs_dense_adj:
        # Built before the move to `device`, so the indices are still on the
        # device the dataset was loaded on.
        data.adj_t = edge_index_to_adjacency_matrix(data.edge_index, data.num_nodes)

    data = data.to(device)
    return dataset, data

# Training settings
parser = argparse.ArgumentParser()
# nargs='+' collects whitespace-separated names into a list of strings.
# The former `type=list` turned a single CLI value such as "Cora" into
# ['C', 'o', 'r', 'a'], so these options only worked through their defaults.
# The two lists are consumed pairwise (dataset i is trained with model i).
parser.add_argument('--Datasets', nargs='+', type=str,
                    default=["chameleon", "chameleon", "chameleon", "chameleon"])
parser.add_argument('--model_name', nargs='+', type=str,
                    default=['NLGCN', 'NLGAT', 'NLMLP', 'GPRGNN'])
# default=False so that the flag actually toggles something: with
# store_true and default=True the value could never become False.
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='Disables CUDA training.')
parser.add_argument('--fastmode', action='store_true', default=False,
                    help='Validate during training pass.')
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
parser.add_argument('--epochs', type=int, default=200,
                    help='Number of epochs to train.')
parser.add_argument('--lr', type=float, default=5e-2,#1e-3,#0.02,
                    help='Initial learning rate.')
parser.add_argument('--weight_decay', type=float, default=0,#5e-4,#5e-5,
                    help='Weight decay (L2 loss on parameters).')
parser.add_argument('--hidden', type=int, default=64,#32
                    help='Number of hidden units.')
parser.add_argument('--heads', type=int, default=8,
                    help='Number of attension head.')
parser.add_argument('--num_layers', type=int, default=3,
                    help='Number of layers.')
parser.add_argument('--output_heads', type=int, default=3)
parser.add_argument('--Init', type=str, default='PPR')
parser.add_argument('--K', type=int, default=10)
parser.add_argument('--ppnp', type=str, default='GPR_prop')
parser.add_argument('--Gamma', type=float, default=None)
parser.add_argument('--dprate', type=float, default=0.7)#0.5)
parser.add_argument('--alpha', type=float, default=1.0)
parser.add_argument('--disp', type=int, default=10)
parser.add_argument('--dropout', type=float, default=0.5,#0,
                    help='Dropout rate (1 - keep probability).')
parser.add_argument('--dropout1', type=float, default=0.5)
parser.add_argument('--dropout2', type=float, default=0.5)
parser.add_argument('--kernel', type=int, default=5)
parser.add_argument('--testratio', type=float, default=0.2)

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Derive the device from the parsed flag so that --no-cuda is honoured.
# Without the flag this still selects CUDA whenever it is available.
device = torch.device('cuda' if args.cuda else 'cpu')

# Seed every RNG in use so runs are repeatable for a given --seed.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Per-run metrics, appended to once per (dataset, model) pair.
test_loss_acc=[]
train_loss_acc=[]
result_train=[]
result_test=[]

# One shadow model is trained per (dataset, model) pair; the two argument
# lists are walked in lockstep.
for (ego_user, modelname) in zip(args.Datasets, args.model_name):
    print(ego_user, modelname)
    shadow_dataset, shadow_data = load_data(ego_user, device)
    # Re-split the nodes: no validation nodes, `testratio` passed as the
    # test-set size.
    shadow_data = RandomNodeSplit(num_val=0, num_test=args.testratio)(shadow_data)

    # All artefacts of this run go under <model>/<dataset>.
    res_dir = "/data/embed_result/" + modelname + "/" + ego_user
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(res_dir, exist_ok=True)

    if modelname == "GCN":
        model = GCN_Net(shadow_dataset, args).to(device)
    elif modelname == "GAT":
        model = GAT_Net(shadow_dataset, args).to(device)
    elif modelname == "SGC":
        model = SGC_Net(shadow_dataset, args).to(device)
    elif modelname == "SAGE":
        model = SAGE_Net(shadow_dataset, args).to(device)
    elif modelname == "GPRGNN":
        model = GPRGNN(shadow_dataset, args).to(device)
    elif modelname == "GCN_PLUS":
        model = GCN_PLUS_Net(shadow_data.num_features, args.hidden, shadow_dataset.num_classes, args.num_layers, args.dropout).to(device)
    elif modelname == "GAT_PLUS":
        model = GAT_PLUS_Net(shadow_data.num_features, args.hidden, shadow_dataset.num_classes, args.num_layers, args.dropout, args.heads, args.output_heads).to(device)
    elif modelname == "SGC_PLUS":
        model = SGC_PLUS_Net(shadow_data.num_features, args.hidden, shadow_dataset.num_classes, args.num_layers, args.dropout).to(device)
    elif modelname == "NLGCN":
        model = NLGCN(shadow_dataset, args).to(device)
    elif modelname == "NLGAT":
        model = NLGAT(shadow_dataset, args).to(device)
    elif modelname == "NLMLP":
        model = NLMLP(shadow_dataset, args).to(device)
    else:
        # Without this branch an unknown name either crashed later with a
        # NameError or silently reused the model of the previous iteration.
        raise ValueError("Unknown model name: {!r}".format(modelname))

    if modelname == "GPRGNN":
        # Separate parameter groups so that the propagation layer (prop1)
        # is never weight-decayed, whatever --weight_decay is set to.
        optimizer = optim.Adam([{'params': model.lin1.parameters(), 'weight_decay': args.weight_decay, 'lr': args.lr},
            {'params': model.lin2.parameters(), 'weight_decay': args.weight_decay, 'lr': args.lr},
            {'params': model.prop1.parameters(), 'weight_decay': 0.0, 'lr': args.lr}], lr=args.lr)
    else:
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    
    def train(epoch):
        """Run one full-batch optimisation step of ``model`` on ``shadow_data``.

        Args:
            epoch: Zero-based epoch index; only used to decide when to log
                (every 100th epoch).

        Returns:
            Tuple ``(loss, acc, output)``. ``loss`` and ``acc`` are the
            ``.item()`` values of the training-mask loss and accuracy from
            the training-mode pass. ``output`` is the model output for the
            whole graph: from an extra eval-mode pass after the update, or
            from the training-mode pass when ``args.fastmode`` is set.
        """
        t = time.time()
        loss_fn = nn.CrossEntropyLoss(reduction='mean')
        model.train()
        optimizer.zero_grad()
        # The model also returns an embedding, which is not needed here.
        output, _ = model(shadow_data)#model(shadow_data.x, shadow_data.adj_t)
        train_mask = shadow_data.train_mask
        loss_train = loss_fn(output[train_mask], shadow_data.y[train_mask])
        acc_train = accuracy(output[train_mask], shadow_data.y[train_mask])
        loss_train.backward()
        optimizer.step()

        if not args.fastmode:
            # Re-evaluate with the updated weights in eval mode. No gradients
            # are needed for this pass, so skip building the autograd graph.
            model.eval()
            with torch.no_grad():
                output, _ = model(shadow_data)#model(shadow_data.x, shadow_data.adj_t)

        if (epoch+1) % 100==0:
            print('Epoch: {:04d}'.format(epoch+1),
                  'loss_train: {:.4f}'.format(loss_train.item()),
                  'acc_train: {:.4f}'.format(acc_train.item()),
                  'time: {:.4f}s'.format(time.time() - t))
        return loss_train.item(),acc_train.item(),output
        
    def test():
        """Evaluate ``model`` on the test mask of ``shadow_data``.

        Returns:
            Tuple ``(output, loss, acc, para, embed)``: the model output for
            the whole graph, the ``.item()`` values of the test-mask loss and
            accuracy, a dict mapping each parameter's position in
            ``model.parameters()`` to a NumPy array of it, and the second
            value returned by the model.
        """
        model.eval()
        loss_fn = nn.CrossEntropyLoss(reduction='mean')
        # Snapshot of all parameters, keyed by their enumeration order.
        para = {
            idx: p.cpu().detach().numpy()
            for idx, p in enumerate(model.parameters())
        }

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            output, embed = model(shadow_data)#model(shadow_data.x, shadow_data.adj_t)
            test_mask = shadow_data.test_mask
            loss_test = loss_fn(output[test_mask], shadow_data.y[test_mask])
            acc_test = accuracy(output[test_mask], shadow_data.y[test_mask])
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return output,loss_test.item(),acc_test.item(),para,embed

    def save_model(net, seed):
        """Save ``net``'s state dict under ``res_dir``, tagged with ``seed``.

        Args:
            net: Module whose ``state_dict()`` is written.
            seed: Value formatted into the checkpoint file name.
        """
        checkpoint_name = '/3smia-shadow-{}.pth'.format(seed)
        weights = net.state_dict()
        torch.save(weights, res_dir + checkpoint_name)


    # Train model
    if args.epochs < 1:
        # With no epochs the metrics below would never be assigned and the
        # script would fail later with an unhelpful NameError.
        raise ValueError("--epochs must be at least 1, got {}".format(args.epochs))

    t_total = time.time()
    for epoch in range(args.epochs):
        loss_train, acc_train, output_train = train(epoch)
    # Only the metrics of the final epoch are recorded.
    train_loss_acc.append([loss_train, acc_train])

    print("Optimization Finished!")
    print("Total time elapsed: {:.4f}s".format(time.time() - t_total))
    print(output_train)
    output_test, loss_test, acc_test, para, embed = test()
    save_model(model, args.seed)
    test_loss_acc.append([loss_test, acc_test])

    # Move everything to host memory as NumPy arrays.
    output_train = output_train.cpu().detach().numpy()
    output_test = output_test.cpu().detach().numpy()
    embed = embed.cpu().detach().numpy()

    savepath = res_dir + '/Shadow-Posterior-layer2-3smia.npy'
    print('***')
    # output_test is already an ndarray, so it is saved without an extra copy.
    # It holds the model output for every node, not only the test nodes.
    np.save(savepath, output_test)
    # Alternative: save the embedding instead of the output.
    #np.save(savepath, np.array(embed))
    