import random
from statistics import mean
from typing import Dict, Optional, List, Tuple

import numpy as np
from numpy import ndarray
import math
import torch

from env.environment import Env
from misc_utils.colors import RED, CYAN
from RL.utils import preprocess_obs, get_preprocessed_obs_len
from sumo.traffic_state import Vehicle
from sumo.utils import is_internal_lane, get_remaining_time
from sumo.constants import SPEED_NORMALIZATION, GLOBAL_MAX_LANE_LENGTH, GLOBAL_MAX_SPEED, LANE_LENGTH_NORMALIZATION, \
    TL_CYCLE_NORMALIZATION


class NoStopRLEnv(Env):
    def step(self, ids: ndarray, action: ndarray, warmup: bool = False) -> Tuple[ndarray, ndarray, ndarray, bool]:
        """
        Applies the given actions to the RL vehicles, advances the simulation and returns the new state.

        :param ids: ids of the vehicles the actions are meant for, aligned index-wise with ``action``.
        :param action: per-vehicle actions; ``action[i]`` is indexed with ``'accel'`` and, optionally,
            ``'lane_change'`` (discrete, 0 to 2).
        :param warmup: when True, no action is applied and ``done`` is always False.
        :return: ``(new_ids, obs, reward, done)``. ``new_ids`` and ``obs`` describe the RL vehicles present after
            the step, whereas ``reward`` is aligned with the ``ids`` that were passed in.

        Lane changes requested by the policy are ignored in baseline runs, like accelerations are.
        """
        if not warmup:
            # Outside of training, the policy only controls vehicles whose lane id starts with one of these
            # prefixes; on any other lane the vehicle gets the IDM acceleration and no lane change command.
            controlled_lane_prefixes = ("A2TL", "B2TL", "C2TL", "D2TL", "E2TL", "F2TL")
            training = self.config.run_mode == "train"

            for i, v_id in enumerate(ids):
                vehicle = self.traffic_state.vehicles[v_id]

                sumo_speed = self.traci.vehicle.getSpeedWithoutTraCI(v_id)
                diff = sumo_speed - vehicle.previous_step_idm_speed
                vehicle.previous_step_idm_speed = \
                    self.traffic_state.get_idm_accel(vehicle) * self.config.sim_step_duration + vehicle.speed

                # We only apply RL accel to the vehicle if SUMO doesn't slow down the vehicle to do a lane change
                # We identify if SUMO slows down to do lane change by computing theoretical IDM speed, and taking the
                # diff with the speed SUMO actually wants to apply.
                if not self.config.baseline_run and (abs(diff) < 0.5 or abs(sumo_speed - vehicle.speed) < 0.5):
                    if training or vehicle.lane_id.startswith(controlled_lane_prefixes):
                        accel = action[i]['accel'][0]
                    else:
                        accel = self.traffic_state.get_idm_accel(vehicle)
                    self.traffic_state.accel(vehicle, accel, use_speed_factor=False)

                # The key is looked up in this vehicle's own action: a membership test on the whole ``action``
                # array compares the key against each element and therefore never matches.
                if (not self.config.baseline_run
                        and 'lane_change' in action[i]
                        and not is_internal_lane(vehicle.lane_id)
                        and (training or vehicle.lane_id.startswith(controlled_lane_prefixes))):
                    # action is discrete from 0 to 2, which maps to a relative lane offset of -1, 0 (stay) or 1.
                    # NOTE(review): the previous comment called -1 "left", but get_vehicle_obs treats
                    # side_lane=-1 as the right lane - confirm the convention of change_lane_relative.
                    vehicle.change_lane_relative(action[i]['lane_change'] - 1)

                self.traffic_state.set_color(vehicle, CYAN if vehicle.direction != 0 else RED)

        super().step(ids, action, warmup)

        reward = self.get_reward(ids)

        # When we control lane changes, it is possible that vehicles get stuck because our strategic (see SUMO doc)
        # lane changes are bad, thus we remove them to keep the simulation going.
        if self.config.control_lane_change:
            for i, v_id in enumerate(ids):
                if v_id in self.traffic_state.current_vehicles:
                    vehicle = self.traffic_state.vehicles[v_id]
                    if (-vehicle.length < vehicle.relative_distance <= 0
                            and vehicle.closest_optim_lane_distance != 0
                            and vehicle.closest_optim_lane_distance is not None
                            and vehicle.speed < 1e-2):
                        self.traffic_state.remove_vehicle(vehicle)
                        # Large fixed penalty for the vehicle that had to be removed
                        reward[i] = -5000
                        self.vehicle_metrics[vehicle.id]['removed'] = [1]

        new_ids = sorted([v_id for v_id, vehicle in self.traffic_state.current_vehicles.items() if vehicle.is_rl])
        obs = self.get_obs(new_ids)

        assert reward.shape[0] == len(ids)
        assert obs.shape[0] == len(new_ids)

        # We consider the simulation over if too many vehicles are present, and they are too slow
        done = not warmup and len(new_ids) > 150 and self.get_average_speed() < 0.5

        return np.array(new_ids), obs, reward, done

    def get_obs(self, current_rl_vehicle_list: List[str]) -> np.ndarray:
        """
        Builds the observation matrix of the given RL vehicles.

        Row i of the returned float32 array is the preprocessed observation of the i-th vehicle id of
        ``current_rl_vehicle_list``; its width is given by ``get_preprocessed_obs_len(self.obs_space)``.
        """
        n_features = get_preprocessed_obs_len(self.obs_space)
        observations = np.empty((len(current_rl_vehicle_list), n_features), dtype=np.float32)

        # The simulation time is read once and shared by every vehicle observation
        sim_time = self.traffic_state.traci_module.simulation.getTime()

        for row, vehicle_id in zip(observations, current_rl_vehicle_list):
            raw_obs = self.get_vehicle_obs(self.traffic_state.vehicles[vehicle_id], sim_time)
            # ``row`` is a view on the matrix, so this fills it in place
            row[:] = preprocess_obs(raw_obs, self.obs_space)

        return observations

    def get_vehicle_obs(self, vehicle: Vehicle, time: float) -> Dict[str, float]:
        """
        Returns the obs dict for a single vehicle

        The dict holds the state of the vehicle itself, of its traffic light, of the leader/follower on its own
        lane and on both side lanes (nested dicts), plus context values that stay constant.

        ``time_remaining``, ``time_remaining2`` and ``time_remaining3`` are all divided by
        ``TL_CYCLE_NORMALIZATION`` as a whole. NOTE(review): the last two used to normalise only the cycle
        duration term because of operator precedence, so a policy trained on those values sees a different scale.

        :param vehicle: the vehicle to observe.
        :param time: simulation time passed to ``get_remaining_time``.
        """
        current_phase = self.traffic_state.get_phase("TL")
        lane_index = vehicle.lane_index

        # When there is no lane on one side, the vehicle itself is used as a placeholder neighbour; it is
        # recognised by its id in get_other_vehicle_obs below.
        if lane_index > 0:
            leader_right = self.traffic_state.get_leader(vehicle, side_lane=-1)
            follower_right = self.traffic_state.get_follower(vehicle, side_lane=-1)
        else:
            leader_right = vehicle
            follower_right = vehicle

        if vehicle.edge_id not in self.traffic_state.lane_counts or \
                lane_index >= self.traffic_state.lane_counts[vehicle.edge_id] - 1:
            leader_left = vehicle
            follower_left = vehicle
        else:
            leader_left = self.traffic_state.get_leader(vehicle, side_lane=1)
            follower_left = self.traffic_state.get_follower(vehicle, side_lane=1)

        def get_other_vehicle_obs(other_veh: Optional[Vehicle]) -> Dict[str, object]:
            if other_veh is None:
                # No neighbour found: report the global maximum speed and distance
                return {
                    'speed': GLOBAL_MAX_SPEED / SPEED_NORMALIZATION,
                    'relative_position': GLOBAL_MAX_LANE_LENGTH / LANE_LENGTH_NORMALIZATION,
                    'blinker_left': False,
                    'blinker_right': False
                }
            elif other_veh.id == vehicle.id:
                # Placeholder for a side lane that does not exist: same speed, zero distance
                return {
                    'speed': vehicle.speed / SPEED_NORMALIZATION,
                    'relative_position': 0,
                    'blinker_left': False,
                    'blinker_right': False
                }
            else:
                return {
                    'speed': other_veh.speed / SPEED_NORMALIZATION,
                    'relative_position': self.traffic_state.get_linear_distance(other_veh,
                                                                                vehicle) / LANE_LENGTH_NORMALIZATION,
                    'blinker_left': other_veh.turn_signal == 1,
                    'blinker_right': other_veh.turn_signal == -1
                }

        remaining_time = get_remaining_time(vehicle.green_phase_timings, time)
        is_green = current_phase == vehicle.green_phase_index
        green_phase_transition = remaining_time if is_green else remaining_time - vehicle.green_phase_timings[1]
        cycle_duration = sum(vehicle.green_phase_timings)

        obs = {
            'speed': vehicle.speed / SPEED_NORMALIZATION,
            'relative_distance': vehicle.relative_distance / LANE_LENGTH_NORMALIZATION,
            'tl_phase': 0 if is_green else (
                1 if current_phase == vehicle.green_phase_index + 1 else 2),
            'time_remaining': remaining_time / TL_CYCLE_NORMALIZATION,
            # Same transition, one and two full cycles later; the parentheses make the normalisation apply to the
            # whole sum rather than to the cycle duration only.
            'time_remaining2': (green_phase_transition + cycle_duration) / TL_CYCLE_NORMALIZATION,
            'time_remaining3': (green_phase_transition + cycle_duration * 2) / TL_CYCLE_NORMALIZATION,
            # 0: edge whose id ends with '2TL', 1: internal lane, 2: any other edge
            'edge_id': 0 if vehicle.edge_id.endswith('2TL') else (1 if is_internal_lane(vehicle.lane_id) else 2),
            'follower': get_other_vehicle_obs(self.traffic_state.get_follower(vehicle)),
            'leader': get_other_vehicle_obs(self.traffic_state.get_leader(vehicle)),
            # Lane index scaled to [0, 1]; 0.5 when the lane count of the edge is unknown
            'lane_index': min(lane_index / max(self.traffic_state.lane_counts[vehicle.edge_id] - 1, 1), 1) \
                if vehicle.edge_id in self.traffic_state.lane_counts else 0.5,
            # 0 for right turn, 1 for no turn necessary, 2 for left turn
            'destination': vehicle.direction + 1,
            'leader_right': get_other_vehicle_obs(leader_right),
            'follower_right': get_other_vehicle_obs(follower_right),
            'leader_left': get_other_vehicle_obs(leader_left),
            'follower_left': get_other_vehicle_obs(follower_left),

            # context, stays constant
            'penetration_rate': self.task_context.penetration_rate,
            'green_phase': vehicle.green_phase_timings[1],
            'red_phase': vehicle.green_phase_timings[0] + vehicle.green_phase_timings[2],
            'speed_limit': self.traffic_state.get_speed_limit(vehicle.origin + '2TL_0'),
            'lane_length': self.traffic_state.get_lane_length(vehicle.origin + '2TL_0')
        }
        return obs

    def get_reward(self, vehicle_list: ndarray) -> np.ndarray:
        """
        Compute the reward of the previous action.

        Each vehicle gets either its individual reward or, with probability ``config.fleet_reward_ratio`` and
        when available, the mean individual reward of the vehicles of its platoon and lane. If
        ``config.fleet_stop_penalty`` is set, the share of given vehicles slower than ``config.threshold``,
        scaled by that penalty, is subtracted from every reward.

        :param vehicle_list: ids of the vehicles to reward.
        :return: array of the same shape as ``vehicle_list``, one reward per vehicle.
        """

        def individual_reward(veh: Vehicle) -> float:
            penalty = 0

            speed = veh.speed
            threshold = self.config.threshold

            if self.config.stop_penalty is not None and speed < threshold:
                penalty += self.config.stop_penalty * (threshold - speed) / threshold

            if self.config.accel_penalty is not None:
                penalty += self.config.accel_penalty * abs(veh.accel)

            if self.config.emission_penalty is not None:
                # The emissions model is only evaluated when its penalty is enabled.
                # Road grade, temperature and humidity are fixed.
                features = torch.from_numpy(np.array([speed, veh.accel, 0, 68, 46])).float()
                # Only the scalar value is needed, so no gradient has to be tracked
                with torch.no_grad():
                    emission = self.config.emissions_model(features).item()
                penalty += self.config.emission_penalty * emission

            return (speed - penalty) / SPEED_NORMALIZATION

        # Empty groups are skipped (mean() raises on them); the membership checks below then fall back to the
        # individual reward.
        fleet_rewards = {
            k1: {
                k2: mean(individual_reward(v) for v in v2)
                for k2, v2 in v1.items()
                if v2
            }
            for k1, v1 in self.traffic_state.current_vehicles_sorted_lists.items()
        }

        result = np.empty(vehicle_list.shape)
        num_stopped_vehicles = 0
        for i, v_id in enumerate(vehicle_list):
            vehicle = self.traffic_state.vehicles[v_id]
            num_stopped_vehicles += vehicle.speed < self.config.threshold
            result[i] = (fleet_rewards[vehicle.platoon][vehicle.lane_index]
                         if random.random() < self.config.fleet_reward_ratio
                            and vehicle.platoon in fleet_rewards
                            and vehicle.lane_index in fleet_rewards[vehicle.platoon]
                         else individual_reward(vehicle))

        if self.config.fleet_stop_penalty is not None and result.shape[0] > 0:
            result -= self.config.fleet_stop_penalty * num_stopped_vehicles / result.shape[0]

        return result
