import torch

from goal_set_planning.distributions import SmoothUniform
from goal_set_planning.util.misc import euclidean_distance

from .base_costs import BaseCost


class Robot2DState(object):
    """Bundle of rollout tensors handed to every cost in this module.

    Values are stored exactly as given and may be None. The costs in this
    module index ``pos`` and ``vel`` as 3-D tensors laid out
    (trajectory, time, feature).

    Args:
        pos: Positions.
        vel: Velocities.
        u: Control inputs.
    """

    def __init__(self, pos=None, vel=None, u=None):
        self.pos, self.vel, self.u = pos, vel, u

    def state(self):
        """Return positions followed by velocities, joined on the last axis."""
        return torch.cat((self.pos, self.vel), -1)


"""Stage Costs."""


class QuadraticStateActionCost(BaseCost):
    """Quadratic running cost on positions, velocities and controls.

    Each term does four steps:
    1. square its input;
    2. sum over the last (feature) axis;
    3. scale by its weight;
    4. sum over the next axis (time).

    The three per-term totals are then added. Each term is reduced on its own,
    so ``u`` need not have the same time length as ``pos``/``vel``.

    Args:
        c_pos: Weight of the squared-position term.
        c_vel: Weight of the squared-velocity term.
        c_u: Weight of the squared-control term.
        **kwargs: Forwarded to ``BaseCost``.
    """

    def __init__(self, c_pos=0., c_vel=0.25, c_u=0.2, **kwargs):
        super(QuadraticStateActionCost, self).__init__(**kwargs)

        self.c_pos = c_pos
        self.c_vel = c_vel
        self.c_u = c_u

    @staticmethod
    def _weighted_energy(weight, values):
        """Return the weighted squared norm of ``values``, summed over its last two axes."""
        per_step = weight * (values * values).sum(dim=-1)
        return per_step.sum(-1)

    def forward(self, state):
        return (self._weighted_energy(self.c_pos, state.pos)
                + self._weighted_energy(self.c_vel, state.vel)
                + self._weighted_energy(self.c_u, state.u))


class SDFCost(BaseCost):
    """Exponentiated signed-distance cost accumulated along each trajectory.

    Every position of every trajectory is evaluated with ``map2D.eval_sdf``.
    The signed distance is scaled and passed through ``exp``, and the results
    are summed over time, giving one value per trajectory.

    Args:
        map2D: Object exposing ``eval_sdf(points)``. Points are passed as a
            flat (N * T, d) batch. Presumably it returns one value per point;
            TODO confirm against the map implementation.
        sigma: Scale applied to negative SDF values. When ``compute_in`` is
            False, it is applied to all values.
        in_sigma: Scale applied to positive SDF values when ``compute_in`` is
            True. The original author describes this as the scale "inside the
            obstacles". Which sign means "inside" depends on the map's SDF
            convention; verify before relying on it.
        compute_in: Whether positive SDF values use ``in_sigma`` instead of
            ``sigma``.
        **kwargs: Forwarded to ``BaseCost``.
    """

    def __init__(self, map2D, sigma=1., in_sigma=1., compute_in=True, **kwargs):
        super(SDFCost, self).__init__(**kwargs)

        self.map = map2D
        self.sigma = sigma
        # Separate scale for the positive side of the SDF (documented by the
        # original author as "inside the obstacles").
        self.in_sigma = in_sigma
        self.compute_in = compute_in

    def forward(self, state):
        n_traj, horizon, _ = state.pos.shape

        # Flatten (N, T, d) -> (N * T, d) so the map sees a plain batch of points.
        flat_points = state.pos.reshape(n_traj * horizon, -1)
        sdf = self.map.eval_sdf(flat_points)

        if not self.compute_in:
            scaled = self.sigma * sdf
        else:
            # Piecewise-linear scaling: negative part by sigma, positive part by in_sigma.
            neg_part = self.sigma * sdf.clamp(max=0)
            pos_part = self.in_sigma * sdf.clamp(min=0)
            scaled = neg_part + pos_part

        per_step = torch.exp(scaled).view(n_traj, horizon)
        return per_step.sum(1)


"""Terminal Costs."""


