import argparse
import pickle

import numpy as np
from scipy.sparse import load_npz
import torch
from deeprobust.graph.data import Dataset

from model import ada_filter, GCN
from logger import Logger
from utils import *
from my_utils.utils import spade,heterophily_score,featurePT,random_edgePT,hnsw,construct_adj, spectral_embedding_eig,SPF,construct_weighted_adj
import copy
import networkx as nx
from torch.nn.functional import cosine_similarity
import random
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from scipy.sparse.csgraph import laplacian
from torch.nn.functional import pairwise_distance
from julia.api import Julia

import warnings
warnings.filterwarnings("ignore")

import torch.nn.functional as F

def calculate_kl(tensor1, tensor2, mymask):
    """KL divergence between two model outputs on the rows selected by ``mymask``.

    Uses ``F.kl_div(input=tensor1, target=exp(tensor2), reduction='batchmean')``,
    i.e. ``sum(P2 * (log P2 - tensor1)) / num_selected_rows`` with
    ``P2 = exp(tensor2)``. When both tensors hold log-probabilities (assumed to be
    log_softmax outputs -- TODO confirm against the backbone models) this is
    ``KL(P2 || P1)``.

    Args:
        tensor1: reference output (used as ``input`` of ``F.kl_div``).
        tensor2: compared output (exponentiated and used as ``target``).
        mymask: index array / boolean mask selecting rows (nodes).

    Returns:
        0-d NumPy array holding the divergence.
    """
    selected_ref = tensor1[mymask]
    selected_cmp = tensor2[mymask]
    kl_div = F.kl_div(selected_ref, torch.exp(selected_cmp), reduction='batchmean')
    # .cpu() so the NumPy conversion also works for tensors living on a GPU.
    return kl_div.detach().cpu().numpy()


def calculate_cosine(tensor1, tensor2, mymask, dim=0):
    """Mean cosine similarity between two outputs restricted to the rows ``mymask``.

    Args:
        tensor1, tensor2: outputs of identical shape, assumed 2-D
            ``(num_nodes, num_classes)`` -- TODO confirm.
        mymask: index array / boolean mask selecting rows (nodes).
        dim: dimension along which ``cosine_similarity`` compares vectors.
            The default ``0`` (historical behaviour) compares each *column*
            across the selected nodes and averages over columns; ``dim=1``
            compares each selected node's output vector and averages over nodes.

    Returns:
        0-d NumPy array with the mean similarity.
    """
    # NOTE(review): per-node similarity (dim=1) looks like the intended metric;
    # the default stays 0 so previously reported numbers remain reproducible.
    cosine_sim = cosine_similarity(tensor1[mymask], tensor2[mymask], dim=dim)
    aver_cosine_sim = torch.mean(cosine_sim)
    # .cpu() so the NumPy conversion also works for tensors living on a GPU.
    return aver_cosine_sim.detach().cpu().numpy()

def calculate_euclidean_distance(tensor1, tensor2, mymask):
    """Mean row-wise Euclidean distance between two outputs on the rows ``mymask``.

    ``pairwise_distance`` is used with its defaults (p=2, eps=1e-6 added to the
    difference), so identical rows give a tiny non-zero distance.

    Args:
        tensor1, tensor2: outputs of identical shape.
        mymask: index array / boolean mask selecting rows (nodes).

    Returns:
        0-d NumPy array with the mean distance.
    """
    masked_tensor1 = tensor1[mymask]
    masked_tensor2 = tensor2[mymask]
    euclidean_dist = pairwise_distance(masked_tensor1, masked_tensor2)
    aver_euclidean_dist = torch.mean(euclidean_dist)
    # .cpu() so the NumPy conversion also works for tensors living on a GPU.
    return aver_euclidean_dist.detach().cpu().numpy()

