import os
import yaml
import argparse

import numpy as np
import matplotlib.pyplot as plt

import torch

from goal_set_planning.sim.robot import LinearPointRobot
from goal_set_planning.sim.rollout import Robot2DRollout

from goal_set_planning.costs.costs_2D import TerminalBBoxPrior
from goal_set_planning.controllers import SteinMPC

from goal_set_planning.util.plotting import plot_trajectory_2D

from util.planar_nav import build_controllers, control_loop_mpc, load_scene, TerminalPathChecker

# Timestep used both for the simulator robot (dt=...) and to turn each
# planner's horizon into a number of MPC steps.
DT = 0.1

# Device/dtype keyword arguments forwarded to the robot, costs and controllers.
tensor_kwargs = dict(device="cpu", dtype=torch.float)

# Experiment names (top-level config keys) to run. Leave as None to use the
# --exp flag, or every experiment in the config when --exp is not given.
# Example: ["cross_entropy_true", "mmd_prior", "kl_cls"]
EXP_TO_RUN = None

# Scenes to benchmark; each name is loaded from config/exp/scenes/<name>.yml.
SCENES = [
    "circles_gaussian",
    "get_in_box_uni",
    "obstacles_mix",
    "squares_mix",
    "squares_uni",
    "circles_uni",
    "get_in_box_gaussian",
]


def run_planners(exp, dmap, robot, x0, rollout, scene_name,
                 p_goal=None, N=5, viz=False, goal_samples_to_save=None):
    """Run each selected planner ``N`` times on one scene and save the results.

    The planners run are the names in the module-level ``EXP_TO_RUN``, or every
    key of ``exp`` when ``EXP_TO_RUN`` is ``None``. For every run a fresh
    ``SteinMPC`` controller is built and driven by ``control_loop_mpc``; output
    goes to ``output/exp/quantitative/<exp_name>/<scene_name>/<run index>``.

    Args:
        exp: Mapping from experiment name to its built configuration (as
            returned by ``build_controllers``). Each entry is read for
            ``"horizon"``, ``"max_steps"``, ``"optim"`` (``n_particles``,
            ``init_sigma``, ``lr``, ``n_iters``, ``prior``),
            ``"goal_samples"``, ``"costs"`` and ``"weight_fn"``.
        dmap: Scene map; ``compute_binary_img()`` and ``lims`` are used.
        robot: Simulator model passed to the control loop.
        x0: Initial robot state.
        rollout: Rollout function given to the controller.
        scene_name: Scene name, used as an output sub-directory.
        p_goal: Goal distribution forwarded to the control loop and the plot.
        N: Number of runs per planner.
        viz: Forwarded to ``control_loop_mpc``; when true, a final
            ``trajectory.jpg`` is also saved for each run.
        goal_samples_to_save: Optional tensor of goal samples, saved as
            ``goal_samples.npy`` in each planner's scene directory.
    """
    map_img = dmap.compute_binary_img().cpu().numpy()

    # EXP_TO_RUN is normally filled in by the __main__ block; fall back to
    # every experiment in ``exp`` when this function is used without it.
    exp_names = EXP_TO_RUN if EXP_TO_RUN is not None else list(exp.keys())

    # The goal samples are identical for every planner, so convert them once.
    goal_samples_np = None
    if goal_samples_to_save is not None:
        goal_samples_np = goal_samples_to_save.cpu().numpy()

    for exp_name in exp_names:
        print("--- Planner:", exp_name, "---")

        outpath = os.path.join("output/exp/quantitative", exp_name, scene_name)
        os.makedirs(outpath, exist_ok=True)
        data = exp[exp_name]

        # Extract parameters. Round before the ceil so floating-point error in
        # horizon / DT (a quotient that should be an integer but lands just
        # above it) does not add a spurious extra step.
        steps = int(np.ceil(round(data["horizon"] / DT, 6)))
        max_steps = data["max_steps"]
        optim = data["optim"]
        n_particles = optim["n_particles"]
        init_sigma = optim["init_sigma"]
        lr = optim["lr"]
        n_iters = optim["n_iters"]
        use_prior = optim["prior"]
        goal_samples = data["goal_samples"]

        if goal_samples_np is not None:
            np.save(os.path.join(outpath, "goal_samples.npy"), goal_samples_np)

        log_smooth_box_prior, bbox = None, None
        if use_prior:
            log_smooth_box_prior = TerminalBBoxPrior(goal_samples, sigma=1e-3, **tensor_kwargs)
            # Also save bounding box for visualization if there is a prior.
            bbox = np.concatenate([log_smooth_box_prior.prior.low.cpu().numpy(),
                                   log_smooth_box_prior.prior.high.cpu().numpy()])

        # Run the controller N times.
        for i in range(N):
            print("    Testing", exp_name, "iteration", i)
            outpath_i = os.path.join(outpath, str(i))
            os.makedirs(outpath_i, exist_ok=True)

            # Some costs need to be reset.
            data["costs"].reset()

            # Create a fresh controller for every run (2 = control dimension
            # argument, as in the original call).
            controller = SteinMPC(
                data["costs"], n_particles, 2, steps,
                rollout_fn=rollout, shift_mode="shift",
                init_cov=init_sigma, sample_mode="best",
                optim_type="adam", optim_params={"lr": lr},
                log_prior=log_smooth_box_prior, weight_fn=data["weight_fn"],
                tensor_args=tensor_kwargs
            )

            # Run the control loop.
            traj = control_loop_mpc(controller, robot, x0, max_steps,
                                    n_iters=n_iters, terminal_fn=TerminalPathChecker(thresh=0.05),
                                    out_path=outpath_i, viz=viz, save_iters=False,
                                    print_freq=10, figsize=(6, 6), save_data=True,
                                    map_img=map_img, lims=dmap.lims, bbox=bbox,
                                    p_goal=p_goal, x_goal=goal_samples)

            if viz:
                # Speed per state, from the columns after the first two of traj.
                vels = np.sqrt(np.sum(traj[:, 2:] * traj[:, 2:], axis=1))
                x_goal = goal_samples.cpu() if torch.is_tensor(goal_samples) else None

                # Draw the trajectory.
                plt.figure(0, figsize=(6, 6))
                plot_trajectory_2D(plt.gca(), traj, map_img=map_img, lims=dmap.lims,
                                   p_goal=p_goal, x_goal=x_goal, bbox=bbox, vels=vels)

                # Save the final trajectory.
                plt.savefig(os.path.join(outpath_i, "trajectory.jpg"))

                # Close any figures that may have been open.
                plt.close('all')


