import torch
import numpy as np
import torch.nn.functional as F
import torch.optim as optim
from deeprobust.graph.defense import GCN
from deeprobust.graph.targeted_attack import Nettack
from deeprobust.graph.utils import *
from deeprobust.graph.data import Dataset
import argparse
from tqdm import tqdm
import grb.utils as utils
import os

from grb.utils.normalize import GCNAdjNorm, SAGEAdjNorm

from grb.dataset import CustomDataset
from scipy.sparse import csr_matrix
import statistics

import sys
import json

import random
from graphcon import GNN_graphcon

from torch_geometric.utils import to_scipy_sparse_matrix,dense_to_sparse
from tqdm import tqdm, trange

# Command-line interface. The first group configures the attack/evaluation run; the
# remaining groups are forwarded (via vars(args)) to the GraphCON defence model.
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
parser.add_argument('--dataset', type=str, default='cora', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
parser.add_argument('--cuda', type=int, default=0, help='cuda.')
parser.add_argument('--gpu', type=int, default=0, help='gpu.')
parser.add_argument('--defence', type=str, default='hamgcnv4', help='model variant')
parser.add_argument('--n_pert', type=int, default=None, help='n_perturbations.')

parser.add_argument('--hidden', type=int, default=64,
                    help='Number of hidden units.')
parser.add_argument('--runtime', type=int, default=10,
                    help='runtime.')
parser.add_argument('--lr', type=float, default=0.001,
                    help='lr.')

###### args for pde model ###################################

parser.add_argument('--hidden_dim', type=int, default=256, help='Hidden dimension.')
parser.add_argument('--proj_dim', type=int, default=256, help='proj_dim dimension.')
parser.add_argument('--fc_out', dest='fc_out', action='store_true',
                    help='Add a fully connected layer to the decoder.')
parser.add_argument('--input_dropout', type=float, default=0.0, help='Input dropout rate.')
parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.')
parser.add_argument("--batch_norm", dest='batch_norm', action='store_true', help='search over reg params')
parser.add_argument('--optimizer', type=str, default='adam', help='One from sgd, rmsprop, adam, adagrad, adamax.')
parser.add_argument('--weight_decay', type=float, default=5e-4, help='Weight decay for optimization')
parser.add_argument('--epoch', type=int, default=500, help='Number of training epochs per iteration.')
parser.add_argument('--alpha', type=float, default=1.0, help='Factor in front matrix A.')
parser.add_argument('--alpha_dim', type=str, default='sc', help='choose either scalar (sc) or vector (vc) alpha')
parser.add_argument('--no_alpha_sigmoid', dest='no_alpha_sigmoid', action='store_true',
                    help='apply sigmoid before multiplying by alpha')
parser.add_argument('--beta_dim', type=str, default='sc', help='choose either scalar (sc) or vector (vc) beta')
parser.add_argument('--block', type=str, default='constant', help='constant, mixed, attention, hard_attention')
parser.add_argument('--function', type=str, default='laplacian', help='laplacian, transformer, dorsey, GAT')
parser.add_argument('--use_mlp', dest='use_mlp', action='store_true',
                    help='Add a fully connected layer to the encoder.')
parser.add_argument('--add_source', dest='add_source', action='store_true',
                    help='If try get rid of alpha param and the beta*x0 source term')

# ODE args
parser.add_argument('--time', type=float, default=3.0, help='End time of ODE integrator.')
parser.add_argument('--augment', action='store_true',
                    help='double the length of the feature vector by appending zeros to stabilise ODE learning')
parser.add_argument('--method', type=str, default='euler',
                    help="set the numerical solver: dopri5, euler, rk4, midpoint")
parser.add_argument('--step_size', type=float, default=1.0,
                    help='fixed step size when using fixed step solvers e.g. rk4')
parser.add_argument('--max_iters', type=float, default=100000, help='maximum number of integration steps')
parser.add_argument("--adjoint_method", type=str, default="adaptive_heun",
                    help="set the numerical solver for the backward pass: dopri5, euler, rk4, midpoint")
parser.add_argument('--adjoint', dest='adjoint', action='store_true',
                    help='use the adjoint ODE method to reduce memory footprint')
parser.add_argument('--adjoint_step_size', type=float, default=1,
                    help='fixed step size when using fixed step adjoint solvers e.g. rk4')
parser.add_argument('--tol_scale', type=float, default=1., help='multiplier for atol and rtol')
parser.add_argument("--tol_scale_adjoint", type=float, default=1.0,
                    help="multiplier for adjoint_atol and adjoint_rtol")
parser.add_argument('--ode_blocks', type=int, default=1, help='number of ode blocks to run')
parser.add_argument("--max_nfe", type=int, default=1000,
                    help="Maximum number of function evaluations in an epoch. Stiff ODEs will hang if not set.")
parser.add_argument("--no_early", action="store_true",
                    help="Whether or not to use early stopping of the ODE integrator when testing.")
parser.add_argument('--earlystopxT', type=float, default=3, help='multiplier for T used to evaluate best model')
parser.add_argument("--max_test_steps", type=int, default=100,
                    help="Maximum number steps for the dopri5Early test integrator. "
                         "used if getting OOM errors at test time")

# Attention args
parser.add_argument('--leaky_relu_slope', type=float, default=0.2,
                    help='slope of the negative part of the leaky relu used in attention')
parser.add_argument('--attention_dropout', type=float, default=0., help='dropout of attention weights')
parser.add_argument('--heads', type=int, default=4, help='number of attention heads')
parser.add_argument('--attention_norm_idx', type=int, default=0, help='0 = normalise rows, 1 = normalise cols')
parser.add_argument('--attention_dim', type=int, default=16,
                    help='the size to project x to before calculating att scores')
parser.add_argument('--mix_features', dest='mix_features', action='store_true',
                    help='apply a feature transformation xW to the ODE')
parser.add_argument('--reweight_attention', dest='reweight_attention', action='store_true',
                    help="multiply attention scores by edge weights before softmax")
parser.add_argument('--attention_type', type=str, default="scaled_dot",
                    help="scaled_dot,cosine_sim,pearson, exp_kernel")
parser.add_argument('--square_plus', action='store_true', help='replace softmax with square plus')

parser.add_argument('--data_norm', type=str, default='gcn',
                    help='rw for random walk, gcn for symmetric gcn norm')
parser.add_argument('--self_loop_weight', type=float, default=1.0, help='Weight of self-loops.')
# Consumed by the defence-model training loop: stop after this many consecutive
# epochs without a validation-accuracy improvement.
parser.add_argument('--patience', type=int, default=100,
                    help='Early-stopping patience: epochs without validation improvement before training stops.')

args = parser.parse_args()
# NOTE(review): this overwrites whatever was passed via --cuda with the actual CUDA
# availability, so the --cuda flag has no effect.
args.cuda = torch.cuda.is_available()
# print('cuda: %s' % args.cuda)
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): always a CUDA device, even when args.cuda is False; CPU-only runs will
# presumably fail on the first .to(device) — confirm before running without a GPU.
device = torch.device('cuda',args.gpu)
# Seed NumPy and torch (and CUDA when available) before any data split or model init.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Load the dataset with deeprobust's 'nettack' setting; the same seed is passed on.
data = Dataset(root='/tmp/', name=args.dataset,setting='nettack', seed=args.seed)
adj, features, labels = data.adj, data.features, data.labels

idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test

# Validation + test nodes (sorted, unique); the demo target below must be one of them.
idx_unlabeled = np.union1d(idx_val, idx_test)

# Setup Surrogate model
# GCN without ReLU, bias or dropout, trained on the clean graph; it is passed to
# Nettack below as the attack's surrogate.
surrogate = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1,
                nhid=16, dropout=0, with_relu=False, with_bias=False, device=device)

surrogate = surrogate.to(device)
surrogate.fit(features, adj, labels, idx_train, idx_val, patience=30)

# Setup Attack Model
# Fixed target node used by main(). Note: `assert` is stripped under `python -O`.
target_node =  859
assert target_node in idx_unlabeled

# Default attacker used by main(): structure-only perturbations (attack_features=False).
# model = Nettack(surrogate, nnodes=adj.shape[0], attack_structure=True, attack_features=True, device=device)
model = Nettack(surrogate, nnodes=adj.shape[0], attack_structure=True, attack_features=False, device=device)
model = model.to(device)

def main():
    """Demo: attack the fixed module-level ``target_node`` once and compare a GCN
    trained on the clean graph with one trained on the perturbed graph.

    Uses the module-level Nettack instance ``model`` with a perturbation budget equal
    to the target node's degree, prints the structure perturbations it chose, then
    calls ``test`` on the clean and on the perturbed graph.
    """
    node_degree = int(adj.sum(0).A1[target_node])

    # Direct attack on the target itself. An indirect (influencer) attack would be:
    # model.attack(features, adj, labels, target_node, node_degree, direct=False, n_influencers=5)
    model.attack(features, adj, labels, target_node, node_degree)
    perturbed_adj, perturbed_features = model.modified_adj, model.modified_features
    print(model.structure_perturbations)

    runs = (
        ('=== testing GCN on original(clean) graph ===', adj, features),
        ('=== testing GCN on perturbed graph ===', perturbed_adj, perturbed_features),
    )
    for banner, graph_adj, graph_features in runs:
        print(banner)
        test(graph_adj, graph_features, target_node)

