import torch
import numpy as np

from scipy.sparse.csgraph import dijkstra
from scipy.sparse import csr_matrix

from sklearn.neighbors import kneighbors_graph
from sklearn.preprocessing import normalize

def construct_affinity_matrix(data, kernel="euclidean", rbf_norm=True):
    """
    Constructs the affinity matrix for the data matrix.

    Parameters
    ----------
    data : torch.tensor
        Data matrix.
    kernel : str
        Distance to use. Options are "euclidean" or "geodesic".
    rbf_norm : bool
        If True, pass the pairwise distances through an RBF kernel whose
        bandwidth is the median of the non-zero distances, zero the diagonal
        and divide by the largest entry. Otherwise, l2-normalise the pairwise
        distance matrix with ``sklearn.preprocessing.normalize`` and
        symmetrise the result from its upper triangle.

    Returns
    -------
    aff_matrix : torch.tensor
        Affinity matrix built from the pairwise distances of `data`.

    Raises
    ------
    ValueError
        If `kernel` is neither "euclidean" nor "geodesic".
    """
    if kernel == "euclidean":
        pairwise_dist = torch.cdist(data, data)
    elif kernel == "geodesic":
        pairwise_dist = calculate_geodesic(data)
    else:
        # Fail early with a clear message instead of hitting an unbound
        # local variable further down.
        raise ValueError(
            f'Unknown kernel "{kernel}"; expected "euclidean" or "geodesic".'
        )

    if not rbf_norm:
        pairwise_dist = torch.from_numpy(normalize(pairwise_dist, norm="l2")).to(torch.float)
        # The normalised matrix is no longer symmetric, so mirror the
        # upper triangle onto the lower one.
        aff_matrix = pairwise_dist.triu() + pairwise_dist.triu(1).T
    else:
        # Bandwidth: median of the non-zero distances, taken before the
        # matrix is symmetrised.
        sigma = torch.median(pairwise_dist[pairwise_dist != 0])
        pairwise_dist = pairwise_dist.triu() + pairwise_dist.triu(1).T
        aff_matrix = torch.exp(-torch.pow(pairwise_dist, 2) / (2 * sigma**2))  # RBF kernel
        aff_matrix.fill_diagonal_(0)  # no self-affinity
        aff_matrix = aff_matrix / aff_matrix.max()

    return aff_matrix

def calculate_geodesic(data, k=30):
    """
    Calculates the geodesic distance matrix for the data matrix.

    A kNN connectivity graph (correlation metric, each point counted among
    its own neighbours) is built from `data`, and shortest paths on that
    undirected graph are computed with Dijkstra's algorithm. Pairs with no
    connecting path are assigned the largest finite path length, and the
    matrix is divided by its maximum.

    Parameters
    ----------
    data : torch.tensor
        Data matrix.
    k : int
        Number of nearest neighbors in kNN graph to approximate geodesic distance matrix
        using Dijkstra's algorithm for shortest path. Values larger than the
        number of samples are reduced to the number of samples.

    Returns
    -------
    graph_dist : torch.tensor
        Float tensor of normalised shortest-path distances.
    """
    # With include_self=True every sample is one of its own neighbours, so
    # at most len(data) neighbours can be requested; clamp rather than let
    # the default k fail on small inputs.
    n_neighbors = min(k, len(data))
    knn = kneighbors_graph(data, n_neighbors, mode="connectivity", metric="correlation", include_self=True)

    # Compute shortest distances
    shortest_path = dijkstra(csgraph=csr_matrix(knn), directed=False, return_predecessors=False)

    # Deal with unconnected pairs (infinities): cap at the largest finite distance.
    max_dist = np.nanmax(shortest_path[shortest_path != np.inf])
    shortest_path[shortest_path > max_dist] = max_dist

    # Finally, normalize the distance matrix. A graph whose largest distance
    # is 0 (e.g. a single sample) is left as zeros instead of becoming 0/0.
    if max_dist > 0:
        shortest_path = shortest_path / max_dist

    return torch.from_numpy(shortest_path).to(torch.float)

def compute_log_scales(lmin, lmax, Nscales, t1=1, t2=2):
    """
    Compute logarithm scales for wavelets.

    The scales are spaced evenly in log-space between ``t2 / lmin`` (first
    entry) and ``t1 / lmax`` (last entry).

    Parameters
    ----------
    lmin : float or 0-dim torch.tensor
        Smallest non-zero eigenvalue.
    lmax : float or 0-dim torch.tensor
        Largest eigenvalue, i.e. :py:attr:`pygsp.graphs.Graph.lmax`.
    Nscales : int
        Number of scales.
    t1 : float
        Numerator of the smallest scale, ``t1 / lmax``.
    t2 : float
        Numerator of the largest scale, ``t2 / lmin``.

    Returns
    -------
    scales : torch.tensor
        Scales of length Nscales, in decreasing order.
    """
    scale_min = t1 / lmax
    scale_max = t2 / lmin
    # float() accepts Python numbers, NumPy scalars and 0-dim tensors alike,
    # so the logarithms no longer require the eigenvalues to be tensors.
    log_scale_max = float(np.log(float(scale_max)))
    log_scale_min = float(np.log(float(scale_min)))
    return torch.exp(torch.linspace(log_scale_max, log_scale_min, Nscales))

