# Generalized Lagrangian Networks | 2020
# Miles Cranmer, Sam Greydanus, Stephan Hoyer (...)

import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
from functools import partial
import numpy as np


# Parameters of the pendulum.
# NOTE(review): units are not stated in this file; presumably SI -- confirm.
m = 1. # presumably the mass; previous value was 0.1
l = 10. # presumably the length; previous value was 1. Multiplies the damping term in the damped equations of motion.
g = 9.81 # presumably the gravitational acceleration
b = 0.5 # damping coefficient read by the damped equations of motion

def unconstrained_eom(model, state, t=None):
  """Evaluate an unconstrained model at the coordinates stored in ``state``.

  Args:
    model: callable invoked as ``model(q, q_t)``.
    state: indexable whose entries 0 and 1 are ``q`` and ``q_t``.
    t: unused; accepted so the function can be called as ``f(state, t)``.

  Returns:
    Whatever ``model(q, q_t)`` returns.
  """
  return model(state[0], state[1])

def lagrangian_eom(lagrangian, state, t=None):
  """Time-scaled Euler-Lagrange right-hand side for a state ``[q, q_t]``.

  The acceleration is obtained from
  ``q_tt = (d2L/dq_t2)^-1 * (dL/dq - d2L/(dq dq_t) * q_t)``.
  The inverse is an elementwise reciprocal and the products are elementwise,
  so this agrees with the matrix form only for a single coordinate
  (presumably the intended use -- confirm against callers).

  ``q`` is used as given; it is not wrapped to ``[0, 2*pi)``.

  Args:
    lagrangian: callable ``lagrangian(q, q_t)`` that JAX can differentiate.
    state: indexable whose entries 0 and 1 are ``q`` and ``q_t``.
    t: unused; accepted so the function can be called as ``f(state, t)``.

  Returns:
    ``jnp`` array ``dt * [q_t, q_tt]`` with the fixed scale ``dt = 1e-1``.
  """
  q = state[0]
  q_t = state[1]

  # Individual terms of the Euler-Lagrange equation.
  d2L_dqt2 = jax.hessian(lagrangian, 1)(q, q_t)
  dL_dq = jax.grad(lagrangian, 0)(q, q_t)
  d2L_dq_dqt = jax.jacobian(jax.jacobian(lagrangian, 1), 0)(q, q_t)

  q_tt = (1/d2L_dqt2) * (dL_dq - d2L_dq_dqt * q_t)

  # The derivative is returned pre-multiplied by a fixed step size.
  dt = 1e-1
  return dt*jnp.array([q_t, q_tt])

def lagrangian_eom_damped(lagrangian, state, t=None):
  """Time-scaled Euler-Lagrange right-hand side with a linear damping term.

  Same as the undamped form, except that ``b*q_t*l`` is subtracted from the
  generalized force before dividing by ``d2L/dq_t2``. ``b`` and ``l`` are the
  module-level constants. The inverse is an elementwise reciprocal and the
  products are elementwise.

  ``q`` is used as given; it is not wrapped to ``[0, 2*pi)``.

  Args:
    lagrangian: callable ``lagrangian(q, q_t)`` that JAX can differentiate.
    state: indexable whose entries 0 and 1 are ``q`` and ``q_t``.
    t: unused; accepted so the function can be called as ``f(state, t)``.

  Returns:
    ``jnp`` array ``dt * [q_t, q_tt]`` with the fixed scale ``dt = 1e-1``.
  """
  q = state[0]
  q_t = state[1]

  # Individual terms of the Euler-Lagrange equation.
  d2L_dqt2 = jax.hessian(lagrangian, 1)(q, q_t)
  dL_dq = jax.grad(lagrangian, 0)(q, q_t)
  d2L_dq_dqt = jax.jacobian(jax.jacobian(lagrangian, 1), 0)(q, q_t)

  # Damping enters as an extra velocity-proportional term.
  q_tt = (1/d2L_dqt2) * (dL_dq - d2L_dq_dqt * q_t - b*q_t*l)

  # The derivative is returned pre-multiplied by a fixed step size.
  dt = 1e-1
  return dt*jnp.array([q_t, q_tt])

