import os
import yaml
import argparse
import torch
import numpy as np
import matplotlib.pyplot as plt

from goal_set_planning.sim.robot import LinearPointRobot
from goal_set_planning.sim.rollout import Robot2DRollout

from goal_set_planning.distributions import distances
from goal_set_planning.util.misc import euclidean_path_length, euclidean_distance

from goal_set_planning.util.plotting import plot_trajectory_2D, plot_trajectories_2D

from util.planar_nav import load_scene

# Root output directory: one sub-directory per experiment, each containing one
# sub-directory per scene (see the __main__ loop below).
EXP_PATH = "output/exp/quantitative"
# Directory of scene configuration files; scene "<name>" is read from "<name>.yml".
SCENE_PATH = "config/exp/scenes"
# Keyword arguments applied to every tensor this script creates (CPU, float32).
tensor_kwargs = {"device": "cpu", "dtype": torch.float}

# Experiment directory names to evaluate; None means every experiment directory under EXP_PATH.
EXP_TO_RUN = None  # ["closest_point", "cross_entropy_true", "closest_point_dyn"]


def get_ordered_runs(path):
    """Return the names of the sub-directories of ``path``, sorted ascending.

    Plain files directly inside ``path`` are skipped; only directory names are
    returned (not full paths).
    """
    return sorted(
        entry
        for entry in os.listdir(path)
        if os.path.isdir(os.path.join(path, entry))
    )


def _plot_to_file(out_file, plot_fn, *args, **kwargs):
    """Draw ``plot_fn`` onto a cleared, reused figure 0 and save it to ``out_file``."""
    plt.figure(0, figsize=(6, 6))
    plt.cla()
    plt.clf()
    plot_fn(plt.gca(), *args, **kwargs)
    plt.savefig(out_file)


def _save_metric(dir_path, filename, array, label):
    """Save ``array`` to ``dir_path/filename`` with ``np.save`` and print where it went."""
    out_path = os.path.join(dir_path, filename)
    np.save(out_path, array)
    print("\t{} saved to:".format(label), out_path)


def compute_runs(args, path, p_goal, goal_samples, dmap):
    """Evaluate every run stored under one experiment/scene directory and save metrics.

    For each run sub-directory of ``path`` this function:

    1. Reads ``data.yaml``, ``trajectory.npy`` and the per-step
       ``particles_XXXX.npy`` files.
    2. Computes the path length up to the stop iteration.
    3. Computes the minimum distance from the final position to the goal
       samples, and success as that distance being below ``args.threshold``.
    4. For every step, rolls out the particles and computes the cross entropy
       and Chamfer distance between the rollout end points and the goal.

    Optionally renders per-step and whole-trajectory figures.

    Args:
        args: Parsed CLI options; ``dt``, ``threshold`` and ``viz`` are read.
        path: Scene directory, expected to be ``<EXP_PATH>/<experiment>/<scene>``.
            The experiment name (the parent directory) selects the goal shown
            in plots.
        p_goal: Goal distribution, passed to ``distances.cross_entropy_sample``
            and to the plotting helpers.
        goal_samples: Array of goal samples, one goal point per row
            (presumably (N, 2); TODO confirm).
        dmap: Scene map providing ``compute_binary_img()`` and ``lims`` for plotting.

    Side effects:
        Writes ``success.npy``, ``terminal_dists.npy``, ``path_dists.npy``,
        ``cross_entropy.npy`` and ``chamfer_dists.npy`` into ``path``. When
        ``args.viz`` is set, also writes .jpg figures into each run directory.
        Returns early, without writing anything, if ``path`` has no runs.
    """
    x0 = [1, 1, 0, 0]  # Starting state.

    robot = LinearPointRobot(x0, dt=args.dt, tensor_kwargs=tensor_kwargs)  # The simulator model.
    rollout = Robot2DRollout(robot)  # The rollout for the controller.

    map_img = dmap.compute_binary_img().cpu().numpy() if args.viz else None

    runs = get_ordered_runs(path)
    if not runs:
        # np.stack below cannot stack an empty list; skip instead of crashing mid-save.
        print("\tNo runs found in", path, "- skipping.")
        return

    # The experiment name is the parent of the scene directory. normpath makes
    # this robust to trailing separators and to non-"/" path separators.
    exp = os.path.basename(os.path.dirname(os.path.normpath(path)))
    bbox = None
    x_goal = goal_samples
    if exp in ("energy", "mmd_prior", "kl_cls", "smooth_knn"):
        bbox = np.concatenate([goal_samples.min(axis=0), goal_samples.max(axis=0)])
    elif exp == "closest_point":
        x_goal = goal_samples[euclidean_distance(robot.x[:2], goal_samples).argmin()]
    elif exp == "cross_entropy_true":
        x_goal = None

    all_path_dists = []
    all_success = []
    all_final_dists = []
    all_ce = []
    all_cd = []

    for run in runs:
        run_path = os.path.join(path, run)

        # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
        # run this on trusted, locally generated experiment output.
        with open(os.path.join(run_path, "data.yaml"), 'r') as f:
            data = yaml.load(f, Loader=yaml.Loader)

        traj = torch.tensor(np.load(os.path.join(run_path, "trajectory.npy")), **tensor_kwargs)
        T = traj.shape[0]
        stop_iter = data.get("stop_iter", T)
        # Without a recorded stop_iter, stop_iter == T, which is one past the
        # last state; clamp so the final-state lookup stays in bounds.
        final_idx = min(stop_iter, T - 1)

        # Only measure path distance up to the stop iteration.
        path_dist = euclidean_path_length(traj[:stop_iter, :2]).item()
        all_path_dists.append(path_dist)

        # Success: the final position is within the threshold of the nearest goal sample.
        dist_to_goal = euclidean_distance(traj[final_idx, :2], goal_samples).min().item()
        all_success.append(dist_to_goal < args.threshold)
        all_final_dists.append(dist_to_goal)

        # Per-step metrics comparing rollout end points against the goal.
        cross_ents = []
        chamfer_dists = []

        for t in range(T - 1):
            particles = np.load(os.path.join(run_path, "particles_{:04d}.npy".format(t)))
            particles = torch.tensor(particles, **tensor_kwargs)
            traj_samples = rollout(traj[t], particles)
            # pos is indexed as (sample, step, coord); take the last step of every rollout.
            end_points = traj_samples.pos[:, -1, :]

            cross_ents.append(distances.cross_entropy_sample(end_points, p_goal).item())
            chamfer_dists.append(distances.chamfer(end_points, goal_samples).item())

            if args.viz:
                if exp == "closest_point_dyn":
                    # The shown goal tracks the sample nearest the current position.
                    x_goal = goal_samples[euclidean_distance(traj[t, :2], goal_samples).argmin()]

                weights = np.zeros(particles.size(0))
                weights[data["path_indices"][t]] = 1  # Mark the selected trajectory.

                _plot_to_file(os.path.join(run_path, "distribution_{:04d}.jpg".format(t)),
                              plot_trajectories_2D, traj_samples.pos.cpu().numpy(), weights,
                              title="Timestep {}".format(t),
                              map_img=map_img, lims=dmap.lims, bbox=bbox,
                              p_goal=p_goal, x_goal=x_goal)

        if args.viz:
            traj_np = traj[:stop_iter].cpu().numpy()
            # Speed magnitude from columns 2 onward of each state.
            vels = np.sqrt(np.sum(traj_np[:, 2:] * traj_np[:, 2:], axis=1))

            _plot_to_file(os.path.join(run_path, "trajectory.jpg"),
                          plot_trajectory_2D, traj_np, map_img=map_img, lims=dmap.lims,
                          p_goal=p_goal, x_goal=x_goal, vels=vels, bbox=bbox)
            # A second copy without velocity colouring, for some visualizations.
            _plot_to_file(os.path.join(run_path, "trajectory_no_vels.jpg"),
                          plot_trajectory_2D, traj_np, map_img=map_img, lims=dmap.lims,
                          p_goal=p_goal, x_goal=x_goal, bbox=bbox)

        all_ce.append(cross_ents)
        all_cd.append(chamfer_dists)

    _save_metric(path, "success.npy", np.array(all_success, dtype=bool), "Path success")
    _save_metric(path, "terminal_dists.npy", np.array(all_final_dists), "Terminal distances")
    _save_metric(path, "path_dists.npy", np.array(all_path_dists), "Path distances")
    # np.stack requires every run to contribute the same number of timesteps.
    _save_metric(path, "cross_entropy.npy", np.stack(all_ce), "Cross entropy data")
    _save_metric(path, "chamfer_dists.npy", np.stack(all_cd), "Chamfer distance data")