def test(adj, features, target_node):
    """Train a fresh GCN on ``adj``/``features`` and report how it does.

    Prints the predicted class probabilities of ``target_node`` and the accuracy on
    the module-level test split.

    Returns:
        float: test-set accuracy.
    """
    n_classes = labels.max().item() + 1
    gcn = GCN(nfeat=features.shape[1], nhid=16, nclass=n_classes,
              dropout=0.5, device=device).to(device)
    gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)

    gcn.eval()
    output = gcn.predict()

    # predict() output is exponentiated to print probabilities for the target row.
    target_probs = torch.exp(output[[target_node]])[0]
    print('Target node probs: {}'.format(target_probs.detach().cpu().numpy()))

    test_acc = accuracy(output[idx_test], labels[idx_test]).item()
    print("Overall test set results:",
          "accuracy= {:.4f}".format(test_acc))
    return test_acc

def accuracy_1(output, labels):
    """Accuracy of argmax predictions, also accepting a single scalar label.

    Args:
        output: 2-D tensor of per-class scores; the prediction is the argmax over dim 1.
        labels: a tensor of labels, or one scalar label (Python/NumPy integer), which
            is wrapped into a one-element ``LongTensor``.

    Returns:
        tuple: ``(accuracy, preds, labels)`` where ``accuracy`` is a 0-d double tensor,
        ``preds`` are the predictions cast to the labels' tensor type and ``labels``
        is the (possibly wrapped) labels tensor.
    """
    try:
        num = len(labels)
    except TypeError:
        # Scalars (ints, NumPy integers, 0-d tensors) have no len(): one sample.
        num = 1

    if not isinstance(labels, torch.Tensor):
        labels = torch.LongTensor([labels])

    # type_as casts preds to the labels' tensor type so eq() compares like with like.
    preds = output.max(1)[1].type_as(labels)
    correct = preds.eq(labels).double().sum()
    return correct / num, preds, labels

def select_nodes(target_gcn=None, n_high=10, n_low=10, n_random=20):
    '''
    Select target nodes as reported in the Nettack paper:
    (i) the ``n_high`` (default 10) nodes with highest margin of classification, i.e. they are clearly correctly classified,
    (ii) the ``n_low`` (default 10) nodes with lowest margin (but still correctly classified) and
    (iii) ``n_random`` (default 20) more nodes drawn at random from the remaining candidates.

    Candidates are test nodes that ``target_gcn`` classifies correctly, that have
    degree >= 1, and for which fewer than half of their neighbours carry a
    different label.

    Args:
        target_gcn: trained model exposing ``eval()`` and ``predict()``. When None, a
            GCN is trained on the clean module-level graph.
        n_high, n_low, n_random: group sizes. Groups never overlap; when there are too
            few candidates the low-margin and random groups shrink instead of failing.

    Returns:
        list: node indices, high-margin group first, then low-margin, then random.
    '''
    if target_gcn is None:
        target_gcn = GCN(nfeat=features.shape[1],
                         nhid=16,
                         nclass=labels.max().item() + 1,
                         dropout=0.5, device=device)
        target_gcn = target_gcn.to(device)
        target_gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)
    target_gcn.eval()
    output = target_gcn.predict()
    degrees = adj.sum(0).A1
    # Row-slice the sparse matrix instead of densifying the whole N x N adjacency
    # once per test node.
    adj_csr = adj.tocsr()

    margin_dict = {}
    for idx in idx_test:
        acc, _, _ = accuracy_1(output[[idx]], labels[idx])
        if acc == 0 or int(degrees[idx]) < 1:  # only keep the correctly classified nodes
            continue
        # Outlier check: skip nodes whose neighbourhood is dominated by other labels.
        neighbours = adj_csr[idx].nonzero()[1]
        same_label = labels[neighbours] == labels[idx]
        outlier_score = 1 - same_label.sum() / len(same_label)
        if outlier_score >= 0.5:
            continue
        margin_dict[idx] = classification_margin(output[idx], labels[idx])

    ranked = [node for node, _ in sorted(margin_dict.items(), key=lambda kv: kv[1], reverse=True)]
    high = ranked[:n_high]
    # Low-margin group starts after the high group so the two never overlap.
    low_start = max(n_high, len(ranked) - n_low)
    low = ranked[low_start:]
    middle = ranked[n_high:low_start]
    n_pick = min(n_random, len(middle))
    other = np.random.choice(middle, n_pick, replace=False).tolist() if n_pick else []

    return high + low + other

