# Generalized Lagrangian Networks | 2020
# Miles Cranmer, Sam Greydanus, Stephan Hoyer (...)

import jax.numpy as jnp
import pickle

def wrap_coords(state, n_angles=1):
  """Wrap the leading angular entries of `state` into the interval [-pi, pi).

  Args:
    state: array sliced along its first axis; the first `n_angles` entries
      are treated as angles in radians, the rest are passed through as-is.
    n_angles: how many leading entries to wrap. Defaults to 1, which
      reproduces the historical single-angle behaviour. Must be a plain
      Python int because it is used as a slice bound.

  Returns:
    The concatenation of the wrapped angles and the untouched remainder.
  """
  angles, remainder = state[:n_angles], state[n_angles:]
  # Shift by pi, reduce modulo one full turn, shift back: result is in [-pi, pi).
  wrapped = (angles + jnp.pi) % (2 * jnp.pi) - jnp.pi
  return jnp.concatenate([wrapped, remainder])

def rk4_step(f, x, t, h):
  """Advance `x` by one classical fourth-order Runge-Kutta step.

  Args:
    f: derivative function, called as `f(x, t)`.
    x: current state.
    t: current time.
    h: step size.

  Returns:
    The state after a single step of size `h`.
  """
  t_mid = t + h/2
  t_end = t + h
  # Each increment is already scaled by the step size.
  inc_start = h * f(x, t)
  inc_mid_a = h * f(x + inc_start/2, t_mid)
  inc_mid_b = h * f(x + inc_mid_a/2, t_mid)
  inc_end = h * f(x + inc_mid_b, t_end)
  # Weights 1-2-2-1 on the four increments.
  weighted = inc_start + 2 * inc_mid_a + 2 * inc_mid_b + inc_end
  return x + 1/6 * weighted

def radial2cartesian(t1, l1):
  """Map an angle `t1` and length `l1` to a Cartesian `(x, y)` pair.

  With this convention `t1 = 0` gives the point `(0, -l1)`, i.e. the angle
  is measured from the negative y-axis.
  """
  horizontal = l1 * jnp.sin(t1)
  vertical = -l1 * jnp.cos(t1)
  return horizontal, vertical

def write_to(data, path):
  """Pickle `data` to the file at `path`, overwriting any existing file.

  Uses the highest pickle protocol available to the running interpreter.
  """
  with open(path, 'wb') as sink:
    pickle.dump(data, sink, protocol=pickle.HIGHEST_PROTOCOL)

def read_from(path):
  """Load and return the pickled object stored in the file at `path`.

  NOTE(review): unpickling can execute arbitrary code, so only call this on
  files from a trusted source.
  """
  with open(path, 'rb') as source:
    return pickle.load(source)





# Generalized Lagrangian Networks | 2020
# Miles Cranmer, Sam Greydanus, Stephan Hoyer (...)

import jax
import jax.numpy as jnp
from jax import random
import numpy as np # get rid of this eventually
from jax.experimental.ode import odeint
from functools import partial # reduces arguments to function by making some subset implicit

import os, sys
# Absolute path of the directory that contains this file.
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
# One level above THIS_DIR.
PARENT_DIR = os.path.dirname(THIS_DIR)
# Presumably so that modules living in the parent directory can be imported
# by name -- TODO confirm this is still needed alongside the relative imports.
sys.path.append(PARENT_DIR)

from .lnn import solve_dynamics
#HACK
#from .physics import lagrangian_fn, analytical_fn
from .lnn_physics import lagrangian_fn, analytical_fn


#@partial(jax.jit, backend='cpu')
def get_trajectory(y0, times, use_lagrangian=False, b=0., **kwargs):
  """Integrate the system from `y0` over `times` and return the trajectory.

  Args:
    y0: initial state handed to the integrator.
    times: time points at which the solution is requested.
    use_lagrangian: when true, integrate `lagrangian_fn` through
      `solve_dynamics`; otherwise integrate `analytical_fn` with `odeint`.
    b: forwarded to `solve_dynamics` as `b=`. It is NOT used on the
      analytical branch.
    **kwargs: forwarded unchanged to whichever solver is selected.

  Returns:
    Whatever the selected solver returns for the requested time points.
    Both branches use rtol = atol = 1e-10.
  """
  if not use_lagrangian:
    return odeint(analytical_fn, y0, t=times, rtol=1e-10, atol=1e-10, **kwargs)
  return solve_dynamics(lagrangian_fn, y0, t=times, is_lagrangian=True,
                        rtol=1e-10, atol=1e-10, b=b, **kwargs)

