from .modules import *


class __HNNBase__(Module):
    """Shared base for Hamiltonian networks.

    Subclasses supply :meth:`hamiltonian`. This base converts a gradient of the
    Hamiltonian into a time derivative following Hamilton's equations,
    ``dq/dt = dH/dp`` and ``dp/dt = -dH/dq``. The last axis of a state (and of
    its gradient) is treated as two equal halves, ``[q, p]``.

    NOTE(review): ``self.grad`` / ``self.discrete_grad`` are not defined here;
    they are presumably provided by ``Module`` from ``.modules`` and return a
    tensor with the same ``[q, p]`` layout as the input - confirm.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x1, x2=None):
        """Delegate to :meth:`hamiltonian` with the same arguments."""
        return self.hamiltonian(x1, x2)

    def hamiltonian(self, x1, x2=None):
        """Subclass hook: evaluate the Hamiltonian at ``x1`` (and ``x2`` if given).

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    @staticmethod
    def _symplectic(dH):
        # Split the gradient into [dH/dq, dH/dp] and rotate it into
        # [dH/dp, -dH/dq], i.e. the vector field (dq/dt, dp/dt).
        dH_dq, dH_dp = dH.chunk(2, dim=-1)
        return torch.cat([dH_dp, -dH_dq], dim=-1)

    def time_derivative(self, x1):
        """Return ``[dH/dp, -dH/dq]`` using the continuous gradient at ``x1``."""
        return self._symplectic(self.grad(x1))

    def discrete_time_derivative(self, x1, x2):
        """Return ``[dH/dp, -dH/dq]`` using the discrete gradient between ``x1`` and ``x2``."""
        return self._symplectic(self.discrete_grad(x1, x2))


class HNN(__HNNBase__):
    """General (non-separable) HNN: a single network ``net`` models H(x).

    Args:
        net: module evaluated as the Hamiltonian. When called with two states
            it is expected to accept ``net(x1, x2)`` (used by the two-argument
            form of :meth:`hamiltonian`).
    """

    def __init__(self, net: Module):
        super().__init__()
        self.net = net

    def hamiltonian(self, x1, x2=None):
        """Evaluate ``net`` on ``x1``, or on the pair ``(x1, x2)`` when ``x2`` is given."""
        inputs = (x1,) if x2 is None else (x1, x2)
        return self.net(*inputs)


class SepHNN(__HNNBase__):
    """HNN under the separable assumption H(q, p) = V(q) + T(p).

    The last axis of a state is split into two equal halves ``[q, p]``.
    ``netV`` models V(q) and ``netT`` models T(p). Because H is separable,
    each half of Hamilton's equations depends on only one of the two networks.

    Args:
        netV: module for the q-dependent term V(q).
        netT: module for the p-dependent term T(p).
    """

    def __init__(self, netV: Module, netT: Module):
        super().__init__()
        self.netV = netV
        self.netT = netT

    def hamiltonian(self, x1, x2=None):
        """Return V(q) + T(p) for ``x1``.

        When ``x2`` is given, both sub-networks are called on the pair of
        halves (``netV(q1, q2)``, ``netT(p1, p2)``). Each call is expected to
        return a 2-tuple, and the result is the pair ``(V1 + T1, V2 + T2)``.
        """
        if x2 is not None:
            q1, p1 = x1.chunk(2, dim=-1)
            q2, p2 = x2.chunk(2, dim=-1)
            V1, V2 = self.netV(q1, q2)
            T1, T2 = self.netT(p1, p2)
            return V1 + T1, V2 + T2
        q, p = x1.chunk(2, dim=-1)
        return self.netV(q) + self.netT(p)

    def time_derivative(self, x1):
        """Return ``[dT/dp, -dV/dq]`` at ``x1``, the separable form of Hamilton's equations."""
        q, p = x1.chunk(2, dim=-1)
        dq_dt = self.netT.grad(p)
        dp_dt = -self.netV.grad(q)
        return torch.cat([dq_dt, dp_dt], dim=-1)

    def time_derivative_q(self, p):
        """Return dq/dt = dT/dp, which depends only on the momentum half ``p``."""
        return self.netT.grad(p)

    def time_derivative_p(self, q):
        """Return dp/dt = -dV/dq, which depends only on the position half ``q``."""
        return -self.netV.grad(q)


class HamiltonianMassSpring(__HNNBase__):
    """Analytic Hamiltonian H = q**2 / 2 + v**2 / 2 (unit mass, unit spring constant).

    Only the first two entries of the last axis are read: index 0 is the
    position ``q`` and index 1 is ``v``. The last axis is reduced away, so
    the result has the input's leading shape.
    """

    def __init__(self):
        super().__init__()

    def hamiltonian(self, x1, x2=None):
        """Return H(x1), or the pair ``(H(x1), H(x2))`` when ``x2`` is given.

        The two-state form squares via ``Pow2`` from ``.modules``.
        NOTE(review): Pow2 presumably squares both arguments jointly (e.g. for
        a discrete-gradient backward pass); confirm against its definition.
        """
        if x2 is not None:
            q1 = x1[..., 0]
            v1 = x1[..., 1]
            q2 = x2[..., 0]
            v2 = x2[..., 1]
            qq1, qq2 = Pow2()(q1, q2)
            vv1, vv2 = Pow2()(v1, v2)
            return 0.5 * (qq1 + vv1), 0.5 * (qq2 + vv2)

        q = x1[..., 0]
        v = x1[..., 1]
        return 0.5 * q**2 + 0.5 * v**2


