# Copyright (c) 2018-2020 CVC.
#
# This work is licensed under the terms of the MIT license.
# For a copy, see <https://opensource.org/licenses/MIT>.

""" This module contains a local planner to perform low-level waypoint following based on PID controllers. """

from enum import Enum
from collections import deque
import random
import numpy as np
import cv2
import copy
import time

import carla
from carla_env.agents.navigation.global_route_planner_dao import GlobalRoutePlannerDAO
from carla_env.agents.navigation.global_route_planner import GlobalRoutePlanner
from carla_env.agents.navigation.controller import VehiclePIDController
from carla_env.agents.tools.misc import draw_waypoints
from carla_env.utils.carla_sync_mode import CarlaSyncMode


class RoadOption(Enum):
    """
    RoadOption represents the possible topological configurations when moving from one lane segment to another.

    Each member is a plain integer-valued enum entry. Within this module the turn options are produced by
    _compute_connection() and LANEFOLLOW is used when a waypoint has a single successor.
    """
    # NOTE(review): presumably "no option / undefined" -- not produced anywhere in this module, confirm with callers
    VOID = -1
    # turn classifications returned by _compute_connection()
    LEFT = 1
    RIGHT = 2
    STRAIGHT = 3
    # keep following the current lane (only one successor waypoint available)
    LANEFOLLOW = 4
    # NOTE(review): lane-change options are not generated in this module; presumably set by the global planner
    CHANGELANELEFT = 5
    CHANGELANERIGHT = 6