def my_evaluate_select(x, edge_index, labels, mymask, model):
    """Run ``model`` in eval mode and measure accuracy on the nodes in ``mymask``.

    Args:
        x: node features passed to ``model``.
        edge_index: graph structure passed to ``model``.
        labels: ground-truth class per node.
        mymask: index array / boolean mask selecting the evaluated nodes.
        model: callable ``model(x, edge_index)`` with an ``eval()`` method.

    Returns:
        ``(accuracy, logits)`` -- accuracy as a Python float over the selected
        nodes (raises ``ZeroDivisionError`` when none are selected) and the model
        output for *all* nodes.
    """
    model.eval()
    with torch.no_grad():
        out = model(x, edge_index)
        target = labels[mymask]
        predicted = torch.max(out[mymask], dim=1)[1]
        n_correct = (predicted == target).sum().item()
        return n_correct * 1.0 / len(target), out

def my_evaluate_all(x, edge_index, labels, model):
    """Run ``model`` in eval mode and measure accuracy over every node.

    Args:
        x: node features passed to ``model``.
        edge_index: graph structure passed to ``model``.
        labels: ground-truth class per node.
        model: callable ``model(x, edge_index)`` with an ``eval()`` method.

    Returns:
        ``(accuracy, logits)`` -- accuracy as a Python float over all entries of
        ``labels`` and the full model output.
    """
    model.eval()
    with torch.no_grad():
        out = model(x, edge_index)
        predicted = torch.max(out, dim=1)[1]
        n_correct = (predicted == labels).sum().item()
        return n_correct * 1.0 / len(labels), out

def adj2graph(adj):
    """Convert a SciPy sparse adjacency matrix into a ``networkx.Graph``.

    ``nx.from_scipy_sparse_matrix`` was removed in NetworkX 3.0; its replacement
    ``nx.from_scipy_sparse_array`` exists since NetworkX 2.7. Use whichever is
    available so this works on both old and new NetworkX releases.
    """
    converter = getattr(nx, 'from_scipy_sparse_array', None)
    if converter is None:
        converter = nx.from_scipy_sparse_matrix
    return converter(adj)

def julia_eig(l_in, num_eigs):
    """Compute eigenpairs of ``l_in`` with ``not_main`` from ./my_utils/eigen.jl.

    Args:
        l_in: matrix handed to the Julia routine.
        num_eigs: number of eigenpairs requested from the Julia routine.

    Returns:
        ``(eigenvalues, eigenvectors)`` exactly as produced by ``not_main``.
    """
    # PyJulia requires the Julia(...) constructor to run before importing
    # ``julia.Main``; compiled_modules=False is its documented workaround for
    # statically linked Python builds. Kept bound for the duration of the call.
    _julia_runtime = Julia(compiled_modules=False)
    from julia import Main as julia_main

    julia_main.include("./my_utils/eigen.jl")
    values, vectors = julia_main.not_main(l_in, num_eigs)
    return values, vectors

def julia_eig_plot(l_in, num_eigs):
    """Compute eigenpairs of ``l_in`` with ``plot_main`` from ./my_utils/eigen.jl.

    Same as ``julia_eig`` but dispatches to the Julia entry point used for the
    eigenvalue plots.

    Returns:
        ``(eigenvalues, eigenvectors)`` exactly as produced by ``plot_main``.
    """
    # Julia(...) must be constructed before ``from julia import Main`` (PyJulia);
    # kept bound for the duration of the call.
    _julia_runtime = Julia(compiled_modules=False)
    from julia import Main as julia_main

    julia_main.include("./my_utils/eigen.jl")
    values, vectors = julia_main.plot_main(l_in, num_eigs)
    return values, vectors

