import os
import time
import yaml
import torch
import numpy as np
import matplotlib.pyplot as plt

from goal_set_planning.costs.costs_2D import (
    QuadraticStateActionCost,
    SDFCost,
    TerminalPositionCost,
    DynamicTerminalPositionCost,
    NearestNeighborTerminalPositionCost,
    TerminalVelocityCost,
    TerminalLogLikelihoodCost,
    TerminalSetCost,
    TerminalBBoxPrior
)
from goal_set_planning.costs import CostSum

from goal_set_planning.sim.map import DiffMap
import goal_set_planning.distributions as dist
from goal_set_planning.distributions import distances
from goal_set_planning.stein.kernels import RBFMedianKernel
from goal_set_planning.util.misc import euclidean_distance

from goal_set_planning.util.plotting import (
    plot_trajectory_2D, plot_trajectories_2D,
    trajectory_files_to_figures
)


class TerminalPathChecker:
    """Convergence test based on how far the robot is from the end of its path.

    Each call measures the distance between the first and the last row of
    ``state.pos`` and keeps the most recent ``window`` measurements. The check
    passes once the median of a full window is at most ``thresh``.
    """

    def __init__(self, window=5, thresh=0.05):
        self.window = window
        self.thresh = thresh

        # Most recent distances, oldest first.
        self.dists = []

    def __call__(self, state, goal_samples=None):
        """Record the current distance and report whether the check passes.

        Args:
            state: Object whose ``pos`` attribute is indexable as
                ``pos[0, :]`` and ``pos[-1, :]``.
            goal_samples: Accepted but not used by this checker.

        Returns:
            False until ``window`` distances have been recorded; afterwards
            the result of ``median(distances) <= thresh``.
        """
        start_pos = state.pos[0, :]
        end_pos = state.pos[-1, :]
        gap = euclidean_distance(start_pos, end_pos)
        self.dists.append(gap.item())

        history_len = len(self.dists)

        # Not enough history to decide yet.
        if history_len < self.window:
            return False

        # Drop the oldest entry so that at most `window` values are kept.
        if history_len > self.window:
            del self.dists[0]

        return np.median(self.dists) <= self.thresh


class QuadraticWeights(object):
    """Normalised trajectory weights derived from a quadratic cost.

    The log-weight of each entry is the negated quadratic state-action cost.
    When goal samples are supplied at construction, the output of a
    ``TerminalBBoxPrior`` built from them is added to the log-weights.
    """

    def __init__(self, quad_params, goal_samples=None, prior_sigma=1e-3,
                 tensor_kwargs={"device": "cpu", "dtype": torch.float}):
        """
        Args:
            quad_params: Keyword arguments for ``QuadraticStateActionCost``.
            goal_samples: Optional samples used to build the box prior; when
                None no prior term is used.
            prior_sigma: ``sigma`` passed to ``TerminalBBoxPrior``.
            tensor_kwargs: Tensor options forwarded to the cost objects.
        """
        self.tensor_kwargs = tensor_kwargs
        self.quad_cost = QuadraticStateActionCost(**quad_params, tensor_kwargs=tensor_kwargs)
        self.log_smooth_box_prior = (
            None if goal_samples is None
            else TerminalBBoxPrior(goal_samples, sigma=prior_sigma, **tensor_kwargs)
        )

    def __call__(self, state):
        """Return weights that sum to one, detached from the autograd graph."""
        scores = -self.quad_cost(state)

        box_prior = self.log_smooth_box_prior
        if box_prior is not None:
            scores += box_prior(state)

        # Shift by the maximum before exponentiating, then normalise.
        unnormalised = torch.exp(scores - scores.max())
        return (unnormalised / unnormalised.sum()).detach()


