import hnswlib
import numpy as np
from scipy.sparse import coo_matrix, diags, identity, csr_matrix
from julia.api import Julia
import scipy.sparse.linalg as sla
import networkx as nx
#import grass_mtx as mtx
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.preprocessing import normalize
from scipy.sparse.linalg import svds
from torch_sparse import SparseTensor
import torch
import random
import copy
from scipy.sparse.linalg import eigsh
from scipy.sparse.csgraph import laplacian

def adj2laplacian(A):
    """Build the shifted graph Laplacian ``D - A + 1e-6 * I``.

    Args:
        A: Square SciPy sparse adjacency matrix.

    Returns:
        Sparse matrix ``D - A + 1e-6 * I`` where ``D`` is the diagonal matrix
        holding the row sums of ``A``.
    """
    # ravel() keeps the degree vector 1-D for every size; squeeze() would
    # collapse a 1x1 matrix to a 0-d array, which diags() cannot consume.
    degrees = np.asarray(A.sum(axis=1)).ravel()
    D = diags(degrees, 0)
    # Small diagonal shift on top of D - A (presumably to keep the result
    # non-singular -- confirm with the solver that consumes it).
    L = D - A + identity(A.shape[0]).multiply(1e-6)

    return L

def laplacian2adj(L):
    """Turn a sparse Laplacian into an adjacency matrix.

    Works on a shallow copy of ``L``: takes element-wise absolute values,
    zeroes the diagonal, and drops the explicitly stored zeros.

    Args:
        L: SciPy sparse Laplacian matrix.

    Returns:
        Sparse matrix with ``|L|`` off the diagonal and an empty diagonal.
    """
    adjacency = np.absolute(copy.copy(L))
    adjacency.setdiag(0)
    # setdiag leaves explicit zeros behind; remove them from the structure.
    adjacency.eliminate_zeros()

    return adjacency

def julia_eigs(l_in, l_out, num_eigs):
    """Compute eigenpairs through the Julia routine in ``./my_utils/eigen.jl``.

    Args:
        l_in: First matrix passed to the Julia ``main`` function.
        l_out: Second matrix passed to the Julia ``main`` function.
        num_eigs: Number of eigenpairs requested.

    Returns:
        Tuple ``(eigenvalues, eigenvectors)`` containing the real parts of
        what the Julia routine returned.
    """
    # Start the Julia runtime (precompiled modules disabled) before importing Main.
    _runtime = Julia(compiled_modules=False)
    from julia import Main as julia_main

    julia_main.include("./my_utils/eigen.jl")
    print('Generate eigenpairs')
    eigen_pairs = julia_main.main(l_in, l_out, num_eigs)
    values, vectors = eigen_pairs

    return values.real, vectors.real

def py_eigs(l_in, l_out, num_eigs):
    """Solve the generalized eigenproblem ``l_in x = w * l_out x`` with SciPy.

    Args:
        l_in: Left-hand matrix of the problem.
        l_out: Right-hand matrix, passed to ``eigs`` as ``M``.
        num_eigs: Number of eigenpairs to compute.

    Returns:
        Tuple ``(eigenvalues, eigenvectors)`` with the real parts of the result.
    """
    eig_values, eig_vectors = sla.eigs(l_in, k=num_eigs, M=l_out)
    return eig_values.real, eig_vectors.real

