from .modules import *


class __NODEBase__(Module):
    """Shared base for the neural-ODE wrappers in this file.

    Holds the wrapped network as ``self.net`` and defines the common method
    surface. ``grad`` and ``discrete_grad`` are not implemented here and raise
    ``NotImplementedError``; ``time_derivative`` simply delegates to
    ``forward``.
    """

    def __init__(self, net: Module) -> None:
        super().__init__()
        # NOTE(review): if Module is torch.nn.Module, this assignment also
        # registers ``net`` as a submodule — confirm against .modules.
        self.net = net

    def grad(self, x1):
        """Not implemented in the base class."""
        raise NotImplementedError

    def discrete_grad(self, x1, x2):
        """Not implemented in the base class."""
        raise NotImplementedError

    def time_derivative(self, x1):
        """Return ``self.forward(x1)``.

        ``forward`` is called directly rather than through ``self(x1)``; if
        ``Module`` is ``torch.nn.Module``, module hooks therefore do not run
        on this path.
        """
        return self.forward(x1)


class NODE(__NODEBase__):
    """First-order neural ODE: du/dt = f(u), where f is ``self.net``."""

    def forward(self, x1, x2=None):
        """Evaluate the learned vector field at ``x1``.

        ``x2`` is not used and must be ``None`` (checked with ``assert``).
        Returns ``self.net(x1)`` unchanged.
        """
        assert x2 is None
        du_dt = self.net(x1)
        return du_dt


class SepSONODE(__NODEBase__):
    """Second-order neural ODE whose acceleration depends on position only.

    Models dq/dt = v, dv/dt = a(q). The state ``x1`` packs ``[q, v]`` along
    its last axis; ``self.net`` computes ``a(q)`` from the first half.
    """

    def forward(self, x1, x2=None):
        """Return the time derivative ``[v, a(q)]`` of the state ``x1``.

        Args:
            x1: State tensor whose last axis is ``[q, v]``. Its size must be
                even so that ``q`` and ``v`` have the same width.
            x2: Not used; must be ``None``.

        Returns:
            ``torch.concat([v, net(q)], dim=-1)``.

        Raises:
            AssertionError: if ``x2`` is not ``None`` (an ``assert``, so it is
                skipped under ``python -O``).
            ValueError: if the last axis of ``x1`` has odd size.
        """
        assert x2 is None
        # torch.chunk returns unequal pieces for an odd-sized axis
        # (e.g. 5 -> 3 + 2), which would silently mis-split q and v.
        width = x1.shape[-1]
        if width % 2 != 0:
            raise ValueError(
                f"SepSONODE expects an even-sized last dimension [q, v], got {width}"
            )
        q, dq = x1.chunk(2, dim=-1)
        ddq = self.net(q)  # acceleration a(q)
        return torch.concat([dq, ddq], dim=-1)


class SONODE(__NODEBase__):
    """Second-order neural ODE with position- and velocity-dependent acceleration.

    Models dq/dt = v, dv/dt = a(q, v). The state ``x1`` packs ``[q, v]``
    along its last axis; ``self.net`` computes ``a(q, v)`` from all of ``x1``.
    """

    def forward(self, x1, x2=None):
        """Return the time derivative ``[v, a(q, v)]`` of the state ``x1``.

        Args:
            x1: State tensor whose last axis is ``[q, v]``. Its size must be
                even so that ``q`` and ``v`` have the same width.
            x2: Not used; must be ``None``.

        Returns:
            ``torch.concat([v, net(x1)], dim=-1)``.

        Raises:
            AssertionError: if ``x2`` is not ``None`` (an ``assert``, so it is
                skipped under ``python -O``).
            ValueError: if the last axis of ``x1`` has odd size.
        """
        assert x2 is None
        # torch.chunk returns unequal pieces for an odd-sized axis
        # (e.g. 5 -> 3 + 2), which would silently pick the wrong v slice.
        # Validate before running the (possibly expensive) network.
        width = x1.shape[-1]
        if width % 2 != 0:
            raise ValueError(
                f"SONODE expects an even-sized last dimension [q, v], got {width}"
            )
        _, dq = x1.chunk(2, dim=-1)
        ddq = self.net(x1)  # acceleration a(q, v) from the full state
        return torch.concat([dq, ddq], dim=-1)
