import jax
import jax.numpy as jnp
import numpy as np # get rid of this eventually
import argparse
from jax import jit
from jax.experimental.ode import odeint
from functools import partial # reduces arguments to function by making some subset implicit

from jax.example_libraries import stax
from jax.example_libraries import optimizers

import os, sys, time

from external_models.lnn import lagrangian_eom_rk4, lagrangian_eom, unconstrained_eom, raw_lagrangian_eom, raw_lagrangian_eom_damped
from external_models.lnn_models import mlp as make_mlp
from external_models.lnn_utils import wrap_coords

from external_models.lnn_hps import learned_dynamics, extended_mlp


class ObjectView:
    """Attribute-style view over a dict, e.g. ``ObjectView({'a': 1}).a == 1``.

    The given dict itself becomes the instance namespace (it is not copied), so
    writes through the view are visible in the dict and vice versa.
    """

    def __init__(self, d):
        object.__setattr__(self, '__dict__', d)
    
from external_models.lnn_physics import analytical_fn

# Jitted, batch-vectorized analytical dynamics (vmap over the leading axis).
# NOTE: vfnc is rebound to an equivalent definition further down, before any
# use, so this first binding is redundant.
vfnc = jax.jit(jax.vmap(analytical_fn))
b = 0.  # damping coefficient; b == 0 selects the undamped equations of motion
EXP_PATH = '.'  # root directory for saved models and plots

m = 1.
l = 10. # old was 1.
g = 9.81


def u_vec(xv):
    """Return the true pendulum vector field at a batch of states.

    For each row (theta, omega) this computes
    (omega, -(g/l) * sin(theta) - (b/m) * omega).

    Args:
        xv: torch tensor of shape (n, 2) with columns (theta, omega).

    Returns:
        torch tensor of shape (n, 2) holding the time derivatives.
    """
    theta = xv[:, 0]
    omega = xv[:, 1]
    # torch.sin rather than np.sin: no NumPy round trip, so this also works for
    # CUDA tensors and tensors that require grad.
    domega = -(g / l) * torch.sin(theta) - (b / m) * omega
    return torch.column_stack((omega, domega))

# Run configuration, exposed with attribute access (args.lr, args.seed, ...).
args = ObjectView(dict(
    dataset_size=200,
    fps=10,
    samples=100,
    num_epochs=80000,
    seed=15,
    loss='l1',
    act='softplus',
    hidden_dim=600,
    output_dim=1,
    layers=3,
    n_updates=1,
    lr=0.001,
    lr2=2e-05,
    dt=0.1,
    model='gln',
    batch_size=512,
    l2reg=5.7e-07,
))
# Base PRNG key derived from the configured seed (used for parameter init below).
rng = jax.random.PRNGKey(args.seed)

from matplotlib import pyplot as plt

vfnc = jax.jit(jax.vmap(analytical_fn, 0, 0))
minibatch_per = 2000  # number of minibatches per epoch
batch = 512  # samples per minibatch

@jax.jit
def get_derivative_dataset(rng):
    """Sample random states and evaluate the analytical dynamics at them.

    Draws ``batch * minibatch_per`` rows. The first two columns are uniform in
    [0, 2*pi), and the last two are uniform in [-10, 10).

    NOTE(review): this yields 4 state columns, while the network in this file is
    initialized for input shape (-1, 2). Confirm this generator matches the
    current system before re-enabling it.

    Args:
        rng: a JAX PRNG key.

    Returns:
        ``(y0, vfnc(y0))``: the sampled states and the analytical outputs.
    """
    n = batch * minibatch_per
    # Split the key rather than reusing it or offsetting it with `rng + 1`:
    # key arithmetic does not give independent streams, and typed keys reject it.
    key_a, key_b = jax.random.split(rng)
    y0 = jnp.concatenate([
        jax.random.uniform(key_a, (n, 2)) * 2.0 * np.pi,
        (jax.random.uniform(key_b, (n, 2)) - 0.5) * 10 * 2,
    ], axis=1)

    return y0, vfnc(y0)


# Best parameters / loss seen so far; updated by the training loop.
best_params = None
best_loss = np.inf
from itertools import product

