import torch
torch.manual_seed(0)
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_max_pool, global_add_pool
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.utils import add_self_loops, degree, to_dense_adj
from torch_geometric.data import Data
import json
import copy
from torch import linalg as LA
import numpy as np
import pandas as pd
import argparse
import sys
import torch_geometric.transforms as T
import random
from tqdm import tqdm
from GCORN import GCORN, normalize_tensor_adj
import matplotlib.pyplot as plt
from torch_geometric.datasets import Planetoid, Coauthor
import os

def extract_split(dataset, data, fold):
    """Materialise the train / validation / test subsets for one fold.

    Args:
        dataset: Indexable collection of samples; items are looked up one at
            a time with 0-d tensor indices.
        data: Per-fold split description. ``data[fold]`` must provide
            ``['model_selection'][0]['train']``,
            ``['model_selection'][0]['validation']`` and ``['test']``, each
            a sequence of integer indices into ``dataset``.
        fold: Key/index of the fold to extract.

    Returns:
        A ``(train_data, val_data, test_data)`` tuple of lists of samples.
    """
    fold_spec = data[fold]
    selection = fold_spec['model_selection'][0]

    # Resolve all three index lists before touching the dataset so that a
    # malformed split description fails early.
    train_idx = selection['train']
    val_idx = selection['validation']
    test_idx = fold_spec['test']

    def gather(indices):
        # Iterating a tensor yields 0-d tensors, which are used as indices.
        return [dataset[idx] for idx in torch.tensor(indices)]

    return gather(train_idx), gather(val_idx), gather(test_idx)

def sample_radius(epsilon, d):
    """Draw a random radius no larger than ``epsilon``.

    A single value ``U`` is drawn from ``np.random.uniform(0, 1)`` and the
    result is ``epsilon * U ** (1 / (d - 1))``.

    Args:
        epsilon: Upper bound (scale) of the sampled radius.
        d: Dimension used in the exponent ``1 / (d - 1)``.

    Returns:
        The sampled radius.

    Raises:
        ZeroDivisionError: If ``d`` is the plain Python number 1.
    """
    uniform_draw = np.random.uniform(0, 1)
    exponent = 1 / (d - 1)
    return epsilon * (uniform_draw ** exponent)

def split(dataset, split_type="random", num_train_per_class=20, num_val=500, num_test=1000):
    """Build boolean train / validation / test node masks for a dataset.

    Args:
        dataset: Object exposing ``get(0)`` (returning the graph, which must
            have a label tensor ``y``) and ``num_classes``.
        split_type: ``"public"`` reuses the masks stored on the graph when it
            has a ``train_mask`` attribute; any other value (or a graph
            without stored masks) triggers a random split.
        num_train_per_class: Maximum number of training nodes drawn per class.
        num_val: Maximum number of validation nodes.
        num_test: Maximum number of test nodes.

    Returns:
        A ``(train_mask, val_mask, test_mask)`` tuple of boolean tensors
        shaped like ``y``.
    """
    graph = dataset.get(0)

    # Public split: hand back the masks shipped with the graph untouched.
    if split_type == "public" and hasattr(graph, "train_mask"):
        return (graph.train_mask, graph.val_mask, graph.test_mask)

    labels = graph.y
    train_mask = torch.zeros_like(labels, dtype=torch.bool)
    val_mask = torch.zeros_like(labels, dtype=torch.bool)
    test_mask = torch.zeros_like(labels, dtype=torch.bool)

    # Training set: a random subset of each class.
    for label in range(dataset.num_classes):
        members = torch.nonzero(labels == label, as_tuple=False).view(-1)
        picked = torch.randperm(members.size(0))[:num_train_per_class]
        train_mask[members[picked]] = True

    # Validation and test sets: consecutive slices of the shuffled
    # non-training nodes, so the three masks never overlap.
    pool = torch.nonzero(~train_mask, as_tuple=False).view(-1)
    pool = pool[torch.randperm(pool.size(0))]
    val_end = num_val
    test_end = num_val + num_test
    val_mask[pool[:val_end]] = True
    test_mask[pool[val_end:test_end]] = True

    return (train_mask, val_mask, test_mask)





def _sample_signed_radii(radius, num_nodes, d):
    """Draw one radius per node and attach a random sign to every feature.

    A first radius is drawn with ``sample_radius(radius, d)``. One node,
    picked uniformly at random, keeps exactly that radius; every other node
    gets a radius re-drawn with ``sample_radius`` using the first one as its
    upper bound.

    Returns:
        A ``(num_nodes, d)`` array whose row ``k`` holds ``+/- radius_k``.
    """
    max_radius = sample_radius(radius, d)
    anchor = np.random.randint(num_nodes)
    radii = np.array(
        [max_radius if k == anchor else sample_radius(max_radius, d)
         for k in range(num_nodes)]
    ).reshape(num_nodes, 1)
    signs = 2 * np.random.randint(0, 2, (num_nodes, d)) - 1
    # Broadcasting the (num_nodes, 1) radii over the d feature columns.
    return signs * radii


def _sample_feature_noise(radius, distance, num_nodes, d):
    """Sample a ``(num_nodes, d)`` feature perturbation for one trial.

    Each row is the signed per-node radius scaled by a random "mass" vector:

    * ``"L1"``: the masses are the gaps between sorted uniform cut points on
      ``[0, 1]`` and therefore sum to 1 per row.
    * ``"L2"``: the square roots of those masses, so their squares sum to 1.
    * ``"Linf"``: independent uniform masses with one random entry per row
      forced to 1.
    """
    signed_radii = _sample_signed_radii(radius, num_nodes, d)

    if distance == "Linf":
        mass = np.random.uniform(0, 1, (num_nodes, d))
        saturated = np.random.randint(d, size=num_nodes)
        mass[np.arange(num_nodes), saturated] = 1
        return mass * signed_radii

    # L1 / L2: split the unit interval into d pieces per node.
    cuts = np.sort(np.random.uniform(0, 1, (num_nodes, d - 1)), axis=1)
    edges = np.concatenate(
        [np.zeros((num_nodes, 1)), cuts, np.ones((num_nodes, 1))], axis=1
    )
    mass = edges[:, 1:] - edges[:, :-1]
    if distance == "L2":
        mass = np.sqrt(mass)
    return mass * signed_radii


