import torch
from flash_attn.bert_padding import pad_input, unpad_input


class Padder(torch.nn.Module):
    """Convert tensors between a padded batch layout and a packed layout.

    Two layouts are supported:

    * Base: ``(batch_size, seq_len, ...)`` -- the padded layout of the tensor
      passed to ``__init__``.
    * Obs: ``(n_valid, ...)`` -- only the non-padded positions, flattened
      across the batch in row-major order (batch first, then position).

    Both layouts can be fed to position-wise layers (e.g. linear layers); the
    Obs layout simply skips the padded positions.

    ``padding_mask`` must have shape ``(batch_size, seq_len)`` and marks
    *padded* positions with truthy entries (it is inverted before use).
    """

    def __init__(self, x: torch.Tensor, padding_mask: torch.Tensor) -> None:
        """Record the unpadding metadata for ``x`` and ``padding_mask``.

        Args:
            x: Padded tensor whose first two dims are ``(batch_size, seq_len)``.
            padding_mask: Mask matching ``x.shape[:2]``; truthy = padded.
        """
        super().__init__()

        self.padding_mask = padding_mask

        # Convert to bool before inverting: ``~`` on an integer 0/1 mask is a
        # bitwise NOT (0 -> -1, 1 -> -2), which would mark every position valid.
        valid_mask = ~padding_mask.bool()

        # Some flash-attn releases return extra trailing values from
        # unpad_input; only the first four are needed here.
        (
            x_o,
            self.indices_o,
            self.cu_seqlens_o,
            self.max_seqlen_in_batch_o,
        ) = unpad_input(x, valid_mask)[:4]

        self.batch_size = x.shape[0]
        # Padded length. obs_to_base must restore this length, not the longest
        # valid sequence: indices_o are flat offsets computed with this stride.
        self.seq_len = x.shape[1]
        self.batch_size_o = x_o.shape[0]

    def base_to_obs(self, x: torch.Tensor) -> torch.Tensor:
        """Pack a Base-layout tensor into the Obs layout.

        Gathers the same positions as ``unpad_input(x, ~padding_mask)`` would,
        but reuses the indices computed in ``__init__`` instead of recomputing
        the mask statistics on every call.

        Args:
            x: Tensor whose first two dims equal ``padding_mask.shape``;
               trailing dims may be arbitrary.

        Returns:
            Tensor of shape ``(n_valid, *x.shape[2:])``.

        Raises:
            ValueError: If ``x``'s leading dims do not match ``padding_mask``.
        """
        if x.shape[:2] != self.padding_mask.shape:
            raise ValueError(
                f"base_to_obs expected leading dims {tuple(self.padding_mask.shape)}, "
                f"got {tuple(x.shape[:2])}"
            )
        return x.flatten(0, 1)[self.indices_o]

    def obs_to_base(self, x: torch.Tensor) -> torch.Tensor:
        """Scatter an Obs-layout tensor back into the padded Base layout.

        Padded positions are filled with zeros. This mirrors flash-attn's
        ``pad_input``, with the output length being the original padded
        ``seq_len``. The longest valid sequence cannot be used as the length:
        ``indices_o`` address a ``(batch_size * seq_len)`` flat buffer.

        Args:
            x: Tensor of shape ``(n_valid, ...)``.

        Returns:
            Tensor of shape ``(batch_size, seq_len, *x.shape[1:])``.
        """
        trailing = x.shape[1:]
        flat = x.new_zeros((self.batch_size * self.seq_len, *trailing))
        # Out-of-place index_put keeps autograd flowing into ``x``.
        flat = flat.index_put((self.indices_o,), x)
        return flat.view(self.batch_size, self.seq_len, *trailing)
    