def plot_eig(L_mtx,n=100,name='mtx',label_number=None):
    """Plot the eigenvalues returned by ``julia_eig_plot(L_mtx, n)``.

    The figure is saved to ``./paper_plot/{name}.png`` (the directory is created
    if missing) and then closed, so repeated calls neither leak figures nor draw
    onto whatever figure happened to be current.

    Args:
        L_mtx: matrix (e.g. a graph Laplacian) passed to the Julia routine.
        n: number of eigenvalues requested.
        name: plot title and output file stem.
        label_number: accepted for backward compatibility; currently unused
            (it only appeared in a disabled title format).

    Side effect: sets matplotlib's global ``font.size`` rcParam to 15.
    """
    import os

    plt.rcParams.update({'font.size': 15})
    eigenvalues, _ = julia_eig_plot(L_mtx, n)
    eigenvalues = np.asarray(eigenvalues)
    ranks = np.arange(1, len(eigenvalues) + 1)

    fig, ax = plt.subplots()
    ax.plot(ranks, eigenvalues.real, marker='o')
    ax.set_xlabel('n-th Smallest Eigenvalue')
    ax.set_ylabel('Eigenvalue')
    ax.set_title('{}'.format(name))
    ax.grid(True)
    # Integer ticks on the x-axis: it indexes eigenvalues by rank.
    ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
    ax.tick_params(labelsize=10)

    os.makedirs('./paper_plot', exist_ok=True)
    fig.savefig('./paper_plot/{}.png'.format(name), dpi=300)
    plt.close(fig)


def _plot_stable_unstable(x_left, x_right, y_left1, y_left2, y_right1, y_right2,
                          dataset_name, model_name, ylabel, file_suffix):
    """Draw the two-panel SAGMAN stable/unstable comparison and save it.

    The left panel plots against the random edge perturbation level and the
    right panel against the random feature perturbation level. Only the left
    panel carries ``ylabel``. The figure is saved to
    ``./paper_plot/{dataset_name}_{model_name}_{file_suffix}.png`` (directory
    created if missing) and closed afterwards so figures do not accumulate.
    """
    import os

    the_font_size = 15
    title = '{},{}'.format(dataset_name, model_name)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    panels = (
        (ax1, x_left, y_left1, y_left2, 'Random EdgesPT'),
        (ax2, x_right, y_right1, y_right2, 'Random FeaturePT'),
    )
    for ax, xs, ys_stable, ys_unstable, xlabel in panels:
        ax.plot(xs, ys_stable, label='SAGMAN Stable')
        ax.plot(xs, ys_unstable, label='SAGMAN Unstable')
        ax.set_title(title, fontsize=the_font_size)
        ax.set_xlabel(xlabel, fontsize=the_font_size)
        ax.set_xticks(xs)
        ax.tick_params(labelsize=10)
        ax.legend(fontsize=the_font_size)
    ax1.set_ylabel(ylabel, fontsize=the_font_size)

    fig.tight_layout()
    os.makedirs('./paper_plot', exist_ok=True)
    fig.savefig('./paper_plot/{}_{}_{}.png'.format(dataset_name, model_name, file_suffix), dpi=300)
    plt.close(fig)


def plot_cosine(x_left,x_right,y_left1,y_left2,y_right1, y_right2,dataset_name,model_name,robust_degrees,nonrobust_degrees,topeig=None):
    """Plot cosine similarity of stable vs. unstable nodes under edge/feature perturbation.

    Saved as ``./paper_plot/{dataset_name}_{model_name}_cos.png``.
    ``robust_degrees``, ``nonrobust_degrees`` and ``topeig`` are accepted for
    backward compatibility but currently unused (they only fed a disabled
    figure suptitle).
    """
    _plot_stable_unstable(x_left, x_right, y_left1, y_left2, y_right1, y_right2,
                          dataset_name, model_name, 'Cosine Similarities', 'cos')


def plot_ED(x_left,x_right,y_left1,y_left2,y_right1, y_right2,dataset_name,model_name,robust_degrees,nonrobust_degrees,topeig=None):
    """Plot KL divergence of stable vs. unstable nodes under edge/feature perturbation.

    Despite the historical "ED" name, the y-axis is labelled as Kullback-Leibler
    divergence and the file is saved as
    ``./paper_plot/{dataset_name}_{model_name}_KL.png``. ``robust_degrees``,
    ``nonrobust_degrees`` and ``topeig`` are accepted for backward compatibility
    but currently unused.
    """
    _plot_stable_unstable(x_left, x_right, y_left1, y_left2, y_right1, y_right2,
                          dataset_name, model_name, 'Kullback-Leibler Divergence', 'KL')


