import torch
import torch.nn as nn


class MaskedGRU(nn.Module):
    """Run a single-layer, unidirectional GRU independently over each of ``L`` tracks.

    Every (batch, track) pair is treated as its own sequence along the time
    axis, and the GRU outputs at positions where ``mask`` is zero are zeroed.

    Note: the mask is applied to the outputs only. Masked timesteps are still
    fed through the GRU, so they do influence the hidden state at later
    timesteps of the same sequence. Because the GRU is unidirectional this is
    harmless when padding only trails the valid steps, but not otherwise.
    """

    def __init__(self, input_dim, hidden_dim):
        """Create the GRU.

        Args:
            input_dim: Size of the last axis of ``x`` passed to ``forward``.
            hidden_dim: GRU hidden size; also the last axis of the output.
        """
        super().__init__()
        # Time-major GRU: forward() moves the time axis to dim 0 before calling it.
        self.gru = nn.GRU(input_dim, hidden_dim, batch_first=False, bidirectional=False)

    def forward(self, x, mask):
        """Apply the GRU along the time axis and zero out masked positions.

        Args:
            x: Tensor of shape ``[bs, t, L, dim]``.
            mask: Tensor of shape ``[bs, t, L]``; nonzero/True marks positions
                whose output is kept.

        Returns:
            Tensor of shape ``[bs, t, L, hidden_dim]`` with the same dtype as
            the GRU output.
        """
        x = x.transpose(0, 1)  # [t, bs, L, dim]
        t, bs, L, dim = x.shape

        # Fold batch and track axes together so each track is an independent sequence.
        x = x.reshape(t, bs * L, dim)
        output, _ = self.gru(x)  # [t, bs*L, hidden_dim]

        output = output.reshape(t, bs, L, -1).transpose(0, 1)  # [bs, t, L, hidden_dim]

        # Cast the mask to the output's dtype (not float32) so half/bfloat16
        # outputs are not silently promoted to float32 by type promotion.
        mask = mask.unsqueeze(-1).to(output.dtype)  # [bs, t, L, 1]
        return output * mask