def load_args():
    """Parse the command line, echo every option, and return the result.

    Returns:
        argparse.Namespace: ``dt`` (float), ``threshold`` (float) and ``viz`` (bool).
    """
    parser = argparse.ArgumentParser(description="Compute metrics.")

    option_specs = [
        ("--dt", {"type": float, "default": 0.1, "help": "Timestep."}),
        ("--threshold", {"type": float, "default": 0.25,
                         "help": "Threshold for defining task success."}),
        ("--viz", {"action": "store_true", "help": "Visualize trajectories."}),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    parsed = parser.parse_args()

    summary = ["\t{} : {}".format(key, value) for key, value in vars(parsed).items()]
    print("Arguments:")
    print('\n'.join(summary))

    return parsed


if __name__ == '__main__':
    args = load_args()

    # Use the configured experiment list, or every experiment directory under
    # EXP_PATH. get_ordered_runs returns sorted sub-directory names, so
    # experiments (and scenes below) are processed in a reproducible order.
    experiments = EXP_TO_RUN if EXP_TO_RUN is not None else get_ordered_runs(EXP_PATH)

    for exp_name in experiments:
        print("\nComputing data for", exp_name)

        exp_path = os.path.join(EXP_PATH, exp_name)
        # Each scene of an experiment is a sub-directory of the experiment directory.
        for scene_name in get_ordered_runs(exp_path):
            scene_path = os.path.join(exp_path, scene_name)

            scene_file = os.path.join(SCENE_PATH, scene_name + ".yml")
            if not os.path.exists(scene_file):
                print(f"Can't find scene file: {scene_file}, skipping.")
                continue

            goal_file = os.path.join(scene_path, "goal_samples.npy")
            if not os.path.exists(goal_file):
                print(f"Can't find goal samples: {goal_file}, skipping.")
                continue

            print(f" -- Computing scene: {scene_name}")
            # The scene's start state is not forwarded: compute_runs builds its
            # robot from its own fixed start state.
            dmap, _scene_x0, p_goal = load_scene(scene_file)

            goal_samples = np.load(goal_file)

            compute_runs(args, scene_path, p_goal, goal_samples, dmap)
