import torch

from src.utils import construct_affinity_matrix
from src.diff.kernels import *
from src.diff.laplacian import compute_laplacian
from src.diff.cheby import compute_cheby_coeff, cheby_op

def get_D(K):
    """
    Build the diagonal degree matrix of an affinity matrix.

    Parameters
    ----------
    K : torch.tensor
        Affinity matrix.

    Returns
    -------
    torch.tensor
        Diagonal matrix whose diagonal entries are the sums of `K` taken
        along its first dimension (i.e. the column sums of a 2-D `K`).
    """
    # Sum over dim 0, flatten to 1-D, then place the result on the diagonal.
    return torch.diag(K.sum(dim=0).ravel())

def calculate_wavelet_coeffs(
        data, n_scales, approx_order=32,
        w_op="meyer", dist="geodesic", aff_matrix=None, rbf_norm=True
    ):
    """
    Calculate the wavelet coefficients for the data matrix.

    The pipeline is: build (or accept) an affinity matrix, compute its
    Laplacian and the Laplacian's eigenvalues, evaluate the chosen wavelet
    kernel, approximate it with Chebyshev polynomials, and apply it to an
    identity (impulse) matrix. Each resulting per-scale matrix has its
    diagonal zeroed and is divided by its largest absolute entry.

    Parameters
    ----------
    data : torch.tensor
        Data matrix. Only used when `aff_matrix` is None.
    n_scales : int
        Number of scales to use for wavelet transform.
    approx_order : int
        Order of Chebyshev polynomial approximation.
    w_op : str
        Wavelet operator to use. Options are "heat", "mexican_hat", "itersine",
        "simple_tight", "half_cosine_kernel", or "meyer".
    dist : str
        Distance metric to use for constructing the affinity matrix
        (forwarded to `construct_affinity_matrix` as `kernel`).
    aff_matrix : torch.tensor
        Affinity matrix. If None, then the affinity matrix will be calculated.
    rbf_norm : bool
        Forwarded unchanged to `construct_affinity_matrix`; only used when
        `aff_matrix` is None.

    Returns
    -------
    torch.tensor
        Stack of normalized wavelet matrices, one `(n_samples, n_samples)`
        slice per retained scale. Scales whose normalized matrix contains
        NaN are dropped, so the first dimension may be smaller than
        `n_scales`.

    Raises
    ------
    ValueError
        If `w_op` is not one of the supported operator names.
    RuntimeError
        If every scale is dropped because it contains NaN values.
    """
    # Fail fast on an unknown operator, before the expensive graph work.
    supported_ops = (
        "heat", "mexican_hat", "itersine", "meyer",
        "simple_tight", "half_cosine_kernel",
    )
    if w_op not in supported_ops:
        raise ValueError(
            f"Unknown wavelet operator {w_op!r}; expected one of {supported_ops}."
        )

    print("Constructing adjacency graph.")
    # construct adj graph
    # Identity check: `==` on an array-like argument is not a reliable
    # None check.
    if aff_matrix is None:
        aff_matrix = construct_affinity_matrix(data, kernel=dist, rbf_norm=rbf_norm)
    n_samples = aff_matrix.shape[0]

    print("Calculating laplacian and diagonalization.")
    L = compute_laplacian(aff_matrix)
    # eigvalsh is the symmetric/Hermitian eigenvalue solver; this presumes
    # compute_laplacian returns such a matrix -- TODO confirm.
    eigenvals = torch.linalg.eigvalsh(L)

    print("Wavelet calculation started.")
    if w_op == "heat":
        # The heat kernel takes explicit scales instead of a filter count.
        scales = torch.linspace(0.1, 10, n_scales)
        wave_filter = heat_kernel(eigenvals, scales)
    elif w_op == "mexican_hat":
        wave_filter = mexican_hat_kernel(eigenvals, Nf=n_scales)
    elif w_op == "itersine":
        wave_filter = itersine_kernel(eigenvals, Nf=n_scales)
    elif w_op == "meyer":
        wave_filter = meyer_kernel(eigenvals, Nf=n_scales)
    elif w_op == "simple_tight":
        wave_filter = simple_tight_kernel(eigenvals, Nf=n_scales)
    else:  # "half_cosine_kernel" (w_op was validated above)
        wave_filter = half_cosine_kernel(eigenvals, Nf=n_scales)

    print("Wavelet transformation.")
    # compute wavelet transforms
    chebyshev = compute_cheby_coeff(wave_filter, n_scales, eigenvals, m=approx_order)
    # Applying the operator to the identity yields the operator's matrix form.
    impulse = torch.eye(n_samples, dtype=torch.float)
    wavelet_coefficients = cheby_op(L, chebyshev, eigenvals, impulse, n_samples)
    phi_matrices = wavelet_coefficients.view((n_scales, n_samples, n_samples))

    # normalize matrices
    cleaned = []
    for i in range(n_scales):
        phi = phi_matrices[i]  # view into phi_matrices; edits are in place
        phi.fill_diagonal_(0)
        # A matrix that is all zeros off the diagonal divides 0 by 0 here and
        # becomes NaN; such scales are discarded below.
        phi /= torch.max(torch.abs(phi))

        if phi.isnan().any():
            continue

        cleaned.append(phi)

    if not cleaned:
        # torch.stack would fail on an empty list with an obscure message.
        raise RuntimeError(
            "All wavelet scales contained NaN values after normalization; "
            "no coefficients to return."
        )

    return torch.stack(cleaned)