def build_dynamic_closest(config, goal_samples, dmap, tensor_kwargs={"device": "cpu", "dtype": torch.float}):
    """Build the cost for the dynamic closest-point baseline.

    Args:
        config: Mapping with "quadratic", "obstacle" and "terminal" entries.
            Each provides a "weight"; "quadratic" and "obstacle" also provide
            "params", and "terminal" provides "num_goals".
        goal_samples: Goal samples; only the first "num_goals" are used.
        dmap: Map object handed to ``SDFCost``.
        tensor_kwargs: Tensor options forwarded to every cost object.

    Returns:
        Dict with "costs" (the weighted ``CostSum``), "weight_fn" (None) and
        "goal_samples" (a clone of the samples that were used).
    """
    goals = goal_samples[:config["terminal"]["num_goals"], ...]

    quad_term = QuadraticStateActionCost(**config["quadratic"]["params"], tensor_kwargs=tensor_kwargs)
    obstacle_term = SDFCost(dmap, **config["obstacle"]["params"], tensor_kwargs=tensor_kwargs)
    goal_term = DynamicTerminalPositionCost(goals, tensor_kwargs=tensor_kwargs)

    # Weights are listed in the same order as the cost terms.
    term_weights = [config[name]["weight"] for name in ("quadratic", "obstacle", "terminal")]
    total = CostSum([quad_term, obstacle_term, goal_term], weights=term_weights,
                    tensor_kwargs=tensor_kwargs)

    return {"costs": total, "weight_fn": None, "goal_samples": goals.clone()}


def build_nn_point(config, goal_samples, dmap, tensor_kwargs={"device": "cpu", "dtype": torch.float}):
    """Build the cost for the nearest-neighbor terminal point baseline.

    Args:
        config: Mapping with "quadratic", "obstacle" and "terminal" entries.
            Each provides a "weight"; "quadratic" and "obstacle" also provide
            "params", and "terminal" provides "num_goals".
        goal_samples: Goal samples; only the first "num_goals" are used.
        dmap: Map object handed to ``SDFCost``.
        tensor_kwargs: Tensor options forwarded to every cost object.

    Returns:
        Dict with "costs" (the weighted ``CostSum``), "weight_fn" (None) and
        "goal_samples" (a clone of the samples that were used).
    """
    goals = goal_samples[:config["terminal"]["num_goals"], ...]

    quad_term = QuadraticStateActionCost(**config["quadratic"]["params"], tensor_kwargs=tensor_kwargs)
    obstacle_term = SDFCost(dmap, **config["obstacle"]["params"], tensor_kwargs=tensor_kwargs)
    goal_term = NearestNeighborTerminalPositionCost(goals, tensor_kwargs=tensor_kwargs)

    # Weights are listed in the same order as the cost terms.
    term_weights = [config[name]["weight"] for name in ("quadratic", "obstacle", "terminal")]
    total = CostSum([quad_term, obstacle_term, goal_term], weights=term_weights,
                    tensor_kwargs=tensor_kwargs)

    return {"costs": total, "weight_fn": None, "goal_samples": goals.clone()}


def build_mmd(mmd_config, goal_samples, dmap, tensor_kwargs={"device": "cpu", "dtype": torch.float}):
    """Build the cost and trajectory weighting for the kernel-MMD method.

    Args:
        mmd_config: Mapping with "quadratic", "obstacle" and "terminal"
            entries. Each provides a "weight"; "quadratic" and "obstacle" also
            provide "params", and "terminal" provides "num_goals".
        goal_samples: Goal samples; only the first "num_goals" are used.
        dmap: Map object handed to ``SDFCost``.
        tensor_kwargs: Tensor options forwarded to the cost objects.

    Returns:
        Dict with "costs" (the weighted ``CostSum``), "weight_fn" (a
        ``QuadraticWeights`` built from the quadratic params and the goal
        samples) and "goal_samples" (a clone of the samples that were used).
    """
    goals = goal_samples[:mmd_config["terminal"]["num_goals"], ...]

    # The kernel receives the goal samples via set_params, and the distance
    # is created with estimate_params=False.
    kernel = RBFMedianKernel()
    kernel.set_params(goals)
    set_distance = distances.KernelMMDSampleDistance(kernel, estimate_params=False)

    quad_term = QuadraticStateActionCost(**mmd_config["quadratic"]["params"], tensor_kwargs=tensor_kwargs)
    obstacle_term = SDFCost(dmap, **mmd_config["obstacle"]["params"], tensor_kwargs=tensor_kwargs)
    goal_term = TerminalSetCost(goals, set_distance, tensor_kwargs=tensor_kwargs)

    # Weights are listed in the same order as the cost terms.
    term_weights = [mmd_config[name]["weight"] for name in ("quadratic", "obstacle", "terminal")]
    total = CostSum([quad_term, obstacle_term, goal_term], weights=term_weights,
                    tensor_kwargs=tensor_kwargs)

    chooser = QuadraticWeights(mmd_config["quadratic"]["params"], goal_samples=goals,
                               tensor_kwargs=tensor_kwargs)

    return {"costs": total, "weight_fn": chooser, "goal_samples": goals.clone()}


