from __future__ import division
from __future__ import print_function
import math
import time
import argparse
import numpy as np
import os
import torch
import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import shap
from utils import *
from models import GCN_pia, Net
from deepset import MLP
from torch_geometric.datasets import Planetoid
from torch_geometric.datasets import Amazon
from torch_geometric.datasets import WikipediaNetwork
from torch_geometric.datasets import AttributedGraphDataset, LastFMAsia
from torch_geometric.transforms import RandomNodeSplit
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, recall_score, precision_score
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
from sklearn.preprocessing import label_binarize, MinMaxScaler, StandardScaler
import itertools
import pandas as pd
import pickle as pkl
import networkx as nx
from deepset import HalfNLHconv


# ---------------------------------------------------------------------------
# Command-line configuration, device selection and RNG seeding.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
# The three options below are parallel lists consumed together with zip():
# entry i of each list describes one experiment.  `nargs='+'` makes argparse
# collect space-separated values into a list of strings; the previous
# `type=list` turned a single command-line string into a list of characters.
parser.add_argument('--target_dataname', type=str, nargs='+',
                    default=['CiteSeer', 'Facebook', 'LastFM', "chameleon"])
parser.add_argument('--shadow_dataname', type=str, nargs='+',
                    default=['CiteSeer', 'Facebook', 'LastFM', 'chameleon'])
parser.add_argument('--model_name', type=str, nargs='+',
                    default=['MLP', 'MLP', 'MLP', "MLP"])
# A store_true flag must default to False, otherwise it can never be turned
# off and CUDA seeding below would be skipped unconditionally.
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='Disables CUDA training.')
parser.add_argument('--fastmode', action='store_true', default=False,
                    help='Validate during training pass.')
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
parser.add_argument('--epochs', type=int, default=50,
                    help='Number of epochs to train.')
parser.add_argument('--lr', type=float, default=1e-3,
                    help='Initial learning rate.')
parser.add_argument('--weight_decay', type=float, default=1e-5,
                    help='Weight decay (L2 loss on parameters).')
parser.add_argument('--hidden', type=int, default=32,
                    help='Number of hidden units.')
parser.add_argument('--dropout', type=float, default=0,
                    help='Dropout rate (1 - keep probability).')
parser.add_argument('--trainratio', type=float, default=0.8)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--num_per_class', type=int, default=500)
parser.add_argument('--num_smia', type=int, default=3)

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

# The device now honours --no-cuda.  With default arguments the choice is the
# same as before: CUDA when available, CPU otherwise.
device = torch.device('cuda' if args.cuda else 'cpu')

# Seed every RNG the script uses so runs are repeatable.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)


def Similarity_vector(path, ppv):
    """Build a 9-value similarity feature vector for a node triplet.

    Only the first three entries of ``path`` are used.  For each of the three
    unordered node pairs the dot product, the cosine similarity and the
    Euclidean distance of their rows in ``ppv`` are computed.

    Args:
        path: indexable collection of node ids; ``path[0..2]`` are read.
        ppv: container indexable by node id that yields one vector per node.

    Returns:
        1-D numpy array laid out as
        ``[3 sorted dot products, 3 sorted cosines, 3 sorted distances]``,
        each group in ascending order.
    """
    # Tiny constant added to the denominator so that zero-norm vectors do not
    # cause a division by zero in the cosine similarity.
    tiny = 0.0000000000000000000000000000001

    vectors = tuple(ppv[path[k]] for k in range(3))

    dots, cosines, distances = [], [], []
    # combinations() yields (0,1), (0,2), (1,2) in that order.
    for left, right in itertools.combinations(vectors, 2):
        inner = np.dot(left, right)
        dots.append(inner)
        cosines.append(
            inner / (np.linalg.norm(left) * np.linalg.norm(right) + tiny))
        distances.append(np.linalg.norm(np.array(left) - np.array(right)))

    # Sorting inside each group makes the features independent of node order.
    return np.array(sorted(dots) + sorted(cosines) + sorted(distances))
        
