import random
from typing import Optional

import carla
import numpy as np
from gym import spaces

from car_dreamer.toolkit import Command, Command2Index, EnvMonitorOpenCV, FourLaneCommandPlanner, Index2Command, Observer, WorldManager

from .carla_wpt_env import CarlaWptEnv


class CarlaMessageEnv(CarlaWptEnv):
    """
    An environment that requires the agent to follow a text command.

    Every discrete action packs a high-level lane command together with a
    low-level (acceleration, steering) control index::

        action = command_index * (n_acc * n_steer) + control_index

    The command is handed to a ``FourLaneCommandPlanner``; the control index is
    forwarded to ``CarlaWptEnv.apply_control``.

    The ego is spawned at one of ``config.lane_start_points`` with ``yaw=-90``.
    The road end counts as reached once the ego's ``y`` drops below the
    destination point's ``y``. Success means reaching it while in the
    destination lane. If ``config.obstacle_positions`` is present, the
    destination lane is fixed to lane 1 and static vehicles are spawned at
    those positions.
    """

    def __init__(self, config, world: Optional[WorldManager] = None, observer: Optional[Observer] = None, monitor: Optional[EnvMonitorOpenCV] = None):
        super().__init__(config, world, observer, monitor)

        # "_e_*" attributes are runtime-tunable through set_e_parameters().
        self._e_num_vehicles = self._config.num_vehicles
        self._e_command_repeat_counts = self._config.command_repeat_counts
        self._e_enable_replan = self._config.enable_replan
        self._has_obstacle = "obstacle_positions" in self._config
        self._observer.register_simple_handler(
            "worker_reward", lambda env_state: np.float32(self._get_completed_distance_ratio()), spaces.Box(low=0.0, high=1.0, dtype=np.float32)
        )
        # Recompute after registering the extra handler (presumably so the
        # "worker_reward" entry is part of the observation space).
        self.observation_space = self._get_observation_space()

    def set_e_parameters(self, **kwargs):
        """Override ``_e_<key>`` attributes, printing each value that actually changes."""
        for key, value in kwargs.items():
            if getattr(self, f"_e_{key}", None) != value:
                setattr(self, f"_e_{key}", value)
                print(f"[CarlaMessageEnv] Setting _e_{key} to {value}!")

    def _get_completed_distance_ratio(self):
        """Return the accumulated fraction of route distance covered this episode.

        The observer handler registered in ``__init__`` may call this before the
        first ``on_reset`` has created ``goal_progress_info``. In that case, and
        whenever the key is missing, 0.0 is returned.
        """
        progress_info = getattr(self, "goal_progress_info", None) or {}
        return progress_info.get("completed_distance_ratio", 0.0)

    def get_state(self):
        """Extend the parent state with destination lane, active command, destination x and task name."""
        state = super().get_state()
        return {**state, "dest_lane_idx": self.dest_lane, "command": self.command, "dest_x": self.dest_point[0], "task": "carla_message"}

    def on_reset(self) -> None:
        """Set up a new episode.

        Picks the start and destination, spawns the ego, obstacles and traffic,
        builds the command planner, and zeroes the progress bookkeeping.
        """
        if self._has_obstacle:
            # Obstacle scenario: fixed start and destination in lane 1.
            self.dest_lane = 1
            self.ego_src = self._config.lane_start_points[self.dest_lane]
            self.dest_point = self._config.lane_end_points[self.dest_lane]
            self.spawn_obstacle(transform_list=self._config.obstacle_positions)
        else:
            # Random destination lane; start lane is sampled independently.
            self.dest_lane = np.random.choice(len(self._config.lane_end_points))
            self.ego_src = self._config.lane_start_points[random.randint(0, len(self._config.lane_start_points) - 1)]
            self.dest_point = self._config.lane_end_points[self.dest_lane]
        ego_transform = carla.Transform(carla.Location(x=self.ego_src[0], y=self.ego_src[1], z=self.ego_src[2]), carla.Rotation(yaw=-90))
        self.ego = self._world.spawn_actor(transform=ego_transform)
        self._world.spawn_auto_actors(self._e_num_vehicles)
        self.ego_planner = FourLaneCommandPlanner(
            self.ego, [p[0] for p in self._config.lane_start_points], force_reset_command=self._config.force_reset_command
        )
        self.command = Command.LaneFollow
        self.ego_planner.set_command(Command.LaneFollow)
        self.waypoints, self.planner_stats = self.ego_planner.run_step()
        self.num_completed = self.planner_stats["num_completed"]
        # Set when a command change is rejected mid-maneuver; makes the command invalid.
        self.command_not_completed = False
        # Consecutive steps the current command was repeated while in the goal lane.
        self.command_completed_count = 0

        total_distance = self.ego_planner.get_total_distance()
        self.goal_progress_info = {
            "total_distance": total_distance,
            "dist_to_goal": total_distance,
            "progress_distance_ratio": 0.0,
            "progress_waypoint_ratio": 0.0,
            "completed_distance_ratio": 0.0,
            "completed_waypoint_ratio": 0.0,
        }

    def on_step(self):
        """Advance the parent env, then update per-step and cumulative progress ratios."""
        super().on_step()
        info = self.goal_progress_info

        total_waypoints = self.ego_planner.get_num_waypoints()
        if total_waypoints > 0:
            # NOTE(review): assumes self.num_completed is refreshed by the parent's on_step — confirm.
            progress_waypoint_ratio = self.num_completed / total_waypoints
            info["progress_waypoint_ratio"] = progress_waypoint_ratio
            info["completed_waypoint_ratio"] += progress_waypoint_ratio
        else:
            # No waypoints means no waypoint progress this step; don't report last step's value.
            info["progress_waypoint_ratio"] = 0.0

        current_location = self.ego.get_location()
        end_location = self.ego_planner.get_end_location()
        dist_to_goal = end_location.distance(current_location)
        # Positive when the ego moved closer to the goal since the last step.
        progress_dist = info["dist_to_goal"] - dist_to_goal
        info["dist_to_goal"] = dist_to_goal
        total_distance = info["total_distance"]
        # Guard against a zero-length route instead of raising ZeroDivisionError.
        progress_distance_ratio = progress_dist / total_distance if total_distance > 0 else 0.0
        info["progress_distance_ratio"] = progress_distance_ratio
        info["completed_distance_ratio"] += progress_distance_ratio

    def apply_control(self, action) -> None:
        """Decode ``action`` into (command, control) and apply both.

        If the command changes while the ego is not yet in the planner's goal
        lane and replanning is disabled, the change is rejected. The rejection
        is flagged via ``command_not_completed`` and no control is applied this
        step. If the same command is repeated while the ego is in the goal
        lane, it is re-issued to the planner every ``command_repeat_counts``
        steps.
        """
        action_num = self.n_acc * self.n_steer
        command, control = action // action_num, action % action_num
        command = Index2Command[command]
        if command != self.command:
            if self.ego_planner.get_ego_lane_id() != self.ego_planner.get_goal_lane_id() and not self._e_enable_replan:
                self.command_not_completed = True
                return
            self.command = command
            self.command_completed_count = 0
            self.ego_planner.set_command(command)
            super().apply_control(control)
            return
        if self.ego_planner.get_ego_lane_id() == self.ego_planner.get_goal_lane_id():
            self.command_completed_count += 1
            if self.command_completed_count >= self._e_command_repeat_counts:
                self.command_completed_count = 0
                self.ego_planner.set_command(self.command)
        super().apply_control(control)

    def spawn_obstacle(self, transform_list) -> None:
        """
        Spawn still vehicles as obstacles.

        Each entry of ``transform_list`` is an ``(x, y, z)`` position. Every
        obstacle is spawned with ``yaw=-90``, the same heading as the ego.
        """
        for trans in transform_list:
            ob_transform = carla.Transform(carla.Location(x=trans[0], y=trans[1], z=trans[2]), carla.Rotation(yaw=-90))
            self._world.spawn_actor(transform=ob_transform)

    def reward(self):
        """Shape the parent reward with lane-deviation, invalid-command and destination terms.

        A positive base reward is divided by ``1 + 2 * lane_diff``. Here
        ``lane_diff`` is the distance from the destination lane to the ego's
        current lane when replanning is enabled, and to the planner's goal lane
        otherwise. Invalid commands add ``-scales["invalid_command"]``.
        Reaching the road end adds ``+/- scales["destination_reached"]``
        depending on the lane, doubled in the obstacle scenario.
        """
        reward, info = super().reward()
        reward_scales = self._config.reward.scales

        if self._e_enable_replan:
            lane_diff = np.abs(self.dest_lane - self.ego_planner.get_ego_lane_id())
        else:
            lane_diff = np.abs(self.dest_lane - self.ego_planner.get_goal_lane_id())
        reward_degrade = 1 + 2.0 * lane_diff
        # Only damp positive rewards; penalties are not softened.
        if reward > 0.0:
            reward /= reward_degrade

        r_invalid_command = 0.0
        if self.is_command_invalid():
            r_invalid_command = -reward_scales["invalid_command"]

        r_destination_reached = 0.0
        if self.is_road_end_reached():
            if self.is_dest_lane_reached():
                r_destination_reached = reward_scales["destination_reached"]
            else:
                r_destination_reached = -reward_scales["destination_reached"]

        if self._has_obstacle:
            r_destination_reached *= 2

        reward += r_invalid_command + r_destination_reached

        info = {
            **info,
            **self.goal_progress_info,
            "r_degrade": reward_degrade,
            "r_invalid_command": r_invalid_command,
            "r_destination_reached": r_destination_reached,
            "command": Command2Index[self.command],
        }
        return reward, info

    def get_terminal_conditions(self):
        """Add invalid-command, destination-reached and road-end flags to the parent's conditions."""
        info = super().get_terminal_conditions()
        info["invalid_command"] = self.is_command_invalid()
        info["destination_reached"] = self.is_road_end_reached() and self.is_dest_lane_reached()
        info["road_end_reached"] = self.is_road_end_reached()
        return info

    def is_road_end_reached(self):
        """Return True once the ego's y coordinate is below the destination point's y."""
        return self.ego.get_location().y < self.dest_point[1]

    def is_dest_lane_reached(self):
        """Return True if the ego currently occupies the destination lane."""
        return self.ego_planner.get_ego_lane_id() == self.dest_lane

    def is_command_invalid(self):
        """Return True if the current command is invalid.

        A command is invalid when any of the following holds:
        - in the obstacle scenario, the ego is in lane 3;
        - the planner rejects the command;
        - a command change was refused mid-maneuver.
        """
        if self._has_obstacle and self.ego_planner.get_ego_lane_id() == 3:
            return True
        return not self.ego_planner.is_command_valid() or self.command_not_completed

    def _get_action_space(self):
        """Build the joint (command x control) discrete action space.

        The size must be ``n_steer * n_acc * n_commands`` so that every pair
        decoded by ``apply_control`` is reachable. ``apply_control`` uses
        ``action // (n_acc * n_steer)`` for the command and
        ``action % (n_acc * n_steer)`` for the control.

        Raises:
            NotImplementedError: if a continuous action space is configured.
        """
        action_config = self._config.action
        if action_config.discrete:
            self.n_steer = len(action_config.discrete_steer)
            self.n_acc = len(action_config.discrete_acc)
            # Product, not sum: with "+" most command indices would be unreachable.
            return spaces.Discrete(self.n_steer * self.n_acc * action_config.n_commands)
        else:
            raise NotImplementedError("Continuous action space is not supported yet.")