def multi_test_poison():
    """Poisoning evaluation on the nodes chosen by ``select_nodes``.

    For each target node a fresh Nettack instance perturbs structure and features,
    with a budget of ``args.n_pert`` perturbations when set and the node's degree
    otherwise. Both a new GCN (via ``single_test``) and a new defence model (via
    ``train_defend``) are then trained on the perturbed graph and evaluated on the
    target node.

    Returns:
        tuple: (GCN misclassification rate, defence-model misclassification rate)
        over the selected nodes.
    """
    import types

    cnt = 0
    cnt_defend = 0
    degrees = adj.sum(0).A1
    node_list = select_nodes()
    num = len(node_list)
    print('=== [Poisoning] Attacking %s nodes respectively ===' % num)
    for target_node in tqdm(node_list):
        if args.n_pert is None:
            n_perturbations = int(degrees[target_node])
        else:
            n_perturbations = args.n_pert
        attacker = Nettack(surrogate, nnodes=adj.shape[0], attack_structure=True, attack_features=True, device=device)
        attacker = attacker.to(device)
        attacker.attack(features, adj, labels, target_node, n_perturbations, verbose=False)
        modified_adj = attacker.modified_adj
        modified_features = attacker.modified_features

        acc = single_test(modified_adj, modified_features, target_node)
        print('test target node : %s' % (target_node))
        print("acc test gcn: ", acc)
        if acc == 0:
            cnt += 1

        # Poisoning: the defence model is also trained on the perturbed graph, in a
        # fresh container so nothing is shared between target nodes.
        poisoned_graph = types.SimpleNamespace(adj=modified_adj, features=modified_features, labels=labels)
        model_defend = train_defend(poisoned_graph)
        acc_defend = single_test_defend(modified_adj, modified_features, labels, target_node,
                                        defend_model=model_defend)
        print("acc test defend: ", acc_defend)
        if acc_defend == 0:
            cnt_defend += 1

        print('misclassification rate on GCN : %s' % (cnt / num))
        print('misclassification rate on defend model: %s' % (cnt_defend / num))
    print('misclassification rate on GCN : %s' % (cnt/num))
    print('misclassification rate on defend model: %s' % (cnt_defend / num))

    return cnt / num, cnt_defend / num

def single_test(adj, features, target_node, gcn=None):
    """Return whether a GCN classifies ``target_node`` correctly on the given graph.

    Args:
        adj: adjacency matrix of the (possibly perturbed) graph.
        features: node feature matrix.
        target_node: index of the node to check.
        gcn: optional, already trained GCN. When None, a fresh GCN is trained on
            ``adj``/``features`` (poisoning setting). Otherwise the given model is only
            run on them (evasion setting).

    Returns:
        bool: True if the predicted class of ``target_node`` equals its label.
    """
    if gcn is None:
        # Poisoning: train from scratch on the supplied graph.
        gcn = GCN(nfeat=features.shape[1],
                  nhid=16,
                  nclass=labels.max().item() + 1,
                  dropout=0.5, device=device)
        gcn = gcn.to(device)
        gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)
        gcn.eval()
        output = gcn.predict()
    else:
        # Evasion: run the supplied, already trained model on the given graph.
        output = gcn.predict(features, adj)

    predicted = output.argmax(1)[target_node]
    return (predicted == labels[target_node]).item()


def test_model(model,data):
    model.eval()
    accs = []
    with torch.no_grad():
        logits = model(data.features,data.adj)
        # for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        #     pred = logits[mask].max(1)[1]
        #     acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        #     accs.append(acc)
        for mask in [data.train_mask, data.val_mask, data.test_mask]:
            pred = logits[mask].max(1)[1]
            acc = pred.eq(data.labels[mask]).sum().item() / mask.sum().item()
            accs.append(acc)
    return  accs