def generate_target_attack_dataset(ppv_2stage, cliques, hoppaths, othersub):
    """Turn sampled target-graph node groups into attack features and labels.

    Every element of ``cliques``, ``hoppaths`` and ``othersub`` is passed to
    ``Similarity_vector`` together with ``ppv_2stage``.  The label is the
    position of the group the element came from: 0 for ``cliques``, 1 for
    ``hoppaths`` and 2 for ``othersub``.

    Returns:
        ``(features, labels)`` as numpy arrays, ordered group by group.
    """
    features = []
    labels = []
    for label, group in enumerate((cliques, hoppaths, othersub)):
        for path in group:
            features.append(Similarity_vector(path, ppv_2stage))
            labels.append(label)
    return np.array(features), np.array(labels)
    
def generate_shadow_attack_dataset(ppv_shadow, cliques, hoppaths, othersub):
    """Turn sampled shadow-graph node groups into attack features and labels.

    Each element of the three groups is converted with ``Similarity_vector``
    using ``ppv_shadow``.  Elements of ``cliques`` are labelled 0, elements of
    ``hoppaths`` 1 and elements of ``othersub`` 2.

    Returns:
        ``(features, labels)`` as numpy arrays, ordered group by group.
    """
    # Single pass over the groups, pairing every feature vector with the
    # index of the group it was drawn from.
    labelled = [
        (Similarity_vector(path, ppv_shadow), class_id)
        for class_id, group in enumerate((cliques, hoppaths, othersub))
        for path in group
    ]
    vectors = [vector for vector, _ in labelled]
    class_ids = [class_id for _, class_id in labelled]

    return np.array(vectors), np.array(class_ids)

class CustomDataset(Dataset):
    """Dataset exposing the subset of (feature, label) pairs chosen by ``indices``."""

    def __init__(self, features, labels, indices):
        # The selection is applied once here; afterwards items are addressed
        # by their position inside the selected subset.
        self.features, self.labels = features[indices], labels[indices]

    def __len__(self):
        """Number of selected samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(feature, label)`` pair at position ``idx`` of the subset."""
        sample = self.features[idx]
        target = self.labels[idx]
        return sample, target

class TargetDataset(Dataset):
    """Dataset wrapping complete, already aligned feature and label containers."""

    def __init__(self, features, labels):
        # Stored as given; no subsetting or copying is performed.
        self.features, self.labels = features, labels

    def __len__(self):
        """Number of samples, taken from the label container."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(feature, label)`` pair stored at ``idx``."""
        sample = self.features[idx]
        target = self.labels[idx]
        return sample, target
    
def split_data_by_label(ratio, num_per_class=None, num_classes=3):
    """Create a per-class random train/test split of sample indices.

    The index arithmetic assumes the samples are stored as ``num_classes``
    consecutive blocks of ``num_per_class`` items each (block ``c`` holds
    class ``c``).  Every block is shuffled independently with
    ``np.random.permutation`` and cut at ``int(num_per_class * ratio)``.

    Args:
        ratio: fraction of each class block assigned to the training split.
        num_per_class: size of each class block.  Defaults to the
            module-level ``args.num_per_class`` when omitted.
        num_classes: number of class blocks (3 by default).

    Returns:
        ``(train_indices, test_indices)`` as integer numpy arrays; each is the
        concatenation of the per-class parts in class order.
    """
    if num_per_class is None:
        num_per_class = args.num_per_class

    train_parts = []
    test_parts = []
    # Classes are processed in ascending order so the sequence of RNG draws
    # is the same regardless of how many classes are requested.
    for class_id in range(num_classes):
        block = np.arange(class_id * num_per_class,
                          (class_id + 1) * num_per_class)
        num_train = int(len(block) * ratio)
        shuffled = np.random.permutation(block)
        train_parts.append(shuffled[:num_train])
        test_parts.append(shuffled[num_train:])

    train_indices = np.concatenate(train_parts)
    test_indices = np.concatenate(test_parts)
    return train_indices.astype(int), test_indices.astype(int)

def init_weights(m):
    """Initialise a plain ``nn.Linear`` layer; leave every other module alone.

    Weights receive Xavier-uniform initialisation and the bias, when the
    layer has one, is filled with 0.01.

    Args:
        m: any module object.
    """
    # Exact-type check kept on purpose: subclasses of nn.Linear are skipped,
    # exactly as before.
    if type(m) is nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight)
        # Layers built with bias=False expose ``bias`` as None.
        if m.bias is not None:
            m.bias.data.fill_(0.01)