class TerminalPositionCost(BaseCost):
    """Squared Euclidean distance from each trajectory's final position to a goal.

    Args:
        goal_pos: Goal position, broadcast against the (N, d) final positions
            (default 0, i.e. the origin).
        **kwargs: Forwarded to ``BaseCost``.
    """

    def __init__(self, goal_pos=0, **kwargs):
        super(TerminalPositionCost, self).__init__(**kwargs)

        self.goal_pos = goal_pos

    def forward(self, state):
        final_pos = state.pos[:, -1, :]
        offset = final_pos - self.goal_pos
        return torch.sum(offset * offset, dim=-1)


class DynamicTerminalPositionCost(BaseCost):
    """Squared distance from the final position to a goal re-chosen each iteration.

    ``init_iteration`` picks the entry of ``goal_samples`` closest to the
    current state. ``forward`` then penalizes the squared Euclidean distance
    between every trajectory's final position and that sample.

    Args:
        goal_samples: Candidate goals, indexed along the first axis.
        **kwargs: Forwarded to ``BaseCost``.
    """

    def __init__(self, goal_samples, **kwargs):
        super(DynamicTerminalPositionCost, self).__init__(**kwargs)

        self.goal_samples = goal_samples
        # Index into ``goal_samples`` of the currently targeted goal; kept as a
        # plain Python int.
        self.closest_idx = 0

    def forward(self, state):
        x_term = state.pos[:, -1, :] - self.goal_samples[self.closest_idx, ...]
        cost = (x_term * x_term).sum(dim=-1)
        return cost

    def init_iteration(self, state=None, **kwargs):
        """Retarget to the goal sample nearest the current state.

        Args:
            state: Current state. Its first two entries are used as the
                position; this assumes a flat [x, y, ...] layout (TODO confirm
                against callers). If None, the previously selected goal is kept.
            **kwargs: Ignored.
        """
        if state is None:
            # ``state`` is optional in the signature; without it there is
            # nothing to measure against, so keep the current target rather
            # than failing on ``None[:2]``.
            return
        goal_dists = euclidean_distance(state[:2], self.goal_samples)
        # Normalize to a Python int so ``closest_idx`` keeps the same type as
        # its initial value instead of becoming a 0-d (possibly GPU) tensor.
        self.closest_idx = int(goal_dists.argmin())


class NearestNeighborTerminalPositionCost(BaseCost):
    """Distance from each trajectory's final position to its nearest goal sample.

    Unlike the squared terminal costs in this module, this returns the minimum
    of ``euclidean_distance`` directly. Whether that is squared depends on the
    helper; presumably it is not. TODO confirm.

    Args:
        goal_samples: Candidate goals passed to ``euclidean_distance``.
        **kwargs: Forwarded to ``BaseCost``.
    """

    def __init__(self, goal_samples, **kwargs):
        super(NearestNeighborTerminalPositionCost, self).__init__(**kwargs)

        self.goal_samples = goal_samples

    def forward(self, state):
        final_pos = state.pos[:, -1, :]
        # Pairwise distances, presumably (N, num_goals) -- TODO confirm the
        # helper's output layout. Reduce over the goal axis.
        pairwise = euclidean_distance(final_pos, self.goal_samples)
        nearest = pairwise.min(dim=-1).values
        return nearest


class TerminalVelocityCost(BaseCost):
    """Squared Euclidean distance from each trajectory's final velocity to a target.

    Args:
        goal_vel: Target velocity, broadcast against the (N, d) final
            velocities (default 0, i.e. come to rest).
        **kwargs: Forwarded to ``BaseCost``.
    """

    def __init__(self, goal_vel=0, **kwargs):
        super(TerminalVelocityCost, self).__init__(**kwargs)

        self.goal_vel = goal_vel

    def forward(self, state):
        final_vel = state.vel[:, -1, :]
        residual = final_vel - self.goal_vel
        return torch.sum(residual * residual, dim=-1)


