import torch
import torch.nn.functional as F
import numpy as np
import scipy.sparse as sp
from scipy.sparse import csr_matrix
import networkx as nx
from sklearn.cluster import KMeans
from sklearn.metrics import roc_auc_score, f1_score, average_precision_score
import random
import os.path as osp
import numpy.ctypeslib as ctl
from ctypes import c_int

def scipy_sparse_to_sparse_tensor(sparse_mx):
    '''
    Convert a scipy sparse matrix to a torch sparse tensor.

    The matrix is first converted to COO format with float32 values, so any
    scipy sparse format / numeric dtype is accepted.

    Parameters
    ----------
    sparse_mx : scipy.sparse_matrix
        Sparse matrix to convert.

    Returns
    -------
    sparse_tensor: torch.Tensor in sparse form
        A float32 sparse COO tensor with the same shape as ``sparse_mx``.
    '''
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    # torch.sparse.FloatTensor(...) is a deprecated legacy constructor;
    # sparse_coo_tensor is the supported equivalent (dtype follows `values`).
    return torch.sparse_coo_tensor(indices, values, shape)
def sparse_tensor_to_scipy_sparse(sparse_tensor):
    '''
    Convert a torch sparse tensor to a scipy sparse matrix.

    The tensor is moved to CPU, detached from autograd and coalesced (duplicate
    indices summed) once before its indices and values are exported.

    Parameters
    ----------
    sparse_tensor : torch.Tensor in sparse form
        A 2-D sparse COO tensor to convert.

    Returns
    -------
    sparse_mx : scipy.sparse.coo_matrix
        Sparse matrix with the same shape and (coalesced) entries.

    '''
    # detach(): .numpy() refuses tensors that require grad.
    # Coalesce a single time instead of once per accessor.
    coalesced = sparse_tensor.cpu().detach().coalesce()
    indices = coalesced.indices()
    row = indices[0].numpy()
    col = indices[1].numpy()
    values = coalesced.values().numpy()
    return sp.coo_matrix((values, (row, col)), shape=tuple(sparse_tensor.shape))
def scipy_sparse_mat_to_torch_sparse_tensor(sparse_mx):
    """
    Convert a scipy.sparse matrix to a float32 torch sparse COO tensor.

    Parameters
    ----------
    sparse_mx : scipy.sparse_matrix
        Sparse matrix in any scipy sparse format and any numeric dtype.

    Returns
    -------
    torch.Tensor
        A float32 sparse COO tensor with the same shape as ``sparse_mx``.
    """
    # Always cast: previously a matrix that was already COO skipped the
    # float32 conversion, so float64/int COO input produced the wrong dtype
    # (or was rejected by the float-only legacy constructor).
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    # sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor(...).
    return torch.sparse_coo_tensor(indices, values, shape)

def adj_to_symmetric_norm(adj, r=0.5):
    """
    Symmetrize, binarize and degree-normalize a scipy sparse adjacency matrix.

    The pattern of ``adj + adj.T + I`` is taken with every stored entry set to
    weight 1 (float32). With ``D`` the row-degree diagonal of that binary
    matrix ``A``, the result is ``D^(r-1) A D^(-r)``; ``r=0.5`` gives the
    usual ``D^-1/2 A D^-1/2``. Infinite powers (zero degree) are replaced by 0.

    Parameters
    ----------
    adj : scipy.sparse matrix
        Square adjacency matrix.
    r : float
        Normalization exponent.

    Returns
    -------
    scipy.sparse matrix
        The normalized adjacency matrix.
    """
    n = adj.shape[0]
    pattern = (adj + adj.transpose() + sp.eye(n)).tocoo()

    # Every stored entry of the merged pattern becomes a float32 weight of 1.
    unit_weights = np.ones(len(pattern.data), dtype=np.float32)
    binary = csr_matrix(
        (unit_weights, (pattern.row.astype(np.int64), pattern.col.astype(np.int64))),
        shape=(n, n),
    )

    degrees = np.array(binary.sum(1))

    left_scale = np.power(degrees, r - 1).flatten()
    left_scale[np.isinf(left_scale)] = 0.
    right_scale = np.power(degrees, -r).flatten()
    right_scale[np.isinf(right_scale)] = 0.

    return binary.dot(sp.diags(left_scale)).transpose().dot(sp.diags(right_scale))

