from typing import Optional

import torch
import torch.nn.functional as F


def window_pad(x: torch.Tensor, window_size: int, dim: int=1) -> torch.Tensor:
    """Right-pad ``x`` with zeros along ``dim`` so it is at least ``window_size`` long.

    If ``x`` is already at least ``window_size`` long along ``dim``, no padding
    is added and the size is unchanged.

    Args:
        x (torch.Tensor): Input to pad.
        window_size (int): Minimum length of ``x`` along ``dim`` after padding.
        dim (int, optional): Dimension to pad. Negative values count from the
            end, as in ``Tensor.size``. Defaults to 1.

    Returns:
        torch.Tensor: ``x`` with zeros appended at the end of ``dim`` as needed.
    """
    # Ensure padded sequence length is not smaller than window_size.
    # x.size(dim) also rejects an out-of-range dim before we go further.
    pad_right = max(window_size - x.size(dim), 0)
    # Normalize negative dims: otherwise the pad pair would be placed for the
    # wrong axis (or none at all) and the tensor would silently go unpadded.
    if dim < 0:
        dim += x.ndim
    # F.pad takes (left, right) pairs starting from the LAST dimension, so the
    # pair for `dim` follows one (0, 0) pair per dimension after it.
    pad = [0, 0] * (x.ndim - 1 - dim) + [0, pad_right]
    return F.pad(x, pad, mode="constant", value=0)
     

def ragged_len(
        x: torch.Tensor, pad_value: Optional[torch.Tensor] = None, positive: bool = True
) -> torch.LongTensor:
    """Get the length of each ragged sequence in a right-padded batch.

    A sequence's length is the index of its first padding element. A sequence
    with no padding element has length ``x.size(1)``.

    Args:
        x: A batch of ragged sequences. The 1st axis indexes sequences in the
            batch and the 2nd axis indexes elements in each sequence. Any further
            axes are flattened into one feature vector per element. An element
            counts as padding only if every feature equals ``pad_value``.
        pad_value: The pad value to look for, compared against each flattened
            element. Defaults to zeros.
        positive: Whether all non-padding values may be assumed positive. Only
            used when each element is a single value and the pad value is zero,
            in which case a cheaper ``min``-based search is used.

    Returns:
        An int64 tensor of shape ``(x.size(0),)`` with each sequence's length.
        If the batch is empty or the sequence axis has size 0, the result is
        all zeros (empty for an empty batch).
    """
    num_x = x.size(0)
    length = x.size(1)
    # Nothing to search. Guarding here matters because reshape(..., -1) cannot
    # infer the feature size of a 0-element tensor, and min/max reject a
    # reduction over an empty axis.
    if num_x == 0 or length == 0:
        return torch.zeros(num_x, dtype=torch.long, device=x.device)
    # Reshape to 3d (num_x, length, features): flattens or unsqueezes trailing axes
    x = x.reshape(num_x, length, -1)
    if pad_value is None:
        pad_value = torch.zeros(
            x.size(2), dtype=x.dtype, device=x.device
        )
    else:
        pad_value = pad_value.flatten()
    # If input is 2d (or only 1 feature per element) with a zero pad and
    # positive values, the minimum is 0 exactly at padding positions.
    if positive and pad_value.numel() == 1 and pad_value == 0:
        # torch.min returns the index of the first minimum, i.e. the first pad.
        values, l_in = torch.min(x, dim=1, keepdim=True)
        # Fix sequences where there is no padding (minimum is not the pad value)
        l_in = torch.where(values == pad_value, l_in, x.size(1)).ravel()
    else:
        # 1 where an element is entirely padding; argmax finds the first such index.
        values, l_in = torch.max(torch.all(x == pad_value, dim=-1).to(dtype=torch.uint8), dim=1)
        # Fix sequences where there is no padding (max of the mask is 0)
        l_in = torch.where(values == 0, x.size(1), l_in)
    return l_in


def conv1d_l_out(
    l_in: torch.IntTensor,
    kernel_size: int,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
) -> torch.IntTensor:
    """Get the length of the output sequence returned by `torch.nn.Conv1d`.

    Computes ``floor((l_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)``.

    Args:
        l_in: Input sequence length(s), as a tensor or a plain Python int.
        kernel_size: ``Conv1d`` kernel size.
        stride: ``Conv1d`` stride.
        padding: ``Conv1d`` padding, applied on both sides.
        dilation: ``Conv1d`` dilation.

    Returns:
        The output length(s) as an int32 tensor with the same shape as ``l_in``.

    Note:
        See the formula for :math:`L_{out}` in the documentation for `torch.nn.Conv1d`.
    """
    l_in = torch.as_tensor(l_in)
    numerator = l_in + 2 * padding - dilation * (kernel_size - 1) - 1
    # floor(a / s + 1) == floor(a / s) + 1. Floor division keeps integer inputs
    # in integer arithmetic. True division would promote them to the default
    # float dtype (float32) and lose precision for large lengths.
    return (torch.div(numerator, stride, rounding_mode="floor") + 1).int()