
import numpy as np
import torch
import torch.nn.functional as F
import wandb

from scipy import signal
from torch import nn


def design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0):
    """Design prototype filter for PQMF.

    This method is based on `A Kaiser window approach for the design of prototype
    filters of cosine modulated filterbanks`_.

    Args:
        taps (int): The number of filter taps (must be even).
        cutoff_ratio (float): Cut-off frequency ratio, strictly between 0 and 1.
        beta (float): Beta coefficient for kaiser window.

    Returns:
        ndarray: Impulse response of prototype filter (taps + 1,).

    .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
        https://ieeexplore.ieee.org/abstract/document/681427

    """
    # check the arguments are valid
    assert taps % 2 == 0, "The number of taps must be even number."
    assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."

    # ideal low-pass (windowed-sinc) impulse response, centred on tap taps / 2
    omega_c = np.pi * cutoff_ratio
    n = np.arange(taps + 1) - 0.5 * taps
    with np.errstate(invalid="ignore"):
        # the centre tap evaluates to 0 / 0 (NaN); it is overwritten below
        h_i = np.sin(omega_c * n) / (np.pi * n)
    # limit of sin(omega_c * n) / (pi * n) as n -> 0
    h_i[taps // 2] = cutoff_ratio

    # apply kaiser window; the top-level ``scipy.signal.kaiser`` alias was
    # removed in SciPy 1.13, the window functions live in ``scipy.signal.windows``
    w = signal.windows.kaiser(taps + 1, beta)
    return h_i * w


class PQMF(torch.nn.Module):
    """PQMF module.

    This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.

    Note that this implementation does not decimate: ``analysis`` keeps the
    full time resolution of its input, and ``synthesis`` expects full-rate
    subband signals and returns signals of the same length.

    .. _`Near-perfect-reconstruction pseudo-QMF banks`:
        https://ieeexplore.ieee.org/document/258122

    """

    def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0, complement=False):
        """Initialize PQMF module.

        The cutoff_ratio and beta parameters are optimized for #subbands = 4.
        See discussion in https://github.com/kan-bayashi/ParallelWaveGAN/issues/195.

        Args:
            subbands (int): The number of subbands.
            taps (int): The number of filter taps (must be even).
            cutoff_ratio (float): Cut-off frequency ratio.
            beta (float): Beta coefficient for kaiser window.
            complement (bool): If True, the modulation index of every band is
                shifted from ``2k + 1`` to ``2k + 2`` (half a band higher), the
                last band is scaled by ``1 / sqrt(2)``, and an extra band built
                from the unmodulated prototype is prepended, giving
                ``subbands + 1`` bands in total.

        """
        super().__init__()
        # extra modulation offset used by the complement layout
        cv = int(complement)

        # build analysis & synthesis filter coefficients for all bands at once:
        # rows index the band k, columns index the tap n centred on taps / 2
        h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
        k = np.arange(subbands)[:, None]
        n = np.arange(taps + 1) - (taps / 2)
        modulation = (2 * k + 1 + cv) * (np.pi / (2 * subbands)) * n
        phase = (-1) ** k * np.pi / 4
        # analysis and synthesis differ only in the sign of the phase term
        h_analysis = 2 * h_proto * np.cos(modulation + phase)
        h_synthesis = 2 * h_proto * np.cos(modulation - phase)

        if complement:
            # prepend a band derived from the prototype itself and rescale
            # the top band by 1 / sqrt(2)
            ha0 = np.sqrt(2) * h_proto * np.cos(np.pi / 4)
            h_analysis[-1] *= 1. / np.sqrt(2)
            h_analysis = np.concatenate([ha0[None, :], h_analysis], axis=0)
            hs0 = np.sqrt(2) * h_proto * np.cos(-np.pi / 4)
            h_synthesis[-1] *= 1. / np.sqrt(2)
            h_synthesis = np.concatenate([hs0[None, :], h_synthesis], axis=0)

        # convert to conv1d weights:
        #   analysis  -> (n_bands, 1, taps + 1): one input, n_bands outputs
        #   synthesis -> (1, n_bands, taps + 1): n_bands inputs, one output
        analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1)
        synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0)

        # register coefficients as buffers (moved with the module, not trained)
        self.register_buffer("analysis_filter", analysis_filter)
        self.register_buffer("synthesis_filter", synthesis_filter)

        # "same" padding: with an even `taps` the (taps + 1)-tap convolution
        # keeps the time length unchanged
        self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)

    def analysis(self, x):
        """Analysis with PQMF (no decimation).

        Args:
            x (Tensor): Input tensor (B, C, T).

        Returns:
            Tensor: Output tensor (B, n_bands * C, T), where n_bands is the
            number of filter bands (``subbands``, or ``subbands + 1`` with
            ``complement=True``). Channels are band-major: output channel
            ``k * C + c`` holds band ``k`` of input channel ``c``.

        """
        b, c, t = x.shape
        n_bands = self.analysis_filter.size(0)
        # one copy of the filter bank per input channel; groups=c keeps the
        # channels independent, producing channel-major output (c * n_bands + k)
        weight = self.analysis_filter.repeat(c, 1, 1)
        y = F.conv1d(self.pad_fn(x), weight, groups=c)
        # reorder channel-major -> band-major (k * C + c)
        return y.reshape(b, c, n_bands, t).transpose(1, 2).reshape(b, -1, t)

    def synthesis(self, x):
        """Synthesis with PQMF (no upsampling).

        Args:
            x (Tensor): Input tensor (B, n_bands * C, T), band-major as
                produced by ``analysis``.

        Returns:
            Tensor: Output tensor (B, C, T).

        """
        b, _, t = x.shape
        n_bands = self.synthesis_filter.size(1)
        # band-major (k * C + c) -> channel-major (c * n_bands + k) so that each
        # conv group sees all bands of a single channel
        x = x.reshape(b, n_bands, -1, t)
        c = x.size(2)
        x = x.transpose(1, 2).reshape(b, -1, t)
        weight = self.synthesis_filter.repeat(c, 1, 1)
        return F.conv1d(self.pad_fn(x), weight, groups=c)


# Per-configuration PQMF presets, each value being (taps, cutoff_ratio).
# Keys are presumably the number of subbands (the 4 entry equals the PQMF
# constructor defaults taps=62, cutoff_ratio=0.142) -- confirm with callers.
_PQMF_PRESETS = {
    3: (48, 0.189),
    4: (62, 0.142),
    8: (124, 0.071),
    16: (246, 0.03552),
    24: (368, 0.0237),
    32: (490, 0.01777),
}
pqmf_cutoff_ratios = {bands: ratio for bands, (_, ratio) in _PQMF_PRESETS.items()}
pqmf_taps = {bands: n_taps for bands, (n_taps, _) in _PQMF_PRESETS.items()}
