import torch
import numpy as np
import torch.nn.functional as F
import torch.optim as optim
from deeprobust.graph.defense import GCN, GAT,ChebNet
from deeprobust.graph.targeted_attack import Nettack
from deeprobust.graph.utils import *
from deeprobust.graph.data import Dataset, Dpr2Pyg
import argparse
from tqdm import tqdm
import warnings
import copy
from deeprobust.graph import utils
from ogb.nodeproppred import NodePropPredDataset
from my_utils.utils import spade,hnsw,construct_adj, spectral_embedding_eig,SPF,construct_weighted_adj,spade_nonetworkx
from scipy.sparse import load_npz
from torch_sparse import SparseTensor
from scipy.sparse import csr_matrix
import pickle

# Globally silence every warning category for the rest of the process.
# NOTE(review): this also hides deprecation and numerical warnings from torch/scipy.
warnings.filterwarnings(action="ignore")


def convert_csr_to_float64(matrix):
    """Return a CSR copy of ``matrix`` whose stored values are float64.

    ``matrix`` may be any SciPy sparse matrix (or a dense array). It is first
    converted to CSR, so CSC/COO inputs are not misread as CSR arrays. For CSR
    input the stored structure is kept exactly as-is. The index arrays are
    copied, so the result never aliases the input's ``indices``/``indptr``.
    """
    csr = csr_matrix(matrix)
    return csr_matrix(
        (csr.data.astype(np.float64), csr.indices.copy(), csr.indptr.copy()),
        shape=csr.shape,
    )


def to_unweighted_csr(adj_matrix_csr):
    """Return a float64 CSR copy of ``adj_matrix_csr`` with every stored entry set to 1.0.

    The input is converted to CSR first, so any SciPy sparse format is
    accepted. Every stored entry becomes 1.0, including explicitly stored
    zeros. The index arrays are copied, so the result does not alias the input.
    """
    csr = csr_matrix(adj_matrix_csr)
    ones = np.ones(len(csr.indices), dtype=np.float64)
    return csr_matrix((ones, csr.indices.copy(), csr.indptr.copy()), shape=csr.shape)


def add_edges_from_top_nodes(adj, embed_adj_mtx, TopNodeList, percentages, *, symmetric=False):
    """Add edges from ``embed_adj_mtx`` to ``adj`` for the leading fraction of ``TopNodeList``.

    For each of the first ``int(len(TopNodeList) * percentages)`` nodes, every
    column stored in that node's row of ``embed_adj_mtx`` but absent from the
    node's row of ``adj`` is inserted into ``adj`` with weight 1. Existing
    entries of ``adj``, and their weights, are left untouched.

    Args:
        adj: SciPy sparse adjacency matrix to augment (copied, never modified).
        embed_adj_mtx: SciPy sparse matrix supplying candidate edges.
        TopNodeList: node indices, most important first.
        percentages: a single float (despite the plural name). It is the
            fraction of ``TopNodeList`` whose rows receive new edges.
        symmetric: if True, also set the mirrored entry ``[new_edge, node]``
            so a symmetric ``adj`` stays symmetric. The default (False) only
            writes ``[node, new_edge]``, so added edges are one-directional.

    Returns:
        The augmented adjacency as a CSR matrix.
    """
    # copy=True: tolil() on an input that is already LIL would otherwise return it
    # unchanged and the writes below would mutate the caller's matrix.
    adj_lil = adj.tolil(copy=True)
    embed_rows = embed_adj_mtx.tolil().rows
    node_count = int(len(TopNodeList) * percentages)
    for node in TopNodeList[:node_count]:
        # Read the live row so edges added earlier in this loop are not re-added.
        new_edges = set(embed_rows[node]) - set(adj_lil.rows[node])
        for new_edge in new_edges:
            adj_lil[node, new_edge] = 1
            if symmetric:
                adj_lil[new_edge, node] = 1
    return adj_lil.tocsr()

