import os
import time
import yaml
import torch
import numpy as np
import matplotlib.pyplot as plt

from goal_set_planning.costs.costs_2D import (
    QuadraticStateActionCost,
    SDFCost,
    TerminalPositionCost,
    DynamicTerminalPositionCost,
    NearestNeighborTerminalPositionCost,
    TerminalVelocityCost,
    TerminalLogLikelihoodCost,
    TerminalSetCost,
    TerminalBBoxPrior
)
from goal_set_planning.costs import CostSum

from goal_set_planning.sim.map import DiffMap
import goal_set_planning.distributions as dist
from goal_set_planning.distributions import distances
from goal_set_planning.util.misc import euclidean_distance

from goal_set_planning.util.plotting import (
    plot_trajectory_2D, plot_trajectories_2D,
    trajectory_files_to_figures
)


def _build_cost_sum(cost_config, dmap, terminal_cost, tensor_kwargs):
    """Return a weighted CostSum of quadratic, SDF-obstacle and terminal costs.

    Args:
        cost_config: Mapping with "quadratic", "obstacle" and "terminal"
            entries. "quadratic" and "obstacle" each provide "params"
            (keyword arguments for the cost constructor) and "weight";
            "terminal" provides "weight".
        dmap: Map object forwarded to ``SDFCost``.
        terminal_cost: Already-constructed terminal cost term.
        tensor_kwargs: Keyword arguments (e.g. device/dtype) forwarded to
            every cost constructor and to ``CostSum``.
    """
    costs = [
        QuadraticStateActionCost(**cost_config["quadratic"]["params"],
                                 tensor_kwargs=tensor_kwargs),
        SDFCost(dmap, **cost_config["obstacle"]["params"],
                tensor_kwargs=tensor_kwargs),
        terminal_cost,
    ]
    # Weights must line up positionally with ``costs`` above.
    weights = [cost_config[name]["weight"]
               for name in ("quadratic", "obstacle", "terminal")]
    return CostSum(costs, weights=weights, tensor_kwargs=tensor_kwargs)


def build_controllers(config, p_goal, goal_samples, x0, dmap,
                      tensor_kwargs=None):
    """Build the cost functions for the two MPPI controller variants.

    Two controller entries in ``config`` are filled in:

    * ``"closest_point"``: the terminal cost pulls toward the single goal
      sample nearest to ``x0``.
    * ``"cross_entropy_true"``: the terminal cost is the log-likelihood under
      the true goal distribution ``p_goal``.

    ``config`` is modified in place. Each entry's ``"costs"`` sub-config is
    replaced by the constructed ``CostSum``, and ``"goal_samples"`` is set.
    The same config therefore should not be passed to this function twice.

    Args:
        config: Dict with ``"closest_point"`` and ``"cross_entropy_true"``
            entries. Each holds a ``"costs"`` mapping with ``"quadratic"``,
            ``"obstacle"`` and ``"terminal"`` sub-entries.
        p_goal: Goal distribution exposing ``log_pdf``.
        goal_samples: Candidate goal points. Presumably a tensor with one
            sample per row (indexed along its first dimension); confirm
            against callers.
        x0: Initial robot state, converted with ``torch.tensor``.
        dmap: Map object forwarded to ``SDFCost``.
        tensor_kwargs: Keyword arguments forwarded to ``torch.tensor`` and all
            cost constructors. Defaults to CPU float tensors.

    Returns:
        The same ``config`` object, updated in place.
    """
    # Build a fresh dict per call rather than sharing a mutable default.
    if tensor_kwargs is None:
        tensor_kwargs = {"device": "cpu", "dtype": torch.float}

    # MPPI 1: head for the goal sample closest to the initial state.
    x0 = torch.tensor(x0, **tensor_kwargs)
    goal_dists = euclidean_distance(x0, goal_samples)
    goal_point = goal_samples[goal_dists.argmin(), ...]

    closest_cfg = config["closest_point"]
    closest_cfg["costs"] = _build_cost_sum(
        closest_cfg["costs"], dmap,
        TerminalPositionCost(goal_point, tensor_kwargs=tensor_kwargs),
        tensor_kwargs,
    )
    closest_cfg["goal_samples"] = goal_point

    # MPPI 2: terminal cost is the log-likelihood under the true distribution.
    cross_ent_cfg = config["cross_entropy_true"]
    cross_ent_cfg["costs"] = _build_cost_sum(
        cross_ent_cfg["costs"], dmap,
        TerminalLogLikelihoodCost(p_goal.log_pdf, tensor_kwargs=tensor_kwargs),
        tensor_kwargs,
    )
    cross_ent_cfg["goal_samples"] = None

    return config
