from __future__ import division
from __future__ import print_function
import math
import time
import argparse
import numpy as np
import os
import torch
import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import shap
from utils import *
from models import GCN_pia, Net#, MLP
from torch_geometric.datasets import Planetoid
from torch_geometric.datasets import Amazon
from torch_geometric.datasets import WikipediaNetwork
from torch_geometric.datasets import AttributedGraphDataset, LastFMAsia
from torch_geometric.transforms import RandomNodeSplit
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, recall_score, precision_score
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
from sklearn.preprocessing import label_binarize
import itertools
import pandas as pd
import pickle as pkl
import networkx as nx
from deepset import HalfNLHconv, MLP

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

parser = argparse.ArgumentParser()
# The three list options below are zipped together by the main loop: position i
# of each list describes experiment i. They use nargs='+' (not type=list, which
# would split a command-line string into single characters) so that
# `--target_dataname CiteSeer LastFM` parses to ['CiteSeer', 'LastFM'].
parser.add_argument('--target_dataname', type=str, nargs='+',
                    default=['CiteSeer', 'Facebook', 'LastFM', "chameleon"],
                    help='Target dataset name(s).')
parser.add_argument('--shadow_dataname', type=str, nargs='+',
                    default=['CiteSeer', 'Facebook', 'LastFM', 'chameleon'],
                    help='Shadow dataset name(s), paired position-wise with --target_dataname.')
parser.add_argument('--model_name', type=str, nargs='+',
                    default=['MLP', 'MLP', 'MLP', "MLP"],
                    help='Model name(s) whose saved embeddings are loaded, paired position-wise with --target_dataname.')
# NOTE(review): with default=True this flag can never be switched off from the
# command line, so args.cuda is always False (it only gates
# torch.cuda.manual_seed below). The global `device` above ignores it and uses
# CUDA whenever it is available.
parser.add_argument('--no-cuda', action='store_true', default=True,
                    help='Disables CUDA training.')
parser.add_argument('--fastmode', action='store_true', default=False,
                    help='Validate during training pass.')
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
parser.add_argument('--epochs', type=int, default=50,
                    help='Number of epochs to train.')
parser.add_argument('--lr', type=float, default=1e-3,
                    help='Initial learning rate.')
parser.add_argument('--weight_decay', type=float, default=1e-5,
                    help='Weight decay (L2 loss on parameters).')
parser.add_argument('--hidden', type=int, default=32,
                    help='Number of hidden units.')
parser.add_argument('--dropout', type=float, default=0,
                    help='Dropout rate (1 - keep probability).')
parser.add_argument('--trainratio', type=float, default=0.8,
                    help='Fraction of each subgraph class used for training.')
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--num_per_class', type=int, default=500,
                    help='Number of sampled subgraphs per class (clique / 2-hop path / other).')
parser.add_argument('--num_smia', type=int, default=3)
parser.add_argument('--warm_ratio', type=float, default=0.2,
                    help='Fraction of epochs during which the subgraph encoder is trained jointly with the classifier.')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Seed the global RNGs used for the train/test split and model initialisation.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
class CustomDataset(Dataset):
    """Index-selected subset of samples yielding ``(features, data_index, label)``.

    ``features``, ``data_index`` and ``labels`` are parallel containers; only the
    rows listed in ``indices`` are kept. ``data_index`` lets the caller look up
    per-sample data that is stored outside the dataset.
    """

    def __init__(self, features, data_index, labels, indices):
        # Apply the same row selection to all three parallel containers.
        self.features, self.data_index, self.labels = (
            column[indices] for column in (features, data_index, labels)
        )

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return tuple(column[idx] for column in (self.features, self.data_index, self.labels))

class CustomDataset_Shadow(Dataset):
    """Index-selected subset of ``(features, label)`` pairs."""

    def __init__(self, features, labels, indices):
        # Keep only the requested rows of both parallel containers.
        self.features, self.labels = features[indices], labels[indices]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        pair = (self.features[idx], self.labels[idx])
        return pair

class TargetDataset(Dataset):
    """Expose pre-computed features and labels, in their given order, as a Dataset."""

    def __init__(self, features, labels):
        self.features, self.labels = features, labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sample = self.features[idx]
        target = self.labels[idx]
        return sample, target