def GetRiemannianDist(Gx, Gy, Lx, Ly, num_eigs):
    """Score the edges and nodes of ``Gx`` from eigenpairs of ``(Lx, Ly)``.

    Each edge ``(p, q)`` of ``Gx`` gets the eigenvalue-weighted squared
    difference of its endpoint eigenvector entries, normalised by the largest
    edge score. Each node gets the sum of its incident edge scores divided by
    its diagonal entry in ``Lx``, normalised by the largest node score.

    Args:
        Gx: Graph whose ``edges`` are scored.
        Gy: Not read by this function.
        Lx: Sparse matrix passed first to ``julia_eigs``; its diagonal is used
            as the per-node divisor.
        Ly: Sparse matrix passed second to ``julia_eigs``.
        num_eigs: Number of eigenpairs to request and accumulate.

    Returns:
        Tuple ``(TopEig, TopEdgeList, TopNodeList, node_score)``: the largest
        eigenvalue, the edges sorted by descending score, the node indices
        sorted by descending score, and the normalised node scores.
    """
    Lx = Lx.asfptype()
    Ly = Ly.asfptype()
    Dxy, Uxy = julia_eigs(Lx, Ly, num_eigs)

    node_total = Uxy.shape[0]
    TopEig = max(Dxy)
    degree = Lx.diagonal()

    # Convert the edge view once and split it into the two endpoint columns.
    edge_array = np.array(Gx.edges)
    p = edge_array[:, 0]
    q = edge_array[:, 1]

    # Edge scores: accumulate one eigenpair at a time.
    Zpq = np.zeros((len(Gx.edges),))
    for eig_idx in np.arange(num_eigs):
        endpoint_gap = Uxy[p, eig_idx] - Uxy[q, eig_idx]
        Zpq = Zpq + np.power(endpoint_gap, 2) * Dxy[eig_idx]
    Zpq = Zpq / max(Zpq)

    # Node scores: every edge contributes its score to both endpoints.
    node_score = np.zeros((node_total,))
    for src, dst, edge_score in zip(p, q, Zpq):
        node_score[src] += edge_score
        node_score[dst] += edge_score
    node_score = node_score / degree
    node_score = node_score / np.amax(node_score)

    # argsort is ascending; flipping yields highest-score-first orderings.
    TopNodeList = np.flip(node_score.argsort(axis=0))
    edge_order = np.flip(Zpq.argsort(axis=0))
    TopEdgeList = np.column_stack((p, q))[edge_order, :]

    return TopEig, TopEdgeList, TopNodeList, node_score


def construct_adj(neighs, weight):
    """Build a symmetric, unweighted k-nearest-neighbour graph.

    Args:
        neighs: Integer array of shape ``(n, k + 1)``. Row ``i`` lists the
            neighbour indices of sample ``i``, with the sample itself expected
            in column 0. Rows where column 0 is not ``i`` are rewritten IN
            PLACE so that columns ``1..k`` hold the row's first ``k`` entries.
        weight: Unused; kept so existing callers keep working.

    Returns:
        Tuple ``(adj, lap, G)``: the binary CSR adjacency matrix, its
        unnormalised Laplacian, and the corresponding NetworkX graph.
    """
    dim = neighs.shape[0]
    k = neighs.shape[1] - 1

    idx0 = np.arange(dim)
    idx1 = neighs[:, 0]
    # Exact comparison: np.isclose's relative tolerance grows with the index
    # value, so with rtol=1e-6 an off-by-one neighbour index was accepted as
    # "self" once indices reached ~1e6.
    mismatch_idx = idx0 != idx1
    neighs[mismatch_idx, 1:] = neighs[mismatch_idx, :k]

    row = np.repeat(idx0, k)
    col = neighs[:, 1:].reshape(-1,)
    # Insert every edge in both directions to make the matrix symmetric.
    all_row = np.concatenate((row, col), axis=0)
    all_col = np.concatenate((col, row), axis=0)
    data = np.ones(all_row.shape[0])
    adj = csr_matrix((data, (all_row, all_col)), shape=(dim, dim))
    # Duplicate (row, col) pairs are summed on construction; reset to binary.
    adj.data[:] = 1
    lap = laplacian(adj, normed=False)
    G = adj2graph(adj)

    return adj, lap, G

def construct_weighted_adj(neighs, distances):
    """Build a symmetric k-nearest-neighbour graph weighted by ``exp(-distance)``.

    Args:
        neighs: Integer array of shape ``(n, k + 1)``. Row ``i`` lists the
            neighbour indices of sample ``i``, with the sample itself expected
            in column 0. Rows where column 0 is not ``i`` are rewritten IN
            PLACE so that columns ``1..k`` hold the row's first ``k`` entries.
        distances: Array with the same shape as ``neighs`` holding the
            distance to each listed neighbour. Not modified.

    Returns:
        Tuple ``(adj, lap, G)``: the weighted CSR adjacency matrix, its
        unnormalised Laplacian, and the corresponding NetworkX graph.
    """
    dim = neighs.shape[0]
    k = neighs.shape[1] - 1
    weights = np.exp(-distances)

    idx0 = np.arange(dim)
    idx1 = neighs[:, 0]
    # Exact comparison: np.isclose's relative tolerance grows with the index
    # value, so with rtol=1e-6 an off-by-one neighbour index was accepted as
    # "self" once indices reached ~1e6.
    mismatch_idx = idx0 != idx1
    neighs[mismatch_idx, 1:] = neighs[mismatch_idx, :k]
    # Shift the weights identically so each weight stays paired with the
    # neighbour it was measured for (previously shifted rows kept the
    # weights of columns 1..k, i.e. of the wrong neighbours).
    weights[mismatch_idx, 1:] = weights[mismatch_idx, :k]

    row = np.repeat(idx0, k)
    col = neighs[:, 1:].reshape(-1,)
    edge_weights = weights[:, 1:].reshape(-1,)
    # Insert every edge in both directions to make the matrix symmetric.
    all_row = np.concatenate((row, col), axis=0)
    all_col = np.concatenate((col, row), axis=0)
    all_data = np.concatenate((edge_weights, edge_weights), axis=0)
    # NOTE: csr_matrix sums duplicate (row, col) entries, so an edge listed
    # by both of its endpoints accumulates both weights.
    adj = csr_matrix((all_data, (all_row, all_col)), shape=(dim, dim))
    G = adj2graph(adj)
    lap = laplacian(adj, normed=False)

    return adj, lap, G



