from typing import Optional, Union
import torch

from pado.core import PadoModule
from pado.nn.functional import gradient_scale

__all__ = ["Add", "Mul", "Transpose", "TimeGradientScale"]


class Add(PadoModule):
    """Module form of the element-wise sum ``torch.add(x1, x2)``."""

    def __init__(self):
        super().__init__()

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """
        :param x1:          first operand.
        :param x2:          second operand (broadcast according to ``torch.add`` rules).
        :return:            ``torch.add(x1, x2)``.
        """
        summed = torch.add(x1, x2)
        return summed


class Mul(PadoModule):
    """Module form of the element-wise product ``torch.mul(x1, x2)``."""

    def __init__(self):
        super().__init__()

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """
        :param x1:          first operand.
        :param x2:          second operand (broadcast according to ``torch.mul`` rules).
        :return:            ``torch.mul(x1, x2)``.
        """
        product = torch.mul(x1, x2)
        return product


class Transpose(PadoModule):
    """Swap two dimensions of the input and return the result in contiguous memory."""

    def __init__(self, dim0: int, dim1: int):
        """
        :param dim0:        first dimension to swap.
        :param dim1:        second dimension to swap.
        """
        super().__init__()
        self.dim0, self.dim1 = dim0, dim1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x:           input tensor.
        :return:            ``x`` with ``dim0`` and ``dim1`` swapped, made contiguous.
        """
        swapped = torch.transpose(x, self.dim0, self.dim1)
        # torch.transpose yields a strided view; .contiguous() lays it out densely so
        # downstream ops that require contiguous memory (e.g. .view) can consume it.
        return swapped.contiguous()


class TimeGradientScale(PadoModule):
    """Apply ``gradient_scale`` to ``x`` with a per-sample factor of ``1 / length``.

    The factor tensor has shape ``(batch_size, 1, ..., 1)`` so it lines up with the
    batch axis (dim 0) of ``x``. Presumably ``gradient_scale`` leaves the forward value
    unchanged and multiplies the incoming gradient by the factor — confirm against
    ``pado.nn.functional.gradient_scale``.
    """

    def __init__(self, dim: int = 1):
        """
        :param dim:         axis of ``x`` holding the sequence; only used to infer the
                            length when ``lengths`` is None.
        """
        super().__init__()
        self.dim = dim

    def forward(self,
                x: torch.Tensor,
                lengths: Optional[Union[int, torch.Tensor]] = None) -> torch.Tensor:
        """
        :param x:           (batch_size, seq_length, ...)
        :param lengths:     (batch_size,) tensor of per-sample lengths; or a single
                            length (int or 0-dim tensor) shared by every sample; or
                            None to use ``x.shape[self.dim]`` for every sample.
        :return:            ``gradient_scale(x, scale)`` where ``scale`` has shape
                            (batch_size, 1, ..., 1) and equals 1 / (lengths + 1e-5).
        :raises ValueError: if ``lengths`` does not have batch_size entries along its
                            first dimension.
        """
        batch_size = x.shape[0]

        if lengths is None:
            lengths = x.shape[self.dim]

        if isinstance(lengths, int):
            lengths = torch.full((batch_size,), lengths, dtype=torch.long, device=x.device)
        else:
            # Build the scale on x's device so gradient_scale never mixes devices.
            lengths = lengths.to(device=x.device)
            if lengths.ndim == 0:
                # A scalar tensor is a shared length, exactly like a plain int.
                lengths = lengths.expand(batch_size)

        if lengths.shape[0] != batch_size:
            raise ValueError(f"TimeGradientScale input shape {x.shape} but length {lengths.shape}.")

        scale_shape = (batch_size,) + (1,) * (x.ndim - 1)
        # The epsilon keeps the reciprocal finite when a length is 0.
        scale = (1 / (lengths + 1e-5)).view(*scale_shape)

        return gradient_scale(x, scale)