def single_test_defend(adj, features,labels, target_node, defend_model=None):
    """Return whether the trained defence model classifies ``target_node`` correctly.

    The inputs go through deeprobust's ``preprocess`` (no adjacency normalisation) and
    the dense adjacency is converted with ``dense_to_sparse``, which is the same
    conversion ``train_defend`` applies before training.

    Args:
        adj: adjacency matrix of the (possibly perturbed) graph.
        features: node feature matrix.
        labels: node labels.
        target_node: index of the node to evaluate.
        defend_model: trained defence model. Required despite the ``None`` default.

    Returns:
        bool: True if the predicted class of ``target_node`` equals its label.

    Raises:
        ValueError: if ``defend_model`` is None.
    """
    if defend_model is None:
        raise ValueError("single_test_defend requires a trained defend_model")

    adj, features, labels = preprocess(adj, features, labels, preprocess_adj=False)
    edge_input = dense_to_sparse(adj)
    features = features.to(args.gpu)
    labels = labels.to(args.gpu)

    defend_model.eval()
    with torch.no_grad():
        output = defend_model(features, edge_input)
    return (output.argmax(1)[target_node] == labels[target_node]).item()


def train_defend(data):
    """Train the GraphCON defence model (``GNN_graphcon``) and return its best weights.

    ``data`` must expose ``adj``, ``features`` and ``labels`` in a form accepted by
    deeprobust's ``preprocess`` (e.g. the dataset's scipy matrices or Nettack's
    ``modified_adj`` / ``modified_features``). ``data`` is only read: the converted
    tensors live in a private container, so the same object can be passed again on a
    later run.

    Training uses Adam with ``args.lr`` / ``args.weight_decay`` for up to
    ``args.epoch`` epochs and stops early after ``args.patience`` epochs without a
    validation improvement. The best-validation weights are saved under
    ``./saved_net_model/`` and reloaded before returning.

    Args:
        data: object with ``adj``, ``features`` and ``labels`` attributes.

    Returns:
        The ``GNN_graphcon`` model with its best-validation weights loaded.
    """
    import types

    n_nodes = data.features.shape[0]

    def _index_mask(indices):
        # Boolean node mask on the training device, True at ``indices``.
        mask = torch.zeros(n_nodes, dtype=torch.bool, device=args.gpu)
        mask[indices] = True
        return mask

    adj, features, labels = preprocess(data.adj, data.features, data.labels, preprocess_adj=False)
    graph = types.SimpleNamespace(
        adj=dense_to_sparse(adj),
        features=features.to(args.gpu),
        labels=labels.to(args.gpu),
        train_mask=_index_mask(idx_train),
        val_mask=_index_mask(idx_val),
        test_mask=_index_mask(idx_test),
    )

    # Copy so that adding 'num_classes' does not modify the global ``args`` namespace.
    opt = dict(vars(args))
    opt['num_classes'] = labels.max().item() + 1
    model = GNN_graphcon(opt, features.shape[1], args.gpu)
    model = model.to(args.gpu)

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    loss_fn = torch.nn.CrossEntropyLoss()

    save_dir = "./saved_net_model/"
    os.makedirs(save_dir, exist_ok=True)
    save_name = args.function + args.dataset + str(args.gpu) + str(args.n_pert) + "_model_nettack.pt"
    save_path = os.path.join(save_dir, save_name)

    # Start below any reachable accuracy so the first epoch always writes a checkpoint;
    # otherwise a run whose validation accuracy stays at 0 would load a missing or
    # stale file at the end.
    val_acc = -1.0
    train_acc = test_acc = 0.0
    best_epoch = 0
    counter = 0  # epochs since the last validation improvement
    saved = False

    epoch_bar = trange(args.epoch, ncols=100)
    for i in epoch_bar:
        model.train()
        optimizer.zero_grad()
        out = model(graph.features, graph.adj)
        loss = loss_fn(out[graph.train_mask], graph.labels.squeeze()[graph.train_mask])
        loss.backward()
        optimizer.step()

        tmp_train_acc, tmp_val_acc, tmp_test_acc = test_model(model, graph)
        if tmp_val_acc > val_acc:
            val_acc, train_acc, test_acc = tmp_val_acc, tmp_train_acc, tmp_test_acc
            best_epoch = i
            counter = 0
            torch.save({'model': model.state_dict()}, save_path)
            saved = True
        else:
            counter += 1
        epoch_bar.set_description(
            "Epoch: {:03d}, loss: {:.4f},Train: {:.4f}, Val: {:.4f}, Test: {:.4f}".format(i, loss.item(),
                                                                                          tmp_train_acc,
                                                                                          tmp_val_acc,
                                                                                          tmp_test_acc))
        if counter == args.patience:
            print("Early Stopping")
            break

    print(
        "Best Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}".format(best_epoch, train_acc, val_acc, test_acc))

    # Reload the best-validation weights (only absent when args.epoch == 0).
    if saved:
        ckp = torch.load(save_path, map_location=device)
        model.load_state_dict(ckp['model'])
    return model


