import numpy as np

from agent import BaseAgent


def sigmoid(x):
    """Return the logistic function of *x*, shifted up by ``1e-6``.

    The logistic value is computed through the identity
    ``sigmoid(x) = 0.5 * (1 + tanh(x / 2))``. The constant ``1e-6`` is then
    added, so the result is never exactly zero (and can slightly exceed 1).
    Works element-wise on NumPy arrays as well as on scalars.
    """
    tanh_of_half = np.tanh(0.5 * x)
    logistic = 0.5 * (1 + tanh_of_half)
    return logistic + 1e-6


class SymLight(BaseAgent):
    """
    SymLight

    Phase-selection agent driven by a pluggable scoring function.

    Every road link (movement) of the intersection is scored by summing
    ``self.func`` over the lane links of that road link. A phase is valued as
    the sum of the scores of the road links it makes available, excluding the
    road links that are available in every phase. ``pick_action`` returns the
    phase with the highest value.

    ``self.func`` is ``None`` after construction and must be assigned by the
    caller before ``pick_action`` is used. It is invoked as
    ``func(wi, wo, ci, co, di, do, li, lo)`` (see ``_get_value_for_move2``).

    The constructor reads ``self.inter`` and ``self.num_phase``; these are
    expected to be provided by ``BaseAgent.__init__``.
    """

    # Road length attributed to one vehicle. Total road length divided by this
    # gives ``max_capacity``; a lane count multiplied by it and divided by the
    # road length gives the ``li`` / ``lo`` ratios passed to ``func``.
    LENGTH_PER_VEHICLE = 7.5

    def __init__(self, config, env, idx):
        """Build the agent and cache intersection-derived lookups.

        Args:
            config: forwarded unchanged to ``BaseAgent.__init__``.
            env: forwarded unchanged to ``BaseAgent.__init__``.
            idx: forwarded unchanged to ``BaseAgent.__init__``; ``self.idx`` is
                later used to select this agent's entry from ``n_obs``.
        """
        super().__init__(config, env, idx)
        self.t_min = 20  # the minimum duration of one phase
        self.next_phase_time = self.t_min

        # Threshold applied to the per-lane arrays obs[4] / obs[5] when
        # counting entries (presumably a distance range -- confirm).
        self.eff_range = 180

        # Scoring callable; must be set by the caller before pick_action().
        self.func = None
        self.shared_roadlinks = self._get_shared_road_link()
        self.idx_road_mapping = {
            road.road_id: road for road in self.inter.n_road
        }
        self.max_capacity = (
            sum(road.length for road in self.inter.n_road)
            / self.LENGTH_PER_VEHICLE
        )

    def _get_shared_road_link(self):
        """Return the set of road-link indices available in every phase."""
        per_phase_links = [
            set(self.inter.n_phase[i].n_available_roadlink_idx)
            for i in range(self.num_phase)
        ]
        return set.intersection(*per_phase_links)

    def _assert(self):
        """Check that the phases together cover every road link.

        Raises:
            AssertionError: if the number of distinct road-link indices
                reachable through the phases differs from the number of road
                links of the intersection.
        """
        controlled_road_links_id = set.union(
            *(set(phase.n_available_roadlink_idx)
              for phase in self.inter.n_phase)
        )
        assert len(controlled_road_links_id) == len(self.inter.n_roadlink), (
            "phases do not cover exactly the intersection's road links"
        )

    def reset(self):
        """Return to phase 0 and restore the phase timer to ``t_min``."""
        self.current_phase = 0
        self.next_phase_time = self.t_min

    def pick_action(self, n_obs, on_training):
        """Choose the next phase for this agent's intersection.

        Args:
            n_obs: observations indexed by agent; only ``n_obs[self.idx]`` is
                read.
            on_training: unused here; kept for interface compatibility.

        Returns:
            The index of the highest-valued phase (the result of
            ``ndarray.argmax``), which is also stored in
            ``self.current_phase``.

        Raises:
            TypeError: if ``self.func`` has not been assigned and a lane link
                needs to be scored.
        """
        obs = n_obs[self.idx]
        num_move = len(self.inter.n_roadlink)
        move_values = np.zeros(num_move)
        for move_id in range(num_move):
            move_values[move_id] = self._get_value_for_move2(obs, move_id)
        phase_values = self._aggregate_for_each_phase(move_values)

        next_phase = phase_values.argmax()

        # Staying in the same phase extends the timer by one more t_min;
        # switching phases restarts it.
        if self.current_phase == next_phase:
            self.next_phase_time += self.t_min
        else:
            self.next_phase_time = self.t_min

        self.current_phase = next_phase
        return self.current_phase

    def _get_value_for_move2(self, obs, move_id):
        """Score one road link by summing ``self.func`` over its lane links.

        Args:
            obs: observation of this intersection. ``obs[0]``, ``obs[2]`` and
                ``obs[4]`` are indexed by position in
                ``self.inter.n_in_lane_id``; ``obs[1]``, ``obs[3]`` and
                ``obs[5]`` by position in ``self.inter.n_out_lane_id``.
                Entries of ``obs[4]`` / ``obs[5]`` must support element-wise
                ``<`` and ``.sum()`` (e.g. NumPy arrays).
            move_id: index into ``self.inter.n_roadlink``.

        Returns:
            The accumulated score, ``0.0`` when the road link has no lane
            links.

        Raises:
            TypeError: if ``self.func`` is still ``None`` when a lane link has
                to be scored.
        """
        # Normaliser shared by the six primary features; never below 1 so the
        # division is always defined.
        total_num = max(1, sum(obs[2]) + sum(obs[3]))
        road_link = self.inter.n_roadlink[move_id]
        start_road = self.idx_road_mapping[road_link.startroad_id]
        end_road = self.idx_road_mapping[road_link.endroad_id]
        total_len = start_road.length + end_road.length

        ret = 0.0
        for lane_link in road_link.n_lanelink_id:
            start_lane_name, end_lane_name = lane_link[0], lane_link[1]
            start = self.inter.n_in_lane_id.index(start_lane_name)
            end = self.inter.n_out_lane_id.index(end_lane_name)
            wi = obs[0][start]
            wo = obs[1][end]
            ci = obs[2][start]
            co = obs[3][end]

            # Number of entries on each lane that fall below eff_range.
            di = (obs[4][start] < self.eff_range).sum()
            do = (obs[5][end] < self.eff_range).sum()

            # Lane values rescaled by LENGTH_PER_VEHICLE relative to the
            # combined length of the start and end roads.
            li = self.LENGTH_PER_VEHICLE * ci / total_len
            lo = self.LENGTH_PER_VEHICLE * co / total_len

            args = np.array([wi, wo, ci, co, di, do]) / total_num
            if self.func is None:
                # Fail with an actionable message instead of the opaque
                # "'NoneType' object is not callable".
                raise TypeError(
                    "SymLight.func is not set; assign a callable taking "
                    "(wi, wo, ci, co, di, do, li, lo) before calling "
                    "pick_action"
                )
            ret += self.func(*args, li, lo)

        return ret

    def _aggregate_for_each_phase(self, move_values):
        """Sum road-link scores into one value per phase.

        Road links available in every phase (``self.shared_roadlinks``) are
        left out because they cannot distinguish one phase from another. A
        phase with no remaining road link gets ``-inf`` so that ``argmax``
        does not prefer it over a phase with a finite value.

        Args:
            move_values: NumPy array of scores indexed by road-link index.

        Returns:
            NumPy array of length ``self.num_phase`` with the phase values.
        """
        phase_values = np.zeros(self.num_phase)
        for phase_id in range(self.num_phase):
            n_roadlink_idx = [
                road_link
                for road_link in self.inter.n_phase[phase_id].n_available_roadlink_idx
                if road_link not in self.shared_roadlinks
            ]
            if n_roadlink_idx:
                phase_values[phase_id] = move_values[n_roadlink_idx].sum()
            else:
                phase_values[phase_id] = -float('inf')
        return phase_values