def normal_adj(adj):
    """Scale a SciPy sparse adjacency as D^{-1/2} A D^{-1/2} and return it as CSR.

    D holds the row sums of ``adj`` (cast to float). A row with zero sum gets a
    scale of 0 instead of inf. No self-loops are added.
    """
    sparse_adj = SparseTensor.from_scipy(adj)
    inv_sqrt_deg = sparse_adj.sum(dim=1).to(torch.float).pow(-0.5)
    # 0 ** -0.5 is inf; zero those entries so isolated nodes stay all-zero.
    inv_sqrt_deg[inv_sqrt_deg == float('inf')] = 0
    row_scale = inv_sqrt_deg.view(-1, 1)
    col_scale = inv_sqrt_deg.view(1, -1)
    scaled = row_scale * sparse_adj * col_scale
    return scaled.to_scipy(layout='csr')

# Command-line options: the RNG seed and which dataset to load.
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
parser.add_argument(
    '--dataset',
    type=str,
    default='citeseer',
    choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed',
             'arxiv', 'products', 'chameleon', 'squirrel'],
    help='dataset',
)

args = parser.parse_args()
args.cuda = torch.cuda.is_available()
print('cuda: %s' % args.cuda)
device = torch.device("cuda:0" if args.cuda else "cpu")

# Seed NumPy and torch (plus CUDA when available) so runs are repeatable.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Load the graph for the selected dataset. Every branch defines adj, features,
# labels, idx_train/idx_val/idx_test, plus the snapshots orig_adj/orig_features.
# The snapshots are taken before the later feature conversion; multi_test_evasion
# uses them for its spectral embedding.
if args.dataset in ['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed']:
    # deeprobust's Dataset supplies adjacency, features, labels and the splits.
    data = Dataset(root='./tmp/', name=args.dataset)
    adj, features, labels = data.adj, data.features, data.labels
    # orig_features is a dense copy (.todense()) of the feature matrix.
    orig_adj, orig_features = copy.copy(adj), copy.copy(features.todense())
    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test

elif args.dataset in ['chameleon', 'squirrel']:
    
    # Features, labels and splits come from a local pickle; the adjacency comes from an .npz.
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted,
    # locally generated files here.
    with open(f'./tmp/{args.dataset}_data.pickle', 'rb') as handle:
            data = pickle.load(handle)
    features = data["features"]
    labels = data["labels"]
    idx_train = data["idx_train"].astype(np.int32)
    idx_val = data["idx_val"].astype(np.int32)
    idx_test = data["idx_test"].astype(np.int32)
    adj = load_npz(f'./tmp/{args.dataset}.npz')

    # Snapshot before features is wrapped in a CSR matrix below.
    orig_adj, orig_features = copy.copy(adj), copy.copy(features)
    features = csr_matrix(features)


else:
    # Remaining choices ('arxiv', 'products') are loaded as OGB node-property datasets.
    ogbn_dataset = NodePropPredDataset(name=f'ogbn-{args.dataset}', root='../../../ogbn/ogbn_data/')
    ogbn_data = ogbn_dataset[0]
    split_idx = ogbn_dataset.get_idx_split()
    idx_train = split_idx['train'].astype(np.int32)
    idx_val = split_idx['valid'].astype(np.int32)
    idx_test = split_idx['test'].astype(np.int32)
    features = ogbn_data[0]['node_feat']
    # Flatten labels to a 1-D array.
    labels = ogbn_data[1].reshape(-1,)

    # Pre-built clean adjacency stored outside the OGB download directory.
    adj = load_npz(f'../../../ogbn/data/{args.dataset}/{args.dataset}_clean.npz')

    if args.dataset == "products":
        # products uses the D^{-1/2} A D^{-1/2}-scaled adjacency (see normal_adj).
        adj = normal_adj(adj)
        # NOTE(review): args.epochs is not read anywhere in the visible code.
        args.epochs = 300

    # For products this snapshot is the normalized adjacency, not the raw one.
    orig_adj, orig_features = copy.copy(adj), copy.copy(features)
    features = csr_matrix(features)

# Validation ∪ test node indices (np.union1d returns them sorted and unique).
idx_unlabeled = np.union1d(idx_val, idx_test)

# Surrogate for Nettack: a GCN with no ReLU, no bias and no dropout,
# trained once on the clean graph.
surrogate = GCN(
    nfeat=features.shape[1],
    nclass=labels.max().item() + 1,
    nhid=16,
    dropout=0,
    with_relu=False,
    with_bias=False,
    device=device,
).to(device)
surrogate.fit(features, adj, labels, idx_train, idx_val, patience=30)