class LocalPlanner(object):
    """
    LocalPlanner implements the basic behavior of following a trajectory of waypoints that is generated on-the-fly.
    The low-level motion of the vehicle is computed by using two PID controllers, one is used for the lateral control
    and the other for the longitudinal control (cruise speed).

    When multiple paths are available (intersections) this local planner makes a random choice.
    """

    # minimum distance to target waypoint as a percentage (e.g. within 90% of
    # total distance)
    MIN_DISTANCE_PERCENTAGE = 0.9

    def __init__(self, vehicle, opt_dict=None):
        """
        :param vehicle: actor to apply to local planner logic onto
        :param opt_dict: dictionary of arguments with the following semantics:
            dt -- time difference between physics control in seconds. This is typically fixed from server side
                  using the arguments -benchmark -fps=F . In this case dt = 1/F

            target_speed -- desired cruise speed in Km/h

            sampling_radius -- search radius for next waypoints in seconds: e.g. 0.5 seconds ahead

            lateral_control_dict -- dictionary of arguments to setup the lateral PID controller
                                    {'K_P':, 'K_D':, 'K_I':, 'dt'}

            longitudinal_control_dict -- dictionary of arguments to setup the longitudinal PID controller
                                        {'K_P':, 'K_D':, 'K_I':, 'dt'}
        """
        self._vehicle = vehicle
        self._map = self._vehicle.get_world().get_map()

        self._dt = None
        self._target_speed = None
        self._sampling_radius = None
        self._min_distance = None
        self._current_waypoint = None
        self._target_road_option = None
        self._next_waypoints = None
        self.target_waypoint = None
        self._vehicle_controller = None
        self._global_plan = None
        self.dynamic_velocity = False
        # queue with tuples of (waypoint, RoadOption)
        self._waypoints_queue = deque(maxlen=20000)
        self._buffer_size = 5
        self._waypoint_buffer = deque(maxlen=self._buffer_size)

        self.global_planner = None
        self.global_node = None
        self.pre_node = None

        # initializing controller
        self._init_controller(opt_dict)
        # self.sync_mode = CarlaSyncMode(
        #     self._vehicle.get_world(), fps=20
        # )


    def __del__(self):
        """
        Destroy the ego-vehicle actor when the planner is garbage collected.

        Nothing happens when the vehicle reference is falsy (e.g. after reset_vehicle()).
        """
        ego_vehicle = self._vehicle
        if not ego_vehicle:
            return
        ego_vehicle.destroy()
        print("Destroying ego-vehicle!")

    def reset_vehicle(self):
        """
        Drop the reference to the ego-vehicle without destroying the actor.

        After this call __del__ no longer destroys the vehicle.
        """
        print("Resetting ego-vehicle!")
        self._vehicle = None

    def _init_controller(self, opt_dict):
        """
        Controller initialization.

        Reads the planner parameters (defaults overridden by ``opt_dict``), creates the
        VehiclePIDController and seeds the waypoint queue with a random-walk trajectory.

        Recognised ``opt_dict`` keys: 'dt', 'target_speed', 'sampling_radius' (seconds),
        'lateral_control_dict', 'longitudinal_control_dict', 'max_throttle', 'max_brake',
        'max_steering' and 'offset'.

        :param opt_dict: dictionary of arguments (may be None or empty).
        :return:
        """
        options = opt_dict if opt_dict else {}

        # scalar parameters: default value unless overridden
        self._dt = options.get('dt', 1.0 / 20.0)
        self._target_speed = options.get('target_speed', 20.0)  # Km/h
        if 'sampling_radius' in options:
            # horizon given in seconds, converted to metres at the (possibly overridden) target speed
            self._sampling_radius = self._target_speed * options['sampling_radius'] / 3.6
        else:
            # default: 1 second horizon at the default 20 Km/h (unchanged historical behaviour)
            self._sampling_radius = 20.0 * 1 / 3.6
        # Derived AFTER the overrides so it always matches the sampling radius actually in use.
        self._min_distance = self._sampling_radius * self.MIN_DISTANCE_PERCENTAGE
        self._max_brake = options.get('max_brake', 0.3)
        self._max_throt = options.get('max_throttle', 0.75)
        self._max_steer = options.get('max_steering', 0.8)
        self._offset = options.get('offset', 0)

        # Default PID gains are built after 'dt' is resolved so the controllers integrate
        # with the same time step as the planner.
        args_lateral_dict = options.get('lateral_control_dict', {
            'K_P': 1.95,
            'K_D': 0.2,
            'K_I': 0.07,
            'dt': self._dt})
        args_longitudinal_dict = options.get('longitudinal_control_dict', {
            'K_P': 1.0,
            'K_D': 0,
            'K_I': 0.05,
            'dt': self._dt})

        # dynamic velocity mode overrides the throttle/brake limits
        if self.dynamic_velocity:
            self._max_brake = 1.0
            self._max_throt = 2.0

        self._current_waypoint = self._map.get_waypoint(self._vehicle.get_location())
        self._vehicle_controller = VehiclePIDController(self._vehicle,
                                                        args_lateral=args_lateral_dict,
                                                        args_longitudinal=args_longitudinal_dict,
                                                        offset=self._offset,
                                                        max_throttle=self._max_throt,
                                                        max_brake=self._max_brake,
                                                        max_steering=self._max_steer)

        self._global_plan = False

        # compute initial waypoints; at a lane end there is no successor, so fall back to the
        # current waypoint instead of failing with IndexError
        next_waypoints = self._current_waypoint.next(self._sampling_radius)
        first_waypoint = next_waypoints[0] if next_waypoints else self._current_waypoint
        self._waypoints_queue.append((first_waypoint, RoadOption.LANEFOLLOW))

        self._target_road_option = RoadOption.LANEFOLLOW
        # fill waypoint trajectory queue
        self._compute_next_waypoints(k=200)

    def set_speed(self, speed):
        """
        Store a new cruise speed for the subsequent control steps.

        Only the target speed is updated; the sampling radius and minimum distance are left untouched.

        :param speed: new target speed in Km/h
        :return:
        """
        self._target_speed = speed

    def _compute_next_waypoints(self, k=1):
        """
        Add new waypoints to the trajectory queue.

        :param k: how many waypoints to compute
        :return:
        """
        # check we do not overflow the queue
        available_entries = self._waypoints_queue.maxlen - len(self._waypoints_queue)
        k = min(available_entries, k)

        for _ in range(k):
            last_waypoint = self._waypoints_queue[-1][0]
            next_waypoints = list(last_waypoint.next(self._sampling_radius))

            if len(next_waypoints) == 0:
                break
            elif len(next_waypoints) == 1:
                # only one option available ==> lanefollowing
                next_waypoint = next_waypoints[0]
                road_option = RoadOption.LANEFOLLOW
            else:
                # random choice between the possible options
                road_options_list = _retrieve_options(
                    next_waypoints, last_waypoint)
                road_option = random.choice(road_options_list)
                next_waypoint = next_waypoints[road_options_list.index(road_option)]

            self._waypoints_queue.append((next_waypoint, road_option))

    def set_global_plan(self, current_plan):
        """
        Replace the current trajectory with ``current_plan``.

        The waypoint queue is refilled from the plan, the look-ahead buffer is re-primed with its
        first entries, and the global-plan flag is raised so run_step() stops generating random
        waypoints.

        :param current_plan: list of (carla.Waypoint, RoadOption)
        :return:
        """
        queue = self._waypoints_queue
        buffer = self._waypoint_buffer

        # load the new plan into the queue
        queue.clear()
        queue.extend(current_plan)
        self._target_road_option = RoadOption.LANEFOLLOW

        # move up to _buffer_size leading entries from the queue into the buffer
        buffer.clear()
        for _ in range(min(self._buffer_size, len(queue))):
            buffer.append(queue.popleft())

        self._global_plan = True

    def get_able_actions(self):
        vehicle_transform = self._vehicle.get_transform()
        _current_waypoint = self._map.get_waypoint(vehicle_transform.location)
        self._closest_node, self._closest_location = self._find_closest_in_node(_current_waypoint, self.global_node)
        able_actions = []
        for edge_keys in self.global_planner._graph.edges.keys():
            # if _closest_node == edge_keys[0]:
            #     self._world.debug.draw_string(
            #         carla.Location(*(self.global_planner._graph.nodes[edge_keys[1]]['vertex'])), f'{edge_keys[1]}',
            #         draw_shadow=False, color=carla.Color(r=255, g=0, b=0), life_time=10.0)

            if self._closest_node == edge_keys[0] and self.global_planner._graph.edges[edge_keys]['type'] not in able_actions:
                able_actions.append(self.global_planner._graph.edges[edge_keys]['type'])
        return able_actions

    def set_global_planner(self, global_planner: GlobalRoutePlanner, inital_transform, target_transform):
        self.global_planner = global_planner
        self.global_node = global_planner.node_vertexes
        self.inital_transform = inital_transform
        self.target_transform = target_transform

        # for node in self.global_planner._graph.nodes:
        #     self._world.debug.draw_string(carla.Location(*(self.global_planner._graph.nodes[node]['vertex'])),
        #                                   f'{node}', draw_shadow=False, color=carla.Color(r=255, g=0, b=0),
        #                                   life_time=100.0)

    def _find_closest_in_node(self, current_info, nodes=None):
        """
        Find the node whose location is closest to ``current_info``.

        :param current_info: a carla.Waypoint, or any object exposing a ``location`` attribute
                             (e.g. a transform)
        :param nodes: iterable of node coordinates, each unpacked into ``carla.Location(*node)``;
                      defaults to ``self.global_node`` when None
        :return: tuple (index of the closest node, its carla.Location); (-1, None) when there are
                 no nodes to search
        """
        if isinstance(current_info, carla.Waypoint):
            current_location = current_info.transform.location
        else:
            current_location = current_info.location

        # honour the explicit node list; previously the argument was silently ignored
        candidates = self.global_node if nodes is None else nodes
        if candidates is None:
            return -1, None

        min_distance = float('inf')
        closest_index = -1
        closest_location = None
        for i, node in enumerate(candidates):
            node_location = carla.Location(*node)
            distance = current_location.distance(node_location)
            # strict comparison keeps the first node on ties
            if distance < min_distance:
                min_distance = distance
                closest_index = i
                closest_location = node_location
        return closest_index, closest_location

    def expert_high_level_run_step(self, next_node):
        node = self.pre_node
        path_edge = None
        action = None

        if (node, next_node) in self.global_planner._graph.edges.keys():
            action = self.global_planner._graph.edges[(node, next_node)]['type'].value
            path_edge = self.global_planner._graph.edges[(node, next_node)]
            self.pre_node = next_node
        # print(path_edge['type'], node)

        if path_edge is not None:
            return path_edge, self.pre_node

        return None, None

    def high_level_run_step(self, action):
        node = self.pre_node
        path_edge = None
        for edge_keys in self.global_planner._graph.edges.keys():
            if self._closest_node == edge_keys[0] and self.global_planner._graph.edges[edge_keys]['type'].value == action:
                path_edge = self.global_planner._graph.edges[edge_keys]
                next_node = edge_keys[1]
                self.pre_node = next_node
                break
        if path_edge is not None:
            return path_edge, self.pre_node
        # if path_edge is not None:
        #     if action != RoadOption.CHANGELANELEFT.value and action != RoadOption.CHANGELANERIGHT.value:
        #         if action == RoadOption.LANEFOLLOW.value:
        #             min_dist = 0.5
        #         else:
        #             min_dist = 5
        #         print(path_edge['path'])
        #         controls = self.route_run_step(path_edge['path'], min_dist)
        #     else:
        #         controls = self.change_lane_run_step(path_edge['exit_waypoint'])
        #     self._closest_node = node
        #
        #     able_actions = self.get_able_actions()
        #     return controls, able_actions

        return None, None

    def one_run_step(self, target_waypoint):
        vehicle_transform = self._vehicle.get_transform()
        self._current_waypoint = self._map.get_waypoint(vehicle_transform.location)

        control = self._vehicle_controller.run_step(self._target_speed * 0.85, target_waypoint)
        # draw_waypoints(self._vehicle.get_world(), [target_waypoint], self._vehicle.get_location().z + 1.0)

        return control

    def change_lane_run_step(self, target_waypoint):
        controls = []
        self.target_waypoint = target_waypoint
        count = 0

        while self._vehicle.get_transform().carla.get_transform().location.distance(target_waypoint.transform.location) > 2.0:
            control = self._vehicle_controller.run_step(self._target_speed * 0.85, self.target_waypoint)
            throttle, steer, brake = control.throttle, control.steer, control.brake

            vehicle_control = carla.VehicleControl(
                throttle=throttle,  # [0,1]
                steer=steer,  # [-1,1]
                brake=brake,  # [0,1]
                hand_brake=False,
                reverse=False,
                manual_gear_shift=False,
            )

            self._vehicle.apply_control(vehicle_control)

            count += 1
            if count >= 1000:
                controls = None
                break

        return controls

    def route_run_step(self, route, min_dist):
        controls = []
        target = route[-1]
        for i in range(len(route) - 1):
            self.target_waypoint = route[i + 1]
            control = self._vehicle_controller.run_step(self._target_speed * 0.85, self.target_waypoint)
            throttle, steer, brake = control.throttle, control.steer, control.brake

            vehicle_control = carla.VehicleControl(
                throttle=throttle,  # [0,1]
                steer=steer,  # [-1,1]
                brake=brake,  # [0,1]
                hand_brake=False,
                reverse=False,
                manual_gear_shift=False,
            )

            self._vehicle.apply_control(vehicle_control)

            if self._vehicle.get_transform().location.distance(target.transform.location) < min_dist:
                break

        count = 0
        self.target_waypoint = route[-1]
        while self._vehicle.get_transform().location.distance(target.transform.location) > min_dist:
            control = self._vehicle_controller.run_step(self._target_speed * 0.85, self.target_waypoint)
            throttle, steer, brake = control.throttle, control.steer, control.brake

            vehicle_control = carla.VehicleControl(
                throttle=throttle,  # [0,1]
                steer=steer,  # [-1,1]
                brake=brake,  # [0,1]
                hand_brake=False,
                reverse=False,
                manual_gear_shift=False,
            )

            self._vehicle.apply_control(vehicle_control)

            count += 1
            if count >= 1000:
                controls = None
                break

        return controls

    def run_step(self, debug=False):
        """
        Execute one step of local planning which involves running the longitudinal and lateral PID controllers to
        follow the waypoints trajectory.

        :param debug: boolean flag to activate waypoints debugging
        :return: control to be applied
        """

        # not enough waypoints in the horizon? => add more!
        if not self._global_plan and len(self._waypoints_queue) < int(self._waypoints_queue.maxlen * 0.5):
            self._compute_next_waypoints(k=100)

        if len(self._waypoints_queue) == 0 and len(self._waypoint_buffer) == 0:
            control = carla.VehicleControl()
            control.steer = 0.0
            control.throttle = 0.0
            control.brake = 1.0
            control.hand_brake = False
            control.manual_gear_shift = False

            return control

        #   Buffering the waypoints
        if not self._waypoint_buffer:
            for _ in range(self._buffer_size):
                if self._waypoints_queue:
                    self._waypoint_buffer.append(
                        self._waypoints_queue.popleft())
                else:
                    break

        # current vehicle waypoint
        vehicle_transform = self._vehicle.get_transform()
        self._current_waypoint = self._map.get_waypoint(vehicle_transform.location)
        # target waypoint
        self.target_waypoint = self._waypoint_buffer[0][0]

        # move using PID controllers
        control = self._vehicle_controller.run_step(self._target_speed, self.target_waypoint)

        # purge the queue of obsolete waypoints
        max_index = -1

        for i, (waypoint, _) in enumerate(self._waypoint_buffer):
            if waypoint.transform.location.distance(vehicle_transform.location) < self._min_distance:
                max_index = i
        if max_index >= 0:
            for i in range(max_index + 1):
                self._waypoint_buffer.popleft()

        self._vehicle.get_world().debug.draw_string(self.target_waypoint.transform.location, 'O',
            draw_shadow=False, color=carla.Color(r=255, g=0, b=0))
        # draw_waypoints(self._vehicle.get_world(), [self.target_waypoint], self._vehicle.get_location().z + 1.0)

        return control

    def set_pre_node(self, node):
        """
        Set the graph node the high-level steps start from.

        :param node: node stored as ``self.pre_node``
        :return:
        """
        self.pre_node = node

    def get_speed(self):
        """
        Return the speed reported by the vehicle controller.

        :return: whatever ``self._vehicle_controller.get_speed()`` returns
        """
        controller = self._vehicle_controller
        return controller.get_speed()

    def done(self):
        """
        Returns whether or not the planner has finished

        :return: boolean
        """
        return len(self._waypoints_queue) == 0 and len(self._waypoint_buffer) == 0