def _load_graph_data(dataset):
    """Return ``(features, labels, adj_mtx)`` for ``dataset``.

    chameleon/squirrel are read from ``data/{dataset}_data.pickle`` and
    ``data/{dataset}.npz``; every other dataset goes through deeprobust's
    ``Dataset(..., setting='prognn')`` with features densified via ``todense()``.
    """
    if dataset in ['chameleon', 'squirrel']:
        # pickle.load can execute arbitrary code -- only use trusted local files.
        with open(f'data/{dataset}_data.pickle', 'rb') as handle:
            data = pickle.load(handle)
        return data["features"], data["labels"], load_npz(f'data/{dataset}.npz')
    data = Dataset(root='./data/', name=dataset, setting='prognn')
    return data.features.todense(), data.labels, data.adj


def _load_backbone(args, d, c, device):
    """Load the pre-trained backbone from ``./embedd_plot_fig/{backbone}_{dataset}.pt``.

    A fresh model is still constructed first. That keeps the
    ``NotImplementedError`` for unknown backbones. It presumably also keeps the
    torch RNG consumption of parameter initialisation, which later random
    perturbations may depend on. The constructed model itself is discarded.
    """
    import inspect

    if args.backbone == "gprgnn":
        ada_filter(d, args.hidden_dim, c, dropout=args.dropout, coe=args.c, P=args.p)
    elif args.backbone == "gcn":
        GCN(d, args.hidden_dim, c, num_layers=3, dropout=args.dropout, use_bn=False, norm=True)
    else:
        raise NotImplementedError

    # The checkpoint is a whole pickled nn.Module. torch>=2.6 defaults to
    # weights_only=True, which rejects such files, so opt out when the keyword
    # exists (torch>=1.13). This unpickles arbitrary objects: trusted files only.
    # map_location lets a checkpoint saved on GPU load in this CPU-only run.
    load_kwargs = {'map_location': device}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    model = torch.load('./embedd_plot_fig/{}_{}.pt'.format(args.backbone, args.dataset), **load_kwargs)
    model.eval()
    return model


def _graph_from_adj(adj):
    """Build a ``networkx.Graph`` from a SciPy sparse adjacency matrix.

    ``from_scipy_sparse_matrix`` was removed in NetworkX 3.0; prefer its
    replacement ``from_scipy_sparse_array`` (NetworkX >= 2.7) when available.
    """
    converter = getattr(nx, 'from_scipy_sparse_array', None) or nx.from_scipy_sparse_matrix
    return converter(adj)


def _evaluate_edge_perturbation(model, x, labels, adj, level, target_nodes, device):
    """Apply ``random_edgePT(adj, level, labels, target_nodes)`` and evaluate on ``target_nodes``.

    Returns ``(accuracy on target_nodes, logits for all nodes)`` from
    ``my_evaluate_select`` run on the perturbed graph.
    """
    pt_adj = random_edgePT(adj, level, labels, target_nodes)
    pt_edge_index = SparseTensor.from_scipy(pt_adj).float().to(device)
    return my_evaluate_select(x, pt_edge_index, labels, target_nodes.copy(), model)


