import torch
import librosa
from omegaconf import DictConfig, OmegaConf

from pado.core.base.transform import PadoTransform
from pado.data.transforms import register_transform

# Public API of this module: only the TrimSilence transform is exported.
__all__ = ["TrimSilence"]


@register_transform("TrimSilence")
class TrimSilence(PadoTransform):
    """Remove leading and trailing silence from a waveform.

    Thin wrapper around ``librosa.effects.trim``. Frames more than ``trim_db``
    decibels below librosa's reference level are treated as silence. Only the
    silence at the start and end of the signal is removed. Every channel is
    cut at the same ``[start, end)`` boundaries, so channels stay aligned.
    """

    def __init__(self,
                 trim_db: float = 60,
                 win_length: int = 2048,
                 hop_length: int = 512):
        """
        :param trim_db:         threshold in dB below the reference for a frame to count as silence;
                                must be > 0 (forwarded to librosa as ``top_db``).
        :param win_length:      analysis frame length in samples (librosa ``frame_length``).
        :param hop_length:      hop between analysis frames in samples (librosa ``hop_length``).
        :raises ValueError:     if ``trim_db`` is not positive.
        """
        super().__init__()

        if trim_db <= 0:
            raise ValueError("Trim DB should be positive value, default: 60.")

        self.trim_db = trim_db
        self.win_length = win_length
        self.hop_length = hop_length

    @torch.no_grad()
    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        :param waveform:        (num_channels, wave_length)
        :return:
                                (num_channels, trimmed_length) -- a new tensor with the same dtype/device,
                                leading and trailing silence removed (may have length 0 if all silent).
        :raises ValueError:     if ``waveform`` is not 2-dimensional.
        """
        if waveform.ndim != 2:
            raise ValueError(f"TrimSilence require 2-dim input, got {tuple(waveform.shape)}.")

        # detach(): no_grad() does not clear requires_grad on the *input*, and
        # Tensor.numpy() refuses tensors that require grad.
        w = waveform.detach().squeeze(0).cpu().numpy()  # (n,) for mono, (num_channels, n) otherwise

        # top_db must be passed by keyword: librosa >= 0.10 makes it keyword-only.
        # Only the detected [start, end) interval is needed; the trimmed array is discarded.
        _, (start, end) = librosa.effects.trim(
            w,
            top_db=self.trim_db,
            frame_length=self.win_length,
            hop_length=self.hop_length,
        )

        # Slice the original tensor rather than rebuilding it from numpy. This keeps
        # dtype/device and avoids `view(num_channels, -1)`, which torch rejects for a
        # 0-element result (an all-silent clip). clone() returns a fresh tensor, so the
        # result does not alias the caller's input.
        return waveform[:, int(start):int(end)].clone()

    @classmethod
    def from_config(cls, cfg: DictConfig):
        """Build a TrimSilence from an OmegaConf config.

        Interpolations are resolved, and the resulting keys are passed as
        keyword arguments to ``__init__``.
        """
        cfg = OmegaConf.to_container(cfg, resolve=True)
        return cls(**cfg)
