import argparse
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import to_undirected, remove_self_loops, add_self_loops

from logger import *
from dataset import load_dataset
from data_utils import eval_acc, eval_rocauc, load_fixed_splits
from eval import *
from parse import parse_method, parser_add_main_args

from my_utils.utils import random_edgePT,hnsw,featurePT, spectral_embedding_eig,SPF,construct_weighted_adj,spade_nonetworkx
from scipy.sparse import csr_matrix
import matplotlib.pyplot as plt
from torch.nn.functional import cosine_similarity

def to_unweighted_csr(adj_matrix_csr):
    """Return a CSR matrix with the same sparsity structure, every stored entry set to 1.0.

    The new matrix reuses the ``indices`` and ``indptr`` arrays of
    ``adj_matrix_csr`` (no copy is requested), only the data array is new.

    Args:
        adj_matrix_csr: a scipy CSR matrix (weighted adjacency).

    Returns:
        scipy ``csr_matrix`` of dtype float64 with the same shape and pattern.
    """
    col_idx, row_ptr = adj_matrix_csr.indices, adj_matrix_csr.indptr
    # One weight of 1.0 for every stored entry (explicitly stored zeros included).
    weights = np.ones(col_idx.shape[0], dtype=np.float64)
    return csr_matrix((weights, col_idx, row_ptr), shape=adj_matrix_csr.shape)

def edge_index_to_csr(edge_index, num_nodes=None):
    """Convert a ``(2, E)`` edge-index tensor into an unweighted scipy CSR adjacency.

    Args:
        edge_index: tensor whose row 0 holds source and row 1 target node ids
            (moved to CPU before conversion).
        num_nodes: matrix dimension. When ``None`` it is inferred as the
            largest node id + 1, or 0 for an empty edge list.

    Returns:
        ``csr_matrix`` of shape ``(num_nodes, num_nodes)`` with float32 entries.
        Each edge contributes 1.0; duplicate edges are summed by scipy, so a
        repeated edge yields a weight > 1.
    """
    edge_index = edge_index.cpu()
    row = edge_index[0].numpy()
    col = edge_index[1].numpy()
    data = np.ones(row.shape[0], dtype=np.float32)
    if num_nodes is None:
        # ndarray.max() raises on an empty array, so guard the edgeless case.
        num_nodes = int(max(row.max(), col.max())) + 1 if row.size else 0
    return csr_matrix((data, (row, col)), shape=(num_nodes, num_nodes))


def csr_to_edge_index(adj_csr):
    """Convert a scipy sparse adjacency matrix into a ``(2, nnz)`` long edge-index tensor.

    Row 0 holds the row (source) index and row 1 the column (target) index of
    every stored entry, in the order produced by ``adj_csr.tocoo()``. Edge
    weights are discarded.

    Args:
        adj_csr: any scipy sparse matrix supporting ``tocoo()``.

    Returns:
        ``torch.LongTensor`` of shape ``(2, nnz)`` on the CPU.
    """
    adj_coo = adj_csr.tocoo()
    # Stack into a single contiguous int64 array first: building a tensor from
    # a Python list of ndarrays is very slow and makes PyTorch emit a warning.
    pairs = np.vstack((adj_coo.row, adj_coo.col)).astype(np.int64, copy=False)
    return torch.from_numpy(pairs)


def calculate_kl(tensor1, tensor2, mymask):
    """Return the batch-mean KL divergence KL(P2 || P1) over the selected rows.

    ``tensor1`` and ``tensor2`` hold per-row class scores (classes on the last
    dimension). Both are normalised with ``log_softmax`` before comparison,
    so raw logits and log-probabilities are both accepted:

    * ``log_softmax`` leaves valid log-probabilities unchanged.
    * Raw logits are not distributions at all until normalised. The
      commented-out training loop applies ``F.log_softmax`` to the model
      output, which suggests the model itself returns raw logits.

    Args:
        tensor1: reference scores; defines P1.
        tensor2: compared scores; defines P2.
        mymask: row indices (or boolean mask) selecting the rows to compare.

    Returns:
        0-d numpy array: sum over rows of KL(P2_row || P1_row), divided by the
        number of selected rows (``reduction='batchmean'``).
    """
    log_p1 = F.log_softmax(tensor1[mymask], dim=-1)
    log_p2 = F.log_softmax(tensor2[mymask], dim=-1)
    # F.kl_div(input, target) = sum target * (log target - input), i.e.
    # KL(target || input); log_target=True takes the target as log-probs.
    kl_div = F.kl_div(log_p1, log_p2, reduction='batchmean', log_target=True)
    return kl_div.detach().cpu().numpy()