def main(args):
    """Compare SPADE-ranked stable/unstable nodes under random perturbations.

    Pipeline (single run, CPU only):
      1. load the dataset and a pre-trained backbone checkpoint (no training);
      2. build a graph over a spectral embedding of the input
         (hnsw -> construct_weighted_adj -> SPF) and run ``spade`` on it;
      3. save eigenvalue plots of the embedded graph's ``L_in`` and of the
         original graph's Laplacian;
      4. take the first / last ``node_percent`` of spade's ``TopNodeList`` as the
         non-robust / robust sets; the same slices of heterophily_score's ranking
         and a uniformly random set serve as baselines;
      5. print accuracy and cosine similarity on those sets under random edge
         (random_edgePT) and feature (featurePT) perturbations. Cosine and KL
         curves for the SPADE sets are saved under ./paper_plot/.
    """
    seed()
    # GPU selection via args.device is disabled; everything runs on CPU.
    device = torch.device('cpu')
    dataset = args.dataset

    features, labels, adj_mtx = _load_graph_data(dataset)
    adj_mtx = adj_mtx.asfptype()
    orig_adj_mtx = copy.copy(adj_mtx)

    labels = torch.as_tensor(labels, dtype=torch.long).to(device)
    x = torch.from_numpy(features).to(device)
    d = x.shape[1]
    c = labels.max().item() + 1

    model = _load_backbone(args, d, c, device)

    the_k = 100 if dataset in ['chameleon', 'squirrel'] else 50

    spec_embed = spectral_embedding_eig(orig_adj_mtx, features, use_feature=True, eig_julia=True)
    neighs, distance = hnsw(spec_embed, k=the_k)
    embed_adj_mtx, _, _ = construct_weighted_adj(neighs, distance)
    embed_adj_mtx = SPF(embed_adj_mtx, 4)
    embed_edge_index = SparseTensor.from_scipy(embed_adj_mtx).float().to(device)

    model.eval()
    embed_out = model(x, embed_edge_index)
    orig_edge_index = SparseTensor.from_scipy(orig_adj_mtx).float().to(device)
    orig_out = model(x, orig_edge_index)

    TopEig, _, TopNodeList, _, L_in, _, _, _ = spade(
        embed_adj_mtx, embed_out.cpu().detach().numpy(), k=the_k, num_eigs=2)

    L_orig = laplacian(orig_adj_mtx, normed=False)
    plot_eig(L_in, n=50, name=dataset + '_' + 'Embedded', label_number=c)
    plot_eig(L_orig, n=50, name=dataset + '_' + 'Orig', label_number=c)

    heteroNodeList, _ = heterophily_score(orig_adj_mtx, orig_out)

    # Size of every selected node set. Always take at least one node: for a
    # zero count, ``seq[-0:]`` is the WHOLE sequence, so the "robust" set would
    # silently become every node (and the "non-robust" set would be empty).
    node_percent = 0.01
    num_select = max(1, int(TopNodeList.shape[0] * node_percent))
    idx_nonrobust = TopNodeList[:num_select]
    hetero_nonrobust = heteroNodeList[:num_select]
    idx_robust = TopNodeList[-num_select:]
    hetero_robust = heteroNodeList[-num_select:]
    idx_random = np.array(random.sample(list(range(TopNodeList.shape[0])), num_select))
    print('model:{}; dataset:{}; the TopEig: {:0.2f}'.format(args.backbone, args.dataset, TopEig))

    # Average node degree of the SPADE sets (forwarded to the plot helpers).
    G_orig = _graph_from_adj(orig_adj_mtx)
    robust_degrees = np.mean(np.array([G_orig.degree(node_id) for node_id in idx_robust]))
    nonrobust_degrees = np.mean(np.array([G_orig.degree(node_id) for node_id in idx_nonrobust]))

    spade_nonrobust_acc, _ = my_evaluate_select(x, orig_edge_index, labels, idx_nonrobust.copy(), model)
    spade_robust_acc, _ = my_evaluate_select(x, orig_edge_index, labels, idx_robust.copy(), model)
    hetero_nonrobust_acc, _ = my_evaluate_select(x, orig_edge_index, labels, hetero_nonrobust.copy(), model)
    hetero_robust_acc, _ = my_evaluate_select(x, orig_edge_index, labels, hetero_robust.copy(), model)
    print('original ; SPADE_robust: {:0.2f}; hetero_robust: {:0.2f}; SPADE_nonrobust: {:0.2f}; hetero_nonrobust: {:0.2f}'.format(spade_robust_acc, hetero_robust_acc, spade_nonrobust_acc, hetero_nonrobust_acc))

    # Random edge perturbation at levels 1..5 (call order preserved per set).
    x_left, y_left1, y_left2, y_left3, y_left4 = [], [], [], [], []
    for level in range(1, 6):
        robust_acc, spade_out_robust = _evaluate_edge_perturbation(
            model, x, labels, orig_adj_mtx, level, idx_robust, device)
        nonrobust_acc, spade_out_nonrobust = _evaluate_edge_perturbation(
            model, x, labels, orig_adj_mtx, level, idx_nonrobust, device)
        hetero_robust_acc, hetero_out_robust = _evaluate_edge_perturbation(
            model, x, labels, orig_adj_mtx, level, hetero_robust, device)
        hetero_nonrobust_acc, hetero_out_nonrobust = _evaluate_edge_perturbation(
            model, x, labels, orig_adj_mtx, level, hetero_nonrobust, device)
        random_acc, random_out = _evaluate_edge_perturbation(
            model, x, labels, orig_adj_mtx, level, idx_random, device)

        print(' edge randomPT:{} ; SPADE_robust: {:0.2f}; hetero_robust: {:0.2f}; random: {:0.2f}; SPADE_nonrobust: {:0.2f}; hetero_nonrobust: {:0.2f}'.format(level, robust_acc, hetero_robust_acc, random_acc, nonrobust_acc, hetero_nonrobust_acc))

        cos_spade_nonrobust = calculate_cosine(orig_out, spade_out_nonrobust, idx_nonrobust.copy())
        cos_spade_robust = calculate_cosine(orig_out, spade_out_robust, idx_robust.copy())
        kl_spade_nonrobust = calculate_kl(orig_out, spade_out_nonrobust, idx_nonrobust.copy())
        kl_spade_robust = calculate_kl(orig_out, spade_out_robust, idx_robust.copy())
        cos_hetero_nonrobust = calculate_cosine(orig_out, hetero_out_nonrobust, hetero_nonrobust.copy())
        cos_hetero_robust = calculate_cosine(orig_out, hetero_out_robust, hetero_robust.copy())
        cos_random = calculate_cosine(orig_out, random_out, idx_random.copy())

        print('SPADE cosine_similarity: robust:{:0.2f}; nonrobust:{:0.2f}'.format(cos_spade_robust, cos_spade_nonrobust))
        print('hetero cosine_similarity:robust:{:0.2f}; nonrobust:{:0.2f}'.format(cos_hetero_robust, cos_hetero_nonrobust))
        print('random cosine_similarity: {:0.2f}'.format(cos_random))
        x_left.append(level)
        y_left1.append(cos_spade_robust)
        y_left2.append(cos_spade_nonrobust)
        y_left3.append(kl_spade_robust)
        y_left4.append(kl_spade_nonrobust)

    # Random feature perturbation at increasing strengths (graph unchanged).
    x_right, y_right1, y_right2, y_right3, y_right4 = [], [], [], [], []
    for noise in [0.4, 0.8, 1.2, 1.6, 2.0]:
        PTfeature = featurePT(x, noise)
        model.eval()
        PT_out = model(PTfeature, orig_edge_index)
        robust_acc, _ = my_evaluate_select(PTfeature, orig_edge_index, labels, idx_robust.copy(), model)
        nonrobust_acc, _ = my_evaluate_select(PTfeature, orig_edge_index, labels, idx_nonrobust.copy(), model)
        random_acc, _ = my_evaluate_select(PTfeature, orig_edge_index, labels, idx_random.copy(), model)
        hetero_robust_acc, _ = my_evaluate_select(PTfeature, orig_edge_index, labels, hetero_robust.copy(), model)
        hetero_nonrobust_acc, _ = my_evaluate_select(PTfeature, orig_edge_index, labels, hetero_nonrobust.copy(), model)

        cos_spade_nonrobust = calculate_cosine(orig_out, PT_out, idx_nonrobust.copy())
        cos_spade_robust = calculate_cosine(orig_out, PT_out, idx_robust.copy())
        kl_spade_nonrobust = calculate_kl(orig_out, PT_out, idx_nonrobust.copy())
        kl_spade_robust = calculate_kl(orig_out, PT_out, idx_robust.copy())
        cos_hetero_nonrobust = calculate_cosine(orig_out, PT_out, hetero_nonrobust.copy())
        cos_hetero_robust = calculate_cosine(orig_out, PT_out, hetero_robust.copy())
        cos_random = calculate_cosine(orig_out, PT_out, idx_random.copy())

        print(' feat randomPT:{} ; SPADE_robust: {:0.2f}; hetero_robust: {:0.2f}; random: {:0.2f}; SPADE_nonrobust: {:0.2f}; hetero_nonrobust: {:0.2f}'.format(noise, robust_acc, hetero_robust_acc, random_acc, nonrobust_acc, hetero_nonrobust_acc))
        print('SPADE cosine_similarity: robust:{:0.2f}; nonrobust:{:0.2f}'.format(cos_spade_robust, cos_spade_nonrobust))
        print('hetero cosine_similarity:robust:{:0.2f}; nonrobust:{:0.2f}'.format(cos_hetero_robust, cos_hetero_nonrobust))
        print('random cosine_similarity: {:0.2f}'.format(cos_random))
        x_right.append(noise)
        y_right1.append(cos_spade_robust)
        y_right2.append(cos_spade_nonrobust)
        y_right3.append(kl_spade_robust)
        y_right4.append(kl_spade_nonrobust)

    plot_cosine(x_left, x_right, y_left1, y_left2, y_right1, y_right2, dataset, args.backbone, robust_degrees, nonrobust_degrees, topeig=TopEig)
    plot_ED(x_left, x_right, y_left3, y_left4, y_right3, y_right4, dataset, args.backbone, robust_degrees, nonrobust_degrees, topeig=TopEig)


    
    