def eval_rob(model, data, radius, distance, sigma, device=None):
    """Estimate how often random feature noise moves the model's predictions.

    The model is evaluated once on the clean features and then on 100
    independently perturbed copies. For each perturbed copy, the spectral
    norm (largest singular value) of the difference between clean and noisy
    softmax outputs is computed and divided by ``2 * sqrt(num_rows)``. For
    every threshold in ``np.arange(0, 1, 0.01)`` the function reports the
    fraction of perturbed copies whose normalised distance exceeds it.

    Args:
        model: Callable as ``model(x, norm_adj)``; put in eval mode here.
        data: Graph object providing ``x``, ``edge_index`` and ``to(device)``.
        radius: Upper bound passed to ``sample_radius`` for the noise radius.
        distance: One of ``"L1"``, ``"L2"`` or ``"Linf"``; selects how the
            noise is spread over the feature dimensions.
        sigma: Unused; kept so existing callers keep working.
        device: Device to run on. Defaults to CUDA when available and CPU
            otherwise (the previous hard-coded CUDA default crashed on
            CPU-only machines).

    Returns:
        Dict mapping each threshold (Python float) to the fraction of noisy
        samples whose normalised output distance is strictly greater.

    Raises:
        ValueError: If ``distance`` is not a supported value.
    """
    if distance not in ("L1", "L2", "Linf"):
        raise ValueError(
            "distance must be one of 'L1', 'L2' or 'Linf', got {!r}".format(distance)
        )
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.eval()
    all_sigma = np.arange(0, 1, 0.01).tolist()
    thresholds = np.asarray(all_sigma)
    exceed_counts = np.zeros(len(all_sigma))
    num_samples = 100

    data = data.to(device)
    num_nodes = data.x.size(0)
    d = data.x.size(-1)
    adj_true = to_dense_adj(data.edge_index, max_num_nodes=num_nodes)[0, :, :].to(device)
    norm_adj = normalize_tensor_adj(adj_true)

    with torch.no_grad():
        ground_truth = torch.softmax(model(data.x, norm_adj), -1)
        # Number of output rows, used to normalise the distance.
        n_nodes = ground_truth.size(0)

        for _ in range(num_samples):
            Z = _sample_feature_noise(radius, distance, num_nodes, d)
            # The noise is float64; cast the sum back to float32 for the model.
            noised_x = (data.x + torch.tensor(Z).to(device)).float()
            noised_y = torch.softmax(model(noised_x, norm_adj), -1)

            # Distance: largest singular value of the output difference.
            norm = LA.matrix_norm(ground_truth - noised_y, ord=2).item()
            norm = norm / (2 * np.sqrt(n_nodes))
            exceed_counts += norm > thresholds

    return dict(zip(all_sigma, exceed_counts / num_samples))



# ---- Command-line arguments ----
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='Cora', help='Dataset to use')
parser.add_argument('--radius', type=float, default=10)
parser.add_argument('--distance', type=str, default="Linf", choices=["L1", "L2", "Linf"])
parser.add_argument('--sigma', type=float, default=10)

args = parser.parse_args()

# ---- Dataset loading ----
# "CS" is loaded from Coauthor with a freshly drawn random split; every other
# name is handed to Planetoid and uses the masks stored with the dataset.
if args.dataset == "CS":
    dataset = Coauthor(root="./data/", name=args.dataset, transform=T.NormalizeFeatures())
    train_mask, val_mask, test_mask = split(
        dataset, split_type="random", num_train_per_class=20, num_val=500, num_test=1000
    )
    data = dataset[0]
    data.train_mask = train_mask
    data.val_mask = val_mask
    data.test_mask = test_mask
else:
    dataset = Planetoid("./data/", args.dataset, transform=T.NormalizeFeatures())
    data = dataset[0]
num_class = dataset.num_classes


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---- Robustness evaluation, averaged over the saved checkpoints ----
all_sigma = np.arange(0, 1, 0.01).tolist()
all_global_robustness = {sig: [] for sig in all_sigma}
num_exp = 10
for exp in range(num_exp):
    checkpoint_path = "./checkpoints/gcorn_best_model_{}_fold_{}.pth".format(args.dataset, exp)
    # map_location lets checkpoints saved on a GPU load on CPU-only machines.
    loaded_checkpoint = torch.load(os.path.abspath(checkpoint_path), map_location=device)
    model = GCORN(
        loaded_checkpoint["num_features"],
        loaded_checkpoint["hidden_channels"],
        loaded_checkpoint["num_classes"],
    ).to(device)
    model.load_state_dict(loaded_checkpoint["model_state_dict"])

    # Evaluate on the same device the model lives on rather than relying on
    # eval_rob's own default.
    global_robustness_per_sigma = eval_rob(
        model, data, args.radius, args.distance, args.sigma, device=device
    )
    for sig in all_sigma:
        all_global_robustness[sig].append(global_robustness_per_sigma[sig])

for sig in all_sigma:
    all_global_robustness[sig] = np.mean(all_global_robustness[sig])

for sig in all_sigma:
    print(" Sigma  :  {}  : ======>     Adf[f]    : {} % ".format(np.around(sig, 2), 100 * all_global_robustness[sig]))