def evaluate_model(model, loader, data_file, device):
    """Evaluate ``model`` on ``loader``, print the metrics and log them.

    Accuracy, macro precision, macro recall, macro F1 and macro ROC AUC are
    computed over all batches of ``loader``.  Each metric is printed and the
    five values are appended as one space-separated line to ``data_file``.

    Args:
        model: callable module mapping a batch of inputs to per-class scores.
        loader: iterable yielding ``(inputs, labels)`` batches.
        data_file: path of the text file opened in append mode.
        device: device the batches are moved to before the forward pass.
    """
    model.eval()

    score_batches, label_batches = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            scores = model(inputs)
            # Results are gathered on the CPU so they can be handed to numpy.
            score_batches.append(scores.cpu().detach())
            label_batches.append(labels.cpu().detach())

    all_scores = torch.cat(score_batches)
    all_targets = torch.cat(label_batches)

    # Predicted class = index of the largest score in each row.
    y_pred = torch.max(all_scores, dim=1).indices.numpy()
    y_true = all_targets.numpy()
    # Row-wise softmax turns the raw scores into class probabilities.
    y_prob = all_scores.softmax(dim=1).numpy()

    accuracy = accuracy_score(y_true, y_pred)
    print(f"Accuracy: {accuracy}")

    precision_macro = precision_score(y_true, y_pred, average='macro')
    recall_macro = recall_score(y_true, y_pred, average='macro')
    f1_macro = f1_score(y_true, y_pred, average='macro')
    print(f"Precision (macro): {precision_macro}")
    print(f"Recall (macro): {recall_macro}")
    print(f"F1 Score (macro): {f1_macro}")

    # One-hot encode the labels using the classes present in y_true.
    # NOTE(review): if a class is missing from the loader's labels, the
    # encoded matrix may not line up with the columns of y_prob - confirm
    # callers always provide every class.
    present_classes = np.unique(y_true)
    y_true_onehot = label_binarize(y_true, classes=present_classes)

    auc_macro = roc_auc_score(y_true_onehot, y_prob, average='macro', multi_class='ovr')
    print(f"ROC AUC Score (macro): {auc_macro}")

    with open(data_file, 'a') as sink:
        sink.write(f"{accuracy} {precision_macro} {recall_macro} {f1_macro} {auc_macro}\n")

def _mean_batch_accuracy(model, loader, device):
    """Return the mean of the per-batch accuracies of ``model`` over ``loader``.

    The model is switched to eval mode and gradients are disabled.  Every
    batch contributes equally to the mean, whatever its size.
    """
    model.eval()
    total = 0.0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).max(1)[1]
            total += pred.eq(y).sum().item() / y.nelement()
    return total / len(loader)


