from goal_set_planning.costs.costs_2D import Robot2DState


class Robot2DRollout(object):
    """Roll a 2D robot model forward under a control sequence.

    Calling an instance is shorthand for :meth:`rollout`. When a ``cost``
    callable is supplied, :meth:`log_pdf` scores a control sequence as the
    negated cost of the trajectory it produces.
    """

    def __init__(self, robot, cost=None):
        """Store the robot model and the optional trajectory cost.

        Args:
            robot: Object providing ``u_dim``, ``reset(state)`` and
                ``rollout(u)``.
            cost: Optional callable applied to the ``Robot2DState`` returned
                by :meth:`rollout`. Only :meth:`log_pdf` uses it.
        """
        self.robot = robot
        self.cost = cost

    def __call__(self, state, u):
        """Return ``self.rollout(state, u)``."""
        return self.rollout(state, u)

    def log_pdf(self, u, state=None):
        """Return the negated cost of the trajectory produced by ``u``.

        Args:
            u: Control tensor, interpreted exactly as in :meth:`rollout`.
            state: State the robot is reset to before rolling out. It is
                keyword-optional only so that ``u`` can stay the first
                positional parameter; a value must be supplied.

        Returns:
            ``-self.cost(trajectory)``.

        Raises:
            TypeError: If ``state`` is not supplied.
            ValueError: If the instance was built without a ``cost``.
        """
        if state is None:
            raise TypeError("log_pdf() requires the `state` to roll out from")
        if self.cost is None:
            raise ValueError("log_pdf() requires a `cost` callable")

        traj = self.rollout(state, u)
        return -self.cost(traj)

    def rollout(self, state, u):
        """Reset the robot to ``state`` and roll it out under controls ``u``.

        Args:
            state: Passed unchanged to ``self.robot.reset``.
            u: Control tensor. With more than one dimension its first
                dimension is taken as the batch size ``N``; otherwise
                ``N = 1``. It is viewed as ``(N, -1, robot.u_dim)`` before
                being handed to the robot.

        Returns:
            A ``Robot2DState`` whose ``pos`` is the first two components of
            the last axis of the robot's rollout output, whose ``vel`` is
            the remaining components, and whose ``u`` is the reshaped
            control tensor.
        """
        N = u.size(0) if u.ndim > 1 else 1
        u = u.view(N, -1, self.robot.u_dim)

        # The robot is stateful: reset must precede the rollout call.
        self.robot.reset(state)
        x = self.robot.rollout(u)

        return Robot2DState(pos=x[..., :2], vel=x[..., 2:], u=u)
