import argparse
from dataset_loader import DataLoader
from utils import random_planetoid_splits
from models import *
import torch
import torch.nn.functional as F
from tqdm import tqdm
import random
import seaborn as sns
import numpy as np
import time
import copy
import networkx as nx
from torch_sparse import SparseTensor
from scipy.sparse import coo_matrix
from my_utils.utils import spade, spectral_embedding, hnsw, construct_adj, heterophily_score,spectral_embedding_eig,SPF,construct_weighted_adj
from torch.nn.functional import cosine_similarity

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from scipy.sparse.csgraph import laplacian
from torch.nn.functional import pairwise_distance
from julia.api import Julia
from deeprobust.graph.data import Dataset
import pickle
from scipy.sparse import load_npz
#import pdb

import warnings
warnings.filterwarnings("ignore")


def create_mask(indices, tensor_size):
    """Return a boolean tensor of size ``tensor_size`` that is True at ``indices``.

    Args:
        indices: positions to switch on (anything accepted by tensor indexing).
        tensor_size: size passed to ``torch.zeros`` for the mask.

    Returns:
        A ``torch.bool`` tensor, False everywhere except at ``indices``.
    """
    # Start from an all-False mask and flip the requested positions.
    selected = torch.zeros(tensor_size, dtype=torch.bool)
    selected[indices] = True
    return selected

def julia_eig(l_in, num_eigs):
    """Run the Julia routine ``not_main`` from ``./my_utils/eigen.jl`` on ``l_in``.

    Args:
        l_in: matrix handed to the Julia routine unchanged.
        num_eigs: number of eigenpairs requested from the Julia routine.

    Returns:
        The ``(eigenvalues, eigenvectors)`` pair produced by ``not_main``.
    """
    # The Julia runtime object is created before ``julia.Main`` is imported,
    # keeping the original initialisation order.
    runtime = Julia(compiled_modules=False)
    from julia import Main as julia_main
    julia_main.include("./my_utils/eigen.jl")
    values, vectors = julia_main.not_main(l_in, num_eigs)
    return values, vectors

def julia_eig_plot(l_in, num_eigs):
    """Run the Julia routine ``plot_main`` from ``./my_utils/eigen.jl`` on ``l_in``.

    Args:
        l_in: matrix handed to the Julia routine unchanged.
        num_eigs: number of eigenpairs requested from the Julia routine.

    Returns:
        The ``(eigenvalues, eigenvectors)`` pair produced by ``plot_main``.
    """
    # The Julia runtime object is created before ``julia.Main`` is imported,
    # keeping the original initialisation order.
    runtime = Julia(compiled_modules=False)
    from julia import Main as julia_main
    julia_main.include("./my_utils/eigen.jl")
    values, vectors = julia_main.plot_main(l_in, num_eigs)
    return values, vectors

def calculate_kl(tensor1, tensor2, mymask):
    """Batch-mean KL divergence between two outputs on the rows picked by ``mymask``.

    ``tensor1`` is passed to ``F.kl_div`` as the input (which that function
    expects in log space) and ``exp(tensor2)`` as the target, so both
    arguments are treated as log-probabilities.

    Args:
        tensor1: reference output.
        tensor2: output to compare against the reference.
        mymask: row selector applied to both tensors.

    Returns:
        The divergence as a NumPy value on the CPU.
    """
    log_input = tensor1[mymask]
    target_probs = torch.exp(tensor2[mymask])
    divergence = F.kl_div(log_input, target_probs, reduction='batchmean')
    return divergence.cpu().detach().numpy()

def calculate_cosine(tensor1,tensor2,mymask):
    """Mean cosine similarity between two outputs on the rows picked by ``mymask``.

    The similarity is taken along ``dim=0`` of the selected rows, i.e. one
    value per column, and those values are then averaged.
    NOTE(review): a per-row similarity would use ``dim=1`` - confirm that
    the column-wise form is intended.

    Args:
        tensor1: reference output.
        tensor2: output to compare against the reference.
        mymask: row selector applied to both tensors.

    Returns:
        The averaged similarity as a NumPy value on the CPU.
    """
    reference_rows = tensor1[mymask]
    compared_rows = tensor2[mymask]
    per_column = cosine_similarity(reference_rows, compared_rows, dim=0)
    return per_column.mean().cpu().detach().numpy()

