import torch
import torch.nn as nn

from torch import Tensor

# Names exported by ``from <module> import *``.
__all__ = ["GRU"]

class GRU(nn.Module):
    """Bidirectional GRU conditioned on a discrete time index ``t``.

    An embedding of ``t`` is repeated along the sequence axis and concatenated
    to every step of ``x``. The bidirectional GRU output is then projected back
    to ``input_size`` features.

    Args:
        input_size: Number of features per step of ``x`` (and of the output).
        rnn_hidden_size: Hidden size of each GRU direction.
        input_length: Stored as ``self.input_length``; not used inside this
            class.
        n_layers: Number of stacked GRU layers.
        n_diff_time: Number of distinct values ``t`` may take (rows of the
            embedding table); ``t`` must lie in ``[0, n_diff_time)``.
            NOTE(review): presumably the number of diffusion steps — confirm.
        t_embed_size: Dimension of the time embedding.
        **kwargs: Accepted and ignored (presumably so a shared config dict can
            be splatted in — confirm with callers).
    """

    def __init__(
        self,
        input_size: int, rnn_hidden_size: int, input_length: int, n_layers: int,
        n_diff_time: int, t_embed_size: int, **kwargs
    ):
        super().__init__()
        self.input_size = input_size
        self.input_length = input_length

        # Attribute names below define the state_dict keys; keep them stable.
        self.t_embed = nn.Embedding(n_diff_time, t_embed_size)

        self.rnn = nn.GRU(
            input_size=input_size + t_embed_size,
            hidden_size=rnn_hidden_size,
            num_layers=n_layers,
            batch_first=True,
            bidirectional=True,
        )

        # Forward and backward hidden states are concatenated by nn.GRU,
        # hence 2 * rnn_hidden_size input features.
        self.read_out = nn.Linear(2 * rnn_hidden_size, input_size)

    def forward(self, x: Tensor, t: Tensor, *_) -> Tensor:
        """Run the time-conditioned GRU over ``x``.

        Args:
            x: Input of shape ``(batch_size, n_time, input_size)``.
            t: Integer time indices of shape ``(batch_size,)``, or a single
                index (0-d or one-element tensor) shared by the whole batch.
            *_: Extra positional arguments are ignored.

        Returns:
            Tensor of shape ``(batch_size, n_time, input_size)``.
        """
        # Shape (batch_size or 1, 1, t_embed_size). Flattening with -1 on the
        # leading axis, instead of forcing batch_size there, keeps a single
        # shared embedding intact. expand() then broadcasts it over the batch
        # rather than splitting its features across samples.
        t_embed = self.t_embed(t).reshape(-1, 1, self.t_embed.embedding_dim)
        # Repeat the embedding at every time step (a view, no copy).
        t_embed = t_embed.expand(*x.shape[:-1], -1)

        input_features = torch.cat([x, t_embed], dim=-1)

        # nn.GRU returns (output, h_n); only the per-step outputs are used.
        rnn_out, _ = self.rnn(input_features)
        return self.read_out(rnn_out)