# Module-level attacker (structure and feature perturbations); main() uses it.
model = Nettack(
    surrogate,
    nnodes=adj.shape[0],
    attack_structure=True,
    attack_features=True,
    device=device,
).to(device)

def main(target_node=None):
    """Run a single direct Nettack attack and compare GCN on clean vs. perturbed graphs.

    The target is attacked with a budget equal to its degree. The structure
    perturbations are printed, then a fresh GCN is trained and evaluated on the
    clean graph and on the perturbed graph.

    Args:
        target_node: node index to attack. When None, the first index of
            ``idx_unlabeled`` (validation ∪ test) is used; previously the
            function read an undefined global and raised NameError.
    """
    if target_node is None:
        target_node = int(idx_unlabeled[0])
    degrees = adj.sum(0).A1
    # Perturbation budget: the target's degree.
    n_perturbations = int(degrees[target_node])

    # Direct attack. For an influencer (indirect) attack use
    # model.attack(..., direct=False, n_influencers=5).
    model.attack(features, adj, labels, target_node, n_perturbations)
    modified_adj = model.modified_adj
    modified_features = model.modified_features
    print(model.structure_perturbations)
    print('=== testing GCN on original(clean) graph ===')
    test(adj, features, target_node)
    print('=== testing GCN on perturbed graph ===')
    test(modified_adj, modified_features, target_node)

def test(adj, features, target_node):
    """Train a fresh GCN on (features, adj) and report results.

    Prints the target node's class probabilities and the overall test accuracy.
    Returns that accuracy as a Python float.
    """
    n_classes = labels.max().item() + 1
    gcn = GCN(nfeat=features.shape[1], nhid=16, nclass=n_classes,
              dropout=0.5, device=device).to(device)
    gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)
    gcn.eval()

    log_probs = gcn.predict()
    # The model's outputs are exponentiated to obtain probabilities.
    target_probs = torch.exp(log_probs[[target_node]])[0]
    print('Target node probs: {}'.format(target_probs.detach().cpu().numpy()))

    test_acc = accuracy(log_probs[idx_test], labels[idx_test]).item()
    print("Overall test set results:",
          "accuracy= {:.4f}".format(test_acc))
    return test_acc

def select_nodes(target_gcn=None):
    """Pick attack targets following the Nettack paper's protocol.

    Only correctly classified test nodes (margin >= 0) are eligible.
    Selection:
      (i)   the 10 with the highest classification margin,
      (ii)  the 10 with the lowest margin, excluding any already chosen in (i),
      (iii) up to 20 more drawn at random from the remaining middle.
    If fewer than 40 nodes are eligible, fewer are returned instead of
    np.random.choice raising.

    Args:
        target_gcn: trained GCN to score nodes with. When None, a fresh GCN is
            trained on the module-level graph.

    Returns:
        list of node indices: high + low + random.
    """
    if target_gcn is None:
        target_gcn = GCN(nfeat=features.shape[1],
                         nhid=16,
                         nclass=labels.max().item() + 1,
                         dropout=0.5, device=device)
        target_gcn = target_gcn.to(device)
        target_gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)
    target_gcn.eval()
    output = target_gcn.predict()

    margin_dict = {}
    for idx in idx_test:
        margin = classification_margin(output[idx], labels[idx])
        if margin < 0:  # only keep the nodes correctly classified
            continue
        margin_dict[idx] = margin
    ranked = [node for node, _ in sorted(margin_dict.items(), key=lambda item: item[1], reverse=True)]

    high = ranked[:10]
    high_set = set(high)
    # With fewer than 20 eligible nodes the two ends overlap; avoid duplicate targets.
    low = [node for node in ranked[-10:] if node not in high_set]
    middle = ranked[10:-10]
    n_random = min(20, len(middle))
    other = np.random.choice(middle, n_random, replace=False).tolist() if n_random > 0 else []

    return high + low + other