# extended_mlp(args) returns a pair unpacked here as (init fn, forward fn)
# (presumably a stax-style init/apply pair — confirm in lnn_hps).
init_random_params, nn_forward_fn = extended_mlp(args)
from external_models import lnn_hps
# Inject the forward function into lnn_hps as a module attribute.
# NOTE(review): presumably learned_dynamics (from lnn_hps) reads this global; confirm.
lnn_hps.nn_forward_fn = nn_forward_fn
# Initialize for inputs of shape (-1, 2); the first return value is discarded.
_, init_params = init_random_params(rng+1, (-1, 2))
rng += 1
model = (nn_forward_fn, init_params)
# Plain-Adam optimizer over the initial params scaled down by 200.
# NOTE: both opt functions and opt_state are replaced further down by the
# OneCycleLR-scheduled Adam, so this state is not what training starts from.
opt_init, opt_update, get_params = optimizers.adam(args.lr)
opt_state = opt_init([[l2/200.0 for l2 in l1] for l1 in init_params])
from jax.tree_util import tree_flatten
from external_models.lnn_hps import make_loss, train
from copy import deepcopy as copy
# train(args, model, data, rng);
from jax.tree_util import tree_flatten

@jax.jit
def loss(params, batch, l2reg):
    """Summed L1 error of the learned equations of motion plus an L2 penalty.

    Args:
        params: network parameters (any pytree).
        batch: ``(states, targets)``. The EOM is vmapped over the leading axis
            of ``states`` and compared elementwise with ``targets``.
        l2reg: coefficient of the squared L2 norm of all parameter leaves. The
            penalty is additionally divided by ``args.batch_size``.

    Returns:
        Scalar ``sum|pred - target| + l2reg * ||params||^2 / args.batch_size``.
    """
    states, targets = batch
    # Damped EOM only when the global damping coefficient b is non-zero.
    eom = raw_lagrangian_eom_damped if b != 0. else raw_lagrangian_eom
    dynamics = partial(eom, learned_dynamics(params))
    predictions = jax.vmap(dynamics)(states)
    data_term = jnp.abs(predictions - targets).sum()
    sq_norm = sum(jnp.vdot(leaf, leaf) for leaf in tree_flatten(params)[0])
    return data_term + l2reg * sq_norm / args.batch_size

# @jax.jit
# def normalize_param_update(param_update):
#     new_params = []
#     num_weights = args.hidden_dim**2*3
#     gradient_norm = sum([jnp.sum(l2**2)
#                          for l1 in param_update
#                          for l2 in l1
#                          if len(l1) != 0])/num_weights
# #     gradient_norm = 1 + 
#     for l1 in param_update:
#         if (len(l1)) == 0: new_params.append(()); continue
#         new_l1 = []
#         for l2 in l1:
#             new_l1.append(
#                 l2/gradient_norm
#             )

#         new_params.append(new_l1)
        
#     return new_params

@jax.jit
def update_derivative(i, opt_state, batch, l2reg):
    """Take one optimizer step on a minibatch.

    Args:
        i: forwarded unchanged to ``opt_update`` as its step argument.
        opt_state: current optimizer state.
        batch: ``(states, targets)`` tuple, as unpacked by ``loss``.
        l2reg: weight of the L2 penalty passed to ``loss``.

    Returns:
        ``(new_opt_state, params)``, where ``params`` are the parameters *before*
        this update.
    """
    params = get_params(opt_state)
    # Normalize the summed loss by the number of samples. len(batch) would be
    # len((states, targets)) == 2 regardless of batch size. The shape is static
    # under jit; max(..., 1) keeps an empty batch from producing NaN gradients.
    n_samples = max(batch[0].shape[0], 1)
    grads = jax.grad(lambda p: loss(p, batch, l2reg) / n_samples)(params)
    return opt_update(i, grads, opt_state), params


nn_forward_fn, init_params = model
# Counters and accumulators for the training phase.
iteration = 0
best_small_loss = np.inf
train_losses, test_losses = [], []
total_epochs = 300  # epochs run by the main training loop
minibatch_per = 2000  # same value as set earlier; re-stated for the loop below

lr = 1e-3  # peak learning rate of the schedule
final_div_factor = 1e4  # the schedule bottoms out at lr / final_div_factor