def _macro_ovr_roc_auc(y_true, y_prob):
    """Macro-averaged one-vs-rest ROC AUC, skipping classes where it is undefined.

    Column ``k`` of ``y_prob`` is scored against the indicator ``y_true == k``.
    A class that is absent from ``y_true`` (or is the only class present) has no
    defined AUC and is left out of the unweighted average. Returns NaN if no
    class remains.
    """
    y_true = np.asarray(y_true)
    per_class = []
    for k in range(y_prob.shape[1]):
        is_k = (y_true == k).astype(int)
        # ROC AUC needs both positive and negative samples for this class.
        if 0 < is_k.sum() < len(is_k):
            per_class.append(roc_auc_score(is_k, y_prob[:, k]))
    return float(np.mean(per_class)) if per_class else float('nan')


def evaluate_model(model, loader, data_file, device):
    """Evaluate a classifier on ``loader`` and append its metrics to ``data_file``.

    Prints the accuracy, the macro precision/recall/F1 and the macro one-vs-rest
    ROC AUC. It then appends them as one space-separated line
    ``"<acc> <precision> <recall> <f1> <roc_auc>"`` to the text file ``data_file``.

    Args:
        model: module mapping a batch of inputs to per-class scores; column ``k``
            of its output is treated as the score for class ``k``.
        loader: iterable of ``(inputs, labels)`` batches.
        data_file: path of the results file (opened in append mode).
        device: device the inputs are moved to before calling ``model``.
    """
    model.eval()

    batch_outputs, batch_labels = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = model(inputs.to(device))
            batch_outputs.append(outputs.cpu())
            batch_labels.append(labels.cpu())

    all_outputs = torch.cat(batch_outputs)
    y_true = torch.cat(batch_labels).numpy()
    y_pred = torch.max(all_outputs, dim=1)[1].numpy()
    y_prob = torch.softmax(all_outputs, dim=1).numpy()

    Accuracy = accuracy_score(y_true, y_pred)
    print(f"Accuracy: {Accuracy}")

    Precision_macro = precision_score(y_true, y_pred, average='macro')
    Recall_macro = recall_score(y_true, y_pred, average='macro')
    F1_macro = f1_score(y_true, y_pred, average='macro')
    print(f"Precision (macro): {Precision_macro}")
    print(f"Recall (macro): {Recall_macro}")
    print(f"F1 Score (macro): {F1_macro}")

    # Scored per output column against y_true == k. This keeps probability
    # columns aligned with class ids and works when only some classes occur.
    roc_auc = _macro_ovr_roc_auc(y_true, y_prob)
    print(f"ROC AUC Score (macro): {roc_auc}")

    # Use a separate handle name so the ``data_file`` path argument is not shadowed.
    with open(data_file, 'a') as fh:
        fh.write(f"{Accuracy} {Precision_macro} {Recall_macro} {F1_macro} {roc_auc}\n")


