import time
import numpy as np
import pinocchio as pin

from acyclic_gen import AcyclicGen
from g1_model import create_g1_models

from g1_jump import plan


# Build the robot models. create_g1_models is a factory and has to be *called*;
# unpacking the bare function object raises
# "TypeError: cannot unpack non-iterable function object".
# Only the first returned model is used here; the second is discarded.
# NOTE(review): called with no arguments -- confirm against the signature in
# g1_model if the factory expects any.
pin_robot, _ = create_g1_models()
# Placeholder location of the G1 URDF handed to the motion generator.
# Raw string on purpose: in a normal literal "\t" is a TAB and "\u" begins a
# \uXXXX unicode escape, which makes "\urdf" a SyntaxError for the whole file.
urdf = r"\path\to\urdf\g1.urdf"

# Initial generalized configuration (30 entries).
# NOTE(review): the first seven entries look like a floating-base pose
# (position + unit quaternion). Pinocchio's free-flyer layout is
# (x, y, z, qx, qy, qz, qw), under which [1, 0, 0, 0] is NOT the identity
# rotation -- confirm which quaternion ordering the G1 model expects.
# NOTE(review): the 6 / 6 / 1 / 5 / 5 grouping of the remaining zeros is
# presumably legs / waist / arms -- verify against the model's joint order.
q0 = np.array(
    [
        0.0, 0.0, 0.754676, 1.0, 0.0, 0.0, 0.0,  # base pose
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0,            # group of 6
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0,            # group of 6
        0.0,                                     # single joint
        0.0, 0.0, 0.0, 0.0, 0.0,                 # group of 5
        0.0, 0.0, 0.0, 0.0, 0.0,                 # group of 5
    ]
)

# Start at rest: zero vector sized by the model's velocity dimension (nv).
v0 = pin.utils.zero(pin_robot.model.nv)

# --- timing -----------------------------------------------------------------
motion_time = plan.T  # total duration taken from the imported jump plan
sim_t = 0.0
sim_dt = 0.001
update_time = 0.0  # sec (time of lag)
lag = int(update_time / sim_dt)  # the lag expressed in simulation steps

# --- motion generator and the state it is seeded with -------------------------
mg = AcyclicGen(pin_robot, urdf)
q = q0
v = v0

## Instantiate the motion generator, which
## contains the centroidal dynamics motion generator
## as well as the subsequent kinematic optimization
##

# Load the jump plan into the generator for the starting state and time,
# then switch off the offline-trajectory option.
mg.update_motion_params(plan, q, sim_t)
mg.params.use_offline_traj = False

# Snapshot of the plan's initial state, taken before the sweep edits the plan.
x0_ = plan.x0.copy()

# Results container: step size dt -> list of per-rollout histories.
plots = dict()

# Ten random offsets drawn uniformly from [-0.1, 0.1); one rollout per offset.
randx = np.random.uniform(low=-0.1, high=0.1, size=10)
print(randx)

# Sweep the collocation step size. For every dt the jump is re-solved once per
# random offset in randx, and the solver's dynamics-violation history of each
# solve is stored under plots[dt].
for dt in [0.01, 0.05, 0.1]:
    rollouts = []

    for x in randx:
        plan.dt = dt
        # T/dt can land just below an integer in floating point
        # (e.g. 0.7 / 0.1 == 6.999999999999999); add a small tolerance before
        # truncating so a collocation point is not silently dropped.
        plan.n_col = int(plan.T / plan.dt + 1e-9)
        plan.dt_arr = plan.n_col * [plan.dt]
        plan.plan_freq = [[plan.T, 0., plan.T]]
        plan.rho = 1e+6

        # Give the plan its own copy of the saved initial state. Assigning x0_
        # directly would alias it, so any in-place change to plan.x0 would
        # corrupt the snapshot used by every later rollout.
        plan.x0 = x0_.copy()

        # Contact offsets, all shifted by the same random perturbation x.
        HC_init_x = -0.06 + x
        TC_init_x = .13 + x
        MC_init_in_y = .026 + x
        MC_init_out_y = .026 + x

        # Eight contact points: the two x offsets at each of the four lateral
        # offsets (-out, -in, +in, +out), in the order the plan lists them.
        footprints = [
            (HC_init_x, -MC_init_out_y),
            (TC_init_x, -MC_init_out_y),
            (HC_init_x, -MC_init_in_y),
            (TC_init_x, -MC_init_in_y),
            (HC_init_x, MC_init_in_y),
            (TC_init_x, MC_init_in_y),
            (HC_init_x, MC_init_out_y),
            (TC_init_x, MC_init_out_y),
        ]

        # Three phases as (contact flag, t0, tf): contact until stance_time,
        # no contact for flight_time, then contact again until plan.T.
        t_liftoff = plan.stance_time
        t_touchdown = plan.stance_time + plan.flight_time
        phases = [
            (1., 0., t_liftoff),
            (0., t_liftoff, t_touchdown),
            (1., t_touchdown, plan.T),
        ]

        # order [bool cnt on/off, x, y, z, t0, tf]
        plan.cnt_plan = [
            [[active, px, py, 0., t0, tf] for px, py in footprints]
            for active, t0, tf in phases
        ]

        mg.update_motion_params(plan, q, sim_t)
        xs_plan, us_plan, f_plan = mg.optimize(q, v, np.round(sim_t, 3))

        # The next rollout is seeded with the first state of this solution,
        # split into configuration (first nq entries) and velocity (the rest).
        # NOTE(review): if xs_plan[0] can differ from the state passed in,
        # rollouts are not independent of each other -- confirm this is intended.
        q = xs_plan[0][0:pin_robot.model.nq]
        v = xs_plan[0][pin_robot.model.nq:]

        # Store a plain-list copy of this solve's dynamics-violation history.
        data = list(mg.mp.return_dyn_viol_hist())
        rollouts.append(data)

    plots[dt] = rollouts

import matplotlib.pyplot as plt

# For each step size, plot the mean of the recorded histories across rollouts
# with a shaded band of +/- one standard deviation, on a logarithmic y axis.
plt.figure()
for dt, data in plots.items():
    print(data)
    # Stack the rollouts; statistics are taken along axis 0 (across rollouts).
    # NOTE(review): assumes every rollout history has the same length -- confirm.
    violations = np.array(data)
    avg = np.mean(violations, axis=0)
    spread = np.std(violations, axis=0)
    iters = np.arange(len(avg))

    plt.plot(iters, avg, label=f"dt={dt:.2f}s")
    plt.fill_between(iters, avg - spread, avg + spread, alpha=0.3)

plt.xlabel('Number of iterations k')
plt.ylabel('Dynamic Violation')
plt.legend()
plt.yscale('log')
plt.grid(True)
plt.show()