def calculate_cosine(tensor1, tensor2, mymask, dim=0):
    """Return the mean cosine similarity between the selected rows of two tensors.

    Args:
        tensor1: first tensor (e.g. clean model output).
        tensor2: second tensor of the same shape (e.g. perturbed output).
        mymask: row indices (or boolean mask) selecting the rows to compare.
        dim: dimension along which ``cosine_similarity`` is computed.

            * ``0`` (the default, kept for backward compatibility) compares
              each column across the selected rows. For ``(nodes, classes)``
              outputs this gives one similarity per class.
            * ``1`` compares each selected row's vector, i.e. one similarity
              per node.

    Returns:
        0-d numpy array holding the mean of the similarities.
    """
    sims = cosine_similarity(tensor1[mymask], tensor2[mymask], dim=dim)
    return torch.mean(sims).detach().cpu().numpy()


def _plot_stable_vs_unstable(x_left, x_right, y_left1, y_left2, y_right1, y_right2,
                             dataset_name, model_name, left_ylabel, file_suffix):
    """Draw the two-panel stable/unstable comparison and save it as a PNG.

    The left panel plots the metric against the edge-perturbation levels
    ``x_left``; the right panel plots it against the feature-perturbation
    levels ``x_right``. In each panel the first series is labelled
    'SAGMAN Stable' and the second 'SAGMAN Unstable'. Only the left panel gets
    a y-axis label. The file is written to
    ``./{dataset_name}_{model_name}_{file_suffix}.png`` at 300 dpi.
    """
    the_font_size = 15
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    title = '{},{}'.format(dataset_name, model_name)
    panels = (
        (ax1, x_left, y_left1, y_left2, 'Random EdgesPT'),
        (ax2, x_right, y_right1, y_right2, 'Random FeaturePT'),
    )
    for ax, xs, y_stable, y_unstable, xlabel in panels:
        ax.plot(xs, y_stable, label='SAGMAN Stable')
        ax.plot(xs, y_unstable, label='SAGMAN Unstable')
        ax.set_title(title, fontsize=the_font_size)
        ax.set_xlabel(xlabel, fontsize=the_font_size)
        ax.set_xticks(xs)
        ax.tick_params(labelsize=10)
        ax.legend(fontsize=the_font_size)
    ax1.set_ylabel(left_ylabel, fontsize=the_font_size)
    fig.tight_layout()
    fig.savefig('./{}_{}_{}.png'.format(dataset_name, model_name, file_suffix), dpi=300)
    # Close (not merely clear) the figure: pyplot keeps every figure created by
    # plt.subplots open until closed, so repeated calls would accumulate them.
    plt.close(fig)


def plot_cosine(x_left,x_right,y_left1,y_left2,y_right1, y_right2,dataset_name,model_name):
    """Plot cosine similarities for the stable vs. unstable node sets.

    Saves ``./{dataset_name}_{model_name}_cos.png``. Left panel: edge
    perturbation levels ``x_left``; right panel: feature perturbation levels
    ``x_right``. ``*1`` series are the stable nodes, ``*2`` the unstable ones.
    """
    _plot_stable_vs_unstable(x_left, x_right, y_left1, y_left2, y_right1, y_right2,
                             dataset_name, model_name, 'Cosine Similarities', 'cos')


def plot_ED(x_left,x_right,y_left1,y_left2,y_right1, y_right2,dataset_name,model_name):
    """Plot KL divergences for the stable vs. unstable node sets.

    Saves ``./{dataset_name}_{model_name}_KL.png``. Left panel: edge
    perturbation levels ``x_left``; right panel: feature perturbation levels
    ``x_right``. ``*1`` series are the stable nodes, ``*2`` the unstable ones.
    """
    _plot_stable_vs_unstable(x_left, x_right, y_left1, y_left2, y_right1, y_right2,
                             dataset_name, model_name, 'Kullback-Leibler Divergence', 'KL')