class TerminalLogLikelihoodCost(BaseCost):
    """Negative log-likelihood of each trajectory's final position.

    Args:
        log_likelihood: Callable mapping the final positions (``pos[:, -1, :]``)
            to log-likelihood values.
        **kwargs: Forwarded to ``BaseCost``.
    """

    def __init__(self, log_likelihood, **kwargs):
        super(TerminalLogLikelihoodCost, self).__init__(**kwargs)

        self.log_likelihood = log_likelihood

    def forward(self, state):
        final_pos = state.pos[:, -1, :]
        log_lik = self.log_likelihood(final_pos)
        return -log_lik


class TerminalSetCost(BaseCost):
    """Set-level distance between all terminal positions and a goal sample set.

    ``distance_fn`` is called once with every trajectory's final position and
    the goal samples. Its result is tiled to one entry per trajectory.

    Args:
        goal_samples: Goal sample set passed to ``distance_fn``.
        distance_fn: Callable ``(terminal_positions, goal_samples)``. It may
            optionally expose a ``reset()`` method.
        **kwargs: Forwarded to ``BaseCost``.
    """

    def __init__(self, goal_samples, distance_fn, **kwargs):
        super(TerminalSetCost, self).__init__(**kwargs)

        self.goal_samples = goal_samples
        self.distance_fn = distance_fn

    def forward(self, state):
        final_pos = state.pos[:, -1, :]
        set_dist = self.distance_fn(final_pos, self.goal_samples)
        # distance_fn appears to yield a single value for the whole set;
        # tile it so there is one output cost per particle.
        return set_dist.repeat(final_pos.size(0))

    def reset(self, goal_samples=None, **kwargs):
        """Optionally swap in new goal samples, then reset ``distance_fn`` if it supports it.

        Args:
            goal_samples: Replacement goal samples, converted with
                ``self.tensor_kwargs`` (presumably device/dtype provided by
                ``BaseCost``). None keeps the current samples.
            **kwargs: Ignored.
        """
        if goal_samples is not None:
            self.goal_samples = goal_samples.to(**self.tensor_kwargs)

        # Only stateful distance functions define reset().
        has_reset = hasattr(self.distance_fn, 'reset')
        if has_reset:
            self.distance_fn.reset()


"""Priors."""


class TerminalBBoxPrior(object):
    """Smooth axis-aligned bounding-box prior on terminal positions.

    The box spans the per-dimension min/max of ``samples``, widened by
    ``pad``, and is modeled with ``SmoothUniform``. Calling the prior returns
    ``alpha`` times that distribution's log-density at each trajectory's
    final position.

    Args:
        samples: Points defining the box. They are reduced over dim 0, so
            presumably laid out as (num_samples, d).
        sigma: Smoothing parameter forwarded to ``SmoothUniform``.
        alpha: Multiplier applied to the log-density.
        pad: Subtracted from the lower corner and added to the upper corner.
        device: Device forwarded to ``SmoothUniform``.
        dtype: Dtype forwarded to ``SmoothUniform``.
    """

    def __init__(self, samples, sigma=0.01, alpha=1., pad=0., device="cpu", dtype=torch.float32):
        self.tensor_kwargs = {"device": device, "dtype": dtype}
        self.alpha = alpha
        bbox_low = samples.min(dim=0)[0] - pad
        bbox_high = samples.max(dim=0)[0] + pad
        self.prior = SmoothUniform(bbox_low, bbox_high, sigma=sigma, **self.tensor_kwargs)

    def __call__(self, state, actions=None):
        """Evaluate the Smooth Box Prior.

        Args:
            state: A Robot2DState object, as generated from a rollout function.
                ``state.pos`` must be a 3-D (N, T, d) tensor.
            actions: Unused. Presumably accepted so the prior matches a shared
                call signature.

        Returns:
            ``alpha * log_pdf(x_term)``, where ``x_term = state.pos[:, -1, :]``
            holds the final position of each trajectory.

        Raises:
            ValueError: If ``state.pos`` is not 3-D.
        """
        if state.pos.dim() != 3:
            # Reject non-trajectory input explicitly with a readable message.
            raise ValueError(
                "Expected state.pos of shape (N, T, d), got {}".format(tuple(state.pos.shape)))
        x_term = state.pos[:, -1, :]  # Grab the terminal state from each trajectory.
        return self.alpha * self.prior.log_pdf(x_term)