def random_select_nodes(target_gcn=None, n_nodes=40):
    """Randomly pick correctly classified test nodes as attack targets.

    Args:
        target_gcn: trained GCN to classify with. When None, a fresh GCN is
            trained on the module-level graph. Previously a None value crashed
            on ``.eval()``.
        n_nodes: how many nodes to draw without replacement. It is clamped to
            the number of eligible nodes.

    Returns:
        list of node indices (empty if no test node is classified correctly).
    """
    if target_gcn is None:
        target_gcn = GCN(nfeat=features.shape[1],
                         nhid=16,
                         nclass=labels.max().item() + 1,
                         dropout=0.5, device=device)
        target_gcn = target_gcn.to(device)
        target_gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)
    target_gcn.eval()
    output = target_gcn.predict()

    margin_dict = {}
    for idx in idx_test:
        margin = classification_margin(output[idx], labels[idx])
        if margin < 0:  # only keep the nodes correctly classified
            continue
        margin_dict[idx] = margin
    # np.random.choice draws by position, so keep the margin ordering. This
    # keeps the draw for a given seed the same as before.
    ranked = [node for node, _ in sorted(margin_dict.items(), key=lambda item: item[1], reverse=True)]

    n_pick = min(n_nodes, len(ranked))
    if n_pick == 0:
        return []
    return np.random.choice(ranked, n_pick, replace=False).tolist()



def multi_test_poison():
    """Poisoning evaluation over the nodes chosen by ``select_nodes()``.

    Each target is attacked with a fresh Nettack (structure and features),
    using a budget equal to its degree. A new GCN is then trained on the
    perturbed graph and checked on the target. The function prints the
    fraction of targets that end up misclassified.
    """
    degrees = adj.sum(0).A1
    node_list = select_nodes()
    num = len(node_list)
    print('=== [Poisoning] Attacking %s nodes respectively ===' % num)
    if num == 0:
        # Avoid ZeroDivisionError when no eligible targets exist.
        print('misclassification rate : n/a (no target nodes selected)')
        return

    cnt = 0
    for target_node in tqdm(node_list):
        n_perturbations = int(degrees[target_node])
        attacker = Nettack(surrogate, nnodes=adj.shape[0], attack_structure=True, attack_features=True, device=device)
        attacker = attacker.to(device)
        attacker.attack(features, adj, labels, target_node, n_perturbations, verbose=False)
        # gcn=None -> single_test retrains on the perturbed graph (poisoning).
        acc = single_test(attacker.modified_adj, attacker.modified_features, target_node)
        if acc == 0:
            cnt += 1
    print('misclassification rate : %s' % (cnt/num))

def single_test(adj, features, target_node, gcn=None):
    """Return whether ``target_node`` is classified correctly on (features, adj).

    Modes:
      * ``gcn is None`` (poisoning): train a fresh GCN on the given graph and
        predict with it.
      * ``gcn`` given (evasion): reuse the trained model and only run
        ``gcn.predict(features, adj)`` on the given graph.

    Returns:
        bool: True iff the argmax prediction for ``target_node`` equals
        ``labels[target_node]``.
    """
    if gcn is None:
        # test on GCN (poisoning attack)
        gcn = GCN(nfeat=features.shape[1],
                  nhid=16,
                  nclass=labels.max().item() + 1,
                  dropout=0.5, device=device)
        gcn = gcn.to(device)
        gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)
        gcn.eval()
        output = gcn.predict()
    else:
        # test on GCN (evasion attack)
        gcn.eval()
        output = gcn.predict(features, adj)

    acc_test = (output.argmax(1)[target_node] == labels[target_node])
    return acc_test.item()

def _evasion_mis_rates(graph_adj, node_list, degrees, target_gcn):
    """Attack each node of ``node_list`` on ``graph_adj`` without retraining.

    For every target, a fresh structure-only Nettack is run against
    ``graph_adj`` with a budget of ``int(degrees[node])``. The trained
    ``target_gcn`` is then evaluated on the perturbed graph and on
    ``graph_adj`` itself.

    Returns:
        (clean_mis_rate, perturbed_mis_rate): misclassified fractions of
        ``node_list`` on ``graph_adj`` and on the perturbed graphs.
        ``(0.0, 0.0)`` for an empty list.
    """
    if not node_list:
        return 0.0, 0.0
    clean_miss = 0
    perturbed_miss = 0
    for target_node in tqdm(node_list):
        n_perturbations = int(degrees[target_node])
        attacker = Nettack(surrogate, nnodes=adj.shape[0], attack_structure=True, attack_features=False, device=device)
        attacker = attacker.to(device)
        attacker.attack(features, graph_adj, labels, target_node, n_perturbations, verbose=False)
        acc = single_test(attacker.modified_adj, attacker.modified_features, target_node, gcn=target_gcn)
        clean_acc = single_test(graph_adj, features, target_node, gcn=target_gcn)
        if acc == 0:
            perturbed_miss += 1
        if clean_acc == 0:
            clean_miss += 1
    n_nodes = len(node_list)
    return clean_miss / n_nodes, perturbed_miss / n_nodes


