"""This module contains the TrafficSignal class, which represents a traffic signal in the simulation."""
import os
import sys
from typing import Callable, List, Union


# SUMO_HOME must point at a SUMO installation; fail early at import time otherwise.
if "SUMO_HOME" not in os.environ:
    raise ImportError("Please declare the environment variable 'SUMO_HOME'")
# Make the Python packages shipped under $SUMO_HOME/tools importable.
tools = os.path.join(os.environ["SUMO_HOME"], "tools")
sys.path.append(tools)
import numpy as np
from gymnasium import spaces

import pdb
class TrafficSignal:
    """This class represents a Traffic Signal controlling an intersection.

    It is responsible for retrieving information and changing the traffic phase using the Traci API.

    IMPORTANT: It assumes that the traffic phases defined in the .net file are of the form:
        [green_phase, yellow_phase, green_phase, yellow_phase, ...]
    All-red phases are currently not supported: program phases made up only of 'r'/'s' states are
    ignored when collecting green phases (this should be easy to extend).

    # Observation Space
    The default observation for each traffic signal agent is a vector:

    obs = [phase_one_hot, min_green, lane_1_density,...,lane_n_density, lane_1_queue,...,lane_n_queue]

    - ```phase_one_hot``` is a one-hot encoded vector indicating the current active green phase
    - ```min_green``` is a binary variable indicating whether min_green + yellow_time seconds have passed since the last phase change
    - ```lane_i_density``` is the number of vehicles in incoming lane i divided by the total capacity of the lane
    - ```lane_i_queue``` is the number of queued (speed below 0.1 m/s) vehicles in incoming lane i divided by the total capacity of the lane

    You can change the observation space by implementing a custom observation class. See :py:class:`sumo_rl.environment.observations.ObservationFunction`.

    # Action Space
    Action space is discrete, corresponding to which green phase is going to be open for the next delta_time seconds.

    # Reward Function
    The default reward function is 'diff-waiting-time'. You can change the reward function by implementing a custom reward function and passing to the constructor of :py:class:`sumo_rl.environment.env.SumoEnvironment`.
    The built-in reward names are the keys of ``TrafficSignal.reward_fns``; the per-lane / per-road
    variants return one negated waiting time per lane / road as a NumPy array.
    """

    # Default min gap of SUMO (see https://sumo.dlr.de/docs/Simulation/Safety.html). Should this be parameterized?
    MIN_GAP = 2.5

    def __init__(
        self,
        env,
        ts_id: str,
        delta_time: int,
        yellow_time: int,
        min_green: int,
        max_green: int,
        begin_time: int,
        reward_fn: Union[str, Callable],
        sumo,
        reward_exponent: float = 1.0,
        reward_scale_factor: float = 1000.0,
    ):
        """Initializes a TrafficSignal object.

        Args:
            env (SumoEnvironment): The environment this traffic signal belongs to.
            ts_id (str): The id of the traffic signal.
            delta_time (int): The time in seconds between actions.
            yellow_time (int): The time in seconds of the yellow phase.
            min_green (int): The minimum time in seconds of the green phase.
            max_green (int): The maximum time in seconds of the green phase.
            begin_time (int): The time in seconds when the traffic signal starts operating.
            reward_fn (Union[str, Callable]): The reward function. Can be a string with the name of the reward function or a callable function.
            sumo (Sumo): The Sumo instance.
            reward_exponent (float): Exponent used only by the 'waiting-time-road-big-nonlinear' reward.
            reward_scale_factor (float): Divisor applied to waiting times before exponentiation, used only
                by the 'waiting-time-road-big-nonlinear' reward.

        Raises:
            NotImplementedError: If ``reward_fn`` is a string that is not a key of ``reward_fns``.
        """
        self.id = ts_id
        self.env = env
        self.delta_time = delta_time
        self.yellow_time = yellow_time
        self.min_green = min_green
        self.max_green = max_green
        self.green_phase = 0
        self.is_yellow = False
        self.time_since_last_phase_change = 0
        self.next_action_time = begin_time
        self.last_measure = 0.0
        self.last_reward = None
        self.reward_fn = reward_fn
        self.sumo = sumo

        # Parameters of the nonlinear per-road reward.
        self.reward_exponent = reward_exponent
        self.reward_scale_factor = reward_scale_factor

        if isinstance(self.reward_fn, str):
            if self.reward_fn not in TrafficSignal.reward_fns:
                raise NotImplementedError(f"Reward function {self.reward_fn} not implemented")
            self.reward_fn = TrafficSignal.reward_fns[self.reward_fn]

        self.observation_fn = self.env.observation_class(self)

        self._build_phases()

        # Remove duplicates while keeping the controller's lane order.
        self.lanes = list(dict.fromkeys(self.sumo.trafficlight.getControlledLanes(self.id)))

        print(f"number of lanes is {len(self.lanes)}")

        self.out_lanes = self._compute_out_lanes()
        self.lanes_length = {lane: self.sumo.lane.getLength(lane) for lane in self.lanes + self.out_lanes}

        self.observation_space = self.observation_fn.observation_space()
        self.action_space = spaces.Discrete(self.num_green_phases)

        # Green-phase index -> incoming lanes, used by the per-road waiting-time rewards.
        self.edges = self._compute_edges()
        self.edges_big = self._compute_edges_big()

    def _compute_out_lanes(self) -> List[str]:
        """Return the distinct outgoing lanes of the controlled links, in first-seen order.

        For every non-empty entry of ``getControlledLinks`` the element at index 1 of its first tuple
        is taken (the outgoing lane in TraCI's link convention). Order is kept deterministic, unlike
        ``set``, so per-lane outputs such as :py:meth:`get_out_lanes_density` have a stable layout
        across runs.
        """
        links = self.sumo.trafficlight.getControlledLinks(self.id)
        return list(dict.fromkeys(link[0][1] for link in links if link))

    def _build_phases(self):
        """Build the green/yellow phase table and install it as the signal's program.

        Green phases are the program phases that contain no 'y' and are not made entirely of 'r'/'s'.
        For every ordered pair (i, j) of distinct green phases a yellow transition is derived from
        phase i by turning each 'G'/'g' signal that is 'r'/'s' in phase j into 'y'. Afterwards
        ``all_phases`` is the green phases followed by the yellow ones, ``yellow_dict[(i, j)]`` is
        the index in ``all_phases`` of the i -> j yellow phase, and the signal is set to the state
        of the first green phase.

        Example (2-way single intersection). The program phases

            GGrrrrGGrrrr, yyrrrryyrrrr, rrGrrrrrGrrr, rryrrrrryrrr,
            rrrGGrrrrGGr, rrryyrrrryyr, rrrrrGrrrrrG, rrrrryrrrrry

        give 4 green phases (indices 0-3) followed by 12 yellow phases (indices 4-15, ordered
        0->1, 0->2, 0->3, 1->0, 1->2, ...).

        When ``env.fixed_ts`` is truthy only ``num_green_phases`` is set (half the program's
        phases) and the SUMO program is left untouched.
        """
        phases = self.sumo.trafficlight.getAllProgramLogics(self.id)[0].phases
        if self.env.fixed_ts:
            # The program is [green, yellow, green, yellow, ...]: one green phase per pair.
            self.num_green_phases = len(phases) // 2
            return

        self.green_phases = []
        self.yellow_dict = {}
        for phase in phases:
            state = phase.state
            if "y" not in state and (state.count("r") + state.count("s") != len(state)):
                self.green_phases.append(self.sumo.trafficlight.Phase(60, state))
        self.num_green_phases = len(self.green_phases)
        self.all_phases = self.green_phases.copy()

        for i, p1 in enumerate(self.green_phases):
            for j, p2 in enumerate(self.green_phases):
                if i == j:
                    continue
                # Signals that are green now and red/stop in the target phase go yellow.
                yellow_state = "".join(
                    "y" if cur in "Gg" and nxt in "rs" else cur
                    for cur, nxt in zip(p1.state, p2.state)
                )
                self.yellow_dict[(i, j)] = len(self.all_phases)
                self.all_phases.append(self.sumo.trafficlight.Phase(self.yellow_time, yellow_state))

        # Replace the first program's phases with the generated table. type is reset to 0
        # (presumably SUMO's static program type; TODO confirm).
        programs = self.sumo.trafficlight.getAllProgramLogics(self.id)
        logic = programs[0]
        logic.type = 0
        logic.phases = self.all_phases
        self.sumo.trafficlight.setProgramLogic(self.id, logic)
        self.sumo.trafficlight.setRedYellowGreenState(self.id, self.all_phases[0].state)

    @property
    def time_to_act(self):
        """Returns True if the traffic signal should act in the current step."""
        return self.next_action_time == self.env.sim_step

    def update(self):
        """Advance the signal's phase clock by one tick.

        Once a yellow transition has lasted ``yellow_time`` ticks, the signal is switched to the
        pending green phase ``self.green_phase``.
        """
        self.time_since_last_phase_change += 1
        if self.is_yellow and self.time_since_last_phase_change == self.yellow_time:
            # The state string is set directly from all_phases rather than via setPhase.
            self.sumo.trafficlight.setRedYellowGreenState(self.id, self.all_phases[self.green_phase].state)
            self.is_yellow = False

    def set_next_phase(self, new_phase: int):
        """Sets what will be the next green phase and sets yellow phase if the next phase is different than the current.

        The request is ignored (the current green is re-applied) when ``new_phase`` is already the
        current green or when fewer than ``yellow_time + min_green`` seconds have passed since the
        last phase change. Either way the next decision is scheduled ``delta_time`` seconds ahead.

        Args:
            new_phase (int): Green phase index in [0 ... num_green_phases - 1]
        """
        new_phase = int(new_phase)
        keep_current = (
            self.green_phase == new_phase
            or self.time_since_last_phase_change < self.yellow_time + self.min_green
        )
        if keep_current:
            self.sumo.trafficlight.setRedYellowGreenState(self.id, self.all_phases[self.green_phase].state)
        else:
            yellow_index = self.yellow_dict[(self.green_phase, new_phase)]
            self.sumo.trafficlight.setRedYellowGreenState(self.id, self.all_phases[yellow_index].state)
            # The new green is applied by update() once the yellow time has elapsed.
            self.green_phase = new_phase
            self.is_yellow = True
            self.time_since_last_phase_change = 0
        self.next_action_time = self.env.sim_step + self.delta_time

    def compute_observation(self):
        """Computes the observation of the traffic signal."""
        return self.observation_fn()

    def compute_reward(self):
        """Computes the reward of the traffic signal."""
        self.last_reward = self.reward_fn(self)
        return self.last_reward

    def _pressure_reward(self):
        return self.get_pressure()

    def _average_speed_reward(self):
        return self.get_average_speed()

    def _queue_reward(self):
        return -self.get_total_queued()

    def _diff_waiting_time_reward(self):
        """Decrease of the total lane waiting time (scaled by 1/100) since the previous call."""
        ts_wait = sum(self.get_accumulated_waiting_time_per_lane()) / 100.0
        reward = self.last_measure - ts_wait
        self.last_measure = ts_wait
        return reward

    def _waiting_time_reward_per_lane(self) -> np.ndarray:
        """Negated waiting time per incoming lane, scaled by 1/100."""
        reward = np.array(self.get_accumulated_waiting_time_per_lane()) / 100.0
        self.last_measure = reward
        return -reward

    def _waiting_time_reward_per_road(self) -> np.ndarray:
        """Negated waiting time per road (``self.edges`` grouping), scaled by 1/100."""
        ts_wait = np.array(self.get_accumulated_waiting_time_per_road()) / 100.0
        self.last_measure = ts_wait
        return -ts_wait

    def _waiting_time_reward_per_road_big(self) -> np.ndarray:
        """Negated waiting time per road (``self.edges_big`` grouping), scaled by 1/1000."""
        ts_wait = np.array(self.get_accumulated_waiting_time_per_road_big()) / 1000.0
        self.last_measure = ts_wait
        return -ts_wait

    def _waiting_time_reward_per_road_big_nonlinear(self) -> np.ndarray:
        """Negated ``(wait / reward_scale_factor) ** reward_exponent`` per road (``edges_big`` grouping)."""
        ts_wait = np.array(self.get_accumulated_waiting_time_per_road_big())
        ts_wait = np.power(ts_wait / self.reward_scale_factor, self.reward_exponent)
        self.last_measure = ts_wait
        return -ts_wait

    def _observation_fn_default(self):
        phase_id = [1 if self.green_phase == i else 0 for i in range(self.num_green_phases)]  # one-hot encoding
        min_green = [0 if self.time_since_last_phase_change < self.min_green + self.yellow_time else 1]
        density = self.get_lanes_density()
        queue = self.get_lanes_queue()
        observation = np.array(phase_id + min_green + density + queue, dtype=np.float32)
        return observation

    def _vehicle_lane_waiting_time(self, veh: str) -> float:
        """Return the waiting time of *veh* attributed to the lane it is currently on.

        ``self.env.vehicles`` maps vehicle id -> {lane id: waiting time attributed to that lane}.
        The value from ``getAccumulatedWaitingTime`` is treated as a running total across lanes, so
        the share for the current lane is that total minus what was already attributed to other
        lanes. The bookkeeping entry is created or updated as a side effect.
        """
        veh_lane = self.sumo.vehicle.getLaneID(veh)
        acc = self.sumo.vehicle.getAccumulatedWaitingTime(veh)
        if veh not in self.env.vehicles:
            self.env.vehicles[veh] = {veh_lane: acc}
        else:
            per_lane = self.env.vehicles[veh]
            per_lane[veh_lane] = acc - sum([wait for lane_id, wait in per_lane.items() if lane_id != veh_lane])
        return self.env.vehicles[veh][veh_lane]

    def _total_waiting_time(self, veh_ids) -> float:
        """Sum of the lane-attributed waiting times of *veh_ids* (0.0 when empty)."""
        return sum((self._vehicle_lane_waiting_time(veh) for veh in veh_ids), 0.0)

    def _vehicles_on_lanes(self, lanes) -> List[str]:
        """Concatenate the last-step vehicle ids of *lanes*, in lane order."""
        veh_list = []
        for lane in lanes:
            veh_list.extend(self.sumo.lane.getLastStepVehicleIDs(lane))
        return veh_list

    def get_accumulated_waiting_time_per_lane(self) -> List[float]:
        """Returns the accumulated waiting time per lane.

        Returns:
            List[float]: List of accumulated waiting time of each intersection lane.
        """
        return [self._total_waiting_time(self.sumo.lane.getLastStepVehicleIDs(lane)) for lane in self.lanes]

    #### 2way-single-intersection
    def get_accumulated_waiting_time_per_road(self) -> List[float]:
        """Return, for each green phase p, the waiting time of vehicles on the lanes ``self.edges[p]``."""
        return [self._total_waiting_time(self._get_veh_list_road(p)) for p in range(self.num_green_phases)]

    def _get_veh_list_road(self, p):
        return self._vehicles_on_lanes(self.edges[p])

    def _compute_edges(self):
        """Map each green phase index p to the incoming lanes ``self.lanes[2p:2p+2]``.

        Assumes two lanes per road and that ``self.lanes`` lists roads in green-phase order
        (holds for the 2-way single-intersection network; verify for other networks).
        """
        return {p: self.lanes[p * 2:p * 2 + 2] for p in range(self.num_green_phases)}

    ### big-intersection
    def get_accumulated_waiting_time_per_road_big(self) -> List[float]:
        """Return, for each green phase p, the waiting time of vehicles on the lanes ``self.edges_big[p]``."""
        return [self._total_waiting_time(self._get_veh_list_road_big(p)) for p in range(self.num_green_phases)]

    def _get_veh_list_road_big(self, p):
        return self._vehicles_on_lanes(self.edges_big[p])

    def _compute_edges_big(self):
        """Map each green phase index p to the incoming lanes ``self.lanes[4p:4p+4]``.

        Assumes four lanes per road and that ``self.lanes`` lists roads in green-phase order
        (intended for the big-intersection network; verify for other networks).
        """
        return {p: self.lanes[p * 4:(p + 1) * 4] for p in range(self.num_green_phases)}

    def get_average_speed(self) -> float:
        """Returns the average speed normalized by the maximum allowed speed of the vehicles in the intersection.

        Obs: If there are no vehicles in the intersection, it returns 1.0.
        """
        vehs = self._get_veh_list()
        if not vehs:
            return 1.0
        total = sum(
            (self.sumo.vehicle.getSpeed(v) / self.sumo.vehicle.getAllowedSpeed(v) for v in vehs),
            0.0,
        )
        return total / len(vehs)

    def get_pressure(self):
        """Returns the pressure (#veh leaving - #veh approaching) of the intersection."""
        return sum(self.sumo.lane.getLastStepVehicleNumber(lane) for lane in self.out_lanes) - sum(
            self.sumo.lane.getLastStepVehicleNumber(lane) for lane in self.lanes
        )

    def _lane_capacity(self, lane: str) -> float:
        """Number of vehicles that fit on *lane*: its length / (MIN_GAP + last-step vehicle length)."""
        return self.lanes_length[lane] / (self.MIN_GAP + self.sumo.lane.getLastStepLength(lane))

    def get_out_lanes_density(self) -> List[float]:
        """Returns the density [0,1] of the vehicles in the outgoing lanes of the intersection."""
        return [
            min(1, self.sumo.lane.getLastStepVehicleNumber(lane) / self._lane_capacity(lane))
            for lane in self.out_lanes
        ]

    def get_lanes_density(self) -> List[float]:
        """Returns the density [0,1] of the vehicles in the incoming lanes of the intersection.

        Obs: The density is computed as the number of vehicles divided by the number of vehicles that could fit in the lane.
        """
        return [
            min(1, self.sumo.lane.getLastStepVehicleNumber(lane) / self._lane_capacity(lane))
            for lane in self.lanes
        ]

    def get_lanes_queue(self) -> List[float]:
        """Returns the queue [0,1] of the vehicles in the incoming lanes of the intersection.

        Obs: The queue is computed as the number of vehicles halting divided by the number of vehicles that could fit in the lane.
        """
        return [
            min(1, self.sumo.lane.getLastStepHaltingNumber(lane) / self._lane_capacity(lane))
            for lane in self.lanes
        ]

    def get_total_queued(self) -> int:
        """Returns the total number of vehicles halting in the intersection."""
        return sum(self.sumo.lane.getLastStepHaltingNumber(lane) for lane in self.lanes)

    def _get_veh_list(self):
        return self._vehicles_on_lanes(self.lanes)

    @classmethod
    def register_reward_fn(cls, fn: Callable):
        """Registers a reward function.

        Args:
            fn (Callable): The reward function to register.

        Raises:
            KeyError: If a reward function with the same ``__name__`` is already registered.
        """
        if fn.__name__ in cls.reward_fns:
            raise KeyError(f"Reward function {fn.__name__} already exists")

        cls.reward_fns[fn.__name__] = fn

    reward_fns = {
        "diff-waiting-time": _diff_waiting_time_reward,
        "average-speed": _average_speed_reward,
        "queue": _queue_reward,
        "pressure": _pressure_reward,
        "waiting-time-road": _waiting_time_reward_per_road,
        "waiting-time-lane": _waiting_time_reward_per_lane,
        "waiting-time-road-big": _waiting_time_reward_per_road_big,
        "waiting-time-road-big-nonlinear": _waiting_time_reward_per_road_big_nonlinear,
    }