def calculate_euclidean_distance(tensor1, tensor2, mymask):
    """Mean pairwise distance between two outputs on the rows picked by ``mymask``.

    Args:
        tensor1: reference output.
        tensor2: output to compare against the reference.
        mymask: row selector applied to both tensors.

    Returns:
        The averaged ``pairwise_distance`` as a NumPy value on the CPU.
    """
    per_row = pairwise_distance(tensor1[mymask], tensor2[mymask])
    return per_row.mean().cpu().detach().numpy()

def jaccard_similarity(arr1, arr2):
    """Jaccard index of the distinct elements of two iterables.

    Returns:
        ``|A & B| / |A | B|``, or ``0`` when both inputs are empty.
    """
    first, second = set(arr1), set(arr2)
    combined = first | second
    if not combined:
        # Both inputs empty: define the similarity as 0 instead of dividing by zero.
        return 0
    return len(first & second) / len(combined)

def adj2graph(adj):
    """Build a networkx graph from a SciPy sparse adjacency matrix.

    Args:
        adj: SciPy sparse adjacency matrix.

    Returns:
        The networkx graph created from ``adj``.
    """
    # networkx 3.0 removed ``from_scipy_sparse_matrix``; its replacement
    # ``from_scipy_sparse_array`` exists since 2.7. Prefer the new name and
    # fall back to the old one so older installations keep working.
    builder = getattr(nx, 'from_scipy_sparse_array', None)
    if builder is None:
        builder = nx.from_scipy_sparse_matrix
    return builder(adj)

def my_evaluate_select(data, mymask, model):
    """Evaluate ``model`` on ``data`` and score it on the nodes picked by ``mymask``.

    Args:
        data: input passed to ``model``; must expose the labels as ``data.y``.
        mymask: selector for the nodes to score.
        model: model to run; it is switched to eval mode.

    Returns:
        ``(accuracy, logits)`` - the fraction of selected nodes whose arg-max
        prediction equals the label, and the model output for *all* nodes.
    """
    model.eval()
    with torch.no_grad():
        logits = model(data)
        selected_labels = data.y[mymask]
        predicted = logits[mymask].max(dim=1)[1]
        num_correct = (predicted == selected_labels).sum()
        accuracy = num_correct.item() * 1.0 / len(selected_labels)
    return accuracy, logits

def edge2adj(edge_index, num_nodes=None):
    """Convert a ``2 x E`` edge-index tensor into a SciPy CSR adjacency matrix.

    Every listed edge contributes a weight of 1 (float32); edges listed more
    than once are summed when the matrix is converted to CSR.

    Args:
        edge_index: tensor whose first row holds source nodes and second row
            target nodes (0-based indices).
        num_nodes: number of nodes in the graph. When omitted it is inferred
            as ``max index + 1``, which silently drops isolated nodes whose
            index is above every edge endpoint - pass it explicitly when the
            true node count is known.

    Returns:
        A ``num_nodes x num_nodes`` ``scipy.sparse.csr_matrix``.
    """
    edge_index_np = edge_index.cpu().numpy()
    num_edges = edge_index_np.shape[1]
    if num_nodes is None:
        # An empty edge list has no maximum; treat it as a graph with no nodes.
        num_nodes = int(edge_index_np.max()) + 1 if num_edges else 0
    weights = np.ones(num_edges, dtype=np.float32)
    coo_adj_matrix = coo_matrix(
        (weights, (edge_index_np[0], edge_index_np[1])),
        shape=(num_nodes, num_nodes),
    )
    return coo_adj_matrix.tocsr()

def adj2edge(adj):
    """Convert a SciPy sparse adjacency matrix into a ``2 x E`` edge-index tensor.

    Args:
        adj: sparse matrix providing ``tocoo()``.

    Returns:
        A ``torch.long`` tensor whose first row holds the row (source) index
        and whose second row holds the column (target) index of every stored
        entry.
    """
    coo = adj.tocoo()
    # Stack sources over targets to obtain the 2 x E layout.
    endpoints = np.stack([coo.row, coo.col], axis=0)
    return torch.tensor(endpoints, dtype=torch.long)

def featurePT(x,beta):
    """Return a copy of ``x`` perturbed with Gaussian noise.

    The noise is standard normal, scaled by the standard deviation of all
    entries of ``x`` and by ``beta``.

    The input tensor is left untouched. The previous implementation added the
    noise in place, so repeated calls on the same tensor accumulated noise
    and also altered every object sharing that tensor.

    Args:
        x: floating-point feature tensor.
        beta: noise strength multiplier.

    Returns:
        A new tensor with the same shape, dtype and device as ``x``.
    """
    std_dev = torch.std(x)
    # randn_like keeps the noise on x's device and in x's dtype.
    noise = torch.randn_like(x) * std_dev
    return x + noise * beta

