#based on code from: Hamiltonian Neural Networks | 2019 | Sam Greydanus, Misko Dzamba, Jason Yosinski


import autograd
import autograd.numpy as np

import scipy.integrate
solve_ivp = scipy.integrate.solve_ivp

# Module-level setup. NOTE: this runs at import time and reseeds NumPy's
# global RNG, so importing this module resets any seed set earlier by a caller.
np.random.seed(0)

# Number of position coordinates; a full phase-space state is [q, p] with
# length 2*DIM.
DIM=2
# Start from diag(1, ..., DIM). Each diagonal entry is then replaced by the
# square of its square root rounded to one decimal (e.g. 2 -> 1.4**2 = 1.96),
# so sqrt(D_ii) is (up to float rounding) a one-decimal number.
# NOTE(review): presumably these square roots are the intended normal-mode
# frequencies of the oscillator — confirm.
D=np.diag(np.arange(1,DIM+1))
D=np.round(np.round(np.sqrt(D),1)**2,5)

# Random Gaussian matrix whose rows are orthonormalised in place by (modified)
# Gram-Schmidt, which makes S an orthogonal matrix.
S=np.random.normal(size=(DIM,DIM))
for i in range(DIM):
    for j in range(i):
        # Remove the component of row i along the already-normalised row j.
        S[i]=S[i]-np.dot(S[i],S[j])*S[j]

    S[i]=S[i]/np.linalg.norm(S[i])



# Symmetric coupling matrix A = S^T D S. Because S is orthogonal, the
# eigenvalues of A are the diagonal entries of D. A is the quadratic form of
# the potential energy q^T A q.
A=np.matmul(np.transpose(S),np.matmul(D,S))

def matrices():
    """Return the pair (S, D) from which the coupling matrix A = S^T D S is built.

    S is the orthogonal matrix produced at import time and D the diagonal
    matrix of eigenvalues of A.
    """
    return (S, D)

def potential_energy(coords):
    """Return the quadratic potential q^T A q.

    ``coords`` is the concatenated state [q, p]; only the first half, q, is
    used here.
    """
    position = np.split(coords, 2)[0]
    coupled = np.matmul(A, position)
    return np.matmul(position, coupled)

def kinetic_energy(coords):
    q, p= np.split(coords,2)
    return np.dot(p,p)

def total_energy(coords):
    """Return potential plus kinetic energy of the state ``coords`` = [q, p].

    This is q^T A q + p . p, the same expression hamiltonian_fn evaluates.
    """
    potential = potential_energy(coords)
    kinetic = kinetic_energy(coords)
    return potential + kinetic


def hamiltonian_fn(coords):
    """Return the Hamiltonian H(q, p) = p . p + q^T A q of the coupled oscillator.

    ``coords`` is the concatenated state [q, p] and is split into two equal
    halves. No factor of 1/2 is applied to either term, which matches
    kinetic_energy and potential_energy. dynamics_fn differentiates this
    function with autograd, so it must stay written in autograd.numpy
    operations.
    """
    q, p = np.split(coords, 2)
    H = np.dot(p, p) + np.matmul(q, np.matmul(A, q))  # spring hamiltonian (linear oscillator)
    return H

def dynamics_fn(t, coords):
    """Return d[q, p]/dt for the state ``coords`` via Hamilton's equations.

    ``t`` is unused. It is accepted so the function matches the ``fun(t, y)``
    signature that scipy's solve_ivp calls. The gradient of hamiltonian_fn
    is split into (dH/dq, dH/dp), and the result is [dH/dp, -dH/dq].
    """
    grad_h = autograd.grad(hamiltonian_fn)(coords)
    dh_dq, dh_dp = np.split(grad_h, 2)
    # dq/dt = dH/dp, dp/dt = -dH/dq
    return np.concatenate([dh_dp, -dh_dq], axis=-1)

def get_trajectory(t_span=(0, 10), timescale=50, radius=None, y0=None, **kwargs):
    """Integrate one trajectory of the oscillator and its time derivatives.

    Args:
        t_span: (t0, t1) integration interval.
        timescale: evaluation points per unit time. The trajectory is sampled at
            int(timescale * (t1 - t0)) evenly spaced times in [t0, t1].
        radius: Euclidean norm the initial state is rescaled to. If None, it is
            drawn uniformly from [0.1 * DIM, DIM).
        y0: initial state [q, p] of length 2*DIM. If None, one is drawn at
            random. Only its direction is used, because it is rescaled to
            ``radius``.
        **kwargs: forwarded to scipy.integrate.solve_ivp. ``rtol`` defaults to
            1e-10 and may be overridden here.

    Returns:
        q, p, dqdt, dpdt: arrays of shape (DIM, len(t_eval)).
        t_eval: the evaluation times.

    Raises:
        ValueError: if the initial state has zero norm.
        RuntimeError: if solve_ivp reports that integration failed.
    """
    t_eval = np.linspace(t_span[0], t_span[1], int(timescale * (t_span[1] - t_span[0])))

    # get initial state
    if y0 is None:
        # rand(...)*2-1 is uniform on [-1, 1) per component. Applying the same
        # shift to a standard normal would bias every component toward -1 and
        # skew the direction of the initial state.
        y0 = np.random.rand(2 * DIM) * 2 - 1
    if radius is None:
        radius = (np.random.rand() * 0.9 + 0.1) * DIM  # sample a range of radii
    y0 = np.asarray(y0, dtype=float)
    norm = np.sqrt((y0 ** 2).sum())
    if norm == 0:
        raise ValueError("y0 must be non-zero: only its direction is used")
    y0 = y0 / norm * radius  # set the appropriate radius

    # High default accuracy, but let callers override it instead of hitting a
    # duplicate-keyword TypeError.
    kwargs.setdefault('rtol', 1e-10)
    spring_ivp = solve_ivp(fun=dynamics_fn, t_span=t_span, y0=y0, t_eval=t_eval, **kwargs)
    if not spring_ivp.success:
        # On failure the solution has fewer columns than t_eval.
        raise RuntimeError("solve_ivp failed: " + str(spring_ivp.message))

    states = spring_ivp['y']
    q, p = states[:DIM], states[DIM:]

    dydt = np.stack([dynamics_fn(None, y) for y in states.T]).T
    dqdt, dpdt = np.split(dydt, 2)

    return q, p, dqdt, dpdt, t_eval

