from typing import Optional
import math
import torch
import torch.nn.functional as F
import torch.cuda.amp as amp
from omegaconf import DictConfig, OmegaConf
import librosa

from pado.core.base.transform import PadoTransform
from pado.data.transforms import register_transform

# Public API of this module: only the MelFilterBank transform is exported by `import *`.
__all__ = ["MelFilterBank"]


@register_transform("MelFilterBank")
class MelFilterBank(PadoTransform):
    """
    MFFB (Mel filter-bank) feature extraction.

    Pipeline applied by ``forward``:
    optional dither -> optional DC removal -> optional peak normalization ->
    optional pre-emphasis -> STFT power (or magnitude) spectrum -> Mel projection ->
    optional dB conversion -> optional per-bin normalization over time ->
    optional zero padding / frame stacking.

    Adapted from:
    https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper/common/features.py#L192
    """

    # Supported (lower-cased) ``window_type`` values -> torch window constructor.
    _WINDOW_FUNCS = {
        "hann": torch.hann_window,
        "hamming": torch.hamming_window,
        "blackman": torch.blackman_window,
        "bartlett": torch.bartlett_window,
    }

    def __init__(self,
                 sample_rate: int,
                 win_length: int,
                 hop_length: int,
                 n_mels: int,
                 n_fft: Optional[int] = None,
                 window_type: str = "hann",
                 pre_emphasize: float = 0.97,
                 f_min: int = 0,
                 f_max: Optional[int] = None,
                 log: bool = True,
                 remove_dc: bool = False,
                 fft_power: bool = True,
                 normalize_audio: bool = False,
                 normalize_feature: bool = True,
                 dither: float = 0.0,
                 pad_end: int = 0,
                 pad_to_max: Optional[int] = None,
                 frame_stack: int = 1, *,
                 top_db: Optional[float] = None):
        """
        :param sample_rate:         audio sample rate in Hz.
        :param win_length:          STFT window length in samples.
        :param hop_length:          STFT hop length in samples.
        :param n_mels:              number of Mel bins.
        :param n_fft:               FFT size; defaults to the next power of two >= win_length.
        :param window_type:         one of "hann", "hamming", "blackman", "bartlett" (case-insensitive).
        :param pre_emphasize:       pre-emphasis coefficient; <= 0 disables pre-emphasis.
        :param f_min:               lowest Mel filter frequency in Hz.
        :param f_max:               highest Mel filter frequency in Hz; defaults to sample_rate // 2.
        :param log:                 convert the Mel spectrum to dB (10 * log10).
        :param remove_dc:           subtract each channel's mean before analysis.
        :param fft_power:           use the power spectrum; if False, use its square root (magnitude).
        :param normalize_audio:     scale each channel to peak absolute value ~1.
        :param normalize_feature:   normalize every Mel bin to zero mean / unit variance over time.
        :param dither:              std. of Gaussian noise added to the waveform; 0 disables it.
        :param pad_end:             number of zero frames appended (ignored when pad_to_max is set).
        :param pad_to_max:          zero-pad every output to exactly this many frames
                                    (must be positive and a multiple of frame_stack).
        :param frame_stack:         number of consecutive frames merged into one output frame.
        :param top_db:              if set, floor dB values at (max dB - top_db).
        :raises ValueError:         for an unsupported window_type or an invalid pad_to_max.
        """
        super().__init__()
        self.sample_rate = sample_rate
        self.win_length = win_length
        self.hop_length = hop_length
        self.n_mels = n_mels

        if n_fft is None:
            n_fft = 2 ** math.ceil(math.log2(self.win_length))
        self.n_fft = n_fft

        window_type = window_type.lower()
        window_func = self._WINDOW_FUNCS.get(window_type)
        if window_func is None:
            raise ValueError(f"MelFilterBank windowing type {window_type} is not supported.")
        self.window_func = window_func

        self.pre_emphasize = pre_emphasize
        self.log = log
        self.dither = dither

        self.remove_dc = remove_dc
        self.fft_power = fft_power
        self.normalize_audio = normalize_audio
        self.normalize_feature = normalize_feature
        self.top_db = top_db

        if f_max is None:
            f_max = sample_rate // 2

        # Symmetric (periodic=False) analysis window of length win_length.
        self.window_kernel = window_func(win_length, periodic=False).to(torch.float32)
        # librosa >= 0.10 makes `sr` / `n_fft` keyword-only; keywords also work on older versions.
        self.mel_filter_banks = torch.tensor(librosa.filters.mel(
            sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmin=f_min, fmax=f_max),
            dtype=torch.float32).view(1, n_mels, -1)
        # (n_mels, n_fft // 2 + 1) -> (1, n_mels, n_fft // 2 + 1)
        self.pad_end = pad_end

        if pad_to_max is not None:
            if (pad_to_max <= 0) or (pad_to_max % frame_stack != 0):
                raise ValueError("MelFilterBank pad_to_max is invalid.")
        self.pad_to_max = pad_to_max
        self.frame_stack = frame_stack

    @torch.no_grad()
    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Mel FilterBank extraction.
        This does not support batch-wise computation. The input tensor is never modified.

        :param waveform:        (num_channels, wave_length) in [-1, 1] range
        :return:
                mel-feature:    (num_channels, n_mels * frame_stack, num_windows')
                                where num_windows' is the (padded) window count // frame_stack
        """
        # `.float()` returns the caller's tensor itself when it is already float32, so every
        # step below is out-of-place to avoid mutating the caller's data.
        waveform = waveform.float()

        if self.dither > 0:
            waveform = torch.add(waveform, torch.randn_like(waveform), alpha=self.dither)

        if self.remove_dc:
            waveform = self._remove_dc_offset(waveform)
        if self.normalize_audio:
            waveform = self._normalize_audio(waveform)

        if self.pre_emphasize > 0:
            # y[n] = x[n + 1] - a * x[n]: the output is one SAMPLE shorter than the input.
            waveform = waveform[..., 1:] - self.pre_emphasize * waveform[..., :-1]

        # The window / filterbank are plain tensor attributes; move them next to the input so
        # that non-CPU input does not hit a device mismatch (no-op when devices already match).
        window = self.window_kernel.to(waveform.device)
        mel_filter_banks = self.mel_filter_banks.to(waveform.device)

        with amp.autocast(enabled=False):
            spec = torch.stft(waveform, n_fft=self.n_fft, hop_length=self.hop_length,
                              win_length=self.win_length, window=window, center=True,
                              pad_mode="reflect", return_complex=True)
            # (num_channels, n_fft // 2 + 1, num_windows), complex

            # re^2 + im^2 via a real view: same arithmetic as the deprecated real-valued output.
            feat = torch.view_as_real(spec).pow(2).sum(-1)  # (num_channels, n_fft // 2 + 1, num_windows), power
            if not self.fft_power:
                feat = feat.sqrt()  # magnitude
            feat = torch.matmul(mel_filter_banks, feat)  # (num_channels, n_mels, num_windows)

            if self.log:  # log mel frequency (amplitude-to-DB)
                feat = self._amplitude_to_db(feat, top_db=self.top_db)

            if self.normalize_feature:  # per-bin normalization over time
                feat = self._normalize_feature(feat)

        if (self.pad_end > 0) or (self.frame_stack > 1) or (self.pad_to_max is not None):
            feat = self._pad_frames(feat)

        if self.frame_stack > 1:
            feat = self._stack_frames(feat)

        return feat

    def _pad_frames(self, feat: torch.Tensor) -> torch.Tensor:
        """Zero-pad the last (time) axis, then round it up to a multiple of ``frame_stack``."""
        feat_length = feat.shape[-1]

        if self.pad_to_max is not None:
            pad_length = self.pad_to_max - feat_length
            if pad_length < 0:
                raise ValueError(f"Pad max length {self.pad_to_max} is not sufficient for sample {feat_length}.")
        else:
            pad_length = self.pad_end

        # Align the length AFTER the base padding. Aligning the unpadded length (and then adding
        # pad_end / padding to pad_to_max) can leave a length that is not a multiple of frame_stack.
        remainder = (feat_length + pad_length) % self.frame_stack
        if remainder != 0:
            pad_length += self.frame_stack - remainder

        feat = F.pad(feat, (0, pad_length), mode="constant", value=0)
        assert feat.shape[-1] % self.frame_stack == 0
        return feat

    def _stack_frames(self, feat: torch.Tensor) -> torch.Tensor:
        """Merge every ``frame_stack`` consecutive frames into one frame.

        (c, m, l) -> (c, m * frame_stack, l // frame_stack). Output row ``mel * frame_stack + k``
        holds Mel bin ``mel`` of the k-th frame in each group (rows are grouped by Mel bin).
        """
        c, m, l = feat.shape
        feat = feat.reshape(c, m, l // self.frame_stack, self.frame_stack)
        return feat.transpose(2, 3).reshape(c, m * self.frame_stack, -1)

    @staticmethod
    def _remove_dc_offset(waveform: torch.Tensor) -> torch.Tensor:
        """Return a copy of ``waveform`` with each channel's mean (over the last dim) removed."""
        # keepdim=True makes the mean (num_channels, 1) so it broadcasts over time; a
        # (num_channels,) mean would be broadcast against the time axis instead.
        return waveform - torch.mean(waveform, dim=-1, keepdim=True)

    @staticmethod
    def _amplitude_to_db(feat: torch.Tensor, top_db: Optional[float] = None) -> torch.Tensor:
        """Return ``10 * log10(feat)`` (floored at 1e-10 before the log).

        If ``top_db`` is given, values below ``max - top_db`` are raised to that floor, where
        ``max`` is taken over the whole tensor. Note the factor 10 is applied regardless of
        whether ``feat`` is a power or a magnitude spectrum. ``feat`` is not modified.
        """
        # https://github.com/speechbrain/speechbrain/blob/master/speechbrain/processing/features.py#L689
        db = 10 * torch.log10(feat.clamp_min(1e-10))

        if top_db is not None:
            db_lower_bound = db.max() - top_db
            db = torch.max(db, db_lower_bound)
        return db

    @staticmethod
    def _normalize_audio(waveform: torch.Tensor) -> torch.Tensor:
        """Scale each channel so its peak absolute value is ~1 (eps 1e-8 avoids divide-by-zero)."""
        # keepdim=True keeps the per-channel gain broadcastable over time.
        gain = 1 / (torch.amax(torch.abs(waveform), dim=-1, keepdim=True) + 1e-8)
        return waveform * gain

    @staticmethod
    def _normalize_feature(feat: torch.Tensor) -> torch.Tensor:
        """Normalize each (channel, Mel bin) row to zero mean / unit (biased) variance over time."""
        # feat: (num_channels, n_mels, num_windows)
        mean = torch.mean(feat, dim=-1, keepdim=True)
        var = torch.var(feat, dim=-1, keepdim=True, unbiased=False)
        inv_std = torch.rsqrt(var + 1e-5)
        return (feat - mean) * inv_std

    @classmethod
    def from_config(cls, cfg: DictConfig) -> "MelFilterBank":
        """Build from an OmegaConf config whose keys are constructor arguments (interpolations resolved)."""
        cfg = OmegaConf.to_container(cfg, resolve=True)
        return cls(**cfg)