def run_scenes(args):
    """Benchmark the configured planners on every scene listed in ``SCENES``.

    For each scene this:
    - loads ``config/exp/scenes/<scene>.yml``;
    - draws 100 samples from the scene's goal distribution;
    - builds the simulator robot and its rollout;
    - re-reads the planner configuration from ``args.config`` and builds it
      with ``build_controllers``;
    - hands everything to ``run_planners``.

    Args:
        args: Parsed command-line arguments; ``config``, ``num_runs`` and
            ``save`` are read here.
    """
    n_goal_samples = 100

    for scene_name in SCENES:
        print("***********************************\n")
        print("  Scene:", scene_name)
        print("\n***********************************")

        # Load the scene definition.
        dmap, x0, p_goal = load_scene(os.path.join("config/exp/scenes", scene_name + ".yml"))
        goals = p_goal.sample(n_goal_samples)

        # Simulator model, plus the rollout the controller plans with.
        sim_robot = LinearPointRobot(x0, dt=DT, tensor_kwargs=tensor_kwargs)
        robot_rollout = Robot2DRollout(sim_robot)

        # The config is re-read from disk for every scene; build_controllers
        # receives this scene's goal distribution, goal samples, start position
        # and map when constructing the costs.
        with open(args.config, 'r') as cfg_file:
            raw_cfg = yaml.load(cfg_file, Loader=yaml.Loader)
        planners = build_controllers(raw_cfg, p_goal, goals, x0[:2], dmap,
                                     tensor_kwargs=tensor_kwargs)

        run_planners(planners, dmap, sim_robot, x0, robot_rollout, scene_name,
                     p_goal=p_goal, N=args.num_runs, viz=args.save,
                     goal_samples_to_save=goals)


def load_args(argv=None):
    """Parse the command-line arguments for this benchmark and print them.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse parse ``sys.argv[1:]``, as before.

    Returns:
        ``argparse.Namespace`` with ``num_runs``, ``config``, ``exp``, ``dt``
        and ``save``.
    """
    # The script writes to output/exp/quantitative and uses the metrics config,
    # so describe it as the quantitative benchmark.
    parser = argparse.ArgumentParser(description="Quantitative planar navigation.")
    parser.add_argument("-n", "--num-runs", type=int, default=5, help="Number of runs for each controller.")
    parser.add_argument("--config", type=str, default="config/exp/planar_nav_metrics.yml", help="Configuration file.")
    parser.add_argument("--exp", type=str, default=None, help="Experiment to run.")
    parser.add_argument("--dt", type=float, default=0.1, help="Timestep.")
    parser.add_argument("--save", action="store_true", help="Save images.")

    args = parser.parse_args(argv)

    print("Arguments:")
    print('\n'.join("\t{} : {}".format(k, v) for k, v in args.__dict__.items()))

    return args


if __name__ == '__main__':
    # Check for NaNs in torch.
    # torch.autograd.set_detect_anomaly(True)

    args = load_args()

    # Honour the --dt flag. run_scenes and run_planners read the module-level
    # DT when they are called, so rebinding it here applies to the whole run.
    DT = args.dt

    # Load parameters.
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # point --config at trusted files (yaml.safe_load if plain YAML suffices).
    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.Loader)

    if EXP_TO_RUN is None:
        if args.exp is not None:
            EXP_TO_RUN = [args.exp]
        else:
            # If experiments were not specified, set to all of them.
            EXP_TO_RUN = list(config.keys())

    # Run all experiments.
    run_scenes(args)