def adj_to_symmetric_norm_tensor(adj, r=0.5):
    """
    Symmetrize a dense adjacency tensor, add self-loops and degree-normalize it.

    For i != j the symmetric weight is ``adj[i, j] + adj[j, i]``; the diagonal
    keeps ``adj[i, i]``, and then 1 is added to it (self-loop). With ``D`` the
    row-degree diagonal of that matrix ``A``, the result is
    ``D^(r-1) A D^(-r)`` (``r=0.5`` gives ``D^-1/2 A D^-1/2``). Infinite powers
    (zero degree) are replaced by 0. Weights are kept (no binarization), so
    gradients can flow through ``adj``.

    Parameters
    ----------
    adj : torch.Tensor
        Square 2-D adjacency tensor.
    r : float
        Normalization exponent.

    Returns
    -------
    torch.Tensor
        Dense normalized adjacency on the same device as ``adj``.
    """
    # Merge both edge directions into the lower triangle...
    lower = torch.tril(adj) + torch.tril(adj.T, -1)
    # ...and mirror it. Without this step the matrix stayed lower-triangular,
    # so degrees and the "symmetric" normalization were wrong.
    adj = lower + torch.tril(lower, -1).T
    adj = adj + torch.eye(adj.size(0), device=adj.device)

    degrees = adj.sum(dim=1)

    r_inv_sqrt_left = torch.pow(degrees, r - 1)
    r_inv_sqrt_left[torch.isinf(r_inv_sqrt_left)] = 0.0

    r_inv_sqrt_right = torch.pow(degrees, -r)
    r_inv_sqrt_right[torch.isinf(r_inv_sqrt_right)] = 0.0

    # diag(left) @ adj @ diag(right) via broadcasting: avoids two dense
    # N x N diagonal matrices and O(N^3) matmuls.
    return r_inv_sqrt_left.unsqueeze(1) * adj * r_inv_sqrt_right.unsqueeze(0)

def accuracy(output, labels, return_idx=False):
    """
    Percentage of rows whose argmax over dim 1 equals ``labels``.

    Parameters
    ----------
    output : torch.Tensor
        Scores; the prediction is the index of the max along dim 1.
    labels : torch.Tensor
        Target class indices, one per row of ``output``.
    return_idx : bool
        If True, also return ``np.where`` over the correctness mask, i.e. a
        tuple holding the positions of correctly classified rows.

    Returns
    -------
    float, or (float, tuple of np.ndarray)
    """
    _, predicted = output.max(dim=1)
    hits = (predicted.type_as(labels) == labels).double()
    score = (hits.sum() / len(labels) * 100.0).item()
    if return_idx:
        return score, np.where(hits.cpu() == 1)
    return score

def node_cls_train(model, train_idx, labels, device, optimizer, loss_fn):
    """
    Run one optimization step of node classification on ``train_idx``.

    Assumes ``model`` exposes ``model_forward(idx, device)`` returning per-node
    scores for the nodes in ``idx`` (presumably ``(len(idx), num_classes)`` —
    confirm against the model classes).

    Returns
    -------
    tuple of (float, float)
        The training loss value and the training accuracy in percent.
    """
    model.train()
    optimizer.zero_grad()

    logits = model.model_forward(train_idx, device)
    targets = labels[train_idx]
    loss = loss_fn(logits, targets)
    train_acc = accuracy(logits, targets)

    loss.backward()
    optimizer.step()
    return loss.item(), train_acc

