from typing import Optional

import torch


class Masking:
    """Generate random binary masks over patches of a batched time series.

    Mask convention: positions with value 0 are hidden (masked), positions with
    value 1 are observed.
    """

    def __init__(
        self,
        mask_ratio: Optional[float] = None,
        single_patch_token_mask: bool = False,
        patch_len: int = 8,
        stride: Optional[int] = None,
    ):
        """
        Indices with 0 mask are hidden, and with 1 are observed.

        Args:
            mask_ratio: Fraction of observed patches to hide, in [0, 1]. Required
                unless ``single_patch_token_mask`` is True.
            single_patch_token_mask: If True, hide exactly one observed patch per
                sample instead of using ``mask_ratio``.
            patch_len: Length of each patch.
            stride: Step between consecutive patch starts; defaults to
                ``patch_len`` (non-overlapping patches).

        Raises:
            ValueError: If ``mask_ratio`` is None while ``single_patch_token_mask``
                is False, or if ``mask_ratio`` lies outside [0, 1].
        """
        if mask_ratio is None:
            if not single_patch_token_mask:
                raise ValueError(
                    "mask_ratio must not be None unless single_patch_token_mask is set to True"
                )
        elif not 0.0 <= mask_ratio <= 1.0:
            # Outside [0, 1] the keep-count would be negative or exceed the number
            # of observed patches (which would mark padded patches as observed).
            raise ValueError(f"mask_ratio must be in [0, 1], got {mask_ratio}")

        # Always define the attribute, even when it is unused (single-patch mode).
        self.mask_ratio = mask_ratio
        self.single_patch_token_mask = single_patch_token_mask
        self.patch_len = patch_len
        self.stride = patch_len if stride is None else stride

    @staticmethod
    def convert_seq_to_patch_view(
        mask: torch.Tensor, patch_len: int = 8, stride: Optional[int] = None
    ):
        """
        Converts a 1D binary mask sequence into a patch-level binary mask.

        Each patch is a sliding window of length `patch_len` across the sequence.
        A patch is marked as 1 only if all elements in the patch are 1.
        Otherwise, it's marked as 0.

        Args:
            mask (torch.Tensor): A binary tensor of shape [batch_size, seq_len],
                where each value should be 0 or 1.
            patch_len (int): The length of each patch (default: 8).
            stride (Optional[int]): The step size between the starts of consecutive patches.
                If not provided, defaults to `patch_len` (i.e., non-overlapping patches).

        Returns:
            torch.Tensor: A long tensor of shape [batch_size, n_patches], where each
                element is 1 if the corresponding patch in the input mask is all 1s,
                or 0 otherwise. Trailing elements not covered by a full window are dropped.

        Example:
            Input: mask = [[1, 1, 1, 1, 0, 0, 1, 1]], patch_len=4, stride=2
            Patches: [1, 1, 1, 1], [1, 1, 0, 0], [0, 0, 1, 1]
            Output: [[1, 0, 0]]
        """
        stride = patch_len if stride is None else stride
        mask = mask.unfold(dimension=-1, size=patch_len, step=stride)
        # mask : [batch_size x n_patches x patch_len]
        return (mask.sum(dim=-1) == patch_len).long()

    @staticmethod
    def convert_patch_to_seq_view(
        mask: torch.Tensor,
        patch_len: int = 8,
    ):
        """
        Expands a patch-level binary mask back to a sequence-level binary mask.

        This is the inverse operation of `convert_seq_to_patch_view`, assuming non-overlapping patches.
        Each patch value is repeated `patch_len` times to reconstruct the original resolution.

        Args:
            mask (torch.Tensor): A binary tensor of shape [batch_size, n_patches],
                where each element is the mask value of one patch.
            patch_len (int): The length of each patch (default: 8). Each patch value will be
                repeated `patch_len` times in the output sequence.

        Returns:
            torch.Tensor: A binary tensor of shape [batch_size, n_patches * patch_len].
                Each patch is expanded back to a flat sequence by repeating its value
                `patch_len` times.

        Example:
            Input: mask = [[1, 0, 1]], patch_len=2
            Output: [[1, 1, 0, 0, 1, 1]]
        """
        return mask.repeat_interleave(patch_len, dim=-1)

    def generate_mask(self, x: torch.Tensor, input_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Generate a mask for ``x``, dispatching on its rank.

        Input:
            x : torch.Tensor of shape
            [batch_size x n_channels x n_patches x patch_len] or
            [batch_size x n_channels x seq_len]
            input_mask: optional torch.Tensor of shape [batch_size x seq_len]
            with 1 for observed time steps and 0 for padding. If None, every
            patch is treated as observed.
        Output:
            mask : long torch.Tensor of shape [batch_size x n_patches] for a 4D
            ``x``, or [batch_size x n_patches * patch_len] for a 3D ``x``.
            0 marks hidden positions, 1 marks observed ones.

        Raises:
            ValueError: If ``x`` is neither 3D nor 4D.
        """
        if x.ndim == 4:
            return self._mask_patch_view(x, input_mask=input_mask)
        if x.ndim == 3:
            return self._mask_seq_view(x, input_mask=input_mask)
        raise ValueError(
            "x must be 3D [batch, channels, seq_len] or 4D "
            f"[batch, channels, n_patches, patch_len], got shape {tuple(x.shape)}"
        )

    def _mask_patch_view(self, x, input_mask=None):
        """
        Input:
            x : torch.Tensor of shape
            [batch_size x n_channels x n_patches x patch_len]
            input_mask: optional torch.Tensor of shape [batch_size x seq_len];
            None means every patch is observed.
        Output:
            mask : long torch.Tensor of shape [batch_size x n_patches];
            0 = hidden patch, 1 = observed patch.
        """
        # x : [batch_size, n_channels, n_patches, patch_len]
        batch_size, _, n_patches, _ = x.shape

        # input_mask denotes where in the input seq is padding (0s) and where is actual data (1s)
        if input_mask is None:
            # No padding information: treat every patch as observed.
            input_mask = torch.ones(
                batch_size, n_patches, dtype=torch.long, device=x.device
            )
        else:
            # A patch counts as observed only if all of its time steps are observed.
            input_mask = self.convert_seq_to_patch_view(
                input_mask, self.patch_len, self.stride
            )
        # input_mask : [batch_size, n_patches]

        if self.single_patch_token_mask:
            # Start fully observed, then hide exactly one observed patch per sample.
            # NOTE(review): padded patches stay at 1 in this mode, whereas the
            # mask_ratio branch below leaves them at 0 — confirm downstream code
            # does not depend on padded patches being 0 here.
            mask = torch.ones([batch_size, n_patches], device=x.device)
            for i in range(batch_size):
                valid_indices = torch.where(input_mask[i] == 1)[0]
                if len(valid_indices) > 0:
                    idx_to_mask = valid_indices[torch.randint(0, len(valid_indices), (1,))]
                    mask[i, idx_to_mask] = 0
            return mask.long()

        # Random masking: keep ceil(n_observed * (1 - mask_ratio)) observed patches.
        n_observed_patches = input_mask.sum(dim=-1, keepdim=True)
        # n_observed_patches : [batch_size, 1]
        len_keep = torch.ceil(n_observed_patches * (1 - self.mask_ratio)).long()

        # noise in [0, 1): [batch_size x n_patches]
        noise = torch.rand(batch_size, n_patches, device=x.device)
        # Padded patches get noise 1, strictly above any observed patch's noise,
        # so they always sort last and are never among the kept patches.
        noise = torch.where(input_mask == 1, noise, torch.ones_like(noise))

        # Ascending sort: small noise is kept, large noise is hidden.
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)  # [batch_size x n_patches]

        # In shuffled order the first len_keep[i] positions are observed (1) and the
        # rest hidden (0), matching the class convention (0 = hidden, 1 = observed).
        positions = torch.arange(n_patches, device=x.device).unsqueeze(0)
        mask = (positions < len_keep).long()

        # Unshuffle back to the original patch order.
        return torch.gather(mask, dim=1, index=ids_restore)

    def _mask_seq_view(self, x, input_mask=None):
        """
        Input:
            x : torch.Tensor of shape
            [batch_size x n_channels x seq_len]
            input_mask: optional torch.Tensor of shape [batch_size x seq_len]
        Output:
            mask : long torch.Tensor of shape [batch_size x n_patches * patch_len]
            (equal to [batch_size x seq_len] when stride == patch_len and
            seq_len is a multiple of patch_len).
        """
        # [batch_size, n_channels, seq_len] -> [batch_size, n_channels, n_patches, patch_len]
        x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        mask = self._mask_patch_view(x, input_mask=input_mask)
        # NOTE(review): with overlapping patches (stride < patch_len) or a
        # seq_len that is not a multiple of patch_len, the expanded length
        # differs from seq_len — confirm callers only use non-overlapping patches.
        return self.convert_patch_to_seq_view(mask, self.patch_len).long()
