from .hnn import *
try:
    from functorch import jacrev, vmap
except:
    pass



class __CHNN__(HNN):
    """Constrained HNN base class: projects the Hamiltonian flow onto the constraints.

    Subclasses implement ``first_integrals(u)`` for a *single* state vector
    ``u`` of length ``2n`` (positions first, then momenta, matching the
    symplectic matrix built in ``time_derivative``). It returns a 1-D tensor of
    ``k`` constraint values. ``time_derivative`` differentiates it with
    ``jacrev`` and maps over the leading (batch) dimension with ``vmap``.
    """

    def first_integrals(self, u):
        """Return the ``k`` constraint values for one state ``u`` (override me)."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement first_integrals(u)")

    @staticmethod
    def _batched_jacobian(fn):
        """Return ``vmap(jacrev(fn))``, preferring functorch and falling back to torch.func."""
        try:
            return vmap(jacrev(fn))
        except NameError:
            # The optional ``functorch`` import at the top of the file failed and
            # left these names undefined; torch >= 2.0 ships the same
            # transforms as ``torch.func``.
            from torch.func import jacrev as _jacrev, vmap as _vmap
            return _vmap(_jacrev(fn))

    def time_derivative(self, u):
        """Return the constrained vector field ``f`` for a batch of states.

        Computes ``f = (I - S M^T (M S M^T)^{-1} M) S dH``. Here ``S`` is the
        canonical symplectic matrix ``[[0, I], [-I, 0]]``, ``dH = self.grad(u)``
        and ``M`` is the per-sample Jacobian of ``first_integrals``. By
        construction ``M f = 0``, so the first integrals are conserved along
        ``f`` to first order.

        Args:
            u: batch of states, shape ``(batch, 2n)``.

        Returns:
            Tensor of shape ``(batch, 2n)``.

        Raises:
            Whatever ``torch.linalg.solve`` raises when ``M S M^T`` is singular.
        """
        n = u.shape[-1] // 2
        # (batch, k, 2n): Jacobian of the first integrals for each sample.
        M = self._batched_jacobian(self.first_integrals)(u)
        Mt = M.transpose(-1, -2)
        # self.grad(u) is (batch, 2n) per the original shape comments;
        # make it a column vector: (batch, 2n, 1).
        dH = self.grad(u).unsqueeze(-1)

        # Build S in M's dtype and on M's device. Creating it in the default
        # dtype and only moving the device breaks the matmul for float64 inputs.
        eye = torch.eye(n, dtype=M.dtype, device=M.device)
        zero = torch.zeros(n, n, dtype=M.dtype, device=M.device)
        S = torch.cat(
            [torch.cat([zero, eye], dim=1),
             torch.cat([-eye, zero], dim=1)], dim=0)

        # A (2n, 2n) @ (batch, 2n, m) product broadcasts over the batch,
        # so S does not need to be tiled per sample.
        SMt = S @ Mt  # (batch, 2n, k)
        SdH = S @ dH  # (batch, 2n, 1)
        # Algebraically identical to the projected form in the docstring:
        #   f = S dH - S M^T (M S M^T)^{-1} (M S dH)
        # but it solves for a (k, 1) RHS instead of a (k, 2n) one and never
        # materializes the (2n, 2n) identity.
        lam = torch.linalg.solve(M @ SMt, M @ SdH)  # (batch, k, 1)
        f = SdH - SMt @ lam
        return f.squeeze(-1)


class CHNN2Pend(__CHNN__):
    """Constrained HNN with two point masses linked by rigid rods, pivot at the origin."""

    def first_integrals(self, u):
        """Return the four constraint values for one state.

        ``u`` is laid out as ``(x1, y1, x2, y2, px1, py1, px2, py2)``. The
        constraints are:

        * ``|r1|^2`` and ``|r1 - r2|^2`` (rod lengths, position level);
        * ``r1 . p1`` and ``(r1 - r2) . (p1 - p2)`` (their momentum-level
          counterparts).

        NOTE(review): the momentum-level terms are proportional to the time
        derivatives of the position-level ones only if ``p = m v`` with equal
        masses for both bodies. This presumably assumes unit masses; confirm
        against the Hamiltonian being learned.
        """
        x1, y1, x2, y2, px1, py1, px2, py2 = torch.chunk(u, 8, dim=-1)
        # Relative position / momentum of body 1 with respect to body 2.
        dx, dy = x1 - x2, y1 - y2
        dpx, dpy = px1 - px2, py1 - py2
        constraints = (
            x1 ** 2 + y1 ** 2,
            dx ** 2 + dy ** 2,
            x1 * px1 + y1 * py1,
            dx * dpx + dy * dpy,
        )
        return torch.cat(constraints, dim=-1)

class CHNN2Body(__CHNN__):
    """Two-body model whose first integrals are the total linear momentum components.

    Known not to work. Both integrals depend only on momenta, so every row of
    their Jacobian ``M`` is zero in the position slots. ``S`` maps momentum
    slots to position slots, so ``M S M^T`` is the zero matrix. The solve in
    ``time_derivative`` therefore hits a singular system.
    """

    def first_integrals(self, u):
        """Return ``(px1 + px2, py1 + py2)`` for a state laid out as (x1, y1, x2, y2, px1, py1, px2, py2)."""
        # Indeed, this leads to a singular projection matrix, and CHNN does not work.
        _, _, _, _, px1, py1, px2, py2 = torch.chunk(u, 8, dim=-1)
        total_px = px1 + px2
        total_py = py1 + py2
        return torch.cat((total_px, total_py), dim=-1)