def random_edgePT(graph, the_p,label,node_index):
    """Return a copy of ``graph`` with random label-aware edge perturbations.

    For every node in ``node_index``:
      * up to ``the_p`` edges are added towards randomly chosen nodes whose
        label differs from the node's label, then
      * up to ``the_p`` edges towards randomly chosen neighbours with the
        same label are removed.

    Both sample sizes are clamped to the number of available candidates, so
    a node with fewer than ``the_p`` different-label candidates no longer
    makes ``random.sample`` raise ``ValueError``.

    Args:
        graph: graph exposing ``copy``, ``nodes``, ``neighbors``,
            ``add_edges_from`` and ``remove_edges_from`` (e.g. a networkx graph).
        the_p: number of edges to add and to remove per node.
        label: per-node labels, indexable by node id.
        node_index: iterable of node ids to perturb.

    Returns:
        The perturbed copy; ``graph`` itself is not modified.
    """
    perturbed_graph = graph.copy()
    for node in node_index:
        node_label = label[node]

        # Connect the node to randomly chosen nodes of a different label.
        other_nodes = [
            n for n in perturbed_graph.nodes()
            if n != node and label[n] != node_label
        ]
        nodes_to_connect = random.sample(other_nodes, min(the_p, len(other_nodes)))
        perturbed_graph.add_edges_from([(node, n) for n in nodes_to_connect])

        # Disconnect the node from randomly chosen neighbours of the same label.
        same_label_neighbors = [
            n for n in perturbed_graph.neighbors(node) if label[n] == node_label
        ]
        edges_to_remove = random.sample(
            same_label_neighbors, min(the_p, len(same_label_neighbors))
        )
        perturbed_graph.remove_edges_from([(node, n) for n in edges_to_remove])

    return perturbed_graph

def plot_eig(L_mtx,n=100,name='mtx',label_number=None):
    """Plot eigenvalues of ``L_mtx`` and save the figure under ``./paper_plot``.

    Args:
        L_mtx: matrix whose eigenvalues are computed through ``julia_eig``.
        n: number of eigenvalues requested.
        name: used as the plot title and as the PNG file stem.
        label_number: accepted for call compatibility; not used.
    """
    plt.rcParams.update({'font.size': 15})
    spectrum, _ = julia_eig(L_mtx, n)
    ranks = np.arange(1, len(spectrum) + 1)

    # Only the real part of each eigenvalue is drawn.
    plt.plot(ranks, spectrum.real, marker='o')
    plt.xlabel('nth Smallest Eigenvalue')
    plt.ylabel('Eigenvalue')
    plt.title('{}'.format(name))
    plt.grid(True)

    axes = plt.gca()
    # Restrict the x-axis ticks to integer positions.
    axes.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
    axes.tick_params(labelsize=10)

    plt.savefig('./paper_plot/{}.png'.format(name), dpi=300)
    plt.clf()

def plot_cosine(x_left,x_right,y_left1,y_left2,y_right1, y_right2,dataset_name,model_name,robust_degrees,nonrobust_degrees,topeig=None):
    """Draw the two-panel cosine-similarity figure and save it as a PNG.

    The left panel shows the edge-perturbation curves, the right panel the
    feature-perturbation curves; each panel has one "stable" and one
    "unstable" line. The file is written to
    ``./paper_plot/<dataset>_<model>_cos.png``.

    Args:
        x_left, x_right: x values of the left / right panel.
        y_left1, y_left2: stable / unstable curve of the left panel.
        y_right1, y_right2: stable / unstable curve of the right panel.
        dataset_name, model_name: used in the panel titles and file name.
        robust_degrees, nonrobust_degrees: numbers shown in the figure title.
        topeig: score shown in the figure title. It is formatted with
            ``{:0.2f}``, so the default ``None`` raises ``TypeError``.
    """
    the_font_size = 15
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    panel_title = '{},{}'.format(dataset_name,model_name)

    # (axis, x values, stable curve, unstable curve, x-axis label) per panel.
    panels = (
        (ax1, x_left, y_left1, y_left2, 'Random EdgesPT'),
        (ax2, x_right, y_right1, y_right2, 'Random FeaturePT'),
    )
    for axis, xs, stable_ys, unstable_ys, x_label in panels:
        axis.plot(xs, stable_ys, label='SAGMAN Stable')
        axis.plot(xs, unstable_ys, label='SAGMAN UnStable')
        axis.set_title(panel_title, fontsize=the_font_size)
        axis.set_xlabel(x_label, fontsize=the_font_size)
        axis.set_xticks(xs)
        axis.tick_params(labelsize=10)

    # Only the left panel carries a y-axis label.
    ax1.set_ylabel('cosine_similarities', fontsize=the_font_size)

    fig.suptitle("SPADE Score:{:0.2f}, Robust Node Degree:{:0.2f}, NonRobust Node Degree:{:0.2f}".format(topeig,robust_degrees,nonrobust_degrees), fontsize=the_font_size)
    for axis in (ax1, ax2):
        axis.legend(fontsize=the_font_size)

    plt.tight_layout()
    plt.savefig('./paper_plot/{}_{}_cos.png'.format(dataset_name,model_name), dpi=300)
    plt.clf()