def build_knn(config, goal_samples, dmap, tensor_kwargs={"device": "cpu", "dtype": torch.float}):
    """Build the cost and trajectory weighting for the smooth-KNN method.

    Args:
        config: Mapping with "quadratic", "obstacle" and "terminal" entries.
            Each provides a "weight"; "quadratic" and "obstacle" also provide
            "params", and "terminal" provides "num_goals".
        goal_samples: Goal samples; only the first "num_goals" are used.
        dmap: Map object handed to ``SDFCost``.
        tensor_kwargs: Tensor options forwarded to the cost objects.

    Returns:
        Dict with "costs" (the weighted ``CostSum``), "weight_fn" (a
        ``QuadraticWeights`` built from the quadratic params and the goal
        samples) and "goal_samples" (a clone of the samples that were used).
    """
    goals = goal_samples[:config["terminal"]["num_goals"], ...]
    set_distance = distances.SmoothKNNStatistic(alphas=[1])

    quad_term = QuadraticStateActionCost(**config["quadratic"]["params"], tensor_kwargs=tensor_kwargs)
    obstacle_term = SDFCost(dmap, **config["obstacle"]["params"], tensor_kwargs=tensor_kwargs)
    goal_term = TerminalSetCost(goals, set_distance, tensor_kwargs=tensor_kwargs)

    # Weights are listed in the same order as the cost terms.
    term_weights = [config[name]["weight"] for name in ("quadratic", "obstacle", "terminal")]
    total = CostSum([quad_term, obstacle_term, goal_term], weights=term_weights,
                    tensor_kwargs=tensor_kwargs)

    chooser = QuadraticWeights(config["quadratic"]["params"], goal_samples=goals,
                               tensor_kwargs=tensor_kwargs)

    return {"costs": total, "weight_fn": chooser, "goal_samples": goals.clone()}


def build_energy(config, goal_samples, dmap, tensor_kwargs={"device": "cpu", "dtype": torch.float}):
    """Build the cost and trajectory weighting for the energy-statistic method.

    Args:
        config: Mapping with "quadratic", "obstacle" and "terminal" entries.
            Each provides a "weight"; "quadratic" and "obstacle" also provide
            "params", and "terminal" provides "num_goals".
        goal_samples: Goal samples; only the first "num_goals" are used.
        dmap: Map object handed to ``SDFCost``.
        tensor_kwargs: Tensor options forwarded to the cost objects.

    Returns:
        Dict with "costs" (the weighted ``CostSum``), "weight_fn" (a
        ``QuadraticWeights`` built from the quadratic params and the goal
        samples) and "goal_samples" (a clone of the samples that were used).
    """
    goals = goal_samples[:config["terminal"]["num_goals"], ...]
    set_distance = distances.EnergyStatistic()

    quad_term = QuadraticStateActionCost(**config["quadratic"]["params"], tensor_kwargs=tensor_kwargs)
    obstacle_term = SDFCost(dmap, **config["obstacle"]["params"], tensor_kwargs=tensor_kwargs)
    goal_term = TerminalSetCost(goals, set_distance, tensor_kwargs=tensor_kwargs)

    # Weights are listed in the same order as the cost terms.
    term_weights = [config[name]["weight"] for name in ("quadratic", "obstacle", "terminal")]
    total = CostSum([quad_term, obstacle_term, goal_term], weights=term_weights,
                    tensor_kwargs=tensor_kwargs)

    chooser = QuadraticWeights(config["quadratic"]["params"], goal_samples=goals,
                               tensor_kwargs=tensor_kwargs)

    return {"costs": total, "weight_fn": chooser, "goal_samples": goals.clone()}


