import carla
import numpy as np

from ..carla_manager import Command
from .agents.navigation.global_route_planner import GlobalRoutePlanner
from .base_planner import BasePlanner


class FourLaneCommandPlanner(BasePlanner):
    """
    Generate waypoints for a high-level command on a straight multi-lane road.

    The road is described only by the x-coordinates of its lane centers. A lane
    id is an index into ``lane_centers``; a position belongs to the lane whose
    center is nearest in x. All plans head towards decreasing y.

    Supported commands:

    * ``Command.LaneFollow`` - keep the current lane for ``_lane_follow_distance``
      along -y.
    * ``Command.LaneChangeLeft`` - move to lane id ``current - 1``.
    * any other command - treated as a lane change to lane id ``current + 1``.

    (Whether a lower lane id is geometrically "left" depends on how the caller
    orders ``lane_centers`` - NOTE(review): confirm the ordering convention.)

    A command becomes invalid (``is_command_valid()`` returns False) when a lane
    change would leave the known lanes, when no route could be produced, or when
    the plan runs out while the vehicle is not in the goal lane.
    """

    def __init__(self, vehicle: carla.Actor, lane_centers, sampling_radius=0.8, force_reset_command=True):
        """
        :param vehicle: ego vehicle actor the route is planned for.
        :param lane_centers: x-coordinate of each lane center; the index is the lane id.
        :param sampling_radius: waypoint spacing, passed to GlobalRoutePlanner and
            used to sample the lateral lane-change segment.
        :param force_reset_command: if True, every ``set_command`` call replaces the
            current command; otherwise a new command is only accepted when there is
            no current command or the current one is invalid or completed.
        """
        super().__init__(vehicle)
        self._sampling_radius = sampling_radius
        # Copy as float so indexing yields np.float64 (a Python float subclass),
        # which can be passed straight to carla.Location even if ints were given.
        self._lane_centers = np.array(lane_centers, dtype=float)
        self.force_reset_command = force_reset_command

        # Planning distances along the (negative) y-axis.
        self._lane_follow_distance = 18      # length of a lane-follow plan
        self._dist_before_lane_change = 5    # straight segment before the lateral move
        self._dist_during_lane_change = 8    # y-extent of the lateral move
        self._dist_after_lane_change = 12    # straight segment in the goal lane
        # The plan is considered outdated once this many waypoints or fewer remain.
        self._min_remaining_waypoints = 8

        # self._map is presumably provided by BasePlanner.__init__.
        self._grp = GlobalRoutePlanner(self._map, self._sampling_radius)
        self._is_command_valid = True
        self._num_waypoints = 0
        self._total_distance = 0.0
        self._command = None
        # Filled in by _plan_route(); initialised here so the getters and
        # is_command_completed() do not raise AttributeError before the first plan.
        self._goal_lane_id = None
        self._start_location = None
        self._end_location = None

    def set_command(self, command):
        """
        Set the current target command for the ego vehicle and re-plan.

        Unless ``force_reset_command`` is set, the command is only replaced when
        there is no current command, the current command is invalid (it can never
        complete, so keeping it would block all future commands), or the current
        command is completed.

        :param command: Command
        """
        can_replace = (
            self.force_reset_command
            or self._command is None
            or not self._is_command_valid
            or self.is_command_completed()
        )
        if can_replace:
            self._command = command
            self._is_command_valid = True
            self._plan_route()

    def init_route(self):
        """No initial route: planning starts with the first ``set_command`` call."""
        pass

    def extend_route(self):
        """
        Re-plan once the current plan is outdated.

        If the vehicle is in the goal lane, continue with a lane-follow plan;
        otherwise mark the command invalid. Does nothing before the first command.
        """
        if self._command is None or not self._is_plan_outdated():
            return
        if self.get_ego_lane_id() != self._goal_lane_id:
            self._is_command_valid = False
            return
        self._command = Command.LaneFollow
        self._plan_route()

    def _plan_route(self):
        """
        Replace the current waypoints with a plan for ``self._command``.

        Sets ``_is_command_valid`` to False and returns early when a lane change
        would leave the known lanes or when no waypoints remain to follow.
        """
        self.clear_waypoints()
        # Snap the vehicle location onto the map before planning.
        start_location = self._map.get_waypoint(self._vehicle.get_location()).transform.location
        current_lane_id = self.get_lane_id(start_location.x)

        if self._command == Command.LaneFollow:
            self._extend_straight(start_location, self._lane_follow_distance)
            self._goal_lane_id = current_lane_id
        else:
            # LaneChangeLeft goes to the lower lane id; any other command to the higher one.
            step = -1 if self._command == Command.LaneChangeLeft else 1
            goal_lane_id = current_lane_id + step
            if not 0 <= goal_lane_id < len(self._lane_centers):
                self._is_command_valid = False
                return
            self._goal_lane_id = goal_lane_id
            self._extend_lane_change(start_location, goal_lane_id)

        # Remove the first waypoint (presumably the one at the vehicle's own position).
        self.pop_waypoint()

        waypoints = self.get_all_waypoints()
        if len(waypoints) == 0:
            # Nothing to follow: report the command as unplannable instead of
            # failing with IndexError on waypoints[-1] below.
            self._is_command_valid = False
            return

        self._num_waypoints = len(waypoints)
        start_location = self._vehicle.get_location()
        self._start_location = start_location
        self._end_location = carla.Location(x=waypoints[-1][0], y=waypoints[-1][1], z=start_location.z)
        self._total_distance = start_location.distance(self._end_location)

    def _is_plan_outdated(self):
        """
        Return True when few waypoints remain, or when the vehicle's y is below
        the last waypoint's y (plans head towards decreasing y, so the vehicle
        has passed the end of the plan).
        """
        waypoints = self.get_all_waypoints()
        if len(waypoints) <= self._min_remaining_waypoints:
            return True
        return self._vehicle.get_location().y < waypoints[-1][1]

    def get_lane_id(self, x):
        """Return the index (as ``int``) of the lane center nearest to ``x``."""
        return int(np.argmin(np.abs(self._lane_centers - x)))

    def _extend_straight(self, start_location, distance):
        """Append the map route from ``start_location`` to the point ``distance`` further along -y."""
        goal_location = carla.Location(start_location.x, start_location.y - distance, start_location.z)
        route = self._grp.trace_route(start_location, goal_location)
        for waypoint in route:
            # trace_route entries are pairs; element 0 is the waypoint.
            self.add_waypoint(waypoint[0])

    def _extend_lane_change(self, start_location, goal_lane_id):
        """
        Append a lane-change plan: a straight segment, a linear lateral move to
        the goal lane center, then a straight segment in the goal lane.
        """
        self._extend_straight(start_location, self._dist_before_lane_change)
        new_location = carla.Location(start_location.x, start_location.y - self._dist_before_lane_change, start_location.z)
        goal_x = float(self._lane_centers[goal_lane_id])
        dx = goal_x - new_location.x
        dy = -self._dist_during_lane_change
        # Constant heading, in degrees, along the straight lateral segment.
        yaw = np.degrees(np.arctan2(dy, dx))
        sample_num = max(2, int(np.hypot(dx, dy) / self._sampling_radius))
        for theta in np.linspace(0, 1, sample_num):
            self.add_waypoint((new_location.x + theta * dx, new_location.y + theta * dy, yaw))
        goal_location = carla.Location(goal_x, new_location.y + dy, new_location.z)
        self._extend_straight(goal_location, self._dist_after_lane_change)

    def is_command_valid(self):
        """Return False if the current command could not be planned or completed."""
        return self._is_command_valid

    def get_num_waypoints(self):
        """Return the number of waypoints in the last successful plan."""
        return self._num_waypoints

    def get_total_distance(self):
        """Return the straight-line distance from the vehicle to the plan's last waypoint at planning time."""
        return self._total_distance

    def get_start_location(self):
        """Return the vehicle location at planning time (None before the first plan)."""
        return self._start_location

    def get_end_location(self):
        """Return the location of the plan's last waypoint (None before the first plan)."""
        return self._end_location

    def get_goal_lane_id(self):
        """Return the goal lane id of the current command (None before the first plan)."""
        return self._goal_lane_id

    def get_ego_lane_id(self):
        """Return the lane id nearest to the vehicle's current x-coordinate."""
        return self.get_lane_id(self._vehicle.get_location().x)

    def is_command_completed(self):
        """Return True once the plan is outdated and the vehicle is in the goal lane."""
        return self._is_plan_outdated() and self.get_ego_lane_id() == self._goal_lane_id
