import torch

from . import ncde
from . import node


class Seq2Seq(torch.nn.Module):
    """Encoder/decoder model trained to forecast the tail of a sequence.

    The leading part of each sequence is passed to an ``ncde.NeuralCDE``
    encoder to produce a context; an ``node.NeuralODE`` decoder is then called
    with the remaining time points and that context, and its output is
    compared against the held-out tail with a mean-squared-error loss.

    Args:
        split: Fraction of the time steps handed to the encoder. The split
            index is ``int(len(t) * split)``; it must leave at least one step
            on each side (checked in ``forward``).
        in_size: Passed as the encoder's ``in_size`` and the decoder's
            ``out_size``. Presumably the number of channels in the last
            dimension of ``seq`` -- confirm against NeuralCDE/NeuralODE.
        context_size: Passed as the encoder's ``out_size`` and the decoder's
            ``in_size``.
        hidden_size: Forwarded unchanged to both encoder and decoder.
        hidden_hidden_size: Forwarded unchanged to both encoder and decoder.
        num_layers: Forwarded unchanged to both encoder and decoder.
    """

    def __init__(self, split, in_size, context_size, hidden_size, hidden_hidden_size, num_layers):
        super(Seq2Seq, self).__init__()

        self._split = split

        # Use equal-size encoder and decoders for simplicity
        self._encoder = ncde.NeuralCDE(in_size=in_size,
                                       out_size=context_size,
                                       hidden_size=hidden_size,
                                       hidden_hidden_size=hidden_hidden_size,
                                       num_layers=num_layers)
        self._decoder = node.NeuralODE(in_size=context_size,
                                       out_size=in_size,
                                       hidden_size=hidden_size,
                                       hidden_hidden_size=hidden_hidden_size,
                                       num_layers=num_layers)

    def forward(self, t, seq):
        """Return the MSE between the decoded and the true tail of ``seq``.

        Args:
            t: Time points; ``len(t)`` must equal ``seq.size(-2)``.
            seq: Tensor whose second-to-last dimension indexes time. Any
                leading dimensions are kept as-is by the slicing here.

        Returns:
            The value of ``torch.nn.functional.mse_loss`` (default reduction)
            between the decoder output and ``seq[..., split:, :]``.

        Raises:
            ValueError: If ``t`` and ``seq`` disagree on the number of time
                steps, or if the configured split leaves the encoder or the
                decoder with no time steps.
        """
        num_steps = len(t)
        # Raise explicitly rather than `assert`: assertions are stripped under
        # `python -O`, which would let a mismatched time grid through silently.
        if num_steps != seq.size(-2):
            raise ValueError(
                "t has {} time points but seq has {} along dim -2".format(
                    num_steps, seq.size(-2)))

        split = int(num_steps * self._split)
        # An empty target would make the MSE an average over zero elements
        # (NaN); an empty input gives the encoder nothing to encode. A negative
        # index would silently slice from the end instead of the start.
        if not 0 < split < num_steps:
            raise ValueError(
                "split={!r} with {} time points gives split index {}; "
                "both the input and the target need at least one step".format(
                    self._split, num_steps, split))

        inp = seq[..., :split, :]
        target = seq[..., split:, :]
        t_out = t[split:]

        context = self._encoder(inp)
        pred_out = self._decoder(t_out, context)

        return torch.nn.functional.mse_loss(pred_out, target)
