from utils.offline_datafactory import DataFactory
from utils.utils import pipeline_wrapper, merge
from utils import config
from utils.cityflow import CityFlowEnv
from utils.model_test import cityflow_logger
from utils.travel_time import travel_time_metric
import time
from multiprocessing import Process
import argparse
import os


def parse_args():
    """Parse the command-line options of this script from ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``memo``, ``mod``,
        ``eightphase``, ``gen``, ``multi_process``, ``workers``,
        ``model_round``, ``record_name`` and ``dataset``.
    """
    # One (flag, keyword-arguments) entry per option, registered in this order.
    option_specs = (
        ("-memo", {"type": str, "default": 'benchmark_1001'}),
        ("-mod", {"type": str, "default": "SOTL"}),
        ("-eightphase", {"action": "store_true", "default": False}),
        ("-gen", {"type": int, "default": 1}),
        # store_true combined with default=True: the value is True whether
        # or not the flag is passed on the command line.
        ("-multi_process", {"action": "store_true", "default": True}),
        ("-workers", {"type": int, "default": 3}),
        ("-model_round", {"type": int, "default": 0}),
        ("-record_name", {"type": str, "default": ""}),
        ("-dataset", {"type": str,
                      "choices": ['jinan', 'hangzhou', 'newyork'],
                      "default": 'jinan'}),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


def main(in_args=None):

    if in_args.dataset == 'hangzhou':
        count = 3600 * 4
        road_net = "4_4"
        # traffic_file_list = ["anon_4_4_hangzhou_real.json", 
        #                      "anon_4_4_hangzhou_real_5816.json"]
        traffic_file_list = ["anon_4_4_hangzhou_real_5816.json"]
        template = "Hangzhou"
        num_rounds = 1
    elif in_args.dataset == 'jinan':
        count = 3600 * 4
        road_net = "3_4"
        # traffic_file_list = ["anon_3_4_jinan_real.json", "anon_3_4_jinan_real_2000.json",
        #                      "anon_3_4_jinan_real_2500.json"]
        traffic_file_list = ["anon_3_4_jinan_real_2500.json"]
        template = "Jinan"
        num_rounds = 1
    elif in_args.dataset == 'newyork':
        count = 3600 * 4
        road_net = "28_7"
        traffic_file_list = ["anon_28_7_newyork_real_double.json"]
        template = "NewYork"
        num_rounds = 1

    NUM_COL = int(road_net.split('_')[1])
    NUM_ROW = int(road_net.split('_')[0])
    num_intersections = NUM_ROW * NUM_COL
    print('num_intersections:', num_intersections)
    print(traffic_file_list)
    process_list = []
    for traffic_file in traffic_file_list:
        dic_traffic_env_conf_extra = {
            "CITYFLOW_MAX_TIME": 3600,
            "NUM_ROUNDS": num_rounds,
            "NUM_GENERATORS": in_args.gen,
            "NUM_AGENTS": 1,
            "NUM_INTERSECTIONS": num_intersections,
            "RUN_COUNTS": in_args.model_round + 1,
            "MODEL_NAME": in_args.mod,
            "NUM_ROW": NUM_ROW,
            "NUM_COL": NUM_COL,
            "TRAFFIC_FILE": traffic_file,
            "ROADNET_FILE": "roadnet_{0}.json".format(road_net),
            "TRAFFIC_SEPARATE": traffic_file,
            "LIST_STATE_FEATURE": [
                "lane_num_waiting_vehicle_in",
                "lane_enter_running_part",
                "traffic_movement_pressure_queue_efficient"
            ],
            "DIC_REWARD_INFO": {
                "traffic_movement_pressure_queue_efficient": -0.25
            },
            "LOG_REWARD_INFO": {
                "traffic_movement_pressure_queue_efficient": -0.25,
                "lane_num_waiting_vehicle_in": -0.25,
            },
            "LIST_INFO_FEATURE": [
                "lane_num_waiting_vehicle_in",
            ]
        }

        if in_args.eightphase:
            dic_traffic_env_conf_extra["PHASE"] = {
                1: [0, 1, 0, 1, 0, 0, 0, 0],
                2: [0, 0, 0, 0, 0, 1, 0, 1],
                3: [1, 0, 1, 0, 0, 0, 0, 0],
                4: [0, 0, 0, 0, 1, 0, 1, 0],
                5: [1, 1, 0, 0, 0, 0, 0, 0],
                6: [0, 0, 1, 1, 0, 0, 0, 0],
                7: [0, 0, 0, 0, 0, 0, 1, 1],
                8: [0, 0, 0, 0, 1, 1, 0, 0]
            }
            dic_traffic_env_conf_extra["PHASE_LIST"] = ['WT_ET', 'NT_ST', 'WL_EL', 'NL_SL',
                                                        'WL_WT', 'EL_ET', 'SL_ST', 'NL_NT']

        dic_path_extra = {
            "PATH_TO_WORK_DIRECTORY": os.path.join(os.getcwd(), "records", in_args.memo, in_args.record_name),
            "PATH_TO_DATA": os.path.join(os.getcwd(), "data", template, str(road_net)),
            "PATH_TO_ERROR": os.path.join(os.getcwd(), "errors", in_args.memo),
            "PATH_TO_MODEL": os.path.join(os.getcwd(), "model", in_args.memo, in_args.record_name)
        }
        
        dic_agent_conf_extra = {
            "FIXED_TIME": [30, 30, 30, 30],
        }

        deploy_dic_agent_conf = merge(config.DIC_BASE_AGENT_CONF, dic_agent_conf_extra)
        deploy_dic_traffic_env_conf = merge(config.dic_traffic_env_conf, dic_traffic_env_conf_extra)
        deploy_dic_path = merge(config.DIC_PATH, dic_path_extra)

        agents = []
        for i in range(deploy_dic_traffic_env_conf['NUM_AGENTS']):
            agent_name = deploy_dic_traffic_env_conf["MODEL_NAME"]
            agent = config.DIC_AGENTS[agent_name](
                dic_agent_conf=deploy_dic_agent_conf,
                dic_traffic_env_conf=deploy_dic_traffic_env_conf,
                dic_path=deploy_dic_path,
                cnt_round=deploy_dic_traffic_env_conf["RUN_COUNTS"],
                intersection_id=str(i)
            )
            agents.append(agent)

        # for i in range(deploy_dic_traffic_env_conf['NUM_AGENTS']):
        #     agents[i].load_network("round_{0}_inter_{1}".format(in_args.model_round, agents[i].intersection_id))
        
        path_to_log = os.path.join(deploy_dic_path["PATH_TO_WORK_DIRECTORY"], "test_round", "round_{}".format(in_args.model_round))
        logger = cityflow_logger(path_to_log, deploy_dic_traffic_env_conf["NUM_INTERSECTIONS"])
        env = CityFlowEnv(
            path_to_work_directory=deploy_dic_path["PATH_TO_WORK_DIRECTORY"],
            dic_traffic_env_conf=deploy_dic_traffic_env_conf
        )
        metric = travel_time_metric(env)
        state = env.reset()
        step_num = 0
        while step_num < int(deploy_dic_traffic_env_conf["CITYFLOW_MAX_TIME"] / deploy_dic_traffic_env_conf["MIN_ACTION_TIME"]):
            action_list = []
            for i in range(deploy_dic_traffic_env_conf["NUM_AGENTS"]):
                if deploy_dic_traffic_env_conf["MODEL_NAME"] in ["EfficientPressLight", "EfficientColight",
                                                                       "EfficientMPLight", "Attend",
                                                                       "AdvancedMPLight", "AdvancedColight", "AdvancedDQN"]:
                    one_state = state
                    action = agents[i].choose_action(step_num, one_state)
                    action_list = action
                else:
                    one_state = state[i]
                    action = agents[i].choose_action(step_num, one_state)
                    action_list.append(action)

            next_state, reward = env.step(action_list)
            metric.update()
            # logger.log_reward(reward)
            logger.log_reward_info(env.get_reward_info(deploy_dic_traffic_env_conf["LOG_REWARD_INFO"]))
            state = next_state
            step_num += 1
            
        metric.update(done=True)
        metric.log_travel_time(path_to_log)
        for feature in deploy_dic_traffic_env_conf["LOG_REWARD_INFO"].keys():
            logger.batch_log_reward(feature)
        
    return in_args.memo


if __name__ == "__main__":
    # Script entry point: parse the command line and run the evaluation.
    main(parse_args())

