import argparse
import os.path as osp

import copy
import numpy as np
import wandb
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid, Coauthor

from torch_geometric.utils import add_self_loops, degree, to_dense_adj
from torch_geometric.datasets import CitationFull

from utils import *
import pickle
import time
from attack import RandomNoise
from gcn import GCN, normalize_tensor_adj
from utils import PGD
# from r_gcn import *
import sys
from tree_inference import Build_CRF_Tree, CRF_inference

if __name__ == '__main__':

    # ------------------------------------------------------------------
    # Command-line interface.
    # NOTE(review): --lr, --epochs and --attack are parsed but not read
    # anywhere in the visible part of this script; they are kept so that
    # existing command lines continue to work.
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='Cora')
    parser.add_argument('--hidden_channels', type=int, default=16)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--num_exp', type=int, default=3, help='Number of experiences')
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--attack', type=str, default ="random", help='Type of attack')
    parser.add_argument('--num_iteration', type=int, default=2, help="number of iteration of the CRF inference")
    parser.add_argument('--num_samples', type=int, default=15, help="number of samples of the CRF inference")
    parser.add_argument('--radius', type=float, default=10.0, help="radius")
    parser.add_argument('--device', type=int, default=0,help='Set CUDA device number; if set to -1, disables cuda.') 
    args = parser.parse_args()

    # Honour the documented "--device -1 disables CUDA" contract: a negative
    # index must fall back to the CPU instead of producing the invalid device
    # string 'cuda:-1' on machines where CUDA is available.
    use_cuda = args.device >= 0 and torch.cuda.is_available()
    device = torch.device('cuda:' + str(args.device)) if use_cuda else torch.device('cpu')
    #device = "mps" if torch.backends.mps.is_available() else "cpu"
    # num_exp = 10


    # Select the dataset family from its name. Every loader receives the
    # NormalizeFeatures transform; anything that is neither "CS" nor
    # "cora_ml" is handed to the Planetoid loader.
    feature_transform = T.NormalizeFeatures()
    dataset_name = args.dataset
    if dataset_name == "CS":
        dataset = Coauthor(root="./data/", name="CS", transform=feature_transform)
    elif dataset_name == "cora_ml":
        dataset = CitationFull("./data/", dataset_name, transform=feature_transform)
    else:
        dataset = Planetoid("./data/", dataset_name, transform=feature_transform)

    # Take the first graph of the dataset and move it to the selected device.
    data = dataset[0].to(device)

    # Dense adjacency of that graph (first entry along the leading dimension
    # returned by to_dense_adj), then its normalised form as computed by
    # gcn.normalize_tensor_adj.
    adj_true = to_dense_adj(data.edge_index)[0]
    norm_adj = normalize_tensor_adj(adj_true)

    #l_acc_gcn = []
    #l_acc_gcn_attacked = []

    # One accumulator per evaluation setting. Each maps a sigma value from the
    # sweep np.arange(0.05, 1, 0.05) to the list of test accuracies collected
    # over the experiments (one entry appended per experiment).
    l_acc_crf = {sigma: [] for sigma in np.arange(0.05, 1, 0.05)}               # clean features
    l_acc_crf_attacked_0_5 = {sigma: [] for sigma in np.arange(0.05, 1, 0.05)}  # random noise, budget 0.5
    l_acc_crf_attacked_1 = {sigma: [] for sigma in np.arange(0.05, 1, 0.05)}    # random noise, budget 1.0
    l_acc_crf_pgd = {sigma: [] for sigma in np.arange(0.05, 1, 0.05)}           # PGD attack
    def _sweep_sigma(model, eval_data, reference, results):
        """Build a CRF tree on ``eval_data`` and record test accuracy per sigma.

        Args:
            model: the GCN passed to ``Build_CRF_Tree``.
            eval_data: graph data (clean or perturbed) the tree is built from.
            reference: data object providing ``test_mask`` and the labels ``y``
                used to score the predictions.
            results: dict mapping each sigma of the sweep to a list; the
                accuracy obtained for that sigma is appended in place.
        """
        tree = Build_CRF_Tree(model, eval_data, norm_adj, args.radius, args.num_samples, args.num_iteration, device)
        mask = reference.test_mask
        num_test = int(mask.sum())
        for sigma in np.arange(0.05, 1, 0.05):
            y_hat = CRF_inference(tree, '0', sigma=sigma).argmax(dim=-1).to(device)
            num_correct = int((y_hat[mask] == reference.y[mask]).sum())
            results[sigma].append(num_correct / num_test)

    for exp in range(args.num_exp):
        print('in')

        path_model = "./Models/GCN_{}_{}.pth".format(args.dataset, str(exp + 1))
        model_gcn = GCN(dataset.num_features, args.hidden_channels, dataset.num_classes).to(device)

        # Load the trained model and the data stored alongside it.
        # map_location keeps the load working when the checkpoint was written
        # on a different device than the one selected for this run.
        # NOTE(review): the checkpoint stores a full data object, not only
        # tensors; recent PyTorch releases may need `weights_only=False`
        # here -- confirm against the installed version.
        loaded_checkpoint = torch.load(path_model, map_location=device)
        model_gcn.load_state_dict(loaded_checkpoint['model_state_dict'])
        data = loaded_checkpoint['data']
        # Evaluation mode is essential: the model is only used for inference.
        model_gcn.eval()

        # Random-noise attacks. They are generated from the checkpoint's data
        # (loaded just above) so that the perturbed features match the masks
        # and labels used for scoring below.
        random_noise_0_5 = RandomNoise(0.5)
        data_perturbed_0_5 = copy.deepcopy(data)
        data_perturbed_0_5.x = random_noise_0_5.perturb(data)

        # Budget 1.0 must use its own RandomNoise instance; reusing the 0.5
        # one would silently report budget-0.5 results under the 1.0 label.
        random_noise_1 = RandomNoise(1.0)
        data_perturbed_1 = copy.deepcopy(data)
        data_perturbed_1.x = random_noise_1.perturb(data)

        # CRF accuracy on clean features.
        _sweep_sigma(model_gcn, data, data, l_acc_crf)

        # CRF accuracy under random noise, budget = 0.5.
        _sweep_sigma(model_gcn, data_perturbed_0_5, data, l_acc_crf_attacked_0_5)

        # CRF accuracy under random noise, budget = 1.0.
        _sweep_sigma(model_gcn, data_perturbed_1, data, l_acc_crf_attacked_1)

        # PGD attack: the returned perturbation is added to the clean features.
        # NOTE(review): the 0.1 argument is presumably the attack budget --
        # confirm against utils.PGD.
        perturbation = PGD(model_gcn, data, norm_adj, 0.1).attack()
        data_perturbed_pgd = copy.deepcopy(data)
        data_perturbed_pgd.x = perturbation.data + data.x
        _sweep_sigma(model_gcn, data_perturbed_pgd, data, l_acc_crf_pgd)

        # The CRF trees are local to _sweep_sigma and are released when it
        # returns; drop the model before the next experiment builds a new one.
        del model_gcn


        
    # Summary: for every sigma of the sweep, print mean and standard deviation
    # (both scaled to percent) of the accuracies gathered for each setting.
    report_rows = (
        ('For GCN CRF sigma {}: {} +- {}', l_acc_crf),
        ('For GCN CRF Peturbed Budget 0.5 sigma {}: {} +- {}', l_acc_crf_attacked_0_5),
        ('For GCN CRF Peturbed Budget 01.0 sigma {}: {} +- {}', l_acc_crf_attacked_1),
        ('For GCN CRF PGD sigma {}: {} +- {}', l_acc_crf_pgd),
    )
    for sigma in np.arange(0.05, 1, 0.05):
        for template, results in report_rows:
            accuracies = results[sigma]
            print(template.format(sigma, np.mean(accuracies) * 100, np.std(accuracies) * 100))
    print('---')
 