def get_all_outputs(ppv, cliques, cliques_edgeindex, hoppaths, hoppaths_edgeindex, othersub, othersub_edgeindex, transform, device=None):
    """Encode every sampled subgraph with ``transform`` and mean-pool it to one vector.

    Args:
        ppv: per-node feature rows; ``ppv[i]`` is the feature vector of node ``i``.
        cliques, hoppaths, othersub: sequences of node-id lists, one per subgraph.
        cliques_edgeindex, hoppaths_edgeindex, othersub_edgeindex: matching
            per-subgraph edge indices. Each is converted to a long tensor whose
            dimension 1 indexes edges.
        transform: encoder module called as ``transform(x, edge_index, norm, aggr='add')``.
            It is switched to eval mode.
        device: where to run the encoder. Defaults to the device of the
            transform's parameters, so inputs always sit next to the weights. If
            it has no parameters, defaults to CUDA when available, else CPU.

    Returns:
        ``(outputs, labels)``. ``outputs`` stacks one pooled embedding per
        subgraph. ``labels`` holds 0 for cliques, 1 for 2-hop paths and 2 for
        the other subgraphs, in that order.
    """
    if device is None:
        first_param = next(transform.parameters(), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transform.eval()  # Set the transform to evaluation mode

    data_outputs = []
    data_labels = []

    def process_subgraphs(nodes_list, edge_indices_list, label):
        """Encode and pool each subgraph of one category, tagging it with ``label``."""
        with torch.no_grad():
            for nodes, edge_index in zip(nodes_list, edge_indices_list):
                # Gather node features: [num_nodes_in_subgraph, feature_dim].
                emb = np.stack([ppv[idx] for idx in nodes])
                x = torch.tensor(emb, dtype=torch.float, device=device)
                edge_index = torch.tensor(edge_index, dtype=torch.long, device=device)
                # Unit weight for every edge.
                norm = torch.ones(edge_index.size(1), dtype=torch.float, device=device)
                output = transform(x, edge_index, norm, aggr='add')
                # Mean over nodes -> graph-level representation [1, out_dim].
                graph_emb = output.mean(dim=0, keepdim=True)
                data_outputs.append(graph_emb.cpu().numpy())
                data_labels.append(label)

    process_subgraphs(cliques, cliques_edgeindex, label=0)
    process_subgraphs(hoppaths, hoppaths_edgeindex, label=1)
    process_subgraphs(othersub, othersub_edgeindex, label=2)

    data_outputs = np.concatenate(data_outputs, axis=0)  # [total_samples, out_dim]
    data_labels = np.array(data_labels)

    return data_outputs, data_labels


    
def generate_dataset(ppv, cliques, cliques_edgeindex, hoppaths, hoppaths_edgeindex, othersub, othersub_edgeindex):
    """Turn sampled subgraphs into (node-feature, edge-index, label) training data.

    Subgraphs are emitted category by category: cliques (label 0), then 2-hop
    paths (label 1), then the remaining subgraphs (label 2). Within a category,
    node lists and edge indices are paired with ``zip``, so the shorter of the
    two sequences decides how many samples are produced.

    Returns:
        ``(features, edge_indices, labels)``. ``features`` is ``np.array`` of the
        per-subgraph stacked rows ``ppv[node]``. ``edge_indices`` is a plain
        list of the edge-index entries as given. ``labels`` is an int array.
    """
    categories = (
        (cliques, cliques_edgeindex, 0),
        (hoppaths, hoppaths_edgeindex, 1),
        (othersub, othersub_edgeindex, 2),
    )
    node_features, edge_indices, targets = [], [], []
    for node_sets, edge_sets, category in categories:
        for nodes, edges in zip(node_sets, edge_sets):
            node_features.append(np.stack([ppv[node] for node in nodes]))
            edge_indices.append(edges)
            targets.append(category)
    return np.array(node_features), edge_indices, np.array(targets)

def split_data_by_label(ratio, num_per_class=None, num_classes=3):
    """Split sample indices into class-stratified train/test index arrays.

    Samples are assumed to be stored class by class in contiguous blocks of
    ``num_per_class`` entries (class 0 first, then class 1, ...). Each block is
    shuffled with NumPy's global RNG. The first ``int(num_per_class * ratio)``
    indices of a block go to the training set and the rest to the test set.

    Args:
        ratio: fraction of each class assigned to the training set.
        num_per_class: samples per class. Defaults to ``args.num_per_class``.
        num_classes: number of contiguous class blocks (default 3).

    Returns:
        ``(train_indices, test_indices)`` as integer NumPy arrays, each grouped by class.
    """
    if num_per_class is None:
        num_per_class = args.num_per_class

    train_parts, test_parts = [], []
    for cls in range(num_classes):
        block = np.arange(cls * num_per_class, (cls + 1) * num_per_class)
        num_train = int(len(block) * ratio)
        shuffled = np.random.permutation(block)
        train_parts.append(shuffled[:num_train])
        test_parts.append(shuffled[num_train:])

    train_indices = np.concatenate(train_parts)
    test_indices = np.concatenate(test_parts)
    return train_indices.astype(int), test_indices.astype(int)
    
def match_input_feature(data, additional_columns):
    """Right-pad a 2-D array with ``additional_columns`` zero-filled columns."""
    pad_width = [(0, 0), (0, additional_columns)]  # no rows added; columns added on the right
    return np.pad(data, pad_width, mode='constant', constant_values=0)

def _load_subgraph_samples(sample_dir):
    """Load the six sampled-subgraph arrays saved under ``sample_dir``.

    Returns ``(cliques, cliques_edgeindex, hoppaths, hoppaths_edgeindex,
    othersub, othersub_edgeindex)``.
    """
    return (
        np.load(sample_dir + '/sampled_cliques_4smia.npy'),
        np.load(sample_dir + '/sampled_cliques_edgeindex_4smia.npy'),
        np.load(sample_dir + '/sampled_2hop_paths_4smia.npy'),
        np.load(sample_dir + '/sampled_2hop_paths_edgeindex_4smia.npy'),
        np.load(sample_dir + '/sampled_non_clique_non_2hop_4smia.npy'),
        # allow_pickle unpickles objects from the file: only load trusted, locally generated data.
        np.load(sample_dir + '/sampled_non_clique_non_2hop_edgeindex_4smia.npy', allow_pickle=True),
    )


def _embed_subgraphs(transform, batch_data, batch_data_index, edge_index_list, device):
    """Run ``transform`` on each subgraph of a batch and mean-pool its node outputs.

    Args:
        transform: encoder called as ``transform(x, edge_index, norm, aggr='add')``.
        batch_data: node features of each subgraph, already on ``device``.
        batch_data_index: position of each subgraph's edges in ``edge_index_list``.
        edge_index_list: per-subgraph edge indices as NumPy arrays.
        device: device the edge tensors are created on.

    Returns:
        A tensor with one pooled row per subgraph, ``[len(batch_data), out_dim]``.
    """
    pooled = []
    for x, index in zip(batch_data, batch_data_index):
        # int64 edge indices, matching what get_all_outputs feeds the same
        # transform at inference time.
        edge_index = torch.as_tensor(edge_index_list[int(index)], dtype=torch.long, device=device)
        # Unit weight for every edge.
        norm = torch.ones(edge_index.size(1), dtype=torch.float, device=device)
        output = transform(x, edge_index, norm, aggr='add')
        pooled.append(output.mean(dim=0, keepdim=True))
    return torch.cat(pooled, dim=0)


def _accuracy(model, loader, device):
    """Fraction of samples from ``(x, y)`` batches that ``model`` classifies correctly."""
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).max(1)[1]
            correct += pred.eq(y).sum().item()
            total += y.nelement()
    # Counts are pooled over all batches so a short final batch is not over-weighted.
    return correct / total if total else float('nan')