def hnsw(features, k=10, ef=100, M=48):
    """Query ``k + 1`` approximate nearest neighbours for every row of ``features``.

    An hnswlib index (``space='l2'``) is built over all rows and then queried
    with those same rows.

    Args:
        features: 2-D array with one sample per row.
        k: Number of neighbours wanted per sample; ``k + 1`` are requested.
        ef: Used both as ``ef_construction`` and as the query-time ``ef``.
        M: hnswlib ``M`` parameter for index construction.

    Returns:
        Tuple ``(neighs, distance)`` as returned by ``knn_query``.
    """
    num_samples, dim = features.shape

    index = hnswlib.Index(space='l2', dim=dim)
    index.init_index(max_elements=num_samples, ef_construction=ef, M=M)
    # Label every sample with its own row number.
    index.add_items(features, np.arange(num_samples))
    index.set_ef(ef)

    neighs, distance = index.knn_query(features, k+1)

    return neighs, distance

def spade(adj_in, data_output, k=10, num_eigs=2):
    """Compare an input graph with a kNN graph built from ``data_output``.

    Args:
        adj_in: Sparse adjacency matrix of the input graph.
        data_output: 2-D array of samples from which the output kNN graph is
            built (via ``hnsw`` and ``construct_weighted_adj``).
        k: Number of neighbours for the kNN graph.
        num_eigs: Number of eigenpairs used by ``GetRiemannianDist``.

    Returns:
        Tuple ``(TopEig, TopEdgeList, TopNodeList, node_score, L_in, L_out)``.

    Raises:
        AssertionError: If either the input graph or the kNN graph is not
            connected.
    """
    G_in = nx.from_scipy_sparse_matrix(adj_in)
    knn_index, knn_distance = hnsw(data_output, k)
    adj_out, _, G_out = construct_weighted_adj(knn_index, knn_distance)

    assert nx.is_connected(G_in), "input graph is not connected"
    assert nx.is_connected(G_out), "output graph is not connected"

    L_in = laplacian(adj_in, normed=False)
    # Only the output adjacency goes through SPF before its Laplacian is
    # taken; G_out above still refers to the graph built before SPF.
    L_out = laplacian(SPF(adj_out, 4), normed=False)

    distance_result = GetRiemannianDist(G_in, G_out, L_in, L_out, num_eigs)
    TopEig, TopEdgeList, TopNodeList, node_score = distance_result
    return TopEig, TopEdgeList, TopNodeList, node_score, L_in, L_out


def embedding_normalize(embedding, norm):
    """Normalise an embedding matrix according to ``norm``.

    Args:
        embedding: 2-D array with one embedding per row.
        norm: ``"unit_vector"`` (sklearn ``normalize`` along axis 1),
            ``"standardize"`` (``StandardScaler``), ``"minmax"``
            (``MinMaxScaler``); any other value leaves the input untouched.

    Returns:
        The normalised embedding, or ``embedding`` itself when ``norm`` is
        not one of the recognised names.
    """
    if norm == "unit_vector":
        return normalize(embedding, axis=1)
    if norm == "standardize":
        return StandardScaler().fit_transform(embedding)
    if norm == "minmax":
        return MinMaxScaler().fit_transform(embedding)
    # Unrecognised mode (including None): pass the data through unchanged.
    return embedding
    