@partial(jax.jit, backend='cpu')
def get_trajectory_lagrangian(y0, times, **kwargs):
  """CPU-jitted trajectory of `lagrangian_fn` from `y0` over `times`.

  Calls `solve_dynamics` with `is_lagrangian=True` and rtol = atol = 1e-10;
  any extra keyword arguments are forwarded to `solve_dynamics`.
  """
  solver_options = dict(is_lagrangian=True, rtol=1e-10, atol=1e-10)
  trajectory = solve_dynamics(lagrangian_fn, y0, t=times, **solver_options, **kwargs)
  return trajectory

@partial(jax.jit, backend='cpu')
def get_trajectory_analytic(y0, times, **kwargs):
    """CPU-jitted trajectory of `analytical_fn` from `y0` over `times`.

    Integrates with `odeint` using rtol = atol = 1e-10; any extra keyword
    arguments are forwarded to `odeint`.
    """
    tolerances = dict(rtol=1e-10, atol=1e-10)
    trajectory = odeint(analytical_fn, y0, t=times, **tolerances, **kwargs)
    return trajectory

def get_dataset(seed=0, samples=1, t_span=(0, 2000), fps=1, test_split=0.5, **kwargs):
    """Build a train/test dataset of states, derivatives and times.

    A single trajectory is integrated from a fixed initial state with
    `get_trajectory`, its time derivatives are evaluated with a vmapped
    `analytical_fn`, and the result is repeated `samples` times.

    Args:
      seed: seed applied to NumPy's global RNG. Nothing in this function
        draws random numbers (the initial state is fixed); the call is kept
        so the global RNG state after this function is unchanged for callers.
      samples: number of copies of the trajectory to concatenate. Must be >= 1.
      t_span: two-element sequence `(t_start, t_end)`.
      fps: frames per unit time; the number of time points is
        `int(fps * (t_end - t_start))`.
      test_split: despite the name, this is the fraction of rows that goes
        into the TRAIN arrays (`x`, `dx`, `t`); the rest goes into `test_*`.
      **kwargs: forwarded to `get_trajectory`.

    Returns:
      dict with keys `x`, `dx`, `t`, `test_x`, `test_dx`, `test_t`. No
      metadata entry is included.

    Raises:
      ValueError: if `samples` is smaller than 1.
    """
    if samples < 1:
        raise ValueError("samples must be >= 1, got %r" % (samples,))

    # Seed the global NumPy RNG (see the `seed` argument above).
    np.random.seed(seed)

    t_start, t_end = t_span[0], t_span[1]
    frames = int(fps * (t_end - t_start))
    times = np.linspace(t_start, t_end, frames)
    y0 = np.array([3*np.pi/7, 3*np.pi/4, 0, 0], dtype=np.float32)

    # The initial state and time grid are the same for every sample, so the
    # trajectory is integrated once and reused instead of once per sample.
    x = get_trajectory(y0, times, **kwargs)
    dx = jax.jit(jax.vmap(analytical_fn))(x)

    data = {
        'x': jnp.concatenate([x] * samples),
        'dx': jnp.concatenate([dx] * samples),
        # Repeat the time grid once per sample so that `t` has one entry per
        # row of `x`; otherwise the shared split index below would cut `t`
        # at the wrong place whenever samples > 1.
        't': np.tile(times, samples),
    }

    # make a train/test split: the first `test_split` fraction is the train part
    split_ix = int(len(data['x']) * test_split)
    split_data = {}
    for key, values in data.items():
        split_data[key] = values[:split_ix]
        split_data['test_' + key] = values[split_ix:]
    return split_data