def my_evaluate_select(model, dataset, feature, edge_index, eval_func, mask):
    """Run ``model`` in eval mode on ``(feature, edge_index)`` and score the masked nodes.

    Args:
        model: module called as ``model(feature, edge_index)``; it is switched
            to eval mode and left there.
        dataset: object whose ``label`` attribute is indexed with ``mask``.
        feature: node features passed to the model.
        edge_index: graph connectivity passed to the model.
        eval_func: metric called as ``eval_func(dataset.label[mask], out[mask])``.
        mask: indices (or boolean mask) selecting the nodes to score.

    Returns:
        ``(score, out)``, where ``score`` is the metric value and ``out`` is
        the full model output. ``out`` is computed under ``torch.no_grad()``,
        so it carries no autograd graph.
    """
    model.eval()
    # Inference only: without no_grad each call would build an autograd graph
    # that the returned ``out`` keeps alive.
    with torch.no_grad():
        out = model(feature, edge_index)
    test_acc = eval_func(dataset.label[mask], out[mask])
    return test_acc, out


def fix_seed(seed=42):
    """Seed every RNG used by the script and force deterministic cuDNN kernels.

    Seeds Python's ``random``, NumPy's global RNG, torch's CPU RNG and the CUDA
    RNGs (current device, then all devices). It then sets cuDNN to
    deterministic mode with autotuning (benchmark) disabled.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

### Parse args ###
parser = argparse.ArgumentParser(description='Training Pipeline for Node Classification')
parser_add_main_args(parser)
args = parser.parse_args()
# Reuse the ordinary dropout rate when no (truthy) global dropout was given.
if not args.global_dropout:
    args.global_dropout = args.dropout
print(args)

fix_seed(args.seed)

# CPU when forced with --cpu or when CUDA is unavailable, else the requested GPU.
use_cuda = (not args.cpu) and torch.cuda.is_available()
device = torch.device("cuda:" + str(args.device)) if use_cuda else torch.device("cpu")

### Load and preprocess data ###
dataset = load_dataset(args.data_dir, args.dataset)

# Keep labels 2-D: a 1-D label vector becomes a single column.
if dataset.label.dim() == 1:
    dataset.label = dataset.label.unsqueeze(1)
dataset.label = dataset.label.to(device)

split_idx_lst = load_fixed_splits(args.data_dir, dataset, name=args.dataset)

### Basic information of datasets ###
graph = dataset.graph
n = graph['num_nodes']
e = graph['edge_index'].shape[1]
c = max(dataset.label.max().item() + 1, dataset.label.shape[1])
d = graph['node_feat'].shape[1]

print(f"dataset {args.dataset} | num nodes {n} | num edge {e} | num node feats {d} | num classes {c}")

# Symmetrise the edges, then replace any existing self-loops with exactly one
# self-loop per node (num_nodes=n also covers isolated nodes).
undirected = to_undirected(graph['edge_index'])
undirected, _ = remove_self_loops(undirected)
undirected, _ = add_self_loops(undirected, num_nodes=n)

graph['edge_index'] = undirected.to(device)
graph['node_feat'] = graph['node_feat'].to(device)

### Load method ###
model = parse_method(args, n, c, d, device)

### Loss function (Single-class, Multi-class) ###
# The trailing comma matters: ('questions') is just the string 'questions',
# so `in` would do a substring test and also match e.g. 'quest' or ''.
if args.dataset in ('questions',):
    criterion = nn.BCEWithLogitsLoss()
else:
    criterion = nn.NLLLoss()

### Performance metric (Acc, AUC) ###
eval_func = eval_rocauc if args.metric == 'rocauc' else eval_acc

logger = Logger(args.runs, args)

# model.train()
# print('MODEL:', model)

### Training loop ###
# Training is currently commented out. Each run instead loads a model via
# load_model (presumably the checkpoint saved for that run -- confirm) and
# compares its predictions on the selected robust / non-robust nodes under
# random edge and feature perturbations.
for run in range(args.runs):
    if args.dataset in ('coauthor-cs', 'coauthor-physics', 'amazon-computer', 'amazon-photo'):
        split_idx = split_idx_lst[0]
    else:
        split_idx = split_idx_lst[run]
    train_idx = split_idx['train'].to(device)
    # model.reset_parameters()
    # model._global = False
    optimizer = torch.optim.Adam(model.parameters(),weight_decay=args.weight_decay, lr=args.lr)
    # best_val = float('-inf')
    # best_test = float('-inf')
    # if args.save_model:
    #     save_model(args, model, optimizer, run)

    # for epoch in range(args.local_epochs+args.global_epochs):
    #     if epoch == args.local_epochs:
    #         print("start global attention!!!!!!")
    #         if args.save_model:
    #             model, optimizer = load_model(args, model, optimizer, run)
    #         model._global = True
    #     model.train()
    #     optimizer.zero_grad()

    #     out = model(dataset.graph['node_feat'], dataset.graph['edge_index'])
    #     if args.dataset in ('questions',):
    #         if dataset.label.shape[1] == 1:
    #             true_label = F.one_hot(dataset.label, dataset.label.max() + 1).squeeze(1)
    #         else:
    #             true_label = dataset.label
    #         loss = criterion(out[train_idx], true_label.squeeze(1)[
    #             train_idx].to(torch.float))
    #     else:
    #         out = F.log_softmax(out, dim=1)
    #         loss = criterion(
    #             out[train_idx], dataset.label.squeeze(1)[train_idx])
    #     loss.backward()
    #     optimizer.step()

    #     result = evaluate(model, dataset, split_idx, eval_func, criterion, args)

    #     logger.add_result(run, result[:-1])

    #     if result[1] > best_val:
    #         best_val = result[1]
    #         best_test = result[2]
    #         if args.save_model:
    #             save_model(args, model, optimizer, run)

    #     if epoch % args.display_step == 0:
    #         print(f'Epoch: {epoch:02d}, '
    #               f'Loss: {loss:.4f}, '
    #               f'Train: {100 * result[0]:.2f}%, '
    #               f'Valid: {100 * result[1]:.2f}%, '
    #               f'Test: {100 * result[2]:.2f}%, '
    #               f'Best Valid: {100 * best_val:.2f}%, '
    #               f'Best Test: {100 * best_test:.2f}%')
    # logger.print_statistics(run)
    model._global = True
    model, optimizer = load_model(args, model, optimizer, run)

    # Baseline prediction on the clean graph. Switch to eval mode *before*
    # this forward pass: every perturbed output below is produced in eval
    # mode, so a train-mode baseline (dropout active) would add noise that
    # has nothing to do with the perturbation. no_grad: inference only.
    model.eval()
    with torch.no_grad():
        orig_out = model(dataset.graph['node_feat'], dataset.graph['edge_index'])

    # Build a neighbour graph (hnsw, k=the_k) on a spectral embedding of the
    # input graph + features, post-process it with SPF, and rank nodes with
    # spade_nonetworkx using the model's output on that graph.
    adj = edge_index_to_csr(dataset.graph['edge_index'])
    features = dataset.graph['node_feat'].cpu().numpy()
    the_k = 50
    spec_embed = spectral_embedding_eig(adj,features,use_feature=True,adj_norm=False)#spectral_embedding_eig,spectral_embedding
    neighs, distance = hnsw(spec_embed, k=the_k)
    embed_adj_mtx = construct_weighted_adj(neighs, distance)#construct_weighted_adj,construct_adj
    embed_adj_mtx, inter_edge_adj = SPF(embed_adj_mtx, 4)
    embed_adj_mtx = to_unweighted_csr(embed_adj_mtx)
    embed_edge_index = csr_to_edge_index(embed_adj_mtx).to(device)

    with torch.no_grad():
        embed_out = model(dataset.graph['node_feat'], embed_edge_index)
    TopEig, _, TopNodeList, _, L_in, L_out = spade_nonetworkx(embed_adj_mtx, embed_out.cpu().detach().numpy(), k=the_k)#spade

    # Fraction of nodes taken from each end of TopNodeList: the head is used
    # as the non-robust set, the tail as the robust set.
    node_percent = 0.01
    # Select at least one node per set. If int(N * node_percent) were 0
    # (fewer than 100 nodes), the slice TopNodeList[-0:] would pick every node
    # as "robust" and the non-robust set would be empty.
    num_ranked = TopNodeList.shape[0]
    num_selected = max(1, int(num_ranked * node_percent))
    idx_nonrobust = TopNodeList[:num_selected]
    idx_robust = TopNodeList[num_ranked - num_selected:]

    result = evaluate(model, dataset, split_idx, eval_func, criterion, args)

    x_left=[]
    x_right=[]
    y_left1=[]
    y_left2=[]
    y_left3=[]
    y_left4=[]
    y_right1=[]
    y_right2=[]
    y_right3=[]
    y_right4=[]
    # Edge perturbation at levels 1..5 (passed to random_edgePT), applied
    # around each node set separately.
    for i in range(5):
        pt_adj = random_edgePT(adj, i+1,dataset.label.cpu().numpy(),idx_robust)
        pt_edge_index = csr_to_edge_index(pt_adj).to(device)
        robust_acc, spade_out_robust = my_evaluate_select(model, dataset, dataset.graph['node_feat'], pt_edge_index, eval_func, idx_robust.copy())

        pt_adj = random_edgePT(adj, i+1,dataset.label.cpu().numpy(),idx_nonrobust)
        pt_edge_index = csr_to_edge_index(pt_adj).to(device)
        nonrobust_acc, spade_out_nonrobust = my_evaluate_select(model, dataset, dataset.graph['node_feat'], pt_edge_index, eval_func, idx_nonrobust.copy())

        cos_spade_nonrobust = calculate_cosine(orig_out,spade_out_nonrobust,idx_nonrobust.copy())
        cos_spade_robust = calculate_cosine(orig_out,spade_out_robust,idx_robust.copy())

        ED_spade_nonrobust = calculate_kl(orig_out,spade_out_nonrobust,idx_nonrobust.copy())
        ED_spade_robust = calculate_kl(orig_out,spade_out_robust,idx_robust.copy())

        x_left.append(i+1)
        y_left1.append(cos_spade_robust)
        y_left2.append(cos_spade_nonrobust)
        y_left3.append(ED_spade_robust)
        y_left4.append(ED_spade_nonrobust)

    # Feature perturbation at the listed levels (passed to featurePT); the
    # whole feature matrix is perturbed, so one output serves both node sets.
    for i in [0.4,0.8,1.2,1.6,2.0]:
        PTfeature = featurePT(dataset.graph['node_feat'].cpu(),i)
        PTfeature = PTfeature.to(device)
        model.eval()
        with torch.no_grad():
            PT_out = model(PTfeature, dataset.graph['edge_index'])
        robust_acc,_ = my_evaluate_select(model, dataset, PTfeature, dataset.graph['edge_index'], eval_func, idx_robust.copy())
        nonrobust_acc,_ = my_evaluate_select(model, dataset, PTfeature, dataset.graph['edge_index'], eval_func, idx_nonrobust.copy())

        cos_spade_nonrobust = calculate_cosine(orig_out,PT_out,idx_nonrobust.copy())
        cos_spade_robust = calculate_cosine(orig_out,PT_out,idx_robust.copy())

        ED_spade_nonrobust = calculate_kl(orig_out,PT_out,idx_nonrobust.copy())
        ED_spade_robust = calculate_kl(orig_out,PT_out,idx_robust.copy())

        x_right.append(i)
        y_right1.append(cos_spade_robust)
        y_right2.append(cos_spade_nonrobust)
        y_right3.append(ED_spade_robust)
        y_right4.append(ED_spade_nonrobust)

    # Only the KL curves are plotted; the cosine series (y_*1, y_*2) are
    # collected but not passed to plot_cosine here.
    plot_ED(x_left,x_right,y_left3,y_left4,y_right3, y_right4,args.dataset,'polynormer')


# results = logger.print_statistics()
### Save results ###
# save_result(args, results)