def build_prior_only(config, goal_samples, dmap, tensor_kwargs={"device": "cpu", "dtype": torch.float}):
    """Build the prior-only ablation cost, which has no terminal term.

    Args:
        config: Mapping with "quadratic" and "obstacle" entries, each
            providing "params" and "weight".
        goal_samples: Goal samples; not used by the cost, only cloned into
            the returned dict.
        dmap: Map object handed to ``SDFCost``.
        tensor_kwargs: Tensor options forwarded to the cost objects.

    Returns:
        Dict with "costs" (the weighted ``CostSum``), "weight_fn" (None) and
        "goal_samples" (a clone of all samples passed in).
    """
    quad_term = QuadraticStateActionCost(**config["quadratic"]["params"], tensor_kwargs=tensor_kwargs)
    obstacle_term = SDFCost(dmap, **config["obstacle"]["params"], tensor_kwargs=tensor_kwargs)

    # Weights are listed in the same order as the cost terms.
    term_weights = [config[name]["weight"] for name in ("quadratic", "obstacle")]
    total = CostSum([quad_term, obstacle_term], weights=term_weights, tensor_kwargs=tensor_kwargs)

    return {"costs": total, "weight_fn": None, "goal_samples": goal_samples.clone()}


def build_controllers(config, p_goal, goal_samples, x0, dmap,
                      tensor_kwargs={"device": "cpu", "dtype": torch.float}):
    """Turn the per-controller cost configuration into built cost objects.

    ``config`` is modified in place. For every controller entry handled here,
    "costs" is replaced by a ``CostSum`` and "weight_fn" and "goal_samples"
    are set. The entries "closest_point", "cross_entropy_true" and "kl_cls"
    are required; "closest_point_dyn", "nn_point", "mmd_prior", "smooth_knn",
    "energy" and "prior_only" are built only when present.

    Args:
        config: Mapping from controller name to its settings; each handled
            entry has a "costs" mapping describing the cost terms.
        p_goal: Goal distribution; its ``log_pdf`` is used by the
            cross-entropy baseline.
        goal_samples: Samples from the goal distribution.
        x0: Initial robot state, converted to a tensor with ``tensor_kwargs``.
        dmap: Map object handed to ``SDFCost``.
        tensor_kwargs: Tensor options forwarded to the cost objects.

    Returns:
        The same ``config`` object that was passed in.

    Raises:
        KeyError: If a required controller entry or cost setting is missing.
    """
    # Baseline 1: Pick the closest point.
    x0 = torch.tensor(x0, **tensor_kwargs)
    # Pick the sample that is the closest to the robot at the initial state.
    goal_dists = euclidean_distance(x0, goal_samples)
    goal_point = goal_samples[goal_dists.argmin(), ...]

    closest_point_config = config["closest_point"]["costs"]
    closest_pt_costs = [
        QuadraticStateActionCost(**closest_point_config["quadratic"]["params"], tensor_kwargs=tensor_kwargs),
        SDFCost(dmap, **closest_point_config["obstacle"]["params"], tensor_kwargs=tensor_kwargs),
        TerminalPositionCost(goal_point, tensor_kwargs=tensor_kwargs)
    ]
    close_pt_weights = [closest_point_config["quadratic"]["weight"],
                        closest_point_config["obstacle"]["weight"],
                        closest_point_config["terminal"]["weight"]]
    closest_pt_cost = CostSum(closest_pt_costs, weights=close_pt_weights, tensor_kwargs=tensor_kwargs)
    config["closest_point"]["costs"] = closest_pt_cost
    config["closest_point"]["weight_fn"] = None
    config["closest_point"]["goal_samples"] = goal_point

    # Baseline 2: Dynamic closest point to current state.
    if "closest_point_dyn" in config:
        dynamic_cost_data = build_dynamic_closest(config["closest_point_dyn"]["costs"], goal_samples, dmap,
                                                  tensor_kwargs=tensor_kwargs)
        config["closest_point_dyn"].update(dynamic_cost_data)

    # Baseline 3: Nearest neighbor point to the terminal state of each trajectory.
    if "nn_point" in config:
        nn_cost_data = build_nn_point(config["nn_point"]["costs"], goal_samples, dmap,
                                      tensor_kwargs=tensor_kwargs)
        config["nn_point"].update(nn_cost_data)

    # Baseline 4: Cross entropy against true distribution.
    cross_ent_config = config["cross_entropy_true"]["costs"]
    cross_ent_costs = [
        QuadraticStateActionCost(**cross_ent_config["quadratic"]["params"], tensor_kwargs=tensor_kwargs),
        SDFCost(dmap, **cross_ent_config["obstacle"]["params"], tensor_kwargs=tensor_kwargs),
        TerminalLogLikelihoodCost(p_goal.log_pdf, tensor_kwargs=tensor_kwargs)
    ]
    ce_weights = [cross_ent_config["quadratic"]["weight"],
                  cross_ent_config["obstacle"]["weight"],
                  cross_ent_config["terminal"]["weight"]]
    cross_ent_cost = CostSum(cross_ent_costs, weights=ce_weights, tensor_kwargs=tensor_kwargs)
    config["cross_entropy_true"]["costs"] = cross_ent_cost
    config["cross_entropy_true"]["weight_fn"] = None
    config["cross_entropy_true"]["goal_samples"] = None

    # Method 1: MMD + prior
    if "mmd_prior" in config:
        mmd_cost_data = build_mmd(config["mmd_prior"]["costs"], goal_samples, dmap,
                                  tensor_kwargs=tensor_kwargs)
        config["mmd_prior"].update(mmd_cost_data)

    # Method 2: KL Classification.
    kl_cls_config = config["kl_cls"]["costs"]

    # Keep the truncated set in its own variable: rebinding `goal_samples`
    # here would make every controller built below see only the KL
    # classifier's subset instead of the full sample set.
    kl_goal_samples = goal_samples[:kl_cls_config["terminal"]["num_goals"], ...]
    kl_dist_cls = distances.KLDivergenceClassifier(**kl_cls_config["terminal"]["params"], tensor_kwargs=tensor_kwargs)

    kl_cls_costs = [
        QuadraticStateActionCost(**kl_cls_config["quadratic"]["params"], tensor_kwargs=tensor_kwargs),
        SDFCost(dmap, **kl_cls_config["obstacle"]["params"], tensor_kwargs=tensor_kwargs),
        TerminalSetCost(kl_goal_samples, kl_dist_cls, tensor_kwargs=tensor_kwargs)
    ]
    kl_cls_weights = [kl_cls_config["quadratic"]["weight"],
                      kl_cls_config["obstacle"]["weight"],
                      kl_cls_config["terminal"]["weight"]]
    kl_cls_cost = CostSum(kl_cls_costs, weights=kl_cls_weights, tensor_kwargs=tensor_kwargs)

    config["kl_cls"]["costs"] = kl_cls_cost
    config["kl_cls"]["weight_fn"] = None
    config["kl_cls"]["goal_samples"] = kl_goal_samples.clone()

    # Method 3: Smooth KNN.
    if "smooth_knn" in config:
        knn_cost_data = build_knn(config["smooth_knn"]["costs"], goal_samples, dmap,
                                  tensor_kwargs=tensor_kwargs)
        config["smooth_knn"].update(knn_cost_data)

    # Method 4: Energy Statistic.
    if "energy" in config:
        ener_cost_data = build_energy(config["energy"]["costs"], goal_samples, dmap,
                                      tensor_kwargs=tensor_kwargs)
        config["energy"].update(ener_cost_data)

    # Ablation: Prior only
    if "prior_only" in config:
        prior_cost_data = build_prior_only(config["prior_only"]["costs"], goal_samples, dmap,
                                           tensor_kwargs=tensor_kwargs)
        config["prior_only"].update(prior_cost_data)

    return config


