"""
Contains the complete implementation to reproduce the results of the attacks
results in the case of the "Cora", "CiteSeer" and "PubMed" datataset.
---
The implementation contains all the related benchmarks:
    - GCN
    - RGCN
    - ParsevalR
    - GCORN

For the GCN-K and the Air-GNN, refer to their official repo.
"""

import argparse
import os.path as osp

import copy
import numpy as np
import wandb
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid, Coauthor

from torch_geometric.utils import add_self_loops, degree, to_dense_adj
from torch_geometric.datasets import CitationFull

from utils import *
import pickle
import time
from attack import RandomNoise

from gcn import GCN, normalize_tensor_adj
# from r_gcn import *

from inference import CRF_inference
import sys
from tree_inference import Build_CRF_Tree, CRF_inference
from deeprobust.graph.targeted_attack import Nettack
from utils import PGD

def masked_accuracy(logits, y, mask):
    """Return the accuracy of arg-max predictions on the nodes selected by ``mask``.

    Args:
        logits: per-node class scores; classes are taken along the last dimension.
        y: ground-truth labels, indexable by ``mask``.
        mask: boolean node mask (e.g. ``data.test_mask``).

    Returns:
        Fraction of masked nodes whose predicted class equals the label.

    Raises:
        ValueError: if ``mask`` selects no nodes.
    """
    total = int(mask.sum())
    if total == 0:
        raise ValueError('mask selects no nodes; cannot compute accuracy')
    pred = logits.argmax(dim=-1)
    return int((pred[mask] == y[mask]).sum()) / total


def _load_dataset(name):
    """Load dataset ``name`` with ``T.NormalizeFeatures`` applied to node features."""
    transform = T.NormalizeFeatures()
    if name == 'CS':
        return Coauthor(root='./data/', name='CS', transform=transform)
    if name == 'cora_ml':
        return CitationFull('./data/', name, transform=transform)
    return Planetoid('./data/', name, transform=transform)


@torch.no_grad()
def _evaluate(model, x, adj, data):
    """Test accuracy of ``model`` on features ``x``; no autograd graph is built."""
    return masked_accuracy(model(x, adj), data.y, data.test_mask)


def main():
    """Evaluate pretrained GCN checkpoints under feature attacks and print mean +- std."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='CS')
    parser.add_argument('--hidden_channels', type=int, default=16)
    # --lr, --epochs and --attack are accepted for CLI compatibility but are
    # not read anywhere in this evaluation script.
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--attack', type=str, default="random", help='Type of attack')
    parser.add_argument('--device', type=int, default=0, help='Set CUDA device number; if set to -1, disables cuda.')
    args = parser.parse_args()

    # Honour the documented "-1 disables cuda" contract.
    use_cuda = args.device >= 0 and torch.cuda.is_available()
    device = torch.device('cuda:' + str(args.device)) if use_cuda else torch.device('cpu')
    num_exp = 10

    dataset = _load_dataset(args.dataset)
    graph = dataset[0].to(device)
    # max_num_nodes keeps the dense adjacency N x N even if the highest-indexed
    # nodes are isolated (otherwise its size is inferred from the max edge index).
    adj_true = to_dense_adj(graph.edge_index, max_num_nodes=graph.num_nodes)[0]
    norm_adj = normalize_tensor_adj(adj_true)

    l_acc = []
    l_acc_attacked_0_5 = []
    l_acc_attacked_1 = []
    l_acc_attacked_pgd = []

    for exp in range(num_exp):
        print('Experiment {}/{}'.format(exp + 1, num_exp))

        path_model = "./features_attacks/Models/GCN_{}_{}.pth".format(args.dataset, exp + 1)
        # map_location lets GPU-saved checkpoints load on a CPU-only machine.
        # NOTE(review): the checkpoint pickles a `data` object; newer torch
        # versions default torch.load to weights_only=True, which may reject it.
        checkpoint = torch.load(path_model, map_location=device)
        model_gcn = GCN(dataset.num_features, args.hidden_channels,
                        dataset.num_classes).to(device)
        model_gcn.load_state_dict(checkpoint['model_state_dict'])
        model_gcn.eval()
        # Every attack below is built from the checkpoint's own data so that
        # clean and attacked accuracies are measured on the same test mask.
        data = checkpoint['data'].to(device)

        # Clean accuracy
        l_acc.append(_evaluate(model_gcn, data.x, norm_adj, data))

        # Random noise, budget = 0.5
        x_0_5 = RandomNoise(0.5).perturb(data)
        l_acc_attacked_0_5.append(_evaluate(model_gcn, x_0_5, norm_adj, data))

        # Random noise, budget = 1.0
        x_1 = RandomNoise(1.0).perturb(data)
        l_acc_attacked_1.append(_evaluate(model_gcn, x_1, norm_adj, data))

        # PGD feature attack (needs gradients, so it runs outside no_grad)
        perturbation = PGD(model_gcn, data, norm_adj, 0.1).attack()
        x_pgd = perturbation.data + data.x
        l_acc_attacked_pgd.append(_evaluate(model_gcn, x_pgd, norm_adj, data))

        del model_gcn, checkpoint

    summary = [
        ('For GCN :', l_acc),
        ('For GCN Peturbed Budget 0.5 :', l_acc_attacked_0_5),
        ('For GCN Peturbed Budget 1.0:', l_acc_attacked_1),
        ('For GCN PGD:', l_acc_attacked_pgd),
    ]
    for label, accs in summary:
        print('{} {} +- {}'.format(label, np.mean(accs) * 100, np.std(accs) * 100))

    print('---')


if __name__ == '__main__':
    main()