def plot_ED(x_left,x_right,y_left1,y_left2,y_right1, y_right2,dataset_name,model_name,robust_degrees,nonrobust_degrees,topeig=None):
    """Draw the two-panel divergence figure and save it as a PNG.

    The left panel shows the edge-perturbation curves, the right panel the
    feature-perturbation curves; each panel has one "stable" and one
    "unstable" line. The y axis is labelled as Kullback-Leibler divergence
    and the file is written to ``./paper_plot/<dataset>_<model>_KL.png``.

    Args:
        x_left, x_right: x values of the left / right panel.
        y_left1, y_left2: stable / unstable curve of the left panel.
        y_right1, y_right2: stable / unstable curve of the right panel.
        dataset_name, model_name: used in the panel titles and file name.
        robust_degrees, nonrobust_degrees: numbers shown in the figure title.
        topeig: score shown in the figure title. It is formatted with
            ``{:0.2f}``, so the default ``None`` raises ``TypeError``.
    """
    the_font_size = 15
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    panel_title = '{},{}'.format(dataset_name,model_name)

    # (axis, x values, stable curve, unstable curve, x-axis label) per panel.
    panels = (
        (ax1, x_left, y_left1, y_left2, 'Random EdgesPT'),
        (ax2, x_right, y_right1, y_right2, 'Random FeaturePT'),
    )
    for axis, xs, stable_ys, unstable_ys, x_label in panels:
        axis.plot(xs, stable_ys, label='SAGMAN Stable')
        axis.plot(xs, unstable_ys, label='SAGMAN UnStable')
        axis.set_title(panel_title, fontsize=the_font_size)
        axis.set_xlabel(x_label, fontsize=the_font_size)
        axis.set_xticks(xs)
        axis.tick_params(labelsize=10)

    # Only the left panel carries a y-axis label.
    ax1.set_ylabel('Kullback-Leibler Divergence', fontsize=the_font_size)

    fig.suptitle("SPADE Score:{:0.2f}, Robust Node Degree:{:0.2f}, NonRobust Node Degree:{:0.2f}".format(topeig,robust_degrees,nonrobust_degrees), fontsize=the_font_size)
    for axis in (ax1, ax2):
        axis.legend(fontsize=the_font_size)

    plt.tight_layout()
    plt.savefig('./paper_plot/{}_{}_KL.png'.format(dataset_name,model_name), dpi=300)
    plt.clf()


