import torch
import torch.nn as nn

from torch import Tensor

# Public API of this module: only the model class is exported by a
# star-import of this file.
__all__ = [
    'ParaPhyGRU'
]

class ParaPhyGRU(nn.Module):
    """GRU sequence model with a step embedding and two linear output heads.

    The network embeds an integer index ``t`` (one per batch element),
    appends that embedding to every time step of the input sequence, runs
    the result through a (by default bidirectional) GRU and finally applies
    two independent linear layers to the concatenation of the GRU input and
    the GRU output:

    * ``read_out`` maps back to ``input_size`` features;
    * ``phy_out`` maps to ``repara_size`` features.

    Before anything else, ``forward`` re-centres the first nine input
    features, which it treats as a ``3 x 3`` block (labelled
    ``(n_system, xyz)`` in the original shape comment), by subtracting their
    mean over the time and system axes.

    Args:
        input_size: Number of features per time step of ``x``. ``forward``
            currently only accepts 18.
        hidden_size: Hidden size of the GRU (per direction).
        input_length: Stored on the instance as ``input_length``; not used
            by this class itself.
        n_layers: Number of stacked GRU layers.
        repara_size: Output feature size of the ``phy_out`` head.
        n_diff_time: Number of rows of the embedding table, i.e. valid
            values of ``t`` are ``0 .. n_diff_time - 1`` (presumably the
            number of diffusion steps -- confirm against the caller).
        t_embed_size: Dimension of the embedding of ``t``.
        bidirectional: Whether the GRU is bidirectional.
        **kwargs: Accepted and ignored.
    """

    # ``forward`` hard-requires this feature count.
    _EXPECTED_INPUT_SIZE = 18
    # The leading ``_N_SYSTEM * _N_XYZ`` features are re-centred in ``forward``.
    _N_SYSTEM = 3
    _N_XYZ = 3

    def __init__(self,
            input_size: int, hidden_size: int, input_length: int,
            n_layers: int, repara_size: int,
            n_diff_time: int, t_embed_size: int, bidirectional: bool = True, **kwargs
        ):
        """Build the embedding table, the GRU and the two linear heads."""
        super().__init__()
        self.input_size = input_size
        self.input_length = input_length

        self.t_embed = nn.Embedding(n_diff_time, t_embed_size)

        # The GRU sees the raw features with the embedding of ``t`` appended.
        self.rnn = nn.GRU(
            input_size=input_size + t_embed_size,
            hidden_size=hidden_size,
            num_layers=n_layers,
            batch_first=True,
            bidirectional=bidirectional
        )

        # A bidirectional GRU concatenates both directions on the last axis.
        latent_hidden_size = 2 * hidden_size if bidirectional else hidden_size
        # Both heads consume [input features | t embedding | GRU output].
        head_in_features = input_size + latent_hidden_size + t_embed_size

        self.read_out = nn.Linear(head_in_features, input_size)
        self.phy_out = nn.Linear(head_in_features, repara_size)

    def forward(self, x: Tensor, t: Tensor):
        """Run the model on a batch of sequences.

        Args:
            x: Tensor of shape ``(batch_size, n_time, 18)``. It is NOT
                modified; the re-centring is done on a copy.
            t: Integer tensor of shape ``(batch_size,)`` indexing the
                embedding table.

        Returns:
            A tuple ``(read_out, phy_out)`` of tensors with shapes
            ``(batch_size, n_time, input_size)`` and
            ``(batch_size, n_time, repara_size)``.

        Raises:
            ValueError: If ``x`` is not 3-dimensional or its last dimension
                is not 18.
        """
        batch_size, n_time, input_size = x.shape
        if input_size != self._EXPECTED_INPUT_SIZE:
            raise ValueError(
                f"expected x.shape[-1] == {self._EXPECTED_INPUT_SIZE}, "
                f"got {input_size}"
            )

        n_pos = self._N_SYSTEM * self._N_XYZ
        # positions.shape == (batch_size, n_time, n_system, xyz)
        positions = x[:, :, :n_pos].reshape(
            batch_size, n_time, self._N_SYSTEM, self._N_XYZ
        )
        # Mean over time, then over systems: one offset per (batch, xyz).
        offset = positions.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
        centered = (positions - offset).reshape(batch_size, n_time, n_pos)
        # Rebuild the feature tensor out of place so the caller's ``x`` is
        # left untouched (an in-place write would also fail on leaf tensors
        # that require grad).
        x = torch.cat([centered, x[:, :, n_pos:]], dim=-1)

        t_embed = self.t_embed(t)
        # Broadcast the per-sample embedding along the time axis.
        t_embed = t_embed.reshape(batch_size, 1, -1).expand(
            batch_size, n_time, t_embed.size(-1)
        )

        input_features = torch.cat([x, t_embed], dim=-1)
        rnn_output = self.rnn(input_features)[0]

        rnn_feature = torch.cat([input_features, rnn_output], dim=-1)
        return self.read_out(rnn_feature), self.phy_out(rnn_feature)

