import functools
import torch

def filterbank_handler(func):
    @functools.wraps(func)

    def inner(f, Nf, eigenvals, *args, **kwargs):

        if 'i' in kwargs:
            return func(f, Nf, eigenvals, *args, **kwargs)

        elif Nf <= 1:
            return func(f, Nf, eigenvals, *args, **kwargs)

        else:
            output = []
            for i in range(Nf):
                output.append(func(f, Nf, eigenvals, *args, i=i, **kwargs))
            return output

    return inner

@filterbank_handler
def compute_cheby_coeff(f, Nf, eigenvals, m=30, N=None, *args, **kwargs):
    r"""
    Compute Chebyshev coefficients for a Filterbank.

    Parameters
    ----------
    f : Filter
        Filterbank with at least 1 filter
    m : int
        Maximum order of Chebyshev coeff to compute
        (default = 30)
    N : int
        Grid order used to compute quadrature
        (default = m + 1)
    i : int
        Index of the Filterbank element to compute
        (default = 0)

    Returns
    -------
    c : ndarray
        Matrix of Chebyshev coefficients

    """
    i = kwargs.pop('i', 0)

    if not N:
        N = m + 1

    lmax = eigenvals[-1]
    a_arange = [0, lmax]

    a1 = (a_arange[1] - a_arange[0]) / 2
    a2 = (a_arange[1] + a_arange[0]) / 2
    c = torch.zeros(m + 1)

    tmpN = torch.arange(N)
    num = torch.cos(torch.pi * (tmpN + 0.5) / N)
    for o in range(m + 1):
        c[o] = 2. / N * torch.dot(f[i](a1 * num + a2),
                               torch.cos(torch.pi * o * (tmpN + 0.5) / N))

    return c

def cheby_op(L, c, eigenvals, signal, n_samples, **kwargs):
    r"""
    Chebyshev polynomial of graph Laplacian applied to vector.

    Parameters
    ----------
    G : Graph
    c : ndarray or list of ndarrays
        Chebyshev coefficients for a Filter or a Filterbank
    signal : ndarray
        Signal to filter

    Returns
    -------
    r : ndarray
        Result of the filtering

    """
    if type(c) != list:
        c = torch.atleast_2d(c)
        
    Nscales = len(c)
    M = c[0].shape[0]
    
    if M < 2:
        raise TypeError("The coefficients have an invalid shape")

    # thanks to that, we can also have 1d signal.
    try:
        Nv = signal.shape[1]
        r = torch.zeros((n_samples * Nscales, Nv))
    except IndexError:
        r = torch.zeros((n_samples * Nscales))

    lmax = eigenvals[-1]
    a_arange = [0, lmax]

    a1 = float(a_arange[1] - a_arange[0]) / 2.
    a2 = float(a_arange[1] + a_arange[0]) / 2.

    twf_old = signal
    twf_cur = ((L @ signal) - a2 * signal) / a1

    tmpN = torch.arange(n_samples, dtype=int)
    for i in range(Nscales):
        r[tmpN + n_samples*i] = 0.5 * c[i][0] * twf_old + c[i][1] * twf_cur

    factor = 2/a1 * (L - a2 * torch.eye(n_samples))
    for k in range(2, M):
        twf_new = factor @ twf_cur - twf_old
        for i in range(Nscales):
            r[tmpN + n_samples*i] += c[i][k] * twf_new

        twf_old = twf_cur
        twf_cur = twf_new

    return r