"""
Max-Pressure agent.
observation: [traffic_movement_pressure_queue].
Action: use greedy method select the phase with max value.
"""
from .agent import Agent
import numpy as np
import json
import os


class SOTLAgent(Agent):
    """Threshold-based phase-switching agent for a single intersection.

    The agent keeps the current phase until the waiting vehicles on its green
    incoming lanes drop to ``min_green_vehicle`` or fewer while the waiting
    vehicles on all other incoming lanes exceed ``max_red_vehicle``; it then
    advances to the next phase in cyclic order.
    """

    def __init__(self, dic_agent_conf, dic_traffic_env_conf, dic_path, cnt_round, intersection_id):
        """Load the roadnet and precompute the incoming lanes served by each phase.

        Args:
            dic_agent_conf: agent configuration, forwarded to ``Agent``.
            dic_traffic_env_conf: environment configuration; this class reads
                its "PHASE", "ROADNET_FILE" and "NUM_ROW" entries.
            dic_path: path configuration; "PATH_TO_DATA" is the directory of
                the roadnet file.
            cnt_round: not used by this agent; kept for a uniform agent signature.
            intersection_id: index of the controlled intersection (int or
                numeric string).

        Raises:
            ValueError: if the roadnet has no intersection with the derived name.
        """
        super(SOTLAgent, self).__init__(dic_agent_conf, dic_traffic_env_conf, dic_path, intersection_id)

        self.current_phase_time = 0
        self.phase_length = len(self.dic_traffic_env_conf["PHASE"])

        roadnet_path = os.path.join(self.dic_path["PATH_TO_DATA"], self.dic_traffic_env_conf["ROADNET_FILE"])
        # JSON text is UTF-8; don't depend on the platform's default encoding.
        with open(roadnet_path, encoding="utf-8") as f:
            roadnet = json.load(f)

        num_row = self.dic_traffic_env_conf["NUM_ROW"]
        self.intersection_name = "intersection_{}_{}".format(
            int(intersection_id) // num_row + 1, int(intersection_id) % num_row + 1)
        self.intersection = self._find_intersection(roadnet, self.intersection_name)

        # lanelinks_of_roadlink[i]: (start_lane, end_lane) pairs of roadLinks[i];
        # startlanes: every incoming lane of the intersection, each listed once.
        self.lanelinks_of_roadlink, self.startlanes = self._parse_roadlinks(self.intersection)
        # phase_available_startlanes[k]: distinct incoming lanes that are green in phase k.
        self.phase_available_startlanes = self._build_phase_startlanes(
            self.intersection["trafficLight"]["lightphases"])

        # Waiting-vehicle thresholds used by choose_action.
        self.min_green_vehicle = 100
        self.max_red_vehicle = 150

        # Last action returned by choose_action; reused while cur_phase == -1.
        self.action = None

        # Map state["cur_phase"][0] to a 0-based action index: phase id p -> p - 1
        # for p in 1..phase_length, and 0 -> 0 (meaning of phase id 0 is not
        # visible here). Built for any phase count, not only 4 or 8.
        self.DIC_PHASE_MAP = {phase_id: phase_id - 1 for phase_id in range(1, self.phase_length + 1)}
        self.DIC_PHASE_MAP[0] = 0
        if self.phase_length == 4:
            # Backward-compatible alias of the attribute name used for 4-phase configs.
            self.DIC_PHASE_MAP_4 = self.DIC_PHASE_MAP

    @staticmethod
    def _find_intersection(roadnet, name):
        """Return the entry of roadnet["intersections"] whose "id" equals ``name``.

        Raises:
            ValueError: if no such intersection exists.
        """
        for intersection in roadnet["intersections"]:
            if intersection["id"] == name:
                return intersection
        raise ValueError("intersection {} not found in roadnet".format(name))

    @staticmethod
    def _parse_roadlinks(intersection):
        """Collect lane pairs per road link and the distinct start lanes.

        Lane ids are built as "<road id>_<lane index>".

        Returns:
            (lanelinks_of_roadlink, startlanes): lanelinks_of_roadlink[i] lists the
            (start_lane, end_lane) pairs of intersection["roadLinks"][i];
            startlanes lists each start lane exactly once, in first-seen order.
        """
        lanelinks_of_roadlink = []
        # A start lane appears in several lane links (one per reachable end lane),
        # so deduplicate it; a dict keeps first-seen order deterministically.
        startlanes = {}
        for roadlink in intersection["roadLinks"]:
            pairs = []
            for lanelink in roadlink["laneLinks"]:
                startlane = roadlink["startRoad"] + "_" + str(lanelink["startLaneIndex"])
                endlane = roadlink["endRoad"] + "_" + str(lanelink["endLaneIndex"])
                pairs.append((startlane, endlane))
                startlanes[startlane] = None
            lanelinks_of_roadlink.append(pairs)
        return lanelinks_of_roadlink, list(startlanes)

    def _build_phase_startlanes(self, lightphases):
        """Return, for each phase in dic_traffic_env_conf["PHASE"], its distinct green start lanes.

        An entry equal to 1 at position ``index`` of a phase vector selects light
        phase ``lightphases[index + 1]`` (presumably lightphases[0] is a reserved
        phase such as all-red -- confirm against the roadnet generator); the start
        lanes of every road link in its "availableRoadLinks" are collected.
        """
        phase_startlanes = []
        for phase in self.dic_traffic_env_conf["PHASE"].values():
            lanes = {}
            for index, item in enumerate(phase):
                if item != 1:
                    continue
                for roadlink_id in lightphases[index + 1]["availableRoadLinks"]:
                    for startlane, _endlane in self.lanelinks_of_roadlink[roadlink_id]:
                        lanes[startlane] = None
            phase_startlanes.append(list(lanes))
        return phase_startlanes

    def choose_action(self, count, state):
        """Choose the next phase index using SOTL-style waiting-vehicle thresholds.

        Args:
            count: not used by this agent; kept for a uniform agent signature.
            state: dict with "cur_phase" (sequence whose first entry is the
                current phase id, -1 when unavailable), "intersection_name", and
                "dic_lane_num_waiting_vehicle_in" (waiting vehicles per incoming
                lane id).

        Returns:
            A 0-based phase index. The current phase is kept unless its green
            lanes hold at most ``min_green_vehicle`` waiting vehicles while the
            other incoming lanes hold more than ``max_red_vehicle``; then the next
            phase (cyclically) is chosen. When cur_phase is -1 the previous action
            is returned (None before the first decision).
        """
        cur_phase = state["cur_phase"][0]
        if cur_phase == -1:
            return self.action

        assert state["intersection_name"] == self.intersection_name

        lane_waiting_count = state["dic_lane_num_waiting_vehicle_in"]
        action = self.DIC_PHASE_MAP[cur_phase]

        num_green_vehicles = sum(lane_waiting_count[lane] for lane in self.phase_available_startlanes[action])
        # Green lanes are a subset of startlanes and both are duplicate-free, so
        # the difference is exactly the waiting count on the red lanes.
        num_red_vehicles = sum(lane_waiting_count[lane] for lane in self.startlanes) - num_green_vehicles

        if num_green_vehicles <= self.min_green_vehicle and num_red_vehicles > self.max_red_vehicle:
            action = (action + 1) % self.phase_length
            print("green: {}, red: {}".format(num_green_vehicles, num_red_vehicles))

        # Remember the decision so the cur_phase == -1 path can repeat it.
        self.action = action
        return action
