from typing import Tuple
import random
import torch
from torchaudio.functional.functional import resample
from omegaconf import DictConfig, OmegaConf

from pado.core.base.transform import PadoTransform
from pado.data.transforms import register_transform

__all__ = ["SpeedPerturbation"]


@register_transform("SpeedPerturbation")
class SpeedPerturbation(PadoTransform):
    """Randomly change the length of a waveform by resampling it.

    On every call one factor is drawn from ``speeds`` (each entry is equally
    likely) and the waveform is resampled from ``sample_rate`` to roughly
    ``sample_rate * factor``. The factor ``1.0`` is always part of the pool
    and leaves the input untouched.

    Similar to ``torchaudio.transforms.Resample``; should be called before
    the STFT, i.e. on the raw waveform.

    NOTE(review): a factor > 1 yields MORE samples (the signal sounds slower
    when played back at ``sample_rate``), which is the inverse of the
    sox/Kaldi "speed" convention -- confirm this is the intended meaning.
    """

    # Target sampling rates are rounded down to a multiple of this value.
    # Presumably this keeps gcd(sample_rate, new_freq) large so the resampling
    # kernel stays small -- TODO confirm.
    _FREQ_MULTIPLE = 16

    def __init__(self,
                 sample_rate: int = 16000,
                 speeds: Tuple[float, ...] = (1.0,)):
        """
        :param sample_rate:     sampling rate of the incoming waveforms
        :param speeds:          candidate perturbation factors; ``1.0`` is
                                appended when it is missing
        :raises ValueError:     if ``sample_rate`` is not positive, or if a
                                factor maps to a non-positive target rate
        """
        if sample_rate <= 0:
            raise ValueError(
                f"sample_rate must be positive, got {sample_rate}")

        super().__init__()
        self.sample_rate = sample_rate

        speeds = list(speeds)
        if 1.0 not in speeds:
            speeds.append(1.0)
        self.speeds = sorted(speeds)

        # Fail fast: otherwise a bad factor only blows up inside `forward`,
        # and only on the random calls that happen to pick it.
        for speed in self.speeds:
            if speed != 1.0 and self._target_freq(speed) <= 0:
                raise ValueError(
                    f"speed {speed} gives a non-positive target sampling "
                    f"rate for sample_rate={sample_rate}")

    def _target_freq(self, speed: float) -> int:
        """Return the resampling target rate for ``speed``.

        The product is truncated to an int and then rounded down to a
        multiple of ``_FREQ_MULTIPLE``, e.g. for 16 kHz input:
        0.9 -> 14.4 kHz, 1.1 -> 17.6 kHz.
        """
        multiple = self._FREQ_MULTIPLE
        return int(self.sample_rate * speed) // multiple * multiple

    @torch.no_grad()
    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        :param waveform:        (num_channels, wave_length)
        :return:
                                (num_channels, wave_length * (min_speed ~ max_speed))
                                the input itself is returned when the drawn
                                factor is 1.0
        """
        speed = random.choice(self.speeds)

        if speed == 1.0:
            return waveform

        return resample(waveform, self.sample_rate, self._target_freq(speed))

    @classmethod
    def from_config(cls, cfg: DictConfig) -> "SpeedPerturbation":
        """Build the transform from an OmegaConf node.

        :param cfg:             config whose keys match the ``__init__``
                                keyword arguments; interpolations are resolved
        :return:                the configured transform
        """
        kwargs = OmegaConf.to_container(cfg, resolve=True)
        return cls(**kwargs)
