import torch
import math

from src.utils import compute_log_scales

def heat_kernel(eigenvals, scale=10, normalize=False):
    """
    Design a filter bank of heat kernels.

    Parameters
    ----------
    G : graph
    scale : float or iterable
        Scaling parameter. When solving heat diffusion, it encompasses both
        time and thermal diffusivity.
        If iterable, creates a filter bank with one filter per value.
    normalize : bool
        Whether to normalize the kernel to have unit L2 norm.
        The normalization needs the eigenvalues of the graph Laplacian.
    """

    try:
        iter(scale)
    except TypeError:
        scale = [scale]
        
    lmax = eigenvals[-1]

    def kernel(x, scale):
        return torch.min(torch.exp(-scale * x / lmax), torch.ones_like(x))

    kernels = []
    for s in scale:
        norm = torch.linalg.norm(kernel(eigenvals, s)) if normalize else 1
        kernels.append(lambda x, s=s, norm=norm: kernel(x, s) / norm)

    return kernels

def mexican_hat_kernel(eigenvals, Nf=6, lpfactor=20, scales=None, normalize=False):
    """
    Design a filter bank of Mexican hat wavelets.

    Parameters
    ----------
    G : graph
    Nf : int
        Number of filters to cover the interval [0, lmax].
    lpfactor : float
        Low-pass factor. lmin=lmax/lpfactor will be used to determine scales.
        The scaling function will be created to fill the low-pass gap.
    scales : array_like
        Scales to be used.
        By default, initialized with :func:`pygsp.utils.compute_log_scales`.
    normalize : bool
        Whether to normalize the wavelet by the factor ``sqrt(scales)``.
    """
    # pdb.set_trace()
    lmax = eigenvals[-1]

    lmin = lmax / lpfactor

    if scales is None:
        scales = compute_log_scales(lmin, lmax, Nf-1)

    if len(scales) != Nf - 1:
        raise ValueError('len(scales) should be Nf-1.')

    def band_pass(x):
        return x * torch.exp(-x)

    def low_pass(x):
        return torch.exp(-x**4)

    kernels = [lambda x: 1.2 * math.exp(-1) * low_pass(x / 0.4 / lmin)]

    for i in range(Nf - 1):

        def kernel(x, i=i):
            norm = torch.sqrt(scales[i]) if normalize else 1
            return norm * band_pass(scales[i] * x)

        kernels.append(kernel)

    return kernels

def itersine_kernel(eigenvals, Nf=6, overlap=2):
    """
    Design an itersine filter bank (tight frame).

    Create an itersine half overlap filter bank of Nf filters.
    Going from 0 to lambda_max.

    Parameters
    ----------
    G : graph
    Nf : int (optional)
        Number of filters from 0 to lmax. (default = 6)
    overlap : int (optional)
        (default = 2)
    """
    lmax = eigenvals[-1]
    mu = torch.linspace(0, lmax, steps=Nf)
    scales = lmax / (Nf - overlap + 1) * overlap

    def kernel(x):
        y = torch.cos(x * torch.pi)**2
        y = torch.sin(0.5 * torch.pi * y)
        return y * ((x >= -0.5) * (x <= 0.5))

    kernels = []
    for i in range(1, Nf + 1):

        def kernel_centered(x, i=i):
            y = kernel(x / scales - (i - overlap / 2) / overlap)
            return y * math.sqrt(2 / overlap)

        kernels.append(kernel_centered)

    return kernels

def meyer_kernel(eigenvals, Nf=6, scales=None):
    lmax = eigenvals[-1]
    if scales is None:
        scales = (4./(3 * lmax)) * torch.pow(2., torch.arange(Nf-2, -1, -1))

    if len(scales) != Nf - 1:
        raise ValueError('len(scales) should be Nf-1.')

    kernels = [lambda x: kernel(scales[0] * x, 'scaling_function')]

    for i in range(Nf - 1):
        kernels.append(lambda x, i=i: kernel(scales[i] * x, 'wavelet'))

    def kernel(x, kernel_type):
        r"""
        Evaluates Meyer function and scaling function

        * meyer wavelet kernel: supported on [2/3,8/3]
        * meyer scaling function kernel: supported on [0,4/3]
        """

        l1 = 2/3.
        l2 = 4/3.  # 2*l1
        l3 = 8/3.  # 4*l1

        def v(x):
            return x**4 * (35 - 84*x + 70*x**2 - 20*x**3)

        r1ind = (x < l1)
        r2ind = (x >= l1) * (x < l2)
        r3ind = (x >= l2) * (x < l3)

        # as we initialize r with zero, computed function will implicitly
        # be zero for all x not in one of the three regions defined above
        r = torch.zeros(x.shape)
        if kernel_type == 'scaling_function':
            r[r1ind] = 1
            r[r2ind] = torch.cos((torch.pi/2) * v(torch.abs(x[r2ind])/l1 - 1))
        elif kernel_type == 'wavelet':
            r[r2ind] = torch.sin((torch.pi/2) * v(torch.abs(x[r2ind])/l1 - 1))
            r[r3ind] = torch.cos((torch.pi/2) * v(torch.abs(x[r3ind])/l2 - 1))
        else:
            raise ValueError('Unknown kernel type {}'.format(kernel_type))

        return r

    return kernels

def half_cosine_kernel(eigenvals, Nf=6):
    if Nf <= 2:
        raise ValueError('The number of filters must be greater than 2.')
    lmax = eigenvals[-1]
    dila_fact = lmax * 3 / (Nf - 2)

    def kernel(x):
        y = torch.cos(2 * torch.pi * (x / dila_fact - .5))
        y = torch.multiply((.5 + .5*y), (x >= 0))
        return torch.multiply(y, (x <= dila_fact))

    kernels = []

    for i in range(Nf):
        def kernel_centered(x, i=i):
            return kernel(x - dila_fact/3 * (i - 2))

        kernels.append(kernel_centered)

    return kernels

def simple_tight_kernel(eigenvals, Nf=6, scales=None):

    def kernel(x, kerneltype):
        l1 = 0.25
        l2 = 0.5
        l3 = 1.

        def h(x):
            return torch.sin(torch.pi*x/2.)**2

        r1ind = (x < l1)
        r2ind = (x >= l1) * (x < l2)
        r3ind = (x >= l2) * (x < l3)

        r = torch.zeros(x.shape)
        if kerneltype == 'sf':
            r[r1ind] = 1.
            r[r2ind] = torch.sqrt(1 - h(4*x[r2ind] - 1)**2)
        elif kerneltype == 'wavelet':
            r[r2ind] = h(4*(x[r2ind] - 1/4.))
            r[r3ind] = torch.sqrt(1 - h(2*x[r3ind] - 1)**2)
        else:
            raise TypeError('Unknown kernel type', kerneltype)

        return r

    lmax = eigenvals[-1]
    if not scales:
        scales = (1./(2.*lmax) * torch.pow(2, torch.arange(Nf-2, -1, -1)))

    if len(scales) != Nf - 1:
        raise ValueError('len(scales) should be Nf-1.')

    kernels = [lambda x: kernel(scales[0] * x, 'sf')]

    for i in range(Nf - 1):
        kernels.append(lambda x, i=i: kernel(scales[i] * x, 'wavelet'))

    return kernels