import torch
import torch.nn.functional as F
from omegaconf import DictConfig, OmegaConf

from pado.core.base.transform import PadoTransform
from pado.data.transforms import register_transform

__all__ = ["PadAudio"]


@register_transform("PadAudio")
class PadAudio(PadoTransform):
    """Pad a waveform with a constant value before and/or after its last axis.

    Padding amounts are given in seconds and converted to a whole number of
    samples using ``sample_rate``.

    Related Kaldi utility:
    https://github.com/kaldi-asr/kaldi/blob/master/egs/wsj/s5/utils/data/extend_segment_times.py
    """

    def __init__(self,
                 sample_rate: int = 16000,
                 left_pad_seconds: float = 0.0,
                 right_pad_seconds: float = 0.0,
                 *, pad_value: float = -1.0) -> None:
        """
        :param sample_rate:         samples per second; converts the pad durations to sample counts.
        :param left_pad_seconds:    duration of padding prepended to the waveform.
                                    Values that map to <= 0 samples mean "no left padding".
        :param right_pad_seconds:   duration of padding appended to the waveform.
                                    Values that map to <= 0 samples mean "no right padding".
        :param pad_value:           constant written into the padded region.
        """
        super().__init__()
        self.sample_rate = sample_rate

        # Round to the nearest sample rather than truncating with int():
        # float products like 100 * 0.29 == 28.999999999999996 would otherwise
        # lose a sample. int() guards against round() returning a float type
        # for non-builtin numeric inputs.
        self.left_pad_frames = int(round(sample_rate * left_pad_seconds))
        self.right_pad_frames = int(round(sample_rate * right_pad_seconds))

        # NOTE(review): an earlier comment in forward() stated that 0 was chosen as
        # the default pad value, but the signature default is -1.0. If waveforms are
        # normalized to [-1, 1], 0.0 is presumably silence — confirm which default
        # is intended before relying on it.
        self.pad_value = pad_value

    @torch.no_grad()
    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        :param waveform:        (num_channels, wave_length); only the last dimension is padded.
        :return:
                                (num_channels, wave_length + max(left_pad_frames, 0) + max(right_pad_frames, 0)).
                                When no padding applies, the input tensor itself is returned (no copy).
        """
        # Treat negative amounts as "no padding" on that side. Handing a negative
        # amount to F.pad would crop samples instead, which a padding transform
        # must not do.
        left = max(self.left_pad_frames, 0)
        right = max(self.right_pad_frames, 0)
        if left == 0 and right == 0:
            return waveform

        # A 2-tuple pad spec pads only the last dimension: (before, after).
        waveform = F.pad(waveform, (left, right), mode='constant', value=self.pad_value)
        return waveform

    @classmethod
    def from_config(cls, cfg: DictConfig) -> "PadAudio":
        """Build a ``PadAudio`` from a config whose keys are ``__init__`` keyword arguments.

        Interpolations in ``cfg`` are resolved before the values reach the constructor.
        """
        kwargs = OmegaConf.to_container(cfg, resolve=True)
        return cls(**kwargs)