def normal_adj(adj):
    """Scale an adjacency matrix by inverse-square-root row sums on both sides.

    Computes ``D^-1/2 * A * D^-1/2`` where ``D`` holds the row sums of ``A``.

    Args:
        adj: SciPy sparse adjacency matrix.

    Returns:
        The scaled matrix as a SciPy CSR matrix.
    """
    sparse_adj = SparseTensor.from_scipy(adj)
    inv_sqrt_deg = sparse_adj.sum(dim=1).to(torch.float).pow(-0.5)
    # Rows that sum to zero produce inf; treat their scaling factor as 0.
    inv_sqrt_deg[inv_sqrt_deg == float('inf')] = 0
    scaled = inv_sqrt_deg.view(-1,1) * sparse_adj * inv_sqrt_deg.view(1,-1)

    return scaled.to_scipy(layout='csr')

def hyperEF(L, level, grass):
    """Run the Julia routine ``HyperEF1`` defined in ``./HyperEF1.jl``.

    Args:
        L: First argument forwarded to ``HyperEF1``.
        level: Second argument forwarded to ``HyperEF1``.
        grass: Third argument forwarded to ``HyperEF1``.

    Returns:
        Whatever ``HyperEF1`` returns. Previously the result was computed and
        then dropped, so callers always received ``None``.
    """
    # Start the Julia runtime (precompiled modules disabled) before importing Main.
    _runtime = Julia(compiled_modules=False)
    from julia import Main
    Main.include("./HyperEF1.jl")
    idx = Main.HyperEF1(L, level, grass)

    return idx

def spectral_embedding(adj_mtx,features,use_feature=True,embedding_norm=None,adj_norm=True,num_components=50):
    """Embed nodes with a truncated SVD of the adjacency matrix.

    Args:
        adj_mtx: SciPy sparse adjacency matrix.
        features: Node feature matrix; only read when ``use_feature`` is true.
        use_feature: Append a two-step propagated feature block
            ``A @ (A @ features) / 2`` to the spectral embedding.
        embedding_norm: Mode forwarded to ``embedding_normalize`` for both the
            spectral and the feature block.
        adj_norm: Pass the adjacency matrix through ``normal_adj`` first.
        num_components: Number of singular triplets requested from ``svds``.
            Defaults to 50, the value that used to be hard-coded; ``svds``
            needs it to be smaller than the matrix dimension.

    Returns:
        Array with one row per node: ``U * sqrt(S)`` from the SVD, optionally
        followed by the propagated feature columns.
    """
    adj_mtx = adj_mtx.asfptype()
    if adj_norm:
        adj_mtx = normal_adj(adj_mtx)
    U, S, _ = svds(adj_mtx, num_components)

    # Scale each left singular vector by the square root of its singular value.
    spec_embed = np.sqrt(S.reshape(1,-1))*U
    spec_embed = embedding_normalize(spec_embed, embedding_norm)
    if use_feature:
        feat_embed = adj_mtx @ (adj_mtx @ features)/2
        feat_embed = embedding_normalize(feat_embed, embedding_norm)
        spec_embed = np.concatenate((spec_embed, feat_embed), axis=1)
    return spec_embed

def spectral_embedding_eig(adj_mtx,features,use_feature=True,embedding_norm=None,eig_julia=False,num_components=50):
    """Embed nodes with the smallest eigenpairs of the unnormalised Laplacian.

    Args:
        adj_mtx: SciPy sparse adjacency matrix.
        features: Node feature matrix; only read when ``use_feature`` is true.
        use_feature: Append a two-step propagated feature block
            ``A @ (A @ features) / 2`` to the spectral embedding.
        embedding_norm: Currently unused (the normalisation calls are
            disabled); kept for signature compatibility.
        eig_julia: Use the Julia routine ``not_main`` from
            ``./my_utils/eigen.jl`` instead of SciPy ``eigsh``.
        num_components: Number of eigenpairs to compute. Defaults to 50, the
            value that used to be hard-coded.

    Returns:
        Array with one row per node: eigenvectors divided by the square root
        of their eigenvalue, optionally followed by the feature columns.
    """
    adj_mtx = adj_mtx.asfptype()
    L_mtx = laplacian(adj_mtx, normed=False)

    if not eig_julia:
        S, U = eigsh(L_mtx, k=num_components, which='SM', maxiter=500000)
    else:
        # Start the Julia runtime (precompiled modules disabled) before importing Main.
        _runtime = Julia(compiled_modules=False)
        from julia import Main
        Main.include("./my_utils/eigen.jl")
        S, U = Main.not_main(L_mtx.tocoo(), num_components)

    # NOTE(review): a graph Laplacian has a zero eigenvalue, so the column
    # belonging to it is divided by ~0 here and may come out inf/NaN. Confirm
    # whether that eigenpair should be skipped or clamped.
    spec_embed = np.empty_like(U)
    for i in range(U.shape[1]):
        spec_embed[:, i] = U[:, i] / np.sqrt(S[i])
    if use_feature:
        feat_embed = adj_mtx @ (adj_mtx @ features)/2
        spec_embed = np.concatenate((spec_embed, feat_embed), axis=1)

    return spec_embed


