import sys
import torch
import numpy as np
import torch.nn.functional as F
import torch.optim as optim
from DeepRobust.deeprobust.graph.defense.gcn import GCN
# from DeepRobust.deeprobust.graph.noisy_gcn import Noisy_GCN
from DeepRobust.deeprobust.graph.targeted_attack import Nettack
from DeepRobust.deeprobust.graph.utils import *
from DeepRobust.deeprobust.graph.data import Dataset
import wandb
import argparse
from DeepRobust.deeprobust.graph import utils
from tqdm import tqdm
from scipy import sparse
import numpy as np
import torch.nn.functional as F
import torch.optim as optim
# from DeepRobust.deeprobust.graph.defense import *
from DeepRobust.deeprobust.graph.global_attack import MetaApprox, Metattack, DICE

import argparse
from scipy.sparse import csr_matrix
import pickle
from sklearn.metrics import jaccard_score
from sklearn.preprocessing import normalize
import scipy
import numpy as np

from sklearn.preprocessing import normalize

import warnings
warnings.filterwarnings("ignore")


def _str2bool(value):
    """Convert a command-line token to a real boolean.

    ``type=bool`` is not usable with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so the flag could never be
    switched off from the command line.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept as False, matching what bool('') used to give.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


# Command-line options of the experiment.
parser = argparse.ArgumentParser()
parser.add_argument('--no-cuda', action='store_true', default=False,help='Disables CUDA training.')
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
parser.add_argument('--epochs', type=int, default=200,help='Number of epochs to train.')
parser.add_argument('--lr', type=float, default=0.01,help='Initial learning rate.')
parser.add_argument('--weight_decay', type=float, default=5e-4,help='Weight decay (L2 loss on parameters).')
parser.add_argument('--hidden', type=int, default=16, help='Number of hidden units.')
parser.add_argument('--dropout', type=float,     default=0.5,help='Dropout rate (1 - keep probability).')
parser.add_argument('--dataset', type=str, default='polblogs', choices=['cora','cora_ml', 'citeseer', 'polblogs', 'pubmed','cs','acm'], help='dataset')
parser.add_argument('--ptb_rate', type=float, default=0.05, help='pertubation rate')
parser.add_argument('--model', type=str, default='Meta-Self',choices=['A-Meta-Self', 'Meta-Self'], help='model variant')
parser.add_argument('--modelname', type=str, default='GCN', choices=['GCN', 'GAT','GIN', 'JK'])
parser.add_argument('--defensemodel', type=str, default='GCNJaccard',choices=['GCNJaccard', 'RGCN', 'GCNSVD'])
parser.add_argument('--GNNGuard', type=_str2bool, default=False, choices=[True, False])
parser.add_argument('--beta_max', type=float, default=0.15)
parser.add_argument('--device', type=int, default=0,help='Set CUDA device number; if set to -1, disables cuda.')
parser.add_argument('--beta_min', type=float, default=0.01)
parser.add_argument('--use_wandb', type=_str2bool, default=False, choices=[True, False])


args = parser.parse_args()

# CUDA is used only when it is available and has not been switched off, either
# through --no-cuda or through the documented "--device -1" convention.
use_cuda = torch.cuda.is_available() and not args.no_cuda and args.device >= 0
device = torch.device('cuda:' + str(args.device)) if use_cuda else torch.device('cpu')

# Seed every random number generator the experiment relies on.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# Compare the device *type*: a torch.device never compares equal to a str.
if device.type != 'cpu':
    torch.cuda.manual_seed(args.seed)


def test_model(adj, modified_adj):
    """
    Train a GCN on ``adj`` and score it on ``adj`` and on ``modified_adj``
    ---
    Inputs:
        adj: the adjacency the classifier is fitted on
        modified_adj: the adjacency handed to ``classifier.predict`` for the
            second accuracy

    Output:
        acc_test: accuracy reported by ``classifier.test(idx_test)``
        acc_noise_attacked: accuracy on ``idx_test`` of the predictions
            obtained with ``modified_adj``

    The function reads the module-level ``features``, ``labels``,
    ``idx_train``, ``idx_val``, ``idx_test`` and ``device``. The model size
    and schedule are fixed here (16 hidden units, dropout 0.5, 200
    training iterations).
    """
    classifier = GCN(nfeat=features.shape[1], nhid=16,
                     nclass=labels.max().item() + 1, dropout=0.5,
                     device=device)
    classifier = classifier.to(device)

    classifier.fit(features, adj, labels, idx_train, train_iters=200,
                   idx_val=idx_val,
                   idx_test=idx_test,
                   verbose=False, attention=False)
    classifier.eval()

    # A single model is trained, so there is nothing to select on the
    # validation set: the test accuracy is always computed. Gating it on a
    # validation score left ``acc_test`` unbound whenever that score was 0.
    result = classifier.test(idx_test)
    # ``test()`` in this file unpacks an (accuracy, extra) pair from the same
    # method, so both a bare accuracy and such a pair are accepted here.
    # NOTE(review): confirm which of the two the installed GCN.test returns.
    if isinstance(result, tuple):
        result = result[0]
    acc_test = result.item() if torch.is_tensor(result) else result

    # Accuracy of the same trained model when it predicts on the second graph.
    modified_output = classifier.predict(features, modified_adj)
    acc_noise_attacked = utils.accuracy(modified_output[idx_test],
                                        labels[idx_test]).item()
    return acc_test, acc_noise_attacked



def test(adj, defense="GCN"):
    """
    Fit one of the benchmark models on ``adj`` and return its test accuracy
    ---
    Inputs:
        adj: the clean/perturbed adjacency to be tested
        defense (str,): "GCN" and "Guard" both build the class named by
            ``args.modelname`` (attention is switched on only for "Guard");
            any other value is used as the name of the class to build, which
            additionally receives ``nnodes``

    Output:
        acc_test: ``.item()`` of the first value returned by
            ``classifier.test(idx_test)``
    """
    # Only the "Guard" setting trains with attention.
    use_attention = defense == "Guard"

    # Classes are resolved by name among the module globals.
    if defense in ("GCN", "Guard"):
        model_cls = globals()[args.modelname]
        ctor_kwargs = {}
    else:
        model_cls = globals()[defense]
        ctor_kwargs = {"nnodes": adj.shape[0]}

    ctor_kwargs.update(
        nfeat=features.shape[1],
        nhid=16,
        nclass=labels.max().item() + 1,
        dropout=0.5,
        device=device,
    )
    classifier = model_cls(**ctor_kwargs).to(device)

    classifier.fit(
        features, adj, labels, idx_train,
        train_iters=201,
        idx_val=idx_val,
        idx_test=idx_test,
        verbose=False,
        attention=use_attention,
    )
    classifier.eval()

    accuracy, _ = classifier.test(idx_test)
    return accuracy.item()



if __name__ == '__main__':
    # Per-run accuracies on the clean graph and on the attacked graph.
    clean_acc, attack_acc = [], []

    for training_id in range(10):

        # Load the dataset and pull out the graph, the node features, the
        # labels and the train/val/test node indices.
        data = Dataset(root='/tmp/', name=args.dataset)
        adj = data.adj
        features = data.features
        labels = data.labels
        idx_train = data.idx_train
        idx_val = data.idx_val
        idx_test = data.idx_test
        idx_unlabeled = np.union1d(idx_val, idx_test)

        # The features must be a scipy sparse matrix before preprocessing.
        if not scipy.sparse.issparse(features):
            features = scipy.sparse.csr_matrix(features)

        # Perturbation budget: ptb_rate times half of the adjacency sum.
        perturbations = int(args.ptb_rate * (adj.sum()//2))

        # Preprocess, then store adjacency and features as CSR matrices again.
        adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False)
        adj = csr_matrix(adj)
        features = csr_matrix(features)

        # Transform to undirected adjacency (especially useful for OGB data):
        # add the transpose, then cap every entry above 1 back to 1.
        adj = adj + adj.T
        adj[adj>1] = 1

        # Attacked graph stored on disk for this dataset and run index.
        modified_adj_sparse = sparse.load_npz(
            "./attacked_graph/dice_1/modified_adj_sparse_{}_{}.npz".format(
                args.dataset, training_id))
        attention=False

        # Train on the clean graph; evaluate on the clean and attacked graphs.
        acc_noise_clean, acc_noise_attacked = test_model(adj, modified_adj_sparse)

        print('---------------')
        print("Non Attacked Acc - {}" .format(acc_noise_clean))
        print("Attacked Acc - {}" .format(acc_noise_attacked))
        print('---------------')

        clean_acc.append(acc_noise_clean)
        attack_acc.append(acc_noise_attacked)