def control_loop_mpc(controller, robot, x0, iters, n_iters=50,
                     save_iters=False, print_freq=10, out_path="output", viz=False,
                     terminal_fn=None, bbox=None, save_data=False, stop_early=False,
                     figsize=(6, 6), map_img=None, lims=None, p_goal=None, x_goal=None):
    """Run a receding-horizon control loop and return the recorded robot states.

    On every iteration the controller is asked for an action sequence from the
    current robot state and the first action of that sequence is applied to
    the robot.

    Args:
        controller: Provides ``cost_fn.init_iteration``, ``optimize``,
            ``rollout``, ``action_particles`` and ``num_particles``.
        robot: Provides ``reset``, ``step`` and the current state ``x``.
        x0: Initial state passed to ``robot.reset``.
        iters: Maximum number of control steps.
        n_iters: Forwarded to ``controller.optimize``.
        save_iters: If True, every ``print_freq``-th iteration hands a save
            path to the optimizer and renders the saved data as figures.
        print_freq: Period, in iterations, of progress printing and of the
            ``save_iters`` output.
        out_path: Directory that receives all files written by this function.
            It is created when any output is requested.
        viz: If True, plot the particle rollouts each iteration and save the
            figure under ``out_path``.
        terminal_fn: Optional callable applied to the rollout of the selected
            action sequence; a truthy result marks the terminal condition.
        bbox, map_img, lims, p_goal, x_goal: Forwarded to the plotting
            helpers. A tensor ``x_goal`` is moved to the CPU first.
        save_data: If True, save per-iteration action particles, the
            trajectory and a small YAML summary under ``out_path``.
        stop_early: If True, leave the loop as soon as ``terminal_fn``
            reports the terminal condition.
        figsize: Figure size used for plots.

    Returns:
        Array obtained by stacking the recorded robot states.
    """
    # Convert the goal samples to CPU if necessary.
    if torch.is_tensor(x_goal):
        x_goal = x_goal.cpu()

    # All outputs are written below out_path, so make sure it exists before
    # the first save. An empty out_path means the current directory.
    if out_path and (save_iters or save_data or viz):
        os.makedirs(out_path, exist_ok=True)

    # Set robot to initial position.
    robot.reset(x0)
    traj = [robot.x.cpu().numpy()]

    # For checking average runtime.
    total_time = []
    path_indices = []
    stop_iter = iters

    for it in range(iters):
        save_ani = it % print_freq == 0 and save_iters
        save_path = os.path.join(out_path, "traj_{:04d}".format(it)) if save_ani else None

        state = robot.x

        # Reset the costs.
        controller.cost_fn.init_iteration(state=state)

        start = time.time()

        # Calculate the next control command. The previous solution is only
        # shifted once a first solution exists.
        shift = 1 if it > 0 else 0
        seq, idx = controller.optimize(state, shift_steps=shift,
                                       n_iters=n_iters, return_idx=True,
                                       save_data=save_path)

        # Only the optimizer call is timed; plotting and saving are excluded.
        total_time.append(time.time() - start)
        path_indices.append(idx)

        if save_path is not None:
            # Save all the trajectories for this optimization iteration as images.
            trajectory_files_to_figures(save_path, state, controller.rollout,
                                        figsize=figsize, map_img=map_img,
                                        lims=lims, p_goal=p_goal, x_goal=x_goal)

        # Save converged particles for this iteration, if requested.
        if save_data:
            particles_path = os.path.join(out_path, "particles_{:04d}.npy".format(it))
            np.save(particles_path, controller.action_particles().cpu().numpy())

        if viz:
            # Show the converged trajectories for this iteration.
            u = controller.action_particles()
            traj_samples = controller.rollout(robot.x, u)

            weights = np.zeros(controller.num_particles)
            weights[idx] = 1  # Mark the selected trajectory.

            plt.figure(0, figsize=figsize)
            plt.gca().clear()
            plot_trajectories_2D(plt.gca(), traj_samples.pos.cpu().numpy(), weights,
                                 title="Timestep {}".format(it),
                                 map_img=map_img, lims=lims, bbox=bbox,
                                 p_goal=p_goal, x_goal=x_goal)
            out_file = os.path.join(out_path, "distribution_{:04d}.jpg".format(it))
            plt.savefig(out_file)

        # Save the current robot state to the trajectory.
        # NOTE(review): the state is recorded before the step is applied, so
        # the initial state appears twice and the state reached by the last
        # step is never recorded. Confirm this is intended before changing it;
        # consumers of the saved trajectory may rely on this alignment.
        traj.append(robot.x.cpu().numpy())

        # Check terminal condition.
        if terminal_fn is not None:
            # Get the path associated with the current sequence.
            state_seq = controller.rollout(robot.x, seq.unsqueeze(0))
            if terminal_fn(state_seq):
                # Keep the first iteration at which the condition held.
                stop_iter = min(it, stop_iter)
                if stop_early:
                    print("Terminal condition reached. Iteration:", stop_iter)
                    break

        # Get the next action in the sequence and apply it to the robot.
        u0 = seq[0, :]
        robot.step(u0)

        if it % print_freq == 0:
            print("\tStep:", it)

    # With zero iterations there is nothing to average; report NaN without
    # triggering NumPy's empty-mean warning.
    avg_time = np.mean(total_time) if total_time else float("nan")
    print("\tAverage time:", avg_time)

    # Close plots if necessary.
    if viz:
        plt.close(0)

    traj = np.stack(traj)

    if save_data:
        traj_path = os.path.join(out_path, "trajectory.npy")
        np.save(traj_path, traj)

        run_data = {"stop_iter": stop_iter, "path_indices": path_indices}
        with open(os.path.join(out_path, "data.yaml"), 'w') as f:
            yaml.dump(run_data, f)

    return traj


