import torch
import numpy as np
import torch.nn.functional as F
import torch.optim as optim
from deeprobust.graph.defense import GCN
from deeprobust.graph.targeted_attack import FGA
from deeprobust.graph.utils import *
from deeprobust.graph.data import Dataset
from tqdm import tqdm
import argparse
from my_utils.utils import spade,hnsw,construct_adj, spectral_embedding_eig,SPF,construct_weighted_adj,spade_nonetworkx,rank_samples_by_variance
from deeprobust.graph import utils
import copy
import warnings
warnings.filterwarnings("ignore")

# ---------------------------------------------------------------------------
# Module-level experiment setup. Everything defined here is a global that the
# functions below read directly (args, device, adj, features, labels, the
# index splits, surrogate, model, target_node).
# ---------------------------------------------------------------------------

# Command-line options.
# NOTE(review): --ptb_rate is parsed but not read anywhere in the code visible
# in this file.
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
parser.add_argument('--dataset', type=str, default='citeseer', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
parser.add_argument('--ptb_rate', type=float, default=0.05,  help='pertubation rate')

args = parser.parse_args()
# Use the first CUDA device when one is available, otherwise the CPU.
args.cuda = torch.cuda.is_available()
print('cuda: %s' % args.cuda)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed the NumPy and PyTorch generators (and the CUDA generator when used).
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Load the dataset into ./tmp/ and pull out the graph, features, labels and
# the train/val/test index splits.
data = Dataset(root='./tmp/', name=args.dataset)
adj, features, labels = data.adj, data.features, data.labels
# Snapshots taken before any attack runs: a shallow copy of the adjacency and
# a dense copy of the feature matrix.
orig_adj, orig_features = copy.copy(adj), copy.copy(features.todense())
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test

# Every node that is not a training node (validation + test).
idx_unlabeled = np.union1d(idx_val, idx_test)

# Setup Surrogate model
surrogate = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1,
                nhid=16, device=device)

surrogate = surrogate.to(device)
# Trained once at import time; the attack functions below reuse this model.
surrogate.fit(features, adj, labels, idx_train, idx_val)

# Setup Attack Model
# target_node and model are only used by main(); the multi_test_* functions
# build their own FGA instance per target node.
target_node = 0
model = FGA(surrogate, nnodes=adj.shape[0], device=device)
model = model.to(device)

def main():
    """Attack the module-level ``target_node`` with FGA and compare GCN results.

    The perturbation budget equals the degree of the attacked node. After the
    attack, a fresh GCN is trained and evaluated once on the clean graph and
    once on the perturbed graph (see ``test``).

    Raises:
        AssertionError: if the target node is not in ``idx_unlabeled``.
    """
    # Derive every use of the attacked node from the single module-level
    # setting so the budget, the sanity check and the attack can never refer
    # to different nodes.
    u = target_node  # node to attack
    assert u in idx_unlabeled, 'target node %s is not an unlabeled node' % u

    degrees = adj.sum(0).A1
    n_perturbations = int(degrees[u])  # How many perturbations to perform. Default: Degree of the node

    model.attack(features, adj, labels, idx_train, u, n_perturbations)

    print('=== testing GCN on original(clean) graph ===')
    test(adj, features, u)

    print('=== testing GCN on perturbed graph ===')
    test(model.modified_adj, features, u)

def test(adj, features, target_node):
    """Train a new GCN on ``adj``/``features`` and report how it performs.

    Prints the class probabilities predicted for ``target_node`` (the
    exponential of the model output row) and the accuracy over ``idx_test``.

    Args:
        adj: adjacency structure to train and predict on.
        features: node feature matrix; its second dimension sets the input size.
        target_node: index of the node whose probabilities are printed.

    Returns:
        The test-set accuracy as a Python number.
    """
    n_classes = labels.max().item() + 1
    classifier = GCN(nfeat=features.shape[1],
                     nhid=16,
                     nclass=n_classes,
                     dropout=0.5, device=device)
    if args.cuda:
        classifier = classifier.to(device)

    # Trained on the training split only (no validation set is passed here).
    classifier.fit(features, adj, labels, idx_train)
    classifier.eval()
    log_probs = classifier.predict()

    target_probs = torch.exp(log_probs[[target_node]])[0]
    print('probs: {}'.format(target_probs.detach().cpu().numpy()))

    test_accuracy = accuracy(log_probs[idx_test], labels[idx_test]).item()
    print("Test set results:",
          "accuracy= {:.4f}".format(test_accuracy))

    return test_accuracy