def raw_lagrangian_eom(lagrangian, state, t=None):
  """Unscaled Euler-Lagrange right-hand side for a state ``[q, q_t]``.

  Computes ``q_tt = (d2L/dq_t2)^-1 * (dL/dq - d2L/(dq dq_t) * q_t)`` with an
  elementwise reciprocal and elementwise products, so it agrees with the
  matrix form only for a single coordinate (presumably the intended use --
  confirm against callers). ``q`` is not wrapped to ``[0, 2*pi)``.

  Args:
    lagrangian: callable ``lagrangian(q, q_t)`` that JAX can differentiate.
    state: indexable whose entries 0 and 1 are ``q`` and ``q_t``.
    t: unused; accepted so the function can be called as ``f(state, t)``.

  Returns:
    ``jnp`` array ``[q_t, q_tt]``.
  """
  q, q_t = state[0], state[1]

  # Build the derivative operators first, then evaluate them at (q, q_t).
  hess_fn = jax.hessian(lagrangian, 1)
  grad_fn = jax.grad(lagrangian, 0)
  mixed_fn = jax.jacobian(jax.jacobian(lagrangian, 1), 0)

  generalized_force = grad_fn(q, q_t) - mixed_fn(q, q_t) * q_t
  q_tt = (1/hess_fn(q, q_t)) * generalized_force
  return jnp.array([q_t, q_tt])

def raw_lagrangian_eom_damped(lagrangian, state, t=None):
  """Unscaled Euler-Lagrange right-hand side with a linear damping term.

  Same as the undamped form, except that ``b*q_t*l`` is subtracted from the
  generalized force before dividing by ``d2L/dq_t2``. ``b`` and ``l`` are the
  module-level constants. The inverse is an elementwise reciprocal and the
  products are elementwise. ``q`` is not wrapped to ``[0, 2*pi)``.

  Args:
    lagrangian: callable ``lagrangian(q, q_t)`` that JAX can differentiate.
    state: indexable whose entries 0 and 1 are ``q`` and ``q_t``.
    t: unused; accepted so the function can be called as ``f(state, t)``.

  Returns:
    ``jnp`` array ``[q_t, q_tt]``.
  """
  q, q_t = state[0], state[1]

  # Build the derivative operators first, then evaluate them at (q, q_t).
  hess_fn = jax.hessian(lagrangian, 1)
  grad_fn = jax.grad(lagrangian, 0)
  mixed_fn = jax.jacobian(jax.jacobian(lagrangian, 1), 0)

  # Damping enters as an extra velocity-proportional term.
  generalized_force = grad_fn(q, q_t) - mixed_fn(q, q_t) * q_t - b*q_t*l
  q_tt = (1/hess_fn(q, q_t)) * generalized_force
  return jnp.array([q_t, q_tt])

def lagrangian_eom_rk4(lagrangian, state, n_updates, Dt=1e-1, t=None):
    """Integrate the Euler-Lagrange equations over ``Dt`` with classic RK4.

    The interval ``Dt`` is divided into ``n_updates`` equal sub-steps. The
    function returns the accumulated change of the state, not the new state.

    Args:
        lagrangian: callable ``lagrangian(q, q_t)`` that JAX can differentiate.
        state: 1-D array that ``jnp.split(state, 2)`` divides into ``q`` and
            ``q_t``.
        n_updates: number of RK4 sub-steps; with 0 the loop is skipped and
            the integer ``0`` is returned.
        Dt: total integration interval.
        t: unused.

    Returns:
        The increment to add to ``state`` to advance it by ``Dt``.
    """
    @jax.jit
    def state_derivative(y):
        # Time derivative [q_t, q_tt] of the stacked state y = [q, q_t].
        pos, vel = jnp.split(y, 2)
        inv_mass = jnp.linalg.pinv(jax.hessian(lagrangian, 1)(pos, vel))
        dL_dq = jax.grad(lagrangian, 0)(pos, vel)
        mixed = jax.jacobian(jax.jacobian(lagrangian, 1), 0)(pos, vel)
        accel = inv_mass @ (dL_dq - mixed @ vel)
        return jnp.concatenate([vel, accel])

    @jax.jit
    def rk4_substep(delta):
        # The step size is computed here, not outside, so that n_updates == 0
        # never triggers a division.
        h = Dt/n_updates
        y = state + delta
        k1 = h*state_derivative(y)
        k2 = h*state_derivative(y + k1/2)
        k3 = h*state_derivative(y + k2/2)
        k4 = h*state_derivative(y + k3)
        return delta + 1.0/6.0 * (k1 + 2*k2 + 2*k3 + k4)

    delta = 0
    for _ in range(n_updates):
        delta = rk4_substep(delta)
    return delta
    