def load_scene(scene_file, tensor_kwargs={"device": "cpu", "dtype": torch.float}):
    """Load a scene description from a YAML file.

    The file must provide "map_path", "start" and a "goal" entry with a
    "type" ("mixture", "gaussian" or "uniform") and the matching "params":
    "means"/"sigmas" for a mixture, "mean"/"sigma" for a Gaussian, and
    "low"/"high" with an optional "sigma" (default 0.01) for a uniform goal.

    Args:
        scene_file: Path to the YAML scene file.
        tensor_kwargs: Tensor options forwarded to the map and to the goal
            distribution.

    Returns:
        Tuple ``(dmap, x0, p_goal)``: the ``DiffMap`` built from "map_path",
        the "start" value from the file, and the goal distribution.

    Raises:
        ValueError: If the goal distribution type is not recognized.
    """
    # NOTE(security): yaml.Loader can construct arbitrary Python objects, so
    # only load scene files from trusted sources. If scene files never use
    # Python-specific YAML tags, yaml.safe_load would be the safer choice.
    with open(scene_file, 'r') as f:
        scene_config = yaml.load(f, Loader=yaml.Loader)

    dmap = DiffMap(scene_config["map_path"], tensor_kwargs=tensor_kwargs)

    # Create the goal distribution.
    goal_config = scene_config["goal"]["params"]
    dist_type = scene_config["goal"]["type"]
    if dist_type == "mixture":
        means = goal_config["means"]
        sigmas = goal_config["sigmas"]
        p_goal = dist.Mixture(means, sigmas, **tensor_kwargs)
    elif dist_type == "gaussian":
        mean = goal_config["mean"]
        sigma = goal_config["sigma"]
        p_goal = dist.Gaussian(mean, sigma, **tensor_kwargs)
    elif dist_type == "uniform":
        low = goal_config["low"]
        high = goal_config["high"]
        sigma = goal_config.get("sigma", 0.01)
        p_goal = dist.SmoothUniform(low, high, sigma=sigma, **tensor_kwargs)
    else:
        # ValueError is still caught by callers that handle Exception.
        raise ValueError("ERROR: Unrecognized distribution type: " + str(dist_type))

    x0 = scene_config["start"]  # Starting state.

    return dmap, x0, p_goal