def adj2graph(adj):
    """Build a NetworkX graph from a SciPy sparse adjacency matrix.

    Args:
        adj: SciPy sparse adjacency matrix.

    Returns:
        The NetworkX graph created from ``adj``.
    """
    # NetworkX 3.0 removed from_scipy_sparse_matrix in favour of
    # from_scipy_sparse_array; prefer the original function when it exists so
    # behaviour on older installations is unchanged.
    builder = getattr(nx, "from_scipy_sparse_matrix", None)
    if builder is None:
        builder = nx.from_scipy_sparse_array
    return builder(adj)


def SPF(adj, L, ICr=0.11):
    """Run the Julia routine ``SPF`` defined in ``./my_utils/SPF.jl``.

    Args:
        adj: Adjacency matrix forwarded to the Julia function.
        L: Second argument forwarded to the Julia function.
        ICr: Third argument forwarded to the Julia function.

    Returns:
        The value returned by the Julia ``SPF`` function.
    """
    # Start the Julia runtime (precompiled modules disabled) before importing Main.
    _runtime = Julia(compiled_modules=False)
    from julia import Main as julia_main

    julia_main.include("./my_utils/SPF.jl")
    filtered_adj = julia_main.SPF(adj, L, ICr)
    return filtered_adj

def heterophily_score(adj_matrix: csr_matrix, model_output: torch.Tensor):
    """Rank nodes by the share of neighbours predicted to be in another class.

    Args:
        adj_matrix: CSR adjacency matrix with one row per node.
        model_output: Per-node class scores; the predicted label of a node is
            the argmax along dimension 1.

    Returns:
        Tuple ``(ranking, scores)``: node indices sorted by descending score
        as a NumPy array, and the list of per-node scores (``0`` for nodes
        without neighbours).

    Raises:
        AssertionError: If the matrix and the model output disagree on the
            number of nodes.
    """
    node_total = model_output.shape[0]
    predicted = torch.argmax(model_output, dim=1)
    assert adj_matrix.shape[0] == len(predicted), "Number of nodes in adjacency matrix and labels must match"

    # Plain Python ints make the per-neighbour comparisons cheap.
    label_of = predicted.tolist()
    scores = []
    for node in range(node_total):
        neighbor_ids = adj_matrix[node].indices
        degree = len(neighbor_ids)
        if degree == 0:
            scores.append(0)
            continue
        own_label = label_of[node]
        differing = sum(1 for nb in neighbor_ids if label_of[nb] != own_label)
        scores.append(differing / degree)

    # sorted() is stable, so tied nodes keep their index order.
    ranking = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)

    return np.array(ranking), scores


def jaccard_similarity(arr1, arr2):
    """Return the Jaccard index of two collections, treated as sets.

    Args:
        arr1: First iterable of hashable items.
        arr2: Second iterable of hashable items.

    Returns:
        ``|A & B| / |A | B|``, or ``0`` when both collections are empty.
    """
    first, second = set(arr1), set(arr2)
    union_size = len(first | second)
    if union_size == 0:
        return 0
    return len(first & second) / union_size


def edge2adj(edge_index):
    """Convert a ``(2, num_edges)`` edge-index tensor to a CSR adjacency matrix.

    Node ids are taken to be 0-based, so the matrix has ``max id + 1`` rows
    and columns. Every listed edge contributes a value of one.

    Args:
        edge_index: Tensor whose first row holds source nodes and second row
            holds target nodes.

    Returns:
        SciPy CSR adjacency matrix.
    """
    node_total = edge_index.max().item() + 1
    endpoints = edge_index.numpy()
    sources, targets = endpoints[0], endpoints[1]
    edge_values = torch.ones(edge_index.shape[1])
    adjacency = coo_matrix((edge_values, (sources, targets)), shape=(node_total, node_total))
    return adjacency.tocsr()

def adj2edge(adj):
    """Convert a sparse adjacency matrix to a ``(2, nnz)`` edge-index tensor.

    Args:
        adj: SciPy sparse adjacency matrix.

    Returns:
        ``torch.long`` tensor whose first row holds the row indices and whose
        second row holds the column indices of the stored entries.
    """
    coo = adj.tocoo()
    stacked = np.vstack((coo.row, coo.col))
    return torch.tensor(stacked, dtype=torch.long)