def multi_test_evasion():
    """Evasion evaluation on the nodes chosen by ``select_nodes``.

    A GCN and the defence model are each trained once on the clean graph. Each target
    node is then attacked by a fresh Nettack instance (structure and features), with a
    budget of ``args.n_pert`` perturbations when set and the node's degree otherwise.
    Both pre-trained models are evaluated on the perturbed graph.

    Returns:
        tuple: (GCN misclassification rate, defence-model misclassification rate)
        over the selected nodes.
    """
    import types

    target_gcn = GCN(nfeat=features.shape[1],
                     nhid=16,
                     nclass=labels.max().item() + 1,
                     dropout=0.5, device=device)
    target_gcn = target_gcn.to(device)
    target_gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)

    cnt = 0
    cnt_defend = 0
    degrees = adj.sum(0).A1
    node_list = select_nodes(target_gcn)
    num = len(node_list)

    # Give train_defend a fresh container built from the module-level clean arrays
    # rather than the shared ``data`` object. Whatever train_defend does to its
    # argument then cannot leak into later calls of this function (one per --runtime
    # iteration).
    clean_graph = types.SimpleNamespace(adj=adj, features=features, labels=labels)
    model_defend = train_defend(clean_graph)

    print('=== [Evasion] Attacking %s nodes respectively ===' % num)
    for target_node in tqdm(node_list):
        if args.n_pert is None:
            n_perturbations = int(degrees[target_node])
        else:
            n_perturbations = args.n_pert

        attacker = Nettack(surrogate, nnodes=adj.shape[0], attack_structure=True, attack_features=True, device=device)
        attacker = attacker.to(device)
        attacker.attack(features, adj, labels, target_node, n_perturbations, verbose=False)
        modified_adj = attacker.modified_adj
        modified_features = attacker.modified_features

        acc = single_test(modified_adj, modified_features, target_node, gcn=target_gcn)
        if acc == 0:
            cnt += 1

        acc_defend = single_test_defend(modified_adj, modified_features, labels, target_node,
                                        defend_model=model_defend)
        print("acc test defend: ", acc_defend)
        if acc_defend == 0:
            cnt_defend += 1
        print('misclassification rate on GCN : %s' % (cnt / num))
        print('misclassification rate on defend model: %s' % (cnt_defend / num))

    print('misclassification rate on GCN : %s' % (cnt / num))
    print('misclassification rate on defend model: %s' % (cnt_defend / num))

    return cnt / num, cnt_defend / num

if __name__ == '__main__':
    import time
    import sys
    import json

    # One log file per invocation: the command line first, then running statistics.
    timestr = time.strftime("%Y%m%d-%H%M%S")
    os.makedirs("log_nettack", exist_ok=True)

    filename_log = "log_nettack/" + str(args.function) + args.dataset + "_Nettack_DI_evasion_" + timestr + ".txt"
    command_args = " ".join(sys.argv)
    with open(filename_log, 'a') as f:
        json.dump(command_args, f)

    mis_gcn = []
    mis_def = []
    acc_gcn = []
    acc_def = []
    for rt in range(args.runtime):
        main()
        # Poisoning variant: misclassification_gcn, misclassification_defend = multi_test_poison()
        misclassification_gcn, misclassification_defend = multi_test_evasion()
        mis_gcn.append(misclassification_gcn)
        mis_def.append(misclassification_defend)
        acc_gcn.append(1 - misclassification_gcn)
        acc_def.append(1 - misclassification_defend)

        # Log running mean/stdev after every run, including the first, so a
        # --runtime 1 job still records its result. statistics.stdev needs at least
        # two samples, so the spread is written as nan until a second run exists.
        lines = ["", "runtime: " + str(rt)]
        for series_name, values in (("misclassification_gcn", mis_gcn),
                                    ("misclassification_defend", mis_def),
                                    ("acc_gcn", acc_gcn),
                                    ("acc_def", acc_def)):
            spread = statistics.stdev(values) if len(values) > 1 else float('nan')
            lines.append("{} mean: {},{}".format(series_name, statistics.mean(values), spread))
        with open(filename_log, 'a') as f:
            f.write("\n".join(lines) + "\n")