# One experiment per (target dataset, shadow dataset, model) triple: train the
# attack classifier on shadow features, then evaluate it on shadow test data
# and on the target ("2stage") data.
for (targetdataset, shadowdataset, modelname) in zip(args.target_dataname, args.shadow_dataname, args.model_name):
    print(targetdataset, modelname)
    target_ebd_dir = "/data/embed_result/" + modelname + "/" + targetdataset
    shadow_ebd_dir = "/data/embed_result/" + modelname + "/" + shadowdataset
    target_sample_dir = "/data/sample_result/" + targetdataset
    shadow_sample_dir = "/data/sample_result/" + shadowdataset

    result_data_file = target_ebd_dir + "/4smia" + shadowdataset + "_Simi_defense_results.txt"

    # Per-node vectors of the shadow and target models.
    ppv_shadow = np.load(shadow_ebd_dir + '/Shadow-Posterior-layer2-3smia.npy')
    ppv_2stage = np.load(target_ebd_dir + '/2stage-Posterior-layer2-3smia.npy')

    # Sampled node groups, one file per class, for shadow and target graphs.
    shadow_cliques = np.load(shadow_sample_dir + '/sampled_cliques_4smia.npy')
    shadow_hoppaths = np.load(shadow_sample_dir + '/sampled_2hop_paths_4smia.npy')
    shadow_othersub = np.load(shadow_sample_dir + '/sampled_non_clique_non_2hop_4smia.npy')

    cliques = np.load(target_sample_dir + '/sampled_cliques_4smia.npy')
    hoppaths = np.load(target_sample_dir + '/sampled_2hop_paths_4smia.npy')
    othersub = np.load(target_sample_dir + '/sampled_non_clique_non_2hop_4smia.npy')

    attack_data_shadow, attack_data_label_shadow = generate_shadow_attack_dataset(ppv_shadow, shadow_cliques, shadow_hoppaths, shadow_othersub)
    attack_data_2stage, attack_data_label_2stage = generate_target_attack_dataset(ppv_2stage, cliques, hoppaths, othersub)

    # Standardise both feature sets with one scaler.
    # NOTE(review): the scaler is fit on shadow and target features together -
    # confirm that using target statistics here is intended.
    scaler = StandardScaler()
    combined_data = np.concatenate((attack_data_shadow, attack_data_2stage), axis=0)
    scaler.fit(combined_data)

    attack_data_shadow = scaler.transform(attack_data_shadow)
    attack_data_2stage = scaler.transform(attack_data_2stage)

    attack_data_shadow, attack_data_label_shadow = torch.from_numpy(attack_data_shadow).float(), torch.from_numpy(attack_data_label_shadow)
    attack_data_2stage, attack_data_label_2stage = torch.from_numpy(attack_data_2stage).float(), torch.from_numpy(attack_data_label_2stage)

    # Per-class train/test split of the shadow samples.
    train_indices_shadow, test_indices_shadow = split_data_by_label(args.trainratio)
    train_indices_shadow, test_indices_shadow = torch.from_numpy(train_indices_shadow), torch.from_numpy(test_indices_shadow)

    train_dataset = CustomDataset(attack_data_shadow, attack_data_label_shadow, train_indices_shadow)
    test_dataset = CustomDataset(attack_data_shadow, attack_data_label_shadow, test_indices_shadow)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    # Input width: three similarity measures for every node pair.
    attacker_model = MLP(in_channels=3 * math.comb(args.num_smia, 2),
                         hidden_channels=args.hidden,
                         out_channels=3,
                         num_layers=3,
                         dropout=args.dropout,
                         InputNorm=False).to(device)
    optimizer = torch.optim.Adam(attacker_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(args.epochs):
        attacker_model.train()
        train_loss_sum = 0.0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = attacker_model(x)
            loss = criterion(out, y)
            # Accumulate a plain float: summing the loss tensor itself would
            # keep every batch's autograd graph alive until the epoch ends.
            train_loss_sum += loss.item()
            loss.backward()
            optimizer.step()

        # Progress report every 10 epochs.
        if epoch % 10 == 0:
            print(train_loss_sum / len(train_loader))
            train_acc = _mean_batch_accuracy(attacker_model, train_loader, device)
            test_acc = _mean_batch_accuracy(attacker_model, test_loader, device)
            print(train_acc, test_acc)

    # Shadow test accuracy of the FINAL model.  Previously the logged value
    # was whatever the last progress report had computed (and was undefined
    # when no epoch ran), so it did not match the model evaluated below.
    test_acc = _mean_batch_accuracy(attacker_model, test_loader, device)

    # Test on the target data.
    test_dataset_2stage = TargetDataset(attack_data_2stage, attack_data_label_2stage)
    test_loader_2stage = DataLoader(test_dataset_2stage, batch_size=args.batch_size, shuffle=False)

    acc_2stage = _mean_batch_accuracy(attacker_model, test_loader_2stage, device)

    print('2stage:', acc_2stage)
    with open(result_data_file, 'a') as data_file:
        data_file.write(f"{test_acc} {acc_2stage}\n")

    print("Shadow set metrics:")
    evaluate_model(attacker_model, test_loader, result_data_file, device)

    print("\n2stage set metrics:")
    evaluate_model(attacker_model, test_loader_2stage, result_data_file, device)
    # Separator line between experiments in the results file.
    with open(result_data_file, 'a') as data_file:
        data_file.write(" \n")