"""
Run the Fixed-Time model
On JiNan and HangZhou real data
"""
from utils.utils import pipeline_wrapper, merge
import os
import time
from multiprocessing import Process
from utils import config
import argparse


def parse_args(argv=None):
    """Parse the command-line options for a Fixed-Time benchmark run.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with attributes ``memo``, ``model``,
        ``multi_process``, ``workers``, ``gen`` and ``dataset``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--memo", type=str, default='benchmark_1001',
                        help="sub-directory name used under model/ and records/")
    parser.add_argument("-model", type=str, default="Fixedtime",
                        help="value stored as MODEL_NAME in the env config")
    # Multiprocessing is ON by default. `-multi_process` is kept only for
    # backward compatibility (it is a no-op because the default is already
    # True); `-no_multi_process` is the switch that actually disables it.
    parser.add_argument("-multi_process", action="store_true", default=True,
                        help="run each traffic file in its own process (default)")
    parser.add_argument("-no_multi_process", dest="multi_process", action="store_false",
                        help="run each traffic file sequentially in this process")
    parser.add_argument("-workers", type=int, default=3,
                        help="maximum number of processes running at once")
    parser.add_argument("-gen", type=int, default=1,
                        help="value stored as NUM_GENERATORS in the env config")
    parser.add_argument("-dataset", type=str, choices=['jinan', 'hangzhou', 'newyork'],
                        default='newyork')

    return parser.parse_args(argv)


# Per-dataset settings: (road_net "ROWS_COLS", traffic files, data template dir).
_DATASET_SETTINGS = {
    'hangzhou': ("4_4", ("anon_4_4_hangzhou_real_5816.json",), "Hangzhou"),
    'jinan': ("3_4", ("anon_3_4_jinan_real_2500.json",), "Jinan"),
    'newyork': ("28_7", ("anon_28_7_newyork_real_double.json",), "NewYork"),
}


def _dataset_settings(dataset):
    """Return ``(road_net, traffic_file_list, template)`` for *dataset*.

    Raises:
        ValueError: if *dataset* is not one of the known dataset names.
    """
    try:
        road_net, traffic_files, template = _DATASET_SETTINGS[dataset]
    except KeyError:
        raise ValueError("unknown dataset {0!r}; expected one of {1}".format(
            dataset, sorted(_DATASET_SETTINGS))) from None
    # Hand back a fresh list so callers cannot mutate the shared table.
    return road_net, list(traffic_files), template


def _run_with_worker_limit(processes, workers, poll_interval=1.0):
    """Start every process, keeping at most *workers* alive at once, then join all.

    *workers* values below 1 are treated as 1 so that every process still runs.
    """
    limit = max(1, workers)
    running = []
    for proc in processes:
        # Block until a slot frees up; is_alive() also reaps finished children.
        while len(running) >= limit:
            running = [p for p in running if p.is_alive()]
            if len(running) >= limit:
                time.sleep(poll_interval)
        proc.start()
        running.append(proc)

    # Every process was started above, so joining each one is safe.
    for proc in processes:
        proc.join()


def main(in_args):
    """Build the Fixed-Time run configuration for each traffic file and run it.

    For the selected dataset, this merges the run-specific settings into the
    defaults from ``config`` via ``merge`` (presumably overriding the default
    keys; confirm against utils.utils.merge). It then calls ``pipeline_wrapper``
    with offline set to False, either in worker processes (at most
    ``in_args.workers`` at a time) or inline when ``in_args.multi_process`` is
    false.

    Args:
        in_args: Namespace with ``dataset``, ``gen``, ``model``, ``memo``,
            ``multi_process`` and ``workers`` attributes (see parse_args).

    Raises:
        ValueError: if ``in_args.dataset`` is not a known dataset name.
    """
    count = 3600
    num_rounds = 1
    road_net, traffic_file_list, template = _dataset_settings(in_args.dataset)

    num_row, num_col = (int(part) for part in road_net.split('_'))
    num_intersections = num_row * num_col
    print('num_intersections:', num_intersections)
    print(traffic_file_list)

    process_list = []
    for traffic_file in traffic_file_list:
        dic_traffic_env_conf_extra = {
            "NUM_ROUNDS": num_rounds,
            "NUM_GENERATORS": in_args.gen,
            "NUM_AGENTS": num_intersections,
            "NUM_INTERSECTIONS": num_intersections,
            "MODEL_NAME": in_args.model,
            "RUN_COUNTS": count,
            "NUM_ROW": num_row,
            "NUM_COL": num_col,
            "TRAFFIC_FILE": traffic_file,
            "ROADNET_FILE": "roadnet_{0}.json".format(road_net),
            "LIST_STATE_FEATURE": [
                "cur_phase",
                "time_this_phase",
                "traffic_movement_pressure_queue",
            ],
            "DIC_REWARD_INFO": {
                "traffic_movement_pressure_queue_efficient": -0.25,
            },
        }

        dic_agent_conf_extra = {
            "FIXED_TIME": [30, 30, 30, 30],
        }

        # Take the timestamp once so the model and records directories for
        # this run always share an identical suffix.
        run_name = (traffic_file + "_" +
                    time.strftime('%m_%d_%H_%M_%S', time.localtime(time.time())) + "_FT")
        dic_path_extra = {
            "PATH_TO_MODEL": os.path.join("model", in_args.memo, run_name),
            "PATH_TO_WORK_DIRECTORY": os.path.join("records", in_args.memo, run_name),
            "PATH_TO_DATA": os.path.join("data", template, str(road_net)),
        }
        deploy_dic_agent_conf = merge(config.DIC_BASE_AGENT_CONF, dic_agent_conf_extra)
        deploy_dic_traffic_env_conf = merge(config.dic_traffic_env_conf, dic_traffic_env_conf_extra)
        deploy_dic_path = merge(config.DIC_PATH, dic_path_extra)

        if in_args.multi_process:
            process_list.append(Process(target=pipeline_wrapper,
                                        args=(deploy_dic_agent_conf,
                                              deploy_dic_traffic_env_conf,
                                              deploy_dic_path, False)))
        else:
            pipeline_wrapper(dic_agent_conf=deploy_dic_agent_conf,
                             dic_traffic_env_conf=deploy_dic_traffic_env_conf,
                             dic_path=deploy_dic_path,
                             offline=False)

    if in_args.multi_process:
        # Previously any process beyond the first `workers` was never started.
        _run_with_worker_limit(process_list, in_args.workers)


if __name__ == "__main__":
    args = parse_args()
    main(args)