def featurePT(x,beta):
    """Perturb a feature matrix with Gaussian noise, in place.

    The noise is standard normal, scaled by the standard deviation of all
    entries of ``x`` and by ``beta``.

    Args:
        x: 2-D floating-point tensor; it is modified in place.
        beta: Noise strength multiplier.

    Returns:
        The same tensor object ``x`` after the noise has been added.
    """
    samples_num, dimen = x.shape
    std_dev = torch.std(x)
    # Draw the noise on x's device so the in-place add also works for tensors
    # that do not live on the CPU.
    noise = torch.randn(samples_num, dimen, device=x.device) * std_dev
    x += noise * beta
    return x



def random_edgePT(orig_adj_mtx, the_p,label,node_index):
    """Randomly rewire the edges of selected nodes based on their labels.

    For every node in ``node_index``: connect it to ``the_p`` randomly chosen
    nodes carrying a different label, then remove up to ``the_p`` randomly
    chosen edges to neighbours carrying the same label.

    Args:
        orig_adj_mtx: Sparse adjacency matrix of the original graph.
        the_p: Number of edges to add (and at most to remove) per node.
        label: Indexable collection mapping node id to label.
        node_index: Iterable of node ids to perturb.

    Returns:
        CSR adjacency matrix of the perturbed graph.

    Raises:
        ValueError: From ``random.sample`` when a node has fewer than
            ``the_p`` candidates with a different label.
    """
    perturbed = adj2graph(orig_adj_mtx).copy()
    for target in node_index:
        # Add edges towards nodes whose label differs from the target's.
        candidates = [
            other for other in perturbed.nodes()
            if other != target and label[other] != label[target]
        ]
        chosen = random.sample(candidates, the_p)
        perturbed.add_edges_from([(target, other) for other in chosen])

        # Drop edges towards neighbours that share the target's label.
        same_label = [
            nbr for nbr in perturbed.neighbors(target)
            if label[nbr] == label[target]
        ]
        dropped = random.sample(same_label, min(the_p, len(same_label)))
        perturbed.remove_edges_from([(target, nbr) for nbr in dropped])

    return csr_matrix(nx.adjacency_matrix(perturbed))

def find_edge_index_dif(edge_index1,edge_index2):
    """Rank nodes by how many edges they have that appear in only one edge list.

    Args:
        edge_index1: Tensor of shape ``(2, num_edges)``.
        edge_index2: Tensor of shape ``(2, num_edges)``.

    Returns:
        Tuple ``(node_ids, ranked)``: ``ranked`` is the ``node_ranking`` result
        for the edges present in exactly one of the two inputs (edges only in
        the first input come first), and ``node_ids`` is a NumPy array of the
        node ids in that ranked order.
    """
    # One row per edge, duplicates removed within each input.
    unique_edge_index1 = torch.unique(edge_index1.t(), dim=0)
    unique_edge_index2 = torch.unique(edge_index2.t(), dim=0)

    # After per-input de-duplication an edge occurs at most twice in the
    # stacked tensor; a count of exactly one marks an edge found in only one
    # input. (Comparing the inverse indices of two separate unique() calls,
    # as done previously, does not identify shared edges.)
    stacked = torch.cat((unique_edge_index1, unique_edge_index2), dim=0)
    _, inverse, counts = torch.unique(stacked, dim=0, return_inverse=True, return_counts=True)
    combined_edges = stacked[counts[inverse] == 1]

    ranked_combined = node_ranking(combined_edges)
    combined_node_ids_rank = [node for node, _ in ranked_combined]

    return np.array(combined_node_ids_rank), ranked_combined

def node_ranking(unique_edges):
    """Count how often each node occurs in a collection of edges.

    Args:
        unique_edges: Iterable of edges, each an iterable of nodes exposing
            ``.item()`` (for example the rows of an integer tensor).

    Returns:
        List of ``(node, count)`` tuples sorted by descending count; ties keep
        the order in which the nodes were first seen.
    """
    node_counts = {}
    for edge in unique_edges:
        for node in edge:
            node_id = node.item()
            node_counts[node_id] = node_counts.get(node_id, 0) + 1
    # sorted() is stable, so equal counts stay in first-seen order.
    sorted_nodes = sorted(node_counts.items(), key=lambda x: x[1], reverse=True)
    return sorted_nodes