def multi_test_evasion():
    """Evasion experiment combining SPADE node ranking with embedding-edge augmentation.

    Steps:
      1. Train one GCN on the clean graph.
      2. Build a k-neighbour graph from a spectral embedding of
         (orig_adj, orig_features) via hnsw with k=50.
      3. Rank nodes with SPADE on that graph.
      4. For several fractions of the ranking, add that graph's SPF inter-edges
         to the top nodes. Print the GCN accuracy on the augmented graph.
      5. Compare Nettack evasion mis-rates on the augmented graph with those on
         the original graph. Only top-ranked nodes the clean model classifies
         correctly are attacked.
    """
    target_gcn = GCN(nfeat=features.shape[1],
                     nhid=16,
                     nclass=labels.max().item() + 1,
                     dropout=0.5, device=device)
    target_gcn = target_gcn.to(device)
    target_gcn.fit(features, adj, labels, idx_train, idx_val, patience=30)
    target_gcn.eval()

    the_k = 50
    spec_embed = spectral_embedding_eig(orig_adj, orig_features, use_feature=True, adj_norm=False)
    neighs, distance = hnsw(spec_embed, k=the_k)
    embed_adj_mtx = construct_weighted_adj(neighs, distance)
    embed_adj_mtx, inter_edge_adj = SPF(embed_adj_mtx, 4)
    embed_adj_mtx = to_unweighted_csr(embed_adj_mtx)

    # deeprobust's GCN.predict normalizes the given adjacency the same way fit()
    # does. Calling the model directly on a raw adjacency skips that step and
    # yields predictions inconsistent with training and with single_test.
    embed_out = target_gcn.predict(features, embed_adj_mtx)
    orig_out = target_gcn.predict(features, adj)

    orig_acc_test = accuracy(orig_out, labels)
    print('original acc: {}'.format(orig_acc_test))

    _, _, top_node_list, _, _, _ = spade_nonetworkx(
        embed_adj_mtx, embed_out.cpu().detach().numpy(), k=the_k)

    # Keep only top-ranked nodes the clean model classifies correctly. A boolean
    # mask makes this O(n) instead of one list-membership scan per node.
    predicted = orig_out.argmax(1).cpu().numpy()
    correct = predicted == np.asarray(labels)
    filtered_top_nodes = [idx for idx in top_node_list if correct[idx]]

    # Attack budgets always use degrees in the original graph.
    degrees = adj.sum(0).A1

    for ratio in [0.01, 0.05, 0.1, 0.15, 0.2]:
        select_num = int(len(top_node_list) * ratio)
        node_list = filtered_top_nodes[:select_num]

        adj_part_embedd = add_edges_from_top_nodes(adj, inter_edge_adj, top_node_list, ratio)
        part_out = target_gcn.predict(features, adj_part_embedd)
        part_embedd_acc_test = accuracy(part_out, labels)
        print('{} node added embedded edges acc: {}'.format(ratio, part_embedd_acc_test))

        if not node_list:
            print('No eligible target nodes at ratio {}; skipping attacks'.format(ratio))
            continue

        # Rates are divided by the number of nodes actually attacked, which can be
        # smaller than select_num after misclassified nodes are filtered out.
        print('=== [Evasion] Attacking %s nodes respectively ===' % len(node_list))
        embed_mis, embed_perturbed_mis = _evasion_mis_rates(adj_part_embedd, node_list, degrees, target_gcn)
        print('{} embedded mis rate : {}, perturbed mis rate : {}'.format(ratio, embed_mis, embed_perturbed_mis))

        print('=== [Evasion] Attacking %s nodes respectively ===' % len(node_list))
        orig_mis, orig_perturbed_mis = _evasion_mis_rates(adj, node_list, degrees, target_gcn)
        print('orignal mis rate : {}, perturbed mis rate : {}'.format(orig_mis, orig_perturbed_mis))


if __name__ == '__main__':
    # Runs the evasion experiment. Other drivers in this file: main()
    # (single-node demo) and multi_test_poison() (poisoning setting).
    multi_test_evasion()


