"""
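Analyze a saved cart-pole swing-up result: plot the planned trajectory, check its
dynamics consistency, and replay its inputs open loop through the learned and
ground-truth dynamics.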
"""

from omegaconf import OmegaConf, DictConfig
import os

import hydra
import numpy as np
import torch

from examples.cartpole.cartpole_plant import CartpolePlant, CartpoleNNDynamicalSystem
from examples.cartpole.swingup import plot_result
from score_po.trajectory import SSTrajectory, IVPTrajectory, BVPTrajectory
from score_po.dynamical_system import sim_openloop


@hydra.main(config_path="./config", config_name="swingup")
def main(cfg: DictConfig):
    """Load a planned swing-up trajectory and report how well it holds up.

    Steps:
      1. Restore the learned dynamics (``cfg.dynamics_load_ckpt``) and the
         optimized trajectory (``cfg.trj.load_ckpt``).
      2. Plot the result with ``plot_result``.
      3. For a BVP trajectory, measure the defect ``f(x_t, u_t) - x_{t+1}``
         under both the learned and the ground-truth dynamics and print the
         largest absolute entry of each.
      4. Roll the planned inputs forward open loop through both models and
         print each rollout's final state and its distance to the goal.

    Args:
        cfg: Hydra config. It is saved to ``config.yaml`` in the current
            working directory so the run can be reproduced.
    """
    OmegaConf.save(cfg, os.path.join(os.getcwd(), "config.yaml"))
    device = cfg.device

    nn_plant = CartpoleNNDynamicalSystem(
        hidden_layers=cfg.plant_param.hidden_layers, device=device
    )
    # map_location lets a checkpoint saved on one device (e.g. a GPU) be
    # loaded when cfg.device is different (e.g. a CPU-only machine).
    nn_plant.load_state_dict(torch.load(cfg.dynamics_load_ckpt, map_location=device))

    plant = CartpolePlant(dt=cfg.plant_param.dt)

    x0 = torch.zeros((4,), device=device)
    x_goal = torch.tensor([0, np.pi, 0, 0], device=device)
    if cfg.single_shooting:
        trj = SSTrajectory(dim_x=4, dim_u=1, T=cfg.trj.T, x0=x0).to(device)
    else:
        trj = BVPTrajectory(
            dim_x=4, dim_u=1, T=cfg.trj.T, x0=x0, xT=x_goal
        ).to(device)
    trj.load_state_dict(torch.load(cfg.trj.load_ckpt, map_location=device))

    x_lo = torch.tensor(cfg.plant_param.x_lo)
    x_up = torch.tensor(cfg.plant_param.x_up)
    u_max = cfg.plant_param.u_max
    plot_result(
        trj, nn_plant, plant, x_lo, x_up, u_max, cfg.plant_param.dt, alpha=100, beta=1
    )

    # Everything below is pure evaluation, so skip building autograd graphs
    # through the trajectory parameters.
    with torch.no_grad():
        u_trj = trj.u_trj.detach()

        if isinstance(trj, BVPTrajectory):
            x_trj_plan, _ = trj.get_full_trajectory()
            # Defect f(x, u) - x_next using the learned dynamics.
            error_nn = nn_plant.dynamics_batch(x_trj_plan[:-1], u_trj) - x_trj_plan[1:]
            # Defect f(x, u) - x_next using the ground-truth dynamics.
            error_plant = plant.dynamics_batch(x_trj_plan[:-1], u_trj) - x_trj_plan[1:]
            # builtin abs() works for both torch tensors and numpy arrays.
            print(f"BVP max |f_nn(x, u) - x_next|:    {float(abs(error_nn).max()):.3e}")
            print(f"BVP max |f_plant(x, u) - x_next|: {float(abs(error_plant).max()):.3e}")

        # Replay the planned inputs open loop through both models.
        x_trj_sim_nn = sim_openloop(nn_plant, x0, u_trj, None)
        x_trj_sim_plant = sim_openloop(plant, x0, u_trj, None)

    # NOTE(review): assumes sim_openloop returns states stacked along dim 0
    # with the final state last — confirm against score_po.dynamical_system.
    for label, x_trj_sim in (("learned", x_trj_sim_nn), ("true", x_trj_sim_plant)):
        x_final = torch.as_tensor(x_trj_sim[-1], device=device)
        dist = float(torch.linalg.norm(x_final - x_goal))
        print(
            f"open-loop ({label} dynamics): final state {x_final.tolist()}, "
            f"|x_T - x_goal| = {dist:.3e}"
        )


if __name__ == "__main__":
    # `main` is wrapped by @hydra.main, so it is called without arguments;
    # Hydra composes the config and passes it in as `cfg`.
    main()