def get_dataset(seed=0, samples=100, test_split=0.5, **kwargs):
    """Simulate ``samples`` trajectories and split their states into two parts.

    Args:
        seed: seed for NumPy's global RNG, which is reset at the start of every
            call.
        samples: number of trajectories to simulate. Must be at least 1.
        test_split: fraction of the concatenated time steps placed in the
            leading part ('x' / 'dx'). The remainder goes to 'test_x' /
            'test_dx'. NOTE(review): despite its name, this is the fraction
            given to the non-test keys. The two readings coincide only at the
            default 0.5.
        **kwargs: forwarded to get_trajectory for every sample.

    Returns:
        dict with keys 'x', 'dx', 'test_x', 'test_dx' and 'meta' (the call
        arguments). Rows of 'x' are states [q, p]; rows of 'dx' are their time
        derivatives. Trajectories are concatenated one after another. A split
        index that is not a multiple of the trajectory length therefore cuts
        one trajectory across the two parts.

    Raises:
        ValueError: if ``samples`` is less than 1.
    """
    data = {'meta': locals()}  # captured first so it records only the arguments

    if samples < 1:
        raise ValueError("samples must be at least 1")

    # randomly sample inputs
    np.random.seed(seed)
    xs, dxs = [], []
    for _ in range(samples):
        q, p, dq, dp, _t = get_trajectory(**kwargs)
        # Stack the q/p halves along axis 0, then transpose to one row per time step.
        xs.append(np.concatenate([q, p]).T)
        dxs.append(np.concatenate([dq, dp]).T)

    data['x'] = np.concatenate(xs)
    data['dx'] = np.concatenate(dxs).squeeze()

    # make a train/test split; keep 'meta' instead of silently dropping it
    split_ix = int(len(data['x']) * test_split)
    split_data = {'meta': data['meta']}
    for k in ['x', 'dx']:
        split_data[k], split_data['test_' + k] = data[k][:split_ix], data[k][split_ix:]
    return split_data

def get_field(xmin=-1.2, xmax=1.2, ymin=-1.2, ymax=1.2, gridsize=20, dims=None):
    """Evaluate the Hamiltonian vector field on a 2-D grid slice of phase space.

    A gridsize x gridsize grid spanning [xmin, xmax] x [ymin, ymax] is laid
    over two coordinates of the 2*DIM-dimensional state [q, p], selected by
    ``dims``. Every other coordinate is held at zero.

    Args:
        xmin, xmax: range of state coordinate ``dims[0]``.
        ymin, ymax: range of state coordinate ``dims[1]``.
        gridsize: number of grid points along each axis.
        dims: pair of indices into the state vector to vary. Defaults to
            (0, DIM), i.e. (q_0, p_0). With DIM == 1 this is exactly the
            original (q, p) grid.

    Returns:
        dict with 'meta' (the call arguments), 'x' of shape
        (gridsize**2, 2*DIM) holding the grid states, and 'dx' of the same
        shape holding dynamics_fn evaluated at each state.

    Raises:
        ValueError: if both entries of ``dims`` are the same index.
    """
    field = {'meta': locals()}  # captured first so it records only the arguments
    if dims is None:
        dims = (0, DIM)
    if dims[0] == dims[1]:
        raise ValueError("dims must name two different state coordinates")

    # meshgrid to get vector field
    b, a = np.meshgrid(np.linspace(xmin, xmax, gridsize), np.linspace(ymin, ymax, gridsize))
    # dynamics_fn needs full 2*DIM states; a bare 2-vector only fits DIM == 1.
    ys = np.zeros((2 * DIM, b.size))
    ys[dims[0]] = b.flatten()
    ys[dims[1]] = a.flatten()

    # get vector directions
    dydt = np.stack([dynamics_fn(None, y) for y in ys.T]).T

    field['x'] = ys.T
    field['dx'] = dydt.T
    return field

