import torch
import torch.nn as nn


class TimeTravellingOracle(nn.Module):
    """Project each time step's *next* input through a one-layer ReLU MLP.

    Output step ``t`` equals ``projector(input[:, t + 1])``, so the module
    looks one step into the future. The final output step has no successor
    and simply repeats the projection of the last input step.

    The ``forward(input, state)`` signature and 1-tuple return presumably
    mirror a recurrent-module interface expected by callers (TODO confirm);
    ``state`` is accepted but ignored.
    """

    def __init__(self, input_size, capacity, **_):
        """Build the projector.

        Args:
            input_size: Number of features per time step (last input dim).
            capacity: Output width of the projection, exposed as
                ``hidden_size``.
            **_: Any extra keyword arguments are accepted and ignored.
        """
        super().__init__()
        # forward() indexes inputs as (batch, time, features); callers may
        # read this flag the same way they would on a batch-first RNN.
        self.batch_first = True
        self.input_size = input_size
        self.hidden_size = capacity
        self.projector = nn.Sequential(
            nn.Linear(input_size, capacity),
            nn.ReLU(),
        )

    def forward(self, input, state):
        """Project the one-step-ahead inputs.

        Args:
            input: Tensor of shape (N, T, input_size) with T >= 2.
            state: Ignored.

        Returns:
            A 1-tuple ``(projected,)`` where ``projected`` has shape
            (N, T, hidden_size).

        Raises:
            AssertionError: If T < 2 (when assertions are enabled); with a
                single step there is no future input to project.
        """
        N, T, _ = input.shape
        # Need one step to shift away plus at least one step to project.
        assert T >= 2, f"expected at least 2 time steps, got {T}"
        # Apply the MLP to the time-shifted series: position t sees input t+1.
        projected = self.projector(input[:, 1:])
        # Repeat the last projected step so the output length matches the
        # input length (this final step is ignored anyway).
        projected = torch.cat([projected, projected[:, -1:]], dim=1)
        assert tuple(projected.shape) == (N, T, self.hidden_size), (
            f"unexpected output shape {tuple(projected.shape)}"
        )
        return projected,
