from typing import Union
import torch
from omegaconf import DictConfig, OmegaConf

from pado.core.base.transform import PadoTransform
from pado.data.transforms import register_transform

# Names exported by `from <this module> import *`.
__all__ = ["SpecAugment"]


@register_transform("SpecAugment")
class SpecAugment(PadoTransform):
    """Frequency and time masking in the style of SpecAugment (Park et al., 2019).

    For every channel, ``num_freq_mask`` bands along the feature axis (dim 1 of a
    3D input) and ``num_time_mask`` spans along the time axis (dim 2) are zeroed.
    Each mask length is drawn uniformly from ``[0, width]``. Each start is drawn
    uniformly from every position at which a mask of that length fits inside the axis.

    A width given as an ``int`` is an absolute number of bins/frames. A width given
    as a ``float`` in ``(0, 1)`` is a fraction of the input's axis length. It is
    resolved on every call, is at least 1, and is capped at the axis length.
    """

    def __init__(self,
                 num_freq_mask: int = 2,
                 freq_mask_width: Union[int, float] = 20,
                 num_time_mask: int = 2,
                 time_mask_width: Union[int, float] = 100) -> None:
        """
        :param num_freq_mask:   number of frequency masks per channel (<= 0 disables them).
        :param freq_mask_width: max frequency-mask length: int >= 0 (bins) or float in (0, 1) (fraction).
        :param num_time_mask:   number of time masks per channel (<= 0 disables them).
        :param time_mask_width: max time-mask length: int >= 0 (frames) or float in (0, 1) (fraction).
        :raises ValueError:     if a width is a float outside (0, 1) or a negative int.
        """
        super().__init__()

        self.num_freq_mask = num_freq_mask
        self.freq_mask_width = self._check_width(freq_mask_width, "freq")

        self.num_time_mask = num_time_mask
        self.time_mask_width = self._check_width(time_mask_width, "time")

    @staticmethod
    def _check_width(width: Union[int, float], axis_name: str) -> Union[int, float]:
        """Validate a mask-width argument and return it unchanged."""
        if isinstance(width, float):
            if not (0 < width < 1):
                raise ValueError(f"SpecAugment {axis_name} mask_width should be in range (0, 1) "
                                 f"but got {width}.")
        elif width < 0:
            # A negative int would otherwise only fail later inside torch.randint.
            raise ValueError(f"SpecAugment {axis_name} mask_width should be non-negative "
                             f"but got {width}.")
        return width

    @staticmethod
    def _resolve_width(width: Union[int, float], axis_len: int) -> int:
        """Turn a configured width into an absolute max mask length for an axis of ``axis_len``."""
        if isinstance(width, float):
            width = max(int(axis_len * width), 1)
        # A mask can never be longer than the axis it lives on.
        return min(width, axis_len)

    @staticmethod
    def _sample_span(max_width: int, axis_len: int) -> tuple:
        """Draw ``(start, length)`` with length ~ U[0, max_width] and start ~ U[0, axis_len - length].

        Assumes ``0 <= max_width <= axis_len``, which ``_resolve_width`` guarantees.
        """
        # Sample on the CPU: `.item()` on an accelerator tensor forces a device sync per draw.
        # torch.randint's `high` is exclusive, hence the +1 on both bounds.
        length = torch.randint(0, max_width + 1, (1,)).item()
        start = torch.randint(0, axis_len - length + 1, (1,)).item()
        return start, length

    @torch.no_grad()
    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """
        :param spec:        (num_channels, num_features, feature_length), or
                            (num_features, feature_length) for a single channel.
        :return:
                            masked copy of ``spec`` with the same shape as the input.
        :raises ValueError: if ``spec`` is not 2D or 3D.
        """
        orig_shape = spec.shape
        if spec.ndim == 2:
            spec = spec.unsqueeze(0)
        if spec.ndim != 3:
            raise ValueError(f"SpecAugment requires input to be 2D/3D, but got {orig_shape}.")

        # freq: dim=1, time: dim=2
        ch, freq_len, time_len = spec.shape

        # Work on a copy so the caller's tensor (e.g. a cached feature) is never mutated;
        # the 2D path would otherwise write through the unsqueeze view as well.
        out = spec.clone()

        if self.num_freq_mask > 0:
            freq_width = self._resolve_width(self.freq_mask_width, freq_len)
            for ch_idx in range(ch):
                for _ in range(self.num_freq_mask):
                    start, length = self._sample_span(freq_width, freq_len)
                    out[ch_idx, start:start + length, :] = 0

        if self.num_time_mask > 0:
            time_width = self._resolve_width(self.time_mask_width, time_len)
            for ch_idx in range(ch):
                for _ in range(self.num_time_mask):
                    start, length = self._sample_span(time_width, time_len)
                    out[ch_idx, :, start:start + length] = 0

        # Restore the caller's rank: 2D in -> 2D out.
        return out.reshape(orig_shape)

    @classmethod
    def from_config(cls, cfg: DictConfig) -> "SpecAugment":
        """Build a SpecAugment from a config whose (resolved) keys are ``__init__`` keyword arguments."""
        cfg = OmegaConf.to_container(cfg, resolve=True)
        return cls(**cfg)