def select_nodes(target_gcn=None):
    '''
    selecting nodes as reported in nettack paper:
    (i) the 10 nodes with highest margin of classification, i.e. they are clearly correctly classified,
    (ii) the 10 nodes with lowest margin (but still correctly classified) and
    (iii) 20 more nodes randomly

    Args:
        target_gcn: model whose cached prediction is used to compute the
            margins. When None, a new GCN is trained on the clean graph first.

    Returns:
        A list of test-node indices: high-margin nodes, then low-margin nodes,
        then the randomly drawn ones. It holds 40 nodes when at least 40 test
        nodes are correctly classified; with fewer, the three groups shrink
        (they never overlap) instead of raising or repeating nodes.
    '''

    if target_gcn is None:
        target_gcn = GCN(nfeat=features.shape[1],
                  nhid=16,
                  nclass=labels.max().item() + 1,
                  dropout=0.5, device=device)
        target_gcn = target_gcn.to(device)
        target_gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)
    target_gcn.eval()
    output = target_gcn.predict()

    margin_dict = {}
    for idx in idx_test:
        margin = classification_margin(output[idx], labels[idx])
        if margin < 0: # only keep the nodes correctly classified
            continue
        margin_dict[idx] = margin

    # Node indices ordered from the largest margin to the smallest.
    ranked = [idx for idx, _ in
              sorted(margin_dict.items(), key=lambda item: item[1], reverse=True)]

    # Slice the tail out of what is left after taking `high`, so a short
    # ranking cannot put the same node in both groups.
    high = ranked[:10]
    remainder = ranked[10:]
    low = remainder[-10:]
    other = remainder[:-10]

    # Sampling without replacement raises if asked for more items than exist,
    # so cap the request at the pool size.
    n_random = min(20, len(other))
    if n_random:
        other = np.random.choice(other, n_random, replace=False).tolist()
    else:
        other = []

    return high + low + other

def multi_test_poison():
    """Poisoning experiment over the nodes returned by ``select_nodes()``.

    For each selected node an FGA attack is run against the surrogate with a
    budget equal to the node's degree, then ``single_test`` retrains a GCN on
    the perturbed graph. Prints the fraction of targets that end up
    misclassified.
    """
    flipped = 0
    node_degrees = adj.sum(0).A1
    targets = select_nodes()
    num = len(targets)
    print('=== [Poisoning] Attacking %s nodes respectively ===' % num)

    for node in tqdm(targets):
        budget = int(node_degrees[node])
        # A fresh attacker per target, so no state carries over between nodes.
        attacker = FGA(surrogate, nnodes=adj.shape[0], device=device)
        attacker = attacker.to(device)
        attacker.attack(features, adj, labels, idx_train, node, budget)

        still_correct = single_test(attacker.modified_adj, features, node)
        if still_correct == 0:
            flipped += 1

    print('misclassification rate : %s' % (flipped/num))

def single_test(adj, features, target_node, gcn=None):
    """Check whether ``target_node`` is classified correctly on the given graph.

    Args:
        adj: adjacency structure to evaluate on (typically the perturbed one).
        features: node feature matrix.
        target_node: index of the node whose prediction is checked.
        gcn: when None, a new GCN is trained on ``adj``/``features`` and then
            queried (poisoning setting). Otherwise the supplied model is only
            asked to predict on ``adj``/``features`` (evasion setting).

    Returns:
        The ``.item()`` of the comparison between the predicted class and
        ``labels[target_node]``: truthy when the prediction is correct.
    """
    if gcn is None:
        # test on GCN (poisoning attack)
        gcn = GCN(nfeat=features.shape[1],
                  nhid=16,
                  nclass=labels.max().item() + 1,
                  dropout=0.5, device=device)

        gcn = gcn.to(device)

        gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)
        gcn.eval()
        output = gcn.predict()
    else:
        # test on GCN (evasion attack)
        output = gcn.predict(features, adj)

    # Only the arg-max class of the target row matters here.
    acc_test = (output.argmax(1)[target_node] == labels[target_node])
    return acc_test.item()