def solve_dynamics(dynamics_fn, initial_state, is_lagrangian=True, b=0., **kwargs):
  """Integrate a dynamics function from ``initial_state`` with ``odeint``.

  Args:
    dynamics_fn: a Lagrangian ``L(q, q_t)`` when ``is_lagrangian`` is true,
      otherwise a model called as ``model(q, q_t)``.
    initial_state: starting state passed to ``odeint``.
    is_lagrangian: selects the Lagrangian equations of motion over
      ``unconstrained_eom``.
    b: only compared against zero to pick ``raw_lagrangian_eom_damped`` over
      ``raw_lagrangian_eom``.
      NOTE(review): the damped function reads the module-level ``b``, so the
      value given here is not the coefficient actually applied -- confirm
      whether that is intended.
    **kwargs: forwarded to ``odeint`` (for example the time grid ``t``, which
      this function does not supply itself).

  Returns:
    The result of ``odeint``.
  """
  if b != 0.:
    lagrangian_rhs = raw_lagrangian_eom_damped
  else:
    lagrangian_rhs = raw_lagrangian_eom
  rhs = partial(lagrangian_rhs if is_lagrangian else unconstrained_eom, dynamics_fn)

  # odeint is pinned to the CPU backend: its cost is dominated by control
  # flow, which is slow on GPUs.
  @partial(jax.jit, backend='cpu')
  def integrate(y0):
    return odeint(rhs, y0, mxstep=np.inf, **kwargs)

  return integrate(initial_state)


def custom_init(init_params, seed=0):
    """Re-initialize the parameters of a simple uniform-width MLP for an LNN.

    Empty entries of ``init_params`` are replaced by ``()``. In every other
    entry, 1-D arrays (biases) become zeros and all other arrays (weights)
    are redrawn from a normal distribution with standard deviation
    ``scale / sqrt(n)``, where ``n`` is the largest dimension of the array and
    ``scale`` depends on the running weight index ``i``:

    * ``2.2`` for the first weight array,
    * ``0.58 * i`` for weight arrays in between,
    * ``n`` for the last one (index ``number_of_non_empty_entries - 1``).

    Args:
        init_params: sequence of layers; each layer is a sequence of arrays
            (possibly empty).
        seed: integer seed for the JAX PRNG key.

    Returns:
        A list with one item per input layer: ``()`` for empty layers,
        otherwise a list of new arrays shaped like the originals.

    Raises:
        NotImplementedError: if a computed standard deviation equals zero.
    """
    key = jax.random.PRNGKey(seed)
    number_layers = sum(1 for layer in init_params if len(layer) != 0)

    weight_index = 0
    new_params = []
    for layer in init_params:
        if len(layer) == 0:
            new_params.append(())
            continue

        rebuilt = []
        for arr in layer:
            if len(arr.shape) == 1:
                # Biases are zero-initialized.
                rebuilt.append(jnp.zeros_like(arr))
                continue

            n = max(arr.shape)
            is_first = weight_index == 0
            is_last = weight_index == number_layers - 1

            # The contributions add up, so a single-layer net gets 2.2 + n.
            scale = 0.0
            if is_first:
                scale += 2.2
            if not (is_first or is_last):
                scale += 0.58*weight_index
            if is_last:
                scale += n

            std = 1.0/np.sqrt(n)
            std *= scale
            if std == 0:
                raise NotImplementedError("Wrong dimensions for MLP")

            rebuilt.append(jax.random.normal(key, arr.shape)*std)
            # Step to a new key for the next weight array.
            key = key + 1
            weight_index += 1

        new_params.append(rebuilt)

    return new_params