if __name__ == "__main__":
    # Command-line options in --help order. Several (e.g. --device, --perturbed,
    # --attack) are not read by main() in this script; they are still parsed and
    # handed to preprocess_args.
    cli_options = [
        ('--device', dict(type=int, default=0, help='choose GPU device id')),
        ('--dataset', dict(type=str, default='cora',
                           choices=['cora', 'citeseer', 'pubmed', 'chameleon', 'squirrel'],
                           help='choose graph dataset')),
        ('--perturbed', dict(action='store_true', help='use adversarial graph as input')),
        ('--attack', dict(type=str, default='meta', choices=['nettack', 'meta', 'grbcd'],
                          help='used to choose attack method and test nodes')),
        ('--backbone', dict(type=str, default='gcn', choices=['gcn', 'gprgnn'],
                            help='backbone GNN model')),
        ('--ptb_rate', dict(type=float, default=.2,
                            help='adversarial perturbation budget:'
                                 '                        suggest to use 0.2 for meta attack, '
                                 '5.0 for nettack attack, 0.5 for grbcd attack')),
        ('--runs', dict(type=int, default=10,
                        help='how many runs to compute accuracy mean and std')),
        ('--display_step', dict(type=int, default=1, help='how often to print')),
        ('--full_distortion', dict(action='store_true',
                                   help='Use the non-simplified spectral embedding distortion')),
        ('--no_garnet', dict(action='store_true', help='No using GARNET to purify input graph')),
        ('--random', dict(action='store_true', help='using random edge perturbation')),
        ('--k', dict(type=int, help='k for kNN graph construction')),
        ('--kk', dict(type=int, help='k for kNN graph construction')),
        ('--weighted_knn', dict(type=str, choices=[None, 'True', 'False'],
                                help='use weighted knn graph')),
        ('--adj_norm', dict(type=str, choices=[None, 'True', 'False'],
                            help='normalize adjacency matrix')),
        ('--use_feature', dict(type=str, choices=[None, 'True', 'False'],
                               help='incorporate node features for kNN construction')),
        ('--embedding_norm', dict(type=str, choices=[None, 'unit_vector', 'standardize', 'minmax'],
                                  help='normalize node embeddings for kNN construction')),
        ('--gamma', dict(type=float, help='threshold to sparsify kNN graph')),
        ('--r', dict(type=int, help='number of eigenpairs')),
        ('--epochs', dict(type=int, help='number of epochs')),
        ('--lr', dict(type=float, help='learning rate')),
        ('--dropout', dict(type=float, help='dropout rate')),
        ('--hidden_dim', dict(type=int, help='GNN hidden dimension')),
        ('--weight_decay', dict(type=float, help='weight decay for GNN training')),
        ('--p', dict(type=int, help='adaptive filter degree in GPRGNN')),
        ('--c', dict(type=float, help='coefficients of adaptive filter in GPRGNN')),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    # Combine the parsed CLI values with arguments from the configuration files.
    args = preprocess_args(parser.parse_args())
    print(args)
    main(args)
