import torch
import torchdiffeq

from . import common


class ODEFunc(torch.nn.Module):
    """Time-dependent vector field f(t, z) used as the right-hand side of the ODE.

    The scalar time is prepended as an extra feature to every row of the hidden
    state, and the combined ``1 + hidden_size`` features are fed through a
    ``common.MLP`` whose output size is ``hidden_size``.
    """

    def __init__(self, hidden_size, hidden_hidden_size, num_layers):
        super().__init__()
        # One extra input feature for the time column concatenated in forward().
        # tanh=True is forwarded to common.MLP as-is (presumably a tanh output
        # nonlinearity -- confirm in common.MLP).
        self._mlp = common.MLP(
            in_size=hidden_size + 1,
            out_size=hidden_size,
            hidden_size=hidden_hidden_size,
            num_layers=num_layers,
            tanh=True,
        )

    def forward(self, t, z):
        """Evaluate dz/dt at time ``t``.

        Args:
            t: 0-dim tensor holding the current integration time.
            z: hidden state of shape (batch_size, hidden_size).

        Returns:
            The MLP output for the concatenated ``[t, z]`` features; with
            ``out_size=hidden_size`` this is expected to be
            (batch_size, hidden_size).
        """
        batch_size = z.size(0)
        # Broadcast the scalar time into a (batch_size, 1) column (a view, no copy).
        time_column = t[None, None].expand(batch_size, 1)
        features = torch.cat((time_column, z), dim=1)
        return self._mlp(features)


class NeuralODE(torch.nn.Module):
    """Encode a context vector, evolve it with a learned ODE, and read out every time point.

    A linear layer maps the context to the initial hidden state, the hidden state
    is integrated over the requested times with torchdiffeq's fixed-step
    ``'midpoint'`` solver using ``ODEFunc`` as the vector field, and a second
    linear layer maps each hidden state on the trajectory to ``out_size``.
    """

    def __init__(self, in_size, out_size, hidden_size, hidden_hidden_size, num_layers):
        super(NeuralODE, self).__init__()

        self._initial = torch.nn.Linear(in_size, hidden_size)
        self._func = ODEFunc(hidden_size=hidden_size, hidden_hidden_size=hidden_hidden_size, num_layers=num_layers)
        self._readout = torch.nn.Linear(hidden_size, out_size)

    def forward(self, t, context):
        """Integrate the hidden state over ``t`` and read out each time point.

        Args:
            t: 1-D tensor of evaluation times (torchdiffeq requires it to be
                monotonic). A single time point is allowed and yields just the
                read-out of the initial state.
            context: tensor whose last dimension is ``in_size``. ODEFunc
                concatenates along dim 1, so this is expected to be
                (batch, in_size).

        Returns:
            Tensor of shape (batch, len(t), out_size).
        """
        y0 = self._initial(context)

        options = {}
        if t.size(0) > 1:
            # Fixed step = smallest gap between consecutive requested times, so the
            # solver grid is at least as fine as the finest spacing in t. abs()
            # keeps the step positive when t is decreasing (torchdiffeq negates
            # reversed times internally but does not flip step_size). With a
            # single time point there is no gap, and the fixed-grid solver needs
            # no step size.
            options['step_size'] = t[1:].sub(t[:-1]).abs().min()

        out = torchdiffeq.odeint(self._func, y0, t, method='midpoint', options=options)
        # odeint stacks states time-first: (len(t), batch, hidden) -> (batch, len(t), hidden).
        return self._readout(out.transpose(0, 1))
