from collections import defaultdict

import numpy as np

from agent import BaseAgent


class GPL(BaseAgent):
    """
    Liao, Xiao-Cheng, Yi Mei, and Mengjie Zhang.
    "Learning traffic signal control via genetic programming."
    In Proceedings of the Genetic and Evolutionary Computation Conference, pp. 924-932. 2024.

    Greedy phase selector: on every call to ``pick_action`` each phase is
    scored by applying ``self.func`` to observation values of the lanes that
    phase serves, and the highest-scoring phase is chosen.
    """

    def __init__(self, config, env, idx):
        super().__init__(config, env, idx)
        # The minimum duration of one phase. NOTE: pick_action currently does
        # not enforce it, so this and next_phase_time have no effect there.
        self.t_min = 20
        self.next_phase_time = self.t_min

        # NOTE(review): not read anywhere in this class; kept in case other code uses it.
        self.eff_range = 180

        # Phase-scoring function (presumably the expression evolved by GP in the
        # paper above). It reads its positional arguments at fixed indices:
        #   args[0] - (args[1] + args[2] + args[3]) + args[4] - (args[5] + args[6] + args[7])
        # _get_pressure_for_phase passes the obs[1] values of the phase's lanes
        # followed by their obs[0] values, so this lines up only when each phase
        # has exactly 4 lanes (assumption — TODO confirm for the road network used).
        self.func = lambda *args: args[0] - sum(args[1:4]) + args[4] - sum(args[5:8])

    def _get_phase_lane_groups(self, phase_id):
        """Group the phase's available lane links by start lane.

        Start lanes whose id ends with ``'2'`` are skipped (presumably
        right-turn lanes — TODO confirm the lane naming convention).

        Args:
            phase_id: index into ``self.inter.n_phase``.

        Returns:
            A dict mapping each kept start-lane id to the list of its end-lane
            ids, in the order the links were encountered. It is empty when every
            link of the phase was skipped.
        """
        groups = defaultdict(list)
        for lane_link in self.inter.n_phase[phase_id].n_available_lanelink_id:
            start_lane_id, end_lane_id = lane_link[0], lane_link[1]
            if start_lane_id.endswith('2'):
                continue
            groups[start_lane_id].append(end_lane_id)
        # Plain dict so callers cannot create keys by accident on lookup.
        return dict(groups)

    def get_phase_num_features(self):
        """Return the number of values passed to ``self.func`` per phase.

        A phase's lane list holds each kept start lane plus its end lanes.
        ``_get_pressure_for_phase`` passes two values per lane (one from
        ``obs[1]`` and one from ``obs[0]``), hence the factor of 2.

        Raises:
            ValueError: if there are no phases, or if phases differ in lane
                count (``self.func`` indexes its arguments by fixed position).
        """
        nums = [
            sum(1 + len(end_lanes) for end_lanes in self._get_phase_lane_groups(phase_id).values())
            for phase_id in range(self.num_phase)
        ]
        if len(set(nums)) != 1:
            raise ValueError(
                f"GPL expects every phase to cover the same, non-zero number of phases' lanes; "
                f"got lane counts {nums}"
            )
        return nums[0] * 2

    def reset(self):
        """Reset per-episode state: current phase back to 0 and the phase timer to ``t_min``."""
        self.current_phase = 0
        self.next_phase_time = self.t_min

    def pick_action(self, n_obs, on_training):
        """Choose the phase with the highest ``self.func`` score.

        Args:
            n_obs: per-agent observations; entry ``self.idx`` belongs to this agent.
            on_training: not used by this agent (presumably kept for the
                ``BaseAgent`` interface).

        Returns:
            The chosen phase index as a Python ``int``; it is also stored in
            ``self.current_phase``. Ties, including every phase scoring ``-inf``,
            resolve to the lowest index (``np.argmax`` returns the first maximum).
        """
        obs = n_obs[self.idx]
        # NOTE: the minimum-phase-duration check (keep the current phase while
        # self.inter.current_phase_time < self.next_phase_time) is currently
        # disabled, so a new phase is chosen on every call.
        scores = [self._get_pressure_for_phase(obs, phase_id) for phase_id in range(self.num_phase)]
        self.current_phase = int(np.argmax(scores))
        return self.current_phase

    def _get_pressure_for_phase(self, obs, phase_id):
        """Score one phase by applying ``self.func`` to its lanes' observation values.

        Args:
            obs: this agent's observation. ``obs[0]`` and ``obs[1]`` are indexed
                by a lane's position in ``self.inter.n_lane_id`` (presumably
                per-lane feature vectors — TODO confirm what each holds).
            phase_id: index of the phase to score.

        Returns:
            The value of ``self.func``. Returns ``-inf`` when the phase has no
            usable lane links, so it is never preferred over a phase with a
            finite score.
        """
        groups = self._get_phase_lane_groups(phase_id)
        if not groups:
            return -float('inf')

        # Each start lane is followed by its end lanes, ordered by the integer
        # value of the id's last character. NOTE(review): only one character is
        # used, so this ordering breaks if a lane index can exceed 9.
        lane_list = []
        for start_lane_id, end_lane_ids in groups.items():
            lane_list.append(start_lane_id)
            lane_list.extend(sorted(end_lane_ids, key=lambda lane_id: int(lane_id[-1])))

        lane_positions = [self.inter.n_lane_id.index(lane) for lane in lane_list]
        # All obs[1] values first, then all obs[0] values: the argument order self.func expects.
        features = [obs[1][pos] for pos in lane_positions] + [obs[0][pos] for pos in lane_positions]
        return self.func(*features)