def _retrieve_options(list_waypoints, current_waypoint):
    """
    Compute the type of connection between the current active waypoint and the multiple waypoints present in
    list_waypoints. The result is encoded as a list of RoadOption enums.

    :param list_waypoints: list with the possible target waypoints in case of multiple options
    :param current_waypoint: current active waypoint
    :return: list of RoadOption enums representing the type of connection from the active waypoint to each
             candidate in list_waypoints
    """
    options = []
    for next_waypoint in list_waypoints:
        # this is needed because something we are linking to
        # the beggining of an intersection, therefore the
        # variation in angle is small
        next_next_waypoint = next_waypoint.next(3.0)[0]
        link = _compute_connection(current_waypoint, next_next_waypoint)
        options.append(link)

    return options


def _compute_connection(current_waypoint, next_waypoint, threshold=35):
    """
    Classify the topological connection from ``current_waypoint`` to ``next_waypoint`` by comparing
    their yaw angles.

    NOTE(review): the yaw difference is folded into [0, 180), so turns sharper than 90 degrees may
    be classified on the opposite side -- confirm whether this is intended.

    :param current_waypoint: active waypoint
    :param next_waypoint: target waypoint
    :param threshold: angular tolerance in degrees around 0 / 180 within which the connection
                      counts as straight
    :return: the type of topological connection encoded as a RoadOption enum:
             RoadOption.STRAIGHT
             RoadOption.LEFT
             RoadOption.RIGHT
    """
    target_yaw = next_waypoint.transform.rotation.yaw % 360.0
    current_yaw = current_waypoint.transform.rotation.yaw % 360.0

    # heading change folded into [0, 180)
    delta = (target_yaw - current_yaw) % 180.0

    if delta < threshold or delta > (180 - threshold):
        return RoadOption.STRAIGHT
    if delta > 90.0:
        return RoadOption.LEFT
    return RoadOption.RIGHT