def _evasion_attack_and_report(node_list, target_gcn, degrees, report_fmt):
    """Run one FGA evasion attack per node in ``node_list`` and print the rates.

    Args:
        node_list: target node indices; each one is attacked independently.
        target_gcn: trained model that is evaluated (never retrained) on the
            perturbed and on the clean graph.
        degrees: per-node degree array; a node's degree is its attack budget.
        report_fmt: ``str.format`` template receiving the clean
            misclassification rate and the perturbed misclassification rate.
    """
    num = len(node_list)
    print('=== [Evasion] Attacking %s nodes respectively ===' % num)
    if num == 0:
        # Nothing to attack; report NaN instead of dividing by zero.
        print(report_fmt.format(float('nan'), float('nan')))
        return

    cnt = 0
    orig_cnt = 0
    for target_node in tqdm(node_list):
        n_perturbations = int(degrees[target_node])
        attacker = FGA(surrogate, nnodes=adj.shape[0], device=device)
        attacker = attacker.to(device)
        attacker.attack(features, adj, labels, idx_train, target_node, n_perturbations)

        # Perturbed graph first, clean graph last.
        # NOTE(review): predict(features, adj) may leave the graph it was given
        # cached on the model, and the attacker shares that model (surrogate);
        # ending each iteration on the clean graph keeps the next attack
        # starting from clean state -- confirm against DeepRobust before
        # reordering these two calls.
        acc = single_test(attacker.modified_adj, features, target_node, gcn=target_gcn)
        orig_acc = single_test(adj, features, target_node, gcn=target_gcn)

        if acc == 0:
            cnt += 1
        if orig_acc == 0:
            orig_cnt += 1

    print(report_fmt.format(orig_cnt / num, cnt / num))


def multi_test_evasion():
    """Evasion experiment comparing three ways of choosing target nodes.

    The already-trained surrogate doubles as the victim model. Up to 40 target
    nodes are chosen by each of:
      * ``select_nodes`` (margin-based selection),
      * the ranking returned by ``rank_samples_by_variance`` on the model output,
      * the node ranking returned by ``spade_nonetworkx`` on a k-NN graph built
        from a spectral embedding of the original graph and features.
    For the last two, nodes the model already gets wrong are dropped before the
    top 40 are taken. Each set is attacked with FGA and the clean and perturbed
    misclassification rates are printed, normalised by that set's own size.
    """
    target_gcn = surrogate

    # Tensor versions of the clean features/adjacency for direct forward calls.
    features1, adj1 = utils.to_tensor(features, adj, device=device)

    # Build the embedding-space neighbour graph (orig_adj / orig_features are
    # the snapshots taken at load time).
    the_k = 50
    spec_embed = spectral_embedding_eig(orig_adj, orig_features, use_feature=True, adj_norm=False)
    neighs, distance = hnsw(spec_embed, k=the_k)
    embed_adj_mtx, _, _ = construct_weighted_adj(neighs, distance)
    embed_adj_mtx = SPF(embed_adj_mtx, 4)
    embed_adj_tensor, _ = utils.to_tensor(embed_adj_mtx, adj, device=device)

    # NOTE(review): these forward calls receive the adjacency tensors exactly
    # as converted above, whereas single_test/select_nodes go through
    # predict(); confirm that skipping any normalisation predict() applies is
    # intended here.
    embed_out = target_gcn(features1, embed_adj_tensor)
    orig_out = target_gcn(features1, adj1)
    _, _, TopNodeList, _, _, _ = spade_nonetworkx(
        embed_adj_mtx, embed_out.cpu().detach().numpy(), k=the_k)

    # SPADE ranking: keep only nodes with a non-negative classification margin
    # (i.e. currently classified correctly), then take the first 40.
    spade_selected_node = [
        idx for idx in TopNodeList
        if not classification_margin(orig_out[idx], labels[idx]) < 0
    ][:40]

    # Variance ranking: same idea, using the arg-max prediction directly.
    variance_rank = rank_samples_by_variance(orig_out.cpu().detach().numpy())
    _, predicted_labels = torch.max(orig_out, 1)
    variance_selected_node = [
        idx for idx in variance_rank
        if predicted_labels[idx] == labels[idx]
    ][:40]

    degrees = adj.sum(0).A1
    node_list = select_nodes(target_gcn)

    _evasion_attack_and_report(
        node_list, target_gcn, degrees,
        'orignal mis rate : {}, perturbed mis rate : {}')
    _evasion_attack_and_report(
        variance_selected_node, target_gcn, degrees,
        'variance orig mis rate : {}, perturbed mis rate : {}')
    _evasion_attack_and_report(
        spade_selected_node, target_gcn, degrees,
        'SPADE orignal mis rate : {}, perturbed mis rate : {}')

if __name__ == '__main__':
    # One experiment runs per invocation; (un)comment the calls below to
    # switch between the single-node demo, the evasion study and the
    # poisoning study.
    #main()
    multi_test_evasion()
    #multi_test_poison()
