import argparse
import os.path as osp
import scipy
import copy
import numpy as np
import wandb
import torch
from DeepRobust.deeprobust.graph import utils
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid, Coauthor
from DeepRobust.deeprobust.graph.data import Dataset
from torch_geometric.utils import add_self_loops, degree, to_dense_adj
from torch_geometric.datasets import CitationFull
from DeepRobust.deeprobust.graph.utils import *
from DeepRobust.deeprobust.graph.defense.gcn import GCN
# from utils import *
import pickle
import time
from scipy import sparse
from scipy.sparse import csr_matrix
# from r_gcn import *
import sys
from tree_inference_structure import Build_CRF_Tree_Structure, CRF_inference_Structure

def test_CRF(adj, l_acc_crf):
    """
    Train a GCN on ``adj`` and evaluate CRF inference over a grid of sigma values.

    A DeepRobust ``GCN`` is fitted on the given (clean or perturbed)
    adjacency. A CRF tree structure is then built from it with
    ``Build_CRF_Tree_Structure``, and ``CRF_inference_Structure`` is run once
    per sigma in ``np.arange(0.05, 1, 0.05)``. The test-set accuracy of each
    run is appended to ``l_acc_crf[sigma]``.

    NOTE: this function reads the module-level globals ``features``,
    ``labels``, ``idx_train``, ``idx_val``, ``idx_test``, ``device``,
    ``radius`` and ``args``, which are set up in the ``__main__`` section.
    ---
    Inputs:
        adj: the clean/perturbed adjacency to be tested
        l_acc_crf: dict mapping sigma -> list of CRF test accuracies collected
            so far; a missing sigma key is created with an empty list.

    Output:
        l_acc_crf: the same dict, updated in place with one new accuracy per
            sigma.
    """
    classifier = GCN(nfeat=features.shape[1], nhid=16,
                     nclass=labels.max().item() + 1, dropout=0.5,
                     device=device)
    classifier = classifier.to(device)

    classifier.fit(features, adj, labels, idx_train, train_iters=200,
                   idx_val=idx_val,
                   idx_test=idx_test,
                   verbose=False, attention=False)
    classifier.eval()

    # Baseline (plain GCN) accuracies. Their return values are not used
    # further. NOTE(review): the calls are kept because GCN.test presumably
    # reports the result itself; confirm in the DeepRobust fork.
    classifier.test(idx_val)
    classifier.test(idx_test)

    adj_tree = Build_CRF_Tree_Structure(classifier, features, adj, labels,
                                        idx_train, idx_val, idx_test, radius,
                                        args.num_samples, args.num_iteration,
                                        device)
    for sigma in np.arange(0.05, 1, 0.05):
        # NOTE(review): the meaning of the '0' argument is defined in
        # tree_inference_structure; confirm there.
        y_hat_CRF = CRF_inference_Structure(adj_tree, '0', sigma=sigma)
        acc_test_CRF = utils.accuracy(y_hat_CRF[idx_test], labels[idx_test]).item()
        l_acc_crf.setdefault(sigma, []).append(acc_test_CRF)
    # Explicitly drop the reference to the tree structure.
    del adj_tree

    return l_acc_crf

if __name__ == '__main__':
    # Everything below is kept at module level on purpose: test_CRF reads
    # ``features``, ``labels``, ``idx_train``, ``idx_val``, ``idx_test``,
    # ``device``, ``radius`` and ``args`` as module globals.

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='cora', choices=['cora','cora_ml', 'citeseer', 'polblogs', 'pubmed','cs','acm'], help='dataset')
    # NOTE(review): --hidden_channels, --lr, --epochs and --attack are parsed
    # but not read in this file (test_CRF hard-codes nhid=16, train_iters=200).
    parser.add_argument('--hidden_channels', type=int, default=16)
    parser.add_argument('--seed', type=int, default=15, help='Random seed.')
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--num_exp', type=int, default=3)
    parser.add_argument('--epochs', type=int, default=300)
    parser.add_argument('--attack', type=str, default="random", help='Type of attack')
    parser.add_argument('--num_iteration', type=int, default=2, help="number of iteration of the CRF inference")
    parser.add_argument('--num_samples', type=int, default=5, help="number of samples of the CRF inference")
    parser.add_argument('--radius_prob', type=float, default=0.01, help="radius, as a fraction of the clean adjacency's nonzero entries")
    parser.add_argument('--device', type=int, default=0, help='Set CUDA device number; if set to -1, disables cuda.')

    args = parser.parse_args()
    if args.num_exp < 1:
        # Avoid reporting mean/std of empty accuracy lists at the end.
        parser.error('--num_exp must be at least 1')

    # Honour --device -1 (documented as "disables cuda") and fall back to CPU
    # when CUDA is not available.
    if args.device >= 0 and torch.cuda.is_available():
        device = torch.device('cuda:' + str(args.device))
    else:
        device = torch.device('cpu')

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # ``device`` is a torch.device, so check its type; comparing the object
    # itself to the string 'cpu' is never equal.
    if device.type == 'cuda':
        torch.cuda.manual_seed(args.seed)

    # Load the Dataset
    data = Dataset(root='/tmp/', name=args.dataset)
    adj, features, labels = data.adj, data.features, data.labels

    # Extract the Train/Val/Test idx
    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test

    if not scipy.sparse.issparse(features):
        features = scipy.sparse.csr_matrix(features)

    # Preprocessing and sparsifying the adjacency and the feature matrix
    adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False)
    adj, features = csr_matrix(adj), csr_matrix(features)
    # nnz counts stored entries, so in a symmetric adjacency every undirected
    # edge contributes two.
    number_of_edges = adj.nnz
    radius = int(args.radius_prob * number_of_edges)

    # Transform to undirected adjacency (especially useful for OGB data).
    # NOTE(review): this symmetrised clean ``adj`` is not passed to test_CRF
    # below; only ``radius`` above is derived from the clean graph.
    adj = adj + adj.T
    adj[adj > 1] = 1

    sigmas = np.arange(0.05, 1, 0.05)
    l_acc_crf = {sigma: [] for sigma in sigmas}

    for exp in range(args.num_exp):
        print('Experiment {}/{}'.format(exp + 1, args.num_exp))
        # Load the perturbed adjacency stored on disk for this dataset/run.
        attacked_path = osp.join('.', 'attacked_graph', 'dice_1',
                                 'modified_adj_sparse_{}_{}.npz'.format(args.dataset, exp))
        adj_dice_1 = sparse.load_npz(attacked_path)
        l_acc_crf = test_CRF(adj_dice_1, l_acc_crf)
        print(l_acc_crf)

    for sigma, accs in l_acc_crf.items():
        print('For GCN CRF sigma {:.2f}: {} +- {}'.format(sigma, np.mean(accs) * 100, np.std(accs) * 100))

    print('---')
 