import torch
import torch.nn as nn

from torch import Tensor

__all__ = [
    'GRU'
]

class GRU(nn.Module):
    """Bidirectional GRU over the time axis, conditioned on ``condition`` and ``t``.

    At every time step the recurrent network sees the concatenation of the
    input channels of ``x``, the channels of ``condition`` (broadcast over
    time) and a learned embedding of the integer index ``t`` (also broadcast
    over time; the ``n_diff_time`` naming suggests a diffusion time step —
    confirm with callers). A linear read-out maps the bidirectional hidden
    state back to ``input_size`` channels.

    Args:
        input_size: Number of channels in ``x`` and in ``condition``.
        hidden_size: Hidden size of each GRU direction.
        n_layers: Number of stacked GRU layers.
        n_diff_time: Number of rows in the ``t`` embedding table, so ``t``
            must take values in ``[0, n_diff_time)``.
        t_embed_size: Width of the ``t`` embedding.
        **kwargs: Accepted and ignored (presumably so a shared config can be
            passed to several model classes — TODO confirm).
    """

    def __init__(self,
            input_size: int, hidden_size: int, n_layers: int,
            n_diff_time: int, t_embed_size: int = 128, **kwargs
        ):
        super().__init__()
        self.input_size = input_size

        self.t_embed = nn.Embedding(n_diff_time, t_embed_size)

        # Per-step input = x channels + condition channels + t embedding.
        self.rnn = nn.GRU(
            input_size=2 * input_size + t_embed_size,
            hidden_size=hidden_size,
            num_layers=n_layers,
            batch_first=True,
            bidirectional=True,
        )

        # Bidirectional output concatenates both directions: 2 * hidden_size.
        self.read_out = nn.Linear(
            2 * hidden_size,
            input_size,
        )

    def forward(self, x: Tensor, t: Tensor, condition: Tensor, grid: Tensor) -> Tensor:
        """Run the GRU over time and return a tensor laid out like ``x``.

        Args:
            x: Tensor of shape ``(batch_size, input_size, n_time, 1)``.
            t: Integer indices into the ``t`` embedding; either one per sample
                (shape ``(batch_size,)`` or ``(batch_size, 1)``) or a single
                index shared by the whole batch (0-d or one element).
            condition: Tensor of shape ``(batch_size, input_size, 1, 1)``;
                it is broadcast along the time axis.
            grid: Unused here; kept so the signature matches callers.

        Returns:
            Tensor of shape ``(batch_size, input_size, n_time, 1)``.
        """
        # (B, C, T, 1) -> (B, T, C): the GRU is batch_first with features last.
        x = x.squeeze(dim=-1).permute(0, 2, 1)
        condition = condition.squeeze(dim=-1).permute(0, 2, 1)
        batch_size, n_time, _ = x.shape

        # Fix the embedding width explicitly so a single shared index becomes
        # (1, 1, E) rather than being split across the batch axis, then
        # broadcast the (per-sample or shared) embedding over batch and time.
        t_embed = self.t_embed(t).reshape(-1, 1, self.t_embed.embedding_dim)
        t_embed = t_embed.expand(batch_size, n_time, -1)

        input_features = torch.cat([
            x,
            condition.expand(x.shape),
            t_embed,
        ], dim=-1)

        # nn.GRU returns (output, h_n); output is (B, T, 2 * hidden_size).
        output, _ = self.rnn(input_features)
        return self.read_out(output).permute(0, 2, 1).unsqueeze(dim=-1)