@jax.jit
def OneCycleLR(pct):
    """Cosine one-cycle learning-rate schedule.

    ``pct`` (fraction of training completed, 0..1) is remapped to
    ``0.2 + 0.8 * pct``, so the schedule starts partway up its ramp. The rate
    peaks at ``lr`` when the remapped value is 0.5 and decays to
    ``lr / final_div_factor`` when it reaches 1.
    """
    offset = 0.2  # skip the first part of the warm-up ramp
    phase = pct * (1 - offset) + offset
    peak = lr
    floor = lr / final_div_factor
    ramp = 1.0 - (jnp.cos(2 * jnp.pi * phase) + 1)/2
    return floor + (peak - floor) * ramp
    
from external_models.lnn import custom_init

# Re-initialize the network weights (custom_init, seed 0), then build the
# OneCycleLR-scheduled Adam that replaces the optimizer created earlier.
init_params = custom_init(init_params, seed=0)
opt_init, opt_update, get_params = optimizers.adam(OneCycleLR)
opt_state = opt_init(init_params)
bad_iterations = 0
print(lr)

import torch
# Plain path literals: nesting the same quote type inside an f-string
# (f'{'data'}/...') is a SyntaxError before Python 3.12 (PEP 701).
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
# files. These objects are indexed with [:] below (presumably TensorDatasets).
# Torch versions whose torch.load defaults to weights_only=True may reject them;
# confirm against the installed torch version.
train_dataset = torch.load('data/true_dataset_train.pth')
val_dataset = torch.load('data/true_dataset_val.pth')
test_dataset = torch.load('data/true_dataset_test.pth')
data_x = train_dataset[:][0].numpy()
data_dx = train_dataset[:][1].numpy()
data_test_x = val_dataset[:][0].numpy()
data_test_dx = val_dataset[:][1].numpy()
import numpy as np
def unison_shuffled_copies(a, b, rng=None):
    """Return copies of ``a`` and ``b`` shuffled with one shared permutation.

    Row ``k`` of both outputs comes from the same original row, so paired
    arrays (e.g. inputs and targets) stay aligned.

    Args:
        a, b: arrays supporting integer-array indexing along axis 0, equal length.
        rng: optional ``numpy.random.Generator`` for reproducible shuffles.
            When omitted, NumPy's global random state is used.

    Returns:
        ``(a_shuffled, b_shuffled)``.

    Raises:
        ValueError: if ``a`` and ``b`` have different lengths.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(a) != len(b):
        raise ValueError(f'length mismatch: {len(a)} != {len(b)}')
    source = np.random if rng is None else rng
    perm = source.permutation(len(a))
    return a[perm], b[perm]
# Training / validation arrays used below. Inputs ('x') come from field [1] and
# targets ('dx') from field [2] of each split, in stored row order. A shuffled
# dict built from fields [0]/[1] would be overwritten here before any use, so it
# is not built.
data = {
    'x': train_dataset[:][1].numpy(),
    'dx': train_dataset[:][2].numpy(),
    'test_x': val_dataset[:][1].numpy(),
    'test_dx': val_dataset[:][2].numpy(),
}



# Fresh PRNG key for the training phase; the key used during initialization is
# discarded.
rng = jax.random.PRNGKey(0)
epoch = 0  # first epoch index for the training loop below
#batch_data = get_derivative_dataset(rng)[0][:1000], get_derivative_dataset(rng)[1][:1000]
# First 1000 training samples, used for the two calls below.
batch_data = data['x'][:1000], data['dx'][:1000]
print(batch_data[0].shape)

# Per-sample loss on that slice; the result is discarded (presumably just to
# trigger JIT compilation / sanity-check shapes).
loss(get_params(opt_state), batch_data, 0.0)/len(batch_data[0])

# NOTE(review): this is a real optimizer step, not only a warm-up. It applies an
# update (step 0.0, no L2 penalty) on these 1000 samples to opt_state before the
# main loop starts. Confirm that is intended.
opt_state, params = update_derivative(0.0, opt_state, batch_data, 0.0)
from tqdm import tqdm

# Minibatches per epoch: no more than the data actually holds (ceil division)
# and no more than minibatch_per. Slicing past the end of the data would yield
# empty batches that still counted `batch` samples in the epoch average and
# triggered extra compilations.
n_train = len(data['x'])
n_minibatches = min(minibatch_per, -(-n_train // batch))
if n_minibatches == 0:
    raise ValueError("training data is empty: data['x'] has no rows")
# Loop-invariant: the full training arrays.
all_batch_data = data['x'], data['dx']

for epoch in tqdm(range(epoch, total_epochs)):
    epoch_loss = 0.0
    num_samples = 0
    for minibatch in range(n_minibatches):
        # Fraction of total training completed; drives the OneCycleLR schedule.
        fraction = (epoch + minibatch/n_minibatches)/total_epochs
        lo, hi = minibatch*batch, (minibatch + 1)*batch
        batch_data = (all_batch_data[0][lo:hi], all_batch_data[1][lo:hi])
        rng += 10
        # NOTE(review): `fraction` (in [0, 1)) is passed as the optimizer's step
        # index. If adam also uses the step index for bias correction (the
        # jax.example_libraries implementation appears to), that correction never
        # advances past its first-step value. Confirm this is acceptable.
        opt_state, params = update_derivative(fraction, opt_state, batch_data, 1e-6)
        # update_derivative returns the pre-update params; score them on this batch.
        cur_loss = loss(params, batch_data, 0.0)
        epoch_loss += cur_loss
        num_samples += len(batch_data[0])  # the last batch may be partial
    closs = epoch_loss/num_samples
    print('epoch={} lr={} loss={}'.format(
        epoch, OneCycleLR(fraction), closs)
         )
    if closs < best_loss:
        best_loss = closs
        best_params = [[copy(jax.device_get(l2)) for l2 in l1] if len(l1) > 0 else () for l1 in params]
        
import flax
from flax import linen as nn
from flax.training import checkpoints, train_state
from flax import struct, serialization
import orbax.checkpoint
ckpt_dir = f'{EXP_PATH}/true/saved_models'
# exist_ok avoids the race between an existence check and directory creation.
os.makedirs(ckpt_dir, exist_ok=True)
import optax


    
# Save the model
checkpoint_dir = os.path.abspath(f'{EXP_PATH}/true/saved_models')
# exist_ok avoids the race between an existence check and directory creation.
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, 'pendulum_lnn')
# NOTE(review): this saves `params`, the pre-update parameters from the last
# minibatch of the final epoch. It does not save `best_params`, which the
# training loop tracks. Confirm which is intended. step=None is passed through
# to flax as the checkpoint step.
checkpoints.save_checkpoint(checkpoint_path, params, None, overwrite=True)
#from models.pendulum import u_vec
name = 'true'
title = 'LNN'
# --------------- Plotting ----------------
new_data = torch.load(os.path.join('data', 'true_dataset_test.pth'))
# Split field 0 into 10 groups of 3-column rows, take the first row of each and
# drop column 0. The remaining two columns seed the 10 trajectories below.
new_data_xv0 = new_data[:][0].reshape((10, -1, 3))[:, 0, 1:]
xv0 = new_data_xv0.numpy()
x0, v0 = xv0.T
plot_dir = f'{EXP_PATH}/{name}/plots{title}'
if not os.path.exists(plot_dir):
    os.makedirs(plot_dir)

from matplotlib import patches
dt = 0.001
points = 10
t_max = 10
steps = int(t_max/dt)+1
xv = torch.zeros((points,steps,2))
u_save = torch.zeros((points,steps,2))  # allocated but not used below in this file
xv[:,0,0] = torch.from_numpy(x0)
xv[:,0,1] = torch.from_numpy(v0)
for i in range(1,steps):
    # Explicit Euler step: evaluate the true vector field once and advance both
    # state components together.
    xv[:, i, :] = xv[:, i-1, :] + dt * u_vec(xv[:, i-1, :])
# Derived from `steps`, so the time grid always has one entry per column of xv
# (np.arange with a float stop can be off by one).
t_base = np.arange(steps) * dt

# Number of points for the field
N = 500
xlim = np.pi/2
ylim = 2.
# Regular N x N grid over the plotted window, flattened to (N*N, 2) points
# with columns (angle, angular speed).
X, Y = np.meshgrid(np.linspace(-xlim, xlim, N), np.linspace(-ylim, ylim, N))
pts = np.vstack((X.ravel(), Y.ravel())).T

# Streamplot of the true vector field. The figure stays open; trajectories are
# drawn onto it before it is saved.
plt.figure(figsize=(5,5))
vel = u_vec(torch.from_numpy(pts))
U, V = (np.array(vel[:, col].reshape(X.shape)) for col in range(2))
plt.streamplot(X, Y, U, V, density=1, color=U**2 + V**2, linewidth=0.15)
plt.xlim((-xlim, xlim))
plt.ylim((-ylim, ylim))
#add outline for aesthetics
t_base = np.arange(start=0, stop=t_max+dt, step=dt)
from external_models.lnn_utils import get_trajectory
xv_pred = torch.zeros((xv.shape))
xv_iter = torch.zeros((xv.shape))
# Loop-invariant: choose the EOM (undamped when b == 0) and build the learned
# vector field once rather than once per trajectory. Passing odeint the same
# callable each time may also let JAX reuse traces keyed on function identity.
lag = raw_lagrangian_eom_damped if b!=0 else raw_lagrangian_eom
learned_rhs = partial(lag, learned_dynamics(params))
for i in range(xv.shape[0]):
    print('Gathering trajectory for point', i)
    # Integrate from the i-th initial state over t_base. odeint's first output
    # row is the initial state itself, so no separate initial write is needed.
    traj = odeint(learned_rhs, xv0[i], t_base)
    xv_pred[i,:,:] = torch.from_numpy(np.array(jax.device_get(traj)))

xv = xv.numpy()
# Overlay true (blue) and learned (red) phase-space trajectories on the open
# field figure, then save and close it.
for k in range(10):
    for series, colour in ((xv, 'blue'), (xv_pred, 'red')):
        plt.plot(series[k, :, 0], series[k, :, 1], color=colour)
blue_patch = patches.Patch(color='blue', label='True trajectories')
red_patch = patches.Patch(color='red', label='Predicted trajectories')
plt.legend(handles=[blue_patch, red_patch])
plt.xlabel(r'Angle: $\theta$')
plt.ylabel(r'Angular speed: $\omega$')
plt.title('LNN learning phase trajectories')
plt.savefig(f'{EXP_PATH}/{name}/plots{title}/pendulum_phase_trajectory.png', dpi=300)
plt.close()

# Angle vs. time for the true (blue) and learned (red) trajectories.
plt.figure(figsize=(8,5))
t_base = np.arange(start=0, stop=t_max+dt, step=dt)
for k in range(10):
    for series, colour in ((xv, 'blue'), (xv_pred, 'red')):
        plt.plot(t_base, series[k, :, 0], color=colour)
blue_patch = patches.Patch(color='blue', label='True trajectories')
red_patch = patches.Patch(color='red', label='Predicted trajectories')
plt.legend(handles=[blue_patch, red_patch])
plt.xlabel(r'Time: $t$')
plt.ylabel(r'Angle: $\theta$')
plt.title('LNN learning time trajectories')
plt.savefig(f'{EXP_PATH}/{name}/plots{title}/pendulum_trajectory.png', dpi=300)
plt.close()

# Streamplot of the learned vector field on the same grid, with the learned
# trajectories drawn on top.
plt.figure(figsize=(5,5))
eom = raw_lagrangian_eom if b == 0 else raw_lagrangian_eom_damped
vel = jax.vmap(partial(eom, learned_dynamics(params)))(pts)
U, V = (np.array(vel[:, col].reshape(X.shape)) for col in range(2))
plt.streamplot(X, Y, U, V, density=1, color=U**2 + V**2, linewidth=0.15)
for i in range(10):
    plt.plot(xv_pred[i, :, 0], xv_pred[i, :, 1], label=f'trajectory{i}', color='red')
plt.xlim((-xlim, xlim))
plt.ylim((-ylim, ylim))
plt.xlabel(r'Angle: $\theta$')
plt.ylabel(r'Angular speed: $\omega$')
plt.title('LNN learning predicted field')
plt.savefig(f'{EXP_PATH}/{name}/plots{title}/predicted_field.png')
plt.close()

# Filled contour of the pointwise distance between learned (U, V) and true
# vector fields over the grid.
plt.figure(figsize=(6,5))
vel_true = u_vec(torch.from_numpy(pts))
U_true, V_true = (np.array(vel_true[:, col].reshape(X.shape)) for col in range(2))
field_error = np.sqrt((U - U_true)**2 + (V - V_true)**2)
plt.contourf(X, Y, field_error, 100, cmap='jet')
plt.title('Error in predicted fields')
plt.colorbar()
plt.xlim((-xlim, xlim))
plt.ylim((-ylim, ylim))
plt.xlabel(r'Angle: $\theta$')
plt.ylabel(r'Angular speed: $\omega$')
plt.title('LNN learning field error')
plt.savefig(f'{EXP_PATH}/{name}/plots{title}/error_field.png')
plt.close()