import torch
import torchcde

from . import common


class _CDEFunc(torch.nn.Module):
    """Vector field used by the CDE solver.

    Given a scalar time and a batch of hidden states, the time is prepended
    to each hidden state, the result is passed through an MLP, and the flat
    MLP output is reshaped into one ``(hidden_size, in_size)`` matrix per
    batch element.
    """

    def __init__(self, in_size, hidden_size, hidden_hidden_size, num_layers):
        """Create the MLP backing the vector field.

        Args:
            in_size: Size of the last dimension of the matrices returned by
                ``forward``.
            hidden_size: Number of channels of the hidden state ``z``.
            hidden_hidden_size: Passed to ``common.MLP`` as ``hidden_size``.
            num_layers: Passed to ``common.MLP`` as ``num_layers``.
        """
        super().__init__()
        self._in_size = in_size
        self._hidden_size = hidden_size

        # The MLP input has one extra feature because forward() prepends the
        # time value to the hidden state; its output is a flattened
        # (hidden_size, in_size) matrix.
        mlp_config = {
            'in_size': 1 + hidden_size,
            'out_size': in_size * hidden_size,
            'hidden_size': hidden_hidden_size,
            'num_layers': num_layers,
            'tanh': True,
        }
        self._mlp = common.MLP(**mlp_config)

    def forward(self, t, z):
        """Evaluate the vector field.

        Args:
            t: Tensor of shape ``()`` holding the current time.
            z: Tensor of shape ``(batch_size, hidden_channels)``.

        Returns:
            Tensor of shape ``(batch_size, hidden_size, in_size)``.
        """
        batch_size = z.size(0)
        # Turn the scalar time into a (batch_size, 1) column so it can be
        # concatenated with the hidden state along the channel dimension.
        time_column = t[None, None].repeat(batch_size, 1)
        mlp_input = torch.cat([time_column, z], dim=1)
        flat_output = self._mlp(mlp_input)
        return flat_output.view(batch_size, self._hidden_size, self._in_size)


class NeuralCDE(torch.nn.Module):
    """Neural controlled differential equation with a linear readout.

    The input sequence is turned into a path with
    ``torchcde.LinearInterpolation``; a hidden state is initialised from the
    path's value at the start of its interval, evolved over the whole
    interval with ``torchcde.cdeint``, and the final hidden state is mapped
    to the output by a linear layer.
    """

    def __init__(self, in_size, out_size, hidden_size, hidden_hidden_size, num_layers):
        """Create the initial map, the vector field and the readout.

        Args:
            in_size: Number of input features fed to the initial linear map;
                also forwarded to the vector field as ``in_size``.
            out_size: Number of output features produced by the readout.
            hidden_size: Number of channels of the hidden state.
            hidden_hidden_size: Forwarded to the vector field.
            num_layers: Forwarded to the vector field.
        """
        super().__init__()

        self._initial = torch.nn.Linear(in_size, hidden_size)
        self._func = _CDEFunc(
            in_size=in_size,
            hidden_size=hidden_size,
            hidden_hidden_size=hidden_hidden_size,
            num_layers=num_layers,
        )
        self._readout = torch.nn.Linear(hidden_size, out_size)

    def forward(self, seq):
        """Run the CDE over ``seq`` and read out the terminal hidden state.

        Args:
            seq: Data handed directly to ``torchcde.LinearInterpolation``.
                NOTE(review): presumably shaped (batch, length, in_size) and
                free of NaNs, since no coefficients are computed first;
                confirm against callers.

        Returns:
            The readout applied to the hidden state at the last requested
            time.
        """
        path = torchcde.LinearInterpolation(seq)
        interval = path.interval

        # The hidden state starts from the path's value at the interval start.
        initial_value = path.evaluate(interval[0])
        initial_hidden = self._initial(initial_value)

        # Use the smallest gap between consecutive grid points as the solver
        # step size.
        grid = path.grid_points
        smallest_gap = (grid[1:] - grid[:-1]).min()
        solver_options = {'step_size': smallest_gap}

        hidden_path = torchcde.cdeint(
            func=self._func,
            X=path,
            z0=initial_hidden,
            t=interval,
            method='midpoint',
            options=solver_options,
        )
        # Keep only the final entry along dimension 1.
        terminal_hidden = hidden_path[:, -1]
        return self._readout(terminal_hidden)