def node_cls_evaluate(model, val_idx, test_idx, labels, device):
    """
    Compute validation and test accuracy (in percent) with ``model`` in eval mode.

    Assumes ``model`` exposes ``model_forward(idx=..., device=...)`` returning
    per-node scores for the requested nodes — confirm against the model classes.

    Returns
    -------
    tuple of (float, float)
        ``(acc_val, acc_test)``.
    """
    model.eval()
    # Only accuracies leave this function, so no autograd graph is needed;
    # no_grad avoids building one and saves memory and time.
    with torch.no_grad():
        val_output = model.model_forward(idx=val_idx, device=device)
        test_output = model.model_forward(idx=test_idx, device=device)
    acc_val = accuracy(val_output, labels[val_idx])
    acc_test = accuracy(test_output, labels[test_idx])
    return acc_val, acc_test

def cosine_sim(vec1, vec2):
    """Return the cosine similarity ``<vec1, vec2> / (|vec1| * |vec2|)``."""
    norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    return np.dot(vec1, vec2) / norm_product

def group_cos_sim(group_v):
    """
    Mean cosine similarity between each vector in ``group_v`` and the group mean.

    ``group_v`` is averaged along axis 0 to form the centroid; each member is
    then compared with that centroid via ``cosine_sim``.
    """
    centroid = np.mean(group_v, axis=0)
    return np.mean([cosine_sim(centroid, member) for member in group_v])

def get_rankings(lst):
    """
    Return the 1-based ascending rank of each element of a 1-D sequence.

    The smallest value gets rank 1. Ties are ordered as ``np.argsort`` orders them.
    """
    order = np.argsort(lst)
    ranks = np.empty_like(order)
    # Scatter ranks 1..n to the positions given by the sort order.
    ranks[order] = np.arange(1, len(order) + 1)
    return ranks
def normalize_tensor(adj, add_loop=True):
    """
    Symmetric degree normalization ``D^-1/2 A D^-1/2`` of a dense adjacency tensor.

    Parameters
    ----------
    adj : torch.Tensor
        Square 2-D adjacency tensor.
    add_loop : bool
        If True, add the identity (self-loops) before normalizing.

    Returns
    -------
    torch.Tensor
        Normalized matrix; rows/columns with zero degree become 0.
    """
    if add_loop:
        # Build the identity on adj's device; a CPU-only eye crashed for
        # tensors on other devices (e.g. CUDA).
        adj_loop = adj + torch.eye(adj.shape[0], device=adj.device)
    else:
        adj_loop = adj
    rowsum = adj_loop.sum(1)
    r_inv = rowsum.pow(-1/2).flatten()
    r_inv[torch.isinf(r_inv)] = 0.
    # Equivalent to diag(r_inv) @ adj_loop @ diag(r_inv), without the dense
    # N x N diagonal matrices.
    return r_inv.unsqueeze(1) * adj_loop * r_inv.unsqueeze(0)
def sym_adj(adj):
    """
    Return a symmetrized copy of a dense adjacency array with a unit diagonal.

    Each pair takes ``max(adj[i, j], adj[j, i])``. The input is not modified,
    because ``np.maximum`` allocates a new array.
    """
    symmetric = np.maximum(adj, adj.T)
    np.fill_diagonal(symmetric, 1)
    return symmetric

def edgeindex(adj):
    """
    Return the (row, col) index arrays of the pattern of ``adj + adj.T + I``.

    The pattern is rebuilt as a CSR matrix with unit float32 weights, and its
    ``nonzero()`` is returned.
    """
    n = adj.shape[0]
    pattern = (adj + adj.transpose() + sp.eye(n)).tocoo()
    unit_weights = np.ones(len(pattern.data), dtype=np.float32)
    binary = csr_matrix(
        (unit_weights, (pattern.row.astype(np.int64), pattern.col.astype(np.int64))),
        shape=(n, n),
    )
    return binary.nonzero()