def compute_kernel(Cx, Cy, h):
    """
    Compute Gaussian kernel matrices for a source and a target matrix.

    Each input is scaled by its own bandwidth, ``h * sqrt(mean(C**2) / 2)``,
    before the (unnormalised) Gaussian ``exp(-(C / bandwidth)**2 / 2)`` is
    applied element-wise.

    Parameters
    ----------
    Cx : torch.tensor
        Source pairwise distance matrix.
    Cy : torch.tensor
        Target pairwise distance matrix.
    h : float
        Bandwidth multiplier shared by both kernels.

    Returns
    ----------
    Kx : torch.tensor
        Source kernel.
    Ky : torch.tensor
        Target kernel.
    """
    def _gaussian(dist):
        # Data-dependent bandwidth: h times the root of half the mean square.
        bandwidth = h * torch.sqrt(dist.pow(2).mean() / 2)
        # Gaussian kernel (without normalization)
        return torch.exp(-(dist / bandwidth).pow(2) / 2)

    return _gaussian(Cx), _gaussian(Cy)

def calculate_entropy(Kx):
    """
    Calculate entropy of a kernel matrix using Kernel Density Estimation.

    For input with two or more dimensions the density estimate is the sum
    over dimension 1 divided by the size of that dimension; input with fewer
    dimensions is used as the density directly. The result is
    ``-sum(f * log(f))`` with the convention ``0 * log(0) = 0``.

    Parameters
    ----------
    Kx : torch.tensor
        Kernel matrix.

    Returns
    ----------
    entropy : float or torch.tensor
        Entropy of the kernel matrix: the Python float ``0.0`` when `Kx` is
        entirely zero, otherwise a 0-dim tensor.
    """
    if (Kx == 0.0).all():
        return 0.0

    f_x = Kx if Kx.ndim < 2 else Kx.sum(1) / Kx.shape[1]

    # xlogy(f, f) equals f * log(f) but returns 0 where f == 0, so a single
    # zero-density entry no longer turns the whole entropy into NaN.
    return -torch.xlogy(f_x, f_x).sum()

def make_symmetric(F):
    """
    Build a symmetric matrix from the upper triangle of `F`.

    Entries below the main diagonal are discarded and replaced by the
    mirrored entries from above it. The transpose acts on the last two
    dimensions. We use this to symmetrise the filter matrix F.

    Parameters
    ----------
    F : torch.tensor
        Matrix to symmetrise.

    Returns
    ----------
    F : torch.tensor
        Symmetric matrix.
    """
    upper_with_diag = torch.triu(F)
    strict_upper = torch.triu(F, diagonal=1)
    return upper_with_diag + torch.transpose(strict_upper, -2, -1)

def get_entropy_F(wavelet_coeffs_X1, wavelet_coeffs_X2, h=0.4):
    """
    Get entropy of wavelet coefficients for each scale.

    The inputs are moved to the GPU when CUDA is available and are otherwise
    processed on the device they already live on. Scales at which either
    coefficient matrix is entirely zero are skipped and keep an entropy of 0.

    Parameters
    ----------
    wavelet_coeffs_X1 : torch.tensor (num_scales, ns, ns)
        Wavelet coefficients for source data matrix.
    wavelet_coeffs_X2 : torch.tensor (num_scales, nt, nt)
        Wavelet coefficients for target data matrix. The number of scales
        that is iterated over is read from this tensor.
    h : float
        Bandwidth parameter for Gaussian kernel.

    Returns
    ----------
    F1 : torch.tensor (num_scales, ns)
        Entropy of wavelet coefficients for source data matrix. Each row
        holds the entropy of one scale repeated ns times, divided by the sum
        over all entries. The result is NaN if every scale was skipped.
    F2 : torch.tensor (num_scales, nt)
        Entropy of wavelet coefficients for target data matrix, laid out
        like F1.
    """
    _, _, ns = wavelet_coeffs_X1.shape
    num_scales, _, nt = wavelet_coeffs_X2.shape

    # Only move to the GPU when one exists, so CPU-only hosts keep working.
    if torch.cuda.is_available():
        wavelet_coeffs_X1 = wavelet_coeffs_X1.cuda()
        wavelet_coeffs_X2 = wavelet_coeffs_X2.cuda()

    F1 = torch.zeros((num_scales, ns), device=wavelet_coeffs_X1.device)
    F2 = torch.zeros((num_scales, nt), device=wavelet_coeffs_X2.device)
    for scale_i in range(num_scales):
        coeffs_x = wavelet_coeffs_X1[scale_i]
        coeffs_y = wavelet_coeffs_X2[scale_i]
        # An all-zero matrix would give a zero bandwidth in compute_kernel.
        if (coeffs_x == 0).all() or (coeffs_y == 0).all():
            continue
        Kx, Ky = compute_kernel(coeffs_x, coeffs_y, h)
        # The scalar entropy is broadcast across the whole row.
        F1[scale_i] = calculate_entropy(Kx)
        F2[scale_i] = calculate_entropy(Ky)

    F1 = F1 / F1.sum()
    F2 = F2 / F2.sum()

    return F1, F2