def RunExp(args, data, Net, percls_trn, val_lb,The_spade_number):
    """Train one model and compute the SPADE / heterophily node rankings.

    Steps:
      1. Re-split ``data`` with ``random_planetoid_splits`` and train
         ``Net(data, args)`` with Adam, tracking the best validation loss and
         applying the early-stopping rule configured in ``args``.
      2. Build a spectral embedding of the input graph, turn its k-nearest
         neighbours into a weighted graph, sparsify it with ``SPF`` and run
         the trained model on that graph.
      3. Score nodes with ``spade`` (embedded graph vs. model output) and
         with ``heterophily_score`` (original graph vs. model output).
      4. For GCN only, plot the eigenvalues of the SPADE input Laplacian and
         of the original normalised Laplacian.

    Args:
        args: parsed command-line namespace (net, dataset, lr, epochs, ...).
        data: graph data object; it is re-split and moved to the device.
        Net: model class, instantiated as ``Net(data, args)``.
        percls_trn: number of training nodes per class for the split.
        val_lb: number of validation nodes for the split.
        The_spade_number: forwarded to ``spade`` as ``num_eigs``.

    Returns:
        ``(test_acc, best_val_acc, theta, time_run, TopEig, TopNodeList,
        model, heteroNodeList)``. ``theta`` is always ``0`` (see the note at
        the end of the function); ``time_run`` lists per-epoch training times.
    """

    def train(model, optimizer, data, dprate):
        """Run one optimisation step on the training nodes.

        ``dprate`` is accepted for call compatibility and is not used.
        """
        model.train()
        optimizer.zero_grad()
        out = model(data)[data.train_mask]
        loss = F.nll_loss(out, data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        del out

    def test(model, data):
        """Evaluate on the train/val/test masks with a single forward pass.

        Returns per-split accuracies, predictions and NLL losses plus the
        logits for all nodes. The model is in eval mode, so the logits are
        computed once and reused for every split instead of re-running the
        model per mask; gradients are not tracked.
        """
        model.eval()
        accs, preds, losses = [], [], []
        with torch.no_grad():
            logits = model(data)
            for _, mask in data('train_mask', 'val_mask', 'test_mask'):
                masked_logits = logits[mask]
                pred = masked_logits.max(1)[1]
                acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
                loss = F.nll_loss(masked_logits, data.y[mask])
                preds.append(pred.detach().cpu())
                accs.append(acc)
                losses.append(loss.detach().cpu())
        return accs, preds, losses, logits

    device = torch.device('cuda:'+str(args.device) if torch.cuda.is_available() else 'cpu')
    tmp_net = Net(data, args)

    # Randomly split the dataset.
    data = random_planetoid_splits(data, (data.y.max().item() + 1), percls_trn, val_lb,args.seed)

    model, data = tmp_net.to(device), data.to(device)

    # GPRGNN and BernNet exclude the propagation layer from weight decay;
    # BernNet additionally uses its own learning rate for that layer.
    if args.net=='GPRGNN':
        optimizer = torch.optim.Adam([{ 'params': model.lin1.parameters(), 'weight_decay': args.weight_decay, 'lr': args.lr},
        {'params': model.lin2.parameters(), 'weight_decay': args.weight_decay, 'lr': args.lr},
        {'params': model.prop1.parameters(), 'weight_decay': 0.00, 'lr': args.lr}])

    elif args.net =='BernNet':
        optimizer = torch.optim.Adam([{'params': model.lin1.parameters(),'weight_decay': args.weight_decay, 'lr': args.lr},
        {'params': model.lin2.parameters(), 'weight_decay': args.weight_decay, 'lr': args.lr},
        {'params': model.prop1.parameters(), 'weight_decay': 0.0, 'lr': args.Bern_lr}])
    else:
        optimizer = torch.optim.Adam(model.parameters(),lr=args.lr,weight_decay=args.weight_decay)

    best_val_acc = test_acc = 0
    best_val_loss = float('inf')
    val_loss_history = []
    val_acc_history = []

    time_run=[]

    for epoch in range(args.epochs):
        t_st=time.time()
        train(model, optimizer, data, args.dprate)
        time_epoch=time.time()-t_st  # training time of this epoch
        time_run.append(time_epoch)

        [train_acc, val_acc, tmp_test_acc], preds, [
            train_loss, val_loss, tmp_test_loss], _ = test(model, data)

        # Keep the metrics of the epoch with the lowest validation loss.
        if val_loss < best_val_loss:
            best_val_acc = val_acc
            best_val_loss = val_loss
            test_acc = tmp_test_acc
            if args.net =='BernNet':
                TEST = tmp_net.prop1.temp.clone()
                theta = TEST.detach().cpu()
                theta = torch.relu(theta).numpy()
            else:
                theta = args.alpha

        val_loss_history.append(val_loss)
        val_acc_history.append(val_acc)
        # Early stopping: stop once the current validation loss exceeds the
        # mean of the previous `early_stopping` recorded losses.
        if args.early_stopping > 0 and epoch > args.early_stopping:
            tmp = torch.tensor(
                val_loss_history[-(args.early_stopping + 1):-1])
            if val_loss > tmp.mean().item():
                print('The sum of epochs:',epoch)
                break

    #model = torch.load('./embedded_plot_fig/{}_{}.pt'.format(args.net,args.dataset))
    #model.eval()

    # Neighbourhood size for the kNN graph; larger for Chameleon/Squirrel.
    the_k = 50
    if args.dataset in ['Chameleon','Squirrel']:
        the_k = 100

    adj_mtx = edge2adj(data.edge_index)
    orig_adj_mtx = copy.copy(adj_mtx)
    spec_embed = spectral_embedding_eig(adj_mtx,data.x.cpu(),use_feature=True,eig_julia=True)
    neighs, distance = hnsw(spec_embed, k=the_k)
    embed_adj_mtx,_,_ = construct_weighted_adj(neighs, distance)
    embed_adj_mtx = SPF(embed_adj_mtx, 4)
    #embed_adj_mtx.data = np.ones_like(embed_adj_mtx.data)#weighted to unweighted
    embed_edege_index = adj2edge(embed_adj_mtx)

    # Run the trained model on the embedded graph and on the original graph.
    data_copy = copy.copy(data)
    data_copy.edge_index = embed_edege_index
    _, _, _, embed_out = test(model, data_copy.to(device))
    _, _, _, orig_out = test(model, data.to(device))
    TopEig, _, TopNodeList, _, L_in, L_out= spade(embed_adj_mtx, embed_out.cpu().detach().numpy(), the_k,num_eigs=The_spade_number)#args.embedding_norm)

    heteroNodeList,_ = heterophily_score(orig_adj_mtx, orig_out)

    if args.net == "GCN":
        plot_in_name = args.dataset+'_'+'embedded'
        plot_orig_name = args.dataset+'_'+'orig'
        L_orig = laplacian(orig_adj_mtx, normed=True)
        c = data.y.max().item() + 1
        plot_eig(L_in,n=50,name=plot_in_name,label_number=c)
        plot_eig(L_orig,n=50,name=plot_orig_name,label_number=c)

    # NOTE(review): the theta tracked during training is discarded here and
    # 0 is returned instead; kept as-is so the returned value is unchanged.
    theta = 0

    return test_acc, best_val_acc, theta, time_run, TopEig, TopNodeList, model, heteroNodeList


if __name__ == '__main__':
    # Entry point: train one model, rank nodes by SPADE / heterophily score,
    # then measure how the model output changes on the most and least robust
    # nodes under random edge and feature perturbations, and plot the result.

    # ---- Command-line configuration ----
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=2108550661, help='seeds for random splits.')
    parser.add_argument('--epochs', type=int, default=1000, help='max epochs.')
    parser.add_argument('--lr', type=float, default=0.01, help='learning rate.')
    parser.add_argument('--weight_decay', type=float, default=0.0005, help='weight decay.')
    parser.add_argument('--early_stopping', type=int, default=200, help='early stopping.')
    parser.add_argument('--hidden', type=int, default=64, help='hidden units.')
    parser.add_argument('--dropout', type=float, default=0.0, help='dropout for neural networks.')

    parser.add_argument('--train_rate', type=float, default=0.6, help='train set rate.')
    parser.add_argument('--val_rate', type=float, default=0.2, help='val set rate.')
    parser.add_argument('--K', type=int, default=10, help='propagation steps.')
    parser.add_argument('--alpha', type=float, default=0.1, help='alpha for APPN/GPRGNN.')
    parser.add_argument('--dprate', type=float, default=0.5, help='dropout for propagation layer.')
    parser.add_argument('--Init', type=str,choices=['SGC', 'PPR', 'NPPR', 'Random', 'WS', 'Null'], default='PPR', help='initialization for GPRGNN.')
    parser.add_argument('--heads', default=8, type=int, help='attention heads for GAT.')
    parser.add_argument('--output_heads', default=1, type=int, help='output_heads for GAT.')

    parser.add_argument('--dataset', type=str, choices=['Cora','Citeseer','Pubmed','Computers','Photo','Chameleon','Squirrel','Actor','Texas','Cornell'],
                        default='Cora')
    parser.add_argument('--device', type=int, default=0, help='GPU device.')
    parser.add_argument('--runs', type=int, default=10, help='number of runs.')
    parser.add_argument('--net', type=str, choices=['GCN', 'GAT', 'APPNP', 'ChebNet', 'GPRGNN','BernNet','MLP'], default='BernNet')
    parser.add_argument('--Bern_lr', type=float, default=0.01, help='learning rate for BernNet propagation layer.')
    parser.add_argument('--knn', default=50, type=int, help='knn.')

    args = parser.parse_args()

    # NOTE(review): SEEDS holds a single value and is never read; the split
    # seed actually used is args.seed.
    SEEDS=[1234]
    device = torch.device('cuda:'+str(args.device) if torch.cuda.is_available() else 'cpu')

    print(args)
    print("---------------------------------------------")

    # ---- Model class selection (argparse `choices` restricts the names) ----
    gnn_name = args.net
    if gnn_name == 'GCN':
        Net = GCN_Net
    elif gnn_name == 'GAT':
        Net = GAT_Net
    elif gnn_name == 'APPNP':
        Net = APPNP_Net
    elif gnn_name == 'ChebNet':
        Net = ChebNet
    elif gnn_name == 'GPRGNN':
        Net = GPRGNN
    elif gnn_name == 'BernNet':
        Net = BernNet
    elif gnn_name =='MLP':
        Net = MLP

    # ---- Data loading ----
    # DataLoader provides the data container; its labels, features, edges and
    # masks are overwritten below. The former special case for the lowercase
    # name 'squirrel' could never trigger because the argparse choices are
    # capitalised, so it has been dropped.
    buffer = DataLoader(args.dataset)
    data = buffer[0]

    dataset_key = args.dataset.lower()
    if dataset_key in ['chameleon', 'squirrel']:
        # NOTE: pickle.load can execute arbitrary code - only use trusted files.
        with open(f'data/{dataset_key}_data.pickle', 'rb') as handle:
            data1 = pickle.load(handle)
        features = data1["features"]
        labels = data1["labels"]
        idx_train = data1["idx_train"]
        idx_val = data1["idx_val"]
        idx_test = data1["idx_test"]
        adj_mtx = load_npz(f'data/{dataset_key}.npz')
    else:
        data1 = Dataset(root='./data/', name=dataset_key, setting='prognn')
        labels = data1.labels
        features = data1.features.todense()
        idx_train, idx_val, idx_test = data1.idx_train, data1.idx_val, data1.idx_test
        adj_mtx = data1.adj

    data_size = buffer[0].y.shape[0]
    labels = torch.as_tensor(labels, dtype=torch.long)
    features = torch.from_numpy(features)

    data.y = labels
    data.x = features
    data.edge_index = adj2edge(adj_mtx)

    data.train_mask = create_mask(idx_train, data_size)
    data.val_mask = create_mask(idx_val, data_size)
    data.test_mask = create_mask(idx_test, data_size)

    percls_trn = int(round(args.train_rate*len(data.y)/(labels.max().item() + 1)))
    val_lb = int(round(args.val_rate*len(data.y)))

    def _evaluate_edge_perturbation(graph, node_ids, num_edges, data, labels, model, device):
        """Perturb `num_edges` edges around `node_ids` and evaluate on those nodes.

        Returns the (accuracy, logits) pair of my_evaluate_select.
        """
        perturbed_graph = random_edgePT(graph, num_edges, labels, node_ids)
        data_pt = copy.copy(data)
        data_pt.edge_index = adj2edge(nx.adjacency_matrix(perturbed_graph))
        return my_evaluate_select(data_pt.to(device), node_ids.copy(), model)

    for The_spade_number in [2]:
        test_acc, best_val_acc, theta_0,time_run, TopEig, TopNodeList, model, heteroNodeList = RunExp(args, data, Net, percls_trn, val_lb,The_spade_number)

        # ---- Node group selection: head of each ranking = non-robust, tail = robust ----
        select_percent = 0.01
        print('model:{}; dataset:{}; the TopEig: {:0.2f}'.format(gnn_name,args.dataset,TopEig))
        num_spade = int(TopNodeList.shape[0]*select_percent)
        num_hetero = int(heteroNodeList.shape[0]*select_percent)
        idx_nonrobust = TopNodeList[:num_spade]
        idx_robust = TopNodeList[-num_spade:]
        hetero_nonrobust = heteroNodeList[:num_hetero]
        hetero_robust = heteroNodeList[-num_hetero:]
        idx_random = np.random.choice(TopNodeList, num_spade, replace=False)

        # Average degree of each SPADE group in the loaded graph.
        orig_adj_mtx = copy.copy(adj_mtx)
        G_orig = adj2graph(orig_adj_mtx)
        robust_degrees = np.mean(np.array([G_orig.degree(node_id) for node_id in idx_robust]))
        nonrobust_degrees = np.mean(np.array([G_orig.degree(node_id) for node_id in idx_nonrobust]))

        device = torch.device('cuda:'+str(args.device) if torch.cuda.is_available() else 'cpu')
        model = model.to(device)
        edge_index = data.edge_index
        adj_mtx = edge2adj(edge_index)
        G = adj2graph(adj_mtx)

        # Unperturbed model output: the reference for every comparison below.
        _, orig_out = my_evaluate_select(data, idx_robust.copy(), model)

        x_left=[]
        x_right=[]
        y_left1=[]
        y_left2=[]
        y_left3=[]
        y_left4=[]
        y_right1=[]
        y_right2=[]
        y_right3=[]
        y_right4=[]

        # ---- Edge perturbation ----
        # Each node group is perturbed on its own copy of the graph so the
        # groups do not influence each other.
        for num_edges in range(1, 6):
            _, spade_robust = _evaluate_edge_perturbation(G, idx_robust, num_edges, data, labels, model, device)
            _, spade_nonrobust = _evaluate_edge_perturbation(G, idx_nonrobust, num_edges, data, labels, model, device)
            # The heterophily and random groups are not reported, but they are
            # still perturbed in the original order so that the sequence of
            # random draws seen by later iterations is unchanged.
            for extra_group in (hetero_robust, hetero_nonrobust, idx_random):
                _evaluate_edge_perturbation(G, extra_group, num_edges, data, labels, model, device)

            cos_spade_nonrobust = calculate_cosine(orig_out,spade_nonrobust,idx_nonrobust.copy())
            cos_spade_robust = calculate_cosine(orig_out,spade_robust,idx_robust.copy())
            kl_spade_nonrobust = calculate_kl(orig_out,spade_nonrobust,idx_nonrobust.copy())
            kl_spade_robust = calculate_kl(orig_out,spade_robust,idx_robust.copy())

            print(' SPADE cosine similarity: robust:{:0.2f}; nonrobust:{:0.2f}'.format(cos_spade_robust,cos_spade_nonrobust))
            print('SPADE KL: robust:{:0.2f}; nonrobust:{:0.2f}'.format(kl_spade_robust,kl_spade_nonrobust))

            x_left.append(num_edges)
            y_left1.append(cos_spade_robust)
            y_left2.append(cos_spade_nonrobust)
            y_left3.append(kl_spade_robust)
            y_left4.append(kl_spade_nonrobust)

        # ---- Feature perturbation by adding Gaussian noise ----
        for noise_scale in [0.4,0.8,1.2,1.6,2.0]:
            # Perturb a fresh clone every time: `features` may be the very
            # tensor stored in data.x, and an in-place perturbation would make
            # the noise of successive levels accumulate and would alter the
            # clean data as well.
            data_pt = copy.copy(data)
            data_pt.x = featurePT(features.clone(), noise_scale)
            data_pt = data_pt.to(device)
            # The returned logits cover all nodes, so one forward pass serves
            # both node groups.
            _, perturbed_out = my_evaluate_select(data_pt, idx_robust.copy(), model)

            cos_spade_nonrobust = calculate_cosine(orig_out,perturbed_out,idx_nonrobust.copy())
            cos_spade_robust = calculate_cosine(orig_out,perturbed_out,idx_robust.copy())
            kl_spade_nonrobust = calculate_kl(orig_out,perturbed_out,idx_nonrobust.copy())
            kl_spade_robust = calculate_kl(orig_out,perturbed_out,idx_robust.copy())

            print(' SPADE cosine similarity: robust:{:0.2f}; nonrobust:{:0.2f}'.format(cos_spade_robust,cos_spade_nonrobust))
            print('SPADE KL: robust:{:0.2f}; nonrobust:{:0.2f}'.format(kl_spade_robust,kl_spade_nonrobust))

            x_right.append(noise_scale)
            y_right1.append(cos_spade_robust)
            y_right2.append(cos_spade_nonrobust)
            y_right3.append(kl_spade_robust)
            y_right4.append(kl_spade_nonrobust)

    # Plots use the values of the last (here: only) loop iteration.
    plot_cosine(x_left,x_right,y_left1,y_left2,y_right1, y_right2,args.dataset,args.net,robust_degrees,nonrobust_degrees,topeig=TopEig)
    plot_ED(x_left,x_right,y_left3,y_left4,y_right3, y_right4,args.dataset,args.net,robust_degrees,nonrobust_degrees,topeig=TopEig)