import os
import yaml
import torch
import argparse
import numpy as np
import matplotlib.pyplot as plt

from goal_set_planning.sim.robot import LinearPointRobot
from goal_set_planning.sim.rollout import Robot2DRollout

from goal_set_planning.costs.costs_2D import TerminalBBoxPrior
from goal_set_planning.controllers import SteinMPC

from goal_set_planning.util.plotting import plot_trajectory_2D

from util.planar_nav import build_controllers, control_loop_mpc, load_scene, TerminalPathChecker

# Timestep used by the simulated robot and to convert each experiment's
# planning horizon into a number of control steps.
DT = 0.1
# Tensor placement/precision forwarded to every torch-backed component.
tensor_kwargs = dict(device="cpu", dtype=torch.float)
# Experiments to run. When left as None, the script entry point picks either the
# single experiment given by --exp or every experiment found in the config.
# Example value: ["cross_entropy_true", "mmd_prior", "kl_cls", "closest_point", "smooth_knn"]
EXP_TO_RUN = None


def run(exp, dmap, robot, x0, rollout, save_iters=False, print_freq=10,
        p_goal=None, exp_names=None):
    """Run the selected qualitative experiments and save a trajectory plot for each.

    For every experiment name, a ``SteinMPC`` controller is built from that
    experiment's config entry. The MPC loop is then run from ``x0`` via
    ``control_loop_mpc``, and the final trajectory is plotted to
    ``output/exp/qualitative/<exp_name>/trajectory.jpg``. The same directory is
    also handed to ``control_loop_mpc`` as ``out_path``.

    Args:
        exp: Mapping from experiment name to its config dict. Each entry is read
            for ``horizon``, ``max_steps``, ``goal_samples``, ``costs``,
            ``weight_fn`` and ``optim`` (with ``n_particles``, ``init_sigma``,
            ``lr``, ``n_iters``, ``prior``).
        dmap: Scene map providing ``compute_binary_img()`` and ``lims``.
        robot: Simulated robot driven by the control loop.
        x0: Initial robot state.
        rollout: Rollout function used by the controller.
        save_iters: Forwarded to ``control_loop_mpc``.
        print_freq: Forwarded to ``control_loop_mpc``.
        p_goal: Goal distribution passed to the control loop and the plotter.
            Defaults to the module-level ``p_goal`` created by the script entry
            point, which is what older callers relied on.
        exp_names: Names of the experiments to run. Defaults to the module-level
            ``EXP_TO_RUN``, or to every key of ``exp`` when that is ``None``.

    Raises:
        ValueError: If no goal distribution is given and no module-level
            ``p_goal`` exists.
    """
    if p_goal is None:
        # Backward compatibility: the __main__ section defines ``p_goal`` as a
        # module global and this function used to read it implicitly.
        p_goal = globals().get("p_goal")
        if p_goal is None:
            raise ValueError("run() requires a goal distribution; pass p_goal=...")
    if exp_names is None:
        # Previously a None EXP_TO_RUN (e.g. when imported) crashed the loop.
        exp_names = EXP_TO_RUN if EXP_TO_RUN is not None else list(exp)

    map_img = dmap.compute_binary_img().cpu().numpy()

    for exp_name in exp_names:
        print("Testing", exp_name)
        outpath = os.path.join("output/exp/qualitative", exp_name)
        os.makedirs(outpath, exist_ok=True)
        data = exp[exp_name]

        # Extract parameters.
        optim = data["optim"]
        steps = int(np.ceil(data["horizon"] / DT))  # Horizon expressed in control steps.
        max_steps = data["max_steps"]
        n_particles = optim["n_particles"]
        init_sigma = optim["init_sigma"]
        lr = optim["lr"]
        n_iters = optim["n_iters"]
        use_prior = optim["prior"]
        goal_samples = data["goal_samples"]

        log_smooth_box_prior, bbox = None, None
        if use_prior:
            log_smooth_box_prior = TerminalBBoxPrior(goal_samples, sigma=1e-3, **tensor_kwargs)
            # Keep the prior's box (low corner followed by high corner) for plotting.
            box = log_smooth_box_prior.prior
            bbox = np.concatenate([box.low.cpu().numpy(), box.high.cpu().numpy()])

        # Create the controller.
        # NOTE(review): the positional ``2`` is presumably the control dimension
        # of the planar robot — confirm against SteinMPC's signature.
        controller = SteinMPC(
            data["costs"], n_particles, 2, steps, rollout_fn=rollout,
            init_cov=init_sigma, sample_mode="best", shift_mode="shift",
            optim_type="adam", optim_params={"lr": lr},
            log_prior=log_smooth_box_prior, weight_fn=data["weight_fn"],
            tensor_args=tensor_kwargs
        )

        # Run the control loop.
        traj = control_loop_mpc(controller, robot, x0, max_steps,
                                n_iters=n_iters, terminal_fn=TerminalPathChecker(thresh=0.05),
                                out_path=outpath, viz=True, save_iters=save_iters,
                                print_freq=print_freq, figsize=(6, 6), stop_early=True,
                                map_img=map_img, lims=dmap.lims, bbox=bbox,
                                p_goal=p_goal, x_goal=goal_samples)

        # Speed at each step: norm of the state components from index 2 onwards.
        vels = np.sqrt(np.sum(traj[:, 2:] * traj[:, 2:], axis=1))

        x_goal = goal_samples.cpu() if torch.is_tensor(goal_samples) else None
        fig = plt.figure(figsize=(6, 6))
        try:
            plot_trajectory_2D(fig.gca(), traj, map_img=map_img, lims=dmap.lims,
                               p_goal=p_goal, x_goal=x_goal, bbox=bbox, vels=vels)
            # Save the final trajectory.
            fig.savefig(os.path.join(outpath, "trajectory.jpg"))
        finally:
            # Release the figure; otherwise one figure leaks per experiment.
            plt.close(fig)


def load_args():
    """Parse and echo the command-line options of the qualitative navigation script.

    Every parsed option is printed under an "Arguments:" header, one
    tab-indented ``name : value`` line each, before the namespace is returned.

    Returns:
        argparse.Namespace with ``scene``, ``config``, ``exp``, ``print_freq``,
        ``dt`` and ``save`` attributes.
    """
    parser = argparse.ArgumentParser(description="Qualitative planar navigation.")

    # Declaration order is preserved so --help and the echoed summary keep it.
    option_specs = (
        ("--scene", dict(type=str, required=True, help="Scene file.")),
        ("--config", dict(type=str, default="config/exp/planar_nav_qual.yml",
                          help="Configuration file.")),
        ("--exp", dict(type=str, default=None, help="Experiment to run.")),
        ("--print-freq", dict(type=int, default=10,
                              help="Iteration frequency for debug information.")),
        ("--dt", dict(type=float, default=0.1, help="Timestep.")),
        ("--save", dict(action="store_true", help="Save intermediate steps.")),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    parsed = parser.parse_args()

    summary = "\n".join(f"\t{name} : {value}" for name, value in vars(parsed).items())
    print("Arguments:")
    print(summary)

    return parsed


if __name__ == '__main__':
    # Check for NaNs in torch (uncomment when debugging; it is slow).
    # torch.autograd.set_detect_anomaly(True)

    args = load_args()

    # Honour --dt. Before this line the flag was parsed but ignored, and the
    # hard-coded DT was always used. Rebinding the module global also changes
    # the horizon discretisation in run(), which reads DT.
    DT = args.dt

    # Load parameters.
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # load trusted config files (yaml.safe_load suffices for plain-data configs).
    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.Loader)

    # Load the scene.
    dmap, x0, p_goal = load_scene(args.scene)

    N_GOALS = 100
    goal_samples = p_goal.sample(N_GOALS)

    robot = LinearPointRobot(x0, dt=DT, tensor_kwargs=tensor_kwargs)  # The simulator model.
    rollout = Robot2DRollout(robot)  # The rollout for the controller.

    config = build_controllers(config, p_goal, goal_samples, x0[:2], dmap, tensor_kwargs=tensor_kwargs)

    if EXP_TO_RUN is None:
        if args.exp is not None:
            # Fail early with a readable message instead of a KeyError inside run().
            if args.exp not in config:
                raise SystemExit("Unknown experiment '{}'. Available: {}".format(
                    args.exp, ", ".join(map(str, config))))
            EXP_TO_RUN = [args.exp]
        else:
            # If experiments were not specified, run all of them.
            EXP_TO_RUN = list(config)

    # Run all experiments.
    run(config, dmap, robot, x0, rollout, save_iters=args.save, print_freq=args.print_freq)