for (targetdataset, shadowdataset, modelname) in zip(args.target_dataname, args.shadow_dataname, args.model_name):
    print(targetdataset, modelname)
    target_ebd_dir = "/data/embed_result/" + modelname + "/" + targetdataset
    shadow_ebd_dir = "/data/embed_result/" + modelname + "/" + shadowdataset
    target_sample_dir = "/data/sample_result/" + targetdataset
    shadow_sample_dir = "/data/sample_result/" + shadowdataset

    result_data_file = target_ebd_dir + "/4smia" + shadowdataset + "_defense_results.txt"

    # Saved per-node embeddings used as node features on each graph.
    ppv_shadow = np.load(shadow_ebd_dir + '/Shadow-Posterior-layer2-3smia.npy')
    ppv_2stage = np.load(target_ebd_dir + '/2stage-Posterior-layer2-3smia.npy')

    (shadow_cliques, shadow_cliques_edgeindex, shadow_hoppaths, shadow_hoppaths_edgeindex,
     shadow_othersub, shadow_othersub_edgeindex) = _load_subgraph_samples(shadow_sample_dir)
    (cliques, cliques_edgeindex, hoppaths, hoppaths_edgeindex,
     othersub, othersub_edgeindex) = _load_subgraph_samples(target_sample_dir)

    # Shadow training data: one sample per subgraph, labels 0/1/2 by category.
    data_list, edge_index_list, labels = generate_dataset(ppv_shadow, shadow_cliques, shadow_cliques_edgeindex, shadow_hoppaths, shadow_hoppaths_edgeindex, shadow_othersub, shadow_othersub_edgeindex)
    # data_index[k] == k lets each batch look its edges up in edge_index_list.
    data_index = np.arange(len(edge_index_list))
    data_list, data_index, labels = torch.from_numpy(data_list).float(), torch.from_numpy(data_index), torch.from_numpy(labels)

    # NOTE(review): the split assumes exactly args.num_per_class samples per
    # class, stored class by class. Confirm this matches the saved samples.
    train_indices, test_indices = split_data_by_label(args.trainratio)
    train_indices, test_indices = torch.from_numpy(train_indices), torch.from_numpy(test_indices)

    train_dataset = CustomDataset(data_list, data_index, labels, train_indices)
    test_dataset = CustomDataset(data_list, data_index, labels, test_indices)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    transform = HalfNLHconv(
        in_dim=ppv_shadow.shape[1],
        hid_dim=32,
        out_dim=ppv_shadow.shape[1],
        num_layers=2,
        dropout=0.0,
        Normalization='bn',
        InputNorm=False,
        attention=False
    ).to(device)

    mlpclassifier = MLP(
        in_channels=ppv_shadow.shape[1],
        hidden_channels=args.hidden,
        out_channels=3,
        num_layers=3,
        dropout=args.dropout,
        InputNorm=False).to(device)

    criterion = nn.CrossEntropyLoss()
    # Warm-up: encoder and classifier are trained jointly.
    optimizer1 = torch.optim.Adam(list(transform.parameters()) + list(mlpclassifier.parameters()), lr=5e-4)
    # After warm-up: only the classifier is updated; the encoder stays frozen.
    optimizer2 = torch.optim.Adam(mlpclassifier.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    for epoch in range(args.epochs):
        warm = epoch < args.warm_ratio * args.epochs
        if warm:
            optimizer = optimizer1
            transform.train()
        else:
            optimizer = optimizer2
            transform.eval()
        mlpclassifier.train()
        epoch_loss = 0
        for batch_data, batch_data_index, batch_labels in train_loader:
            batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)

            # The frozen encoder is never stepped after warm-up, so skip building
            # an autograd graph through it (saves compute and stale .grad buffers).
            with torch.set_grad_enabled(warm):
                batch_embeddings = _embed_subgraphs(transform, batch_data, batch_data_index, edge_index_list, device)

            optimizer.zero_grad()
            outputs = mlpclassifier(batch_embeddings)
            loss = criterion(outputs, batch_labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(train_loader)
        print(f'Epoch [{epoch+1}/{args.epochs}], Loss: {avg_loss:.4f}')

        if epoch % 10 == 0:
            # Periodic held-out accuracy on the shadow test split.
            transform.eval()
            mlpclassifier.eval()
            correct, total = 0, 0
            with torch.no_grad():
                for batch_data, batch_data_index, batch_labels in test_loader:
                    batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)
                    val_embeddings = _embed_subgraphs(transform, batch_data, batch_data_index, edge_index_list, device)
                    pred = mlpclassifier(val_embeddings).max(1)[1]
                    correct += pred.eq(batch_labels).sum().item()
                    total += batch_labels.nelement()
            test_acc = correct / total if total else float('nan')
            print(test_acc)

    # Final evaluation of the trained model on the held-out shadow split and on the target graph.
    attack_data_shadow, attack_data_label_shadow = get_all_outputs(ppv_shadow, shadow_cliques, shadow_cliques_edgeindex, shadow_hoppaths, shadow_hoppaths_edgeindex, shadow_othersub, shadow_othersub_edgeindex, transform)
    attack_data_2stage, attack_data_label_2stage = get_all_outputs(ppv_2stage, cliques, cliques_edgeindex, hoppaths, hoppaths_edgeindex, othersub, othersub_edgeindex, transform)

    test_dataset_shadow = CustomDataset_Shadow(attack_data_shadow, attack_data_label_shadow, test_indices)
    test_dataset_2stage = TargetDataset(attack_data_2stage, attack_data_label_2stage)

    test_loader_shadow = DataLoader(test_dataset_shadow, batch_size=args.batch_size, shuffle=False)
    test_loader_2stage = DataLoader(test_dataset_2stage, batch_size=args.batch_size, shuffle=False)

    mlpclassifier.eval()
    # Both numbers come from the final model.
    shadow_acc = _accuracy(mlpclassifier, test_loader_shadow, device)
    stage2_acc = _accuracy(mlpclassifier, test_loader_2stage, device)
    print('Shadow:', shadow_acc)
    print('2stage:', stage2_acc)
    with open(result_data_file, 'a') as data_file:
        data_file.write(f"{shadow_acc} {stage2_acc}\n")

    print("Shadow set metrics:")
    evaluate_model(mlpclassifier, test_loader_shadow, result_data_file, device)

    print("\n2stage set metrics:")
    evaluate_model(mlpclassifier, test_loader_2stage, result_data_file, device)
    with open(result_data_file, 'a') as data_file:
        data_file.write(f" \n")