from utils.utils import pipeline_wrapper, merge
from utils import config
import time
import multiprocessing as mp
import argparse
import os


def parse_args(argv=None):
    """Parse command-line options for the Slight training and evaluation pipeline.

    Args:
        argv: Argument strings to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, so calling
            ``parse_args()`` behaves exactly like a normal CLI invocation.

    Returns:
        An ``argparse.Namespace`` holding every option defined below plus four
        derived boolean flags -- ``hangzhou``, ``jinan``, ``custom`` and
        ``NY`` -- of which exactly one is true, matching the lower-cased
        ``dataset`` value.
    """
    parser = argparse.ArgumentParser(
        description="Run the Slight training and evaluation pipeline."
    )
    # Registered before the options so that argparse copies these values into
    # the matching actions' defaults (``--multi-process`` would otherwise
    # default to False as a plain store_true flag).
    parser.set_defaults(
        multi_process=True,
        dataset="ny",
    )

    general_group = parser.add_argument_group("General configuration")
    general_group.add_argument(
        "--memo",
        "-memo",
        type=str,
        default="slight_benchmark",
        help="Memo name used to construct experiment directories.",
    )
    general_group.add_argument(
        "--mod",
        "-mod",
        type=str,
        default="Slight",
        help="Model identifier recorded in experiment outputs.",
    )

    exec_group = parser.add_argument_group("Execution settings")
    exec_group.add_argument(
        "--generators",
        "--gen",
        "-gen",
        dest="gen",
        type=int,
        default=1,
        help="Number of data generators to run for each scenario.",
    )
    exec_group.add_argument(
        "--workers",
        "-workers",
        type=int,
        default=3,
        help="Maximum number of worker processes to spawn simultaneously.",
    )
    # --multi-process and --single-process share one destination; the
    # default (True) comes from set_defaults above.
    exec_group.add_argument(
        "--multi-process",
        "--multi_process",
        "-multi_process",
        dest="multi_process",
        action="store_true",
        help="Run each traffic file in a separate process (default).",
    )
    exec_group.add_argument(
        "--single-process",
        dest="multi_process",
        action="store_false",
        help="Run sequentially in the current process.",
    )

    dataset_group = parser.add_argument_group("Dataset selection")
    # Legacy per-dataset switches are kept (hidden from --help) for existing
    # scripts; all of them write the same ``dataset`` destination.
    dataset_options = dataset_group.add_mutually_exclusive_group()
    dataset_options.add_argument(
        "--dataset",
        choices=["hangzhou", "jinan", "custom", "ny"],
        dest="dataset",
        metavar="NAME",
        help="Traffic scenario to run (default: ny).",
    )
    dataset_options.add_argument(
        "--hangzhou",
        "-hangzhou",
        dest="dataset",
        action="store_const",
        const="hangzhou",
        help=argparse.SUPPRESS,
    )
    dataset_options.add_argument(
        "--jinan",
        "-jinan",
        dest="dataset",
        action="store_const",
        const="jinan",
        help=argparse.SUPPRESS,
    )
    dataset_options.add_argument(
        "--custom",
        "-custom",
        dest="dataset",
        action="store_const",
        const="custom",
        help=argparse.SUPPRESS,
    )
    dataset_options.add_argument(
        "--ny",
        "--NY",
        "-NY",
        dest="dataset",
        action="store_const",
        const="ny",
        help=argparse.SUPPRESS,
    )

    slight_group = parser.add_argument_group("Slight options")
    slight_group.add_argument(
        "--num-skills",
        "--num_skills",
        dest="num_skills",
        type=int,
        default=4,
        help="Number of discrete skills managed by the meta controller.",
    )
    slight_group.add_argument(
        "--meta-freq",
        "--meta_freq",
        dest="meta_freq",
        type=int,
        default=10,
        help="Frequency (in steps) for selecting high-level skills.",
    )
    slight_group.add_argument(
        "--cvae-latent",
        "--cvae_latent",
        dest="cvae_latent",
        type=int,
        default=16,
        help="Latent dimensionality used for the CVAE module.",
    )

    # Like the options above, every embedding option accepts both the dashed
    # and the underscored spelling.
    embed_group = parser.add_argument_group("Embedding hyperparameters")
    embed_group.add_argument(
        "--lstm-hidden",
        "--lstm_hidden",
        dest="lstm_hidden",
        type=int,
        default=128,
        help="Hidden size for the LSTM agent embedding.",
    )
    embed_group.add_argument(
        "--lstm-layers",
        "--lstm_layers",
        dest="lstm_layers",
        type=int,
        default=1,
        help="Number of layers for the LSTM agent embedding.",
    )
    embed_group.add_argument(
        "--lstm-dropout",
        "--lstm_dropout",
        dest="lstm_dropout",
        type=float,
        default=0.0,
        help="Dropout for the LSTM agent embedding.",
    )
    embed_group.add_argument(
        "--cvae-hidden",
        "--cvae_hidden",
        dest="cvae_hidden",
        type=int,
        default=128,
        help="Hidden size for the CVAE encoder/decoder.",
    )
    embed_group.add_argument(
        "--cvae-lr",
        "--cvae_lr",
        dest="cvae_lr",
        type=float,
        default=1e-3,
        help="Learning rate for the CVAE optimizer.",
    )
    embed_group.add_argument(
        "--cvae-kl",
        "--cvae_kl",
        dest="cvae_kl",
        type=float,
        default=1e-4,
        help="KL weight for the CVAE loss.",
    )

    args = parser.parse_args(argv)

    # Derive the boolean dataset flags that main() dispatches on.
    args.dataset = args.dataset.lower()
    args.hangzhou = args.dataset == "hangzhou"
    args.jinan = args.dataset == "jinan"
    args.custom = args.dataset == "custom"
    args.NY = args.dataset == "ny"

    return args


# Per-dataset run settings, checked in this order against the boolean flags
# that parse_args() derives from --dataset. Each entry is:
# (flag attribute, RUN_COUNTS, road network "ROWS_COLS", traffic files,
#  NUM_ROUNDS, data template directory under "data/").
_DATASET_SETTINGS = (
    ("hangzhou", 3600, "4_4",
     ("anon_4_4_hangzhou_real.json", "anon_4_4_hangzhou_real_5816.json"),
     50, "Hangzhou"),
    ("jinan", 3600, "3_4",
     ("anon_3_4_jinan_real.json", "anon_3_4_jinan_real_2000.json",
      "anon_3_4_jinan_real_2500.json"),
     50, "Jinan"),
    ("custom", 3600, "16_16",
     ("anon_16_16_custom_sumo_2.json",),
     50, "SumoConstruct2"),
    ("NY", 3600, "28_7",
     ("anon_28_7_newyork_real_double.json",),
     50, "newyork_28_7"),
)


def _select_dataset(in_args):
    """Return ``(count, road_net, traffic_files, num_rounds, template)`` for *in_args*.

    The first entry of ``_DATASET_SETTINGS`` whose flag attribute is truthy on
    *in_args* wins. ``traffic_files`` is a fresh list, so callers may mutate it.

    Raises:
        ValueError: if none of the dataset flags is set on *in_args*.
    """
    for flag, count, road_net, traffic_files, num_rounds, template in _DATASET_SETTINGS:
        if getattr(in_args, flag, False):
            return count, road_net, list(traffic_files), num_rounds, template
    raise ValueError(
        "No dataset selected: expected one of the flags "
        + ", ".join(entry[0] for entry in _DATASET_SETTINGS)
        + " to be true on the arguments namespace."
    )


def _run_in_batches(process_list, workers):
    """Start *process_list* in batches of *workers*, joining each batch before the next.

    A non-positive *workers* is treated as 1. Otherwise ``range(..., 0)``
    would raise, and a negative step would silently start nothing.
    """
    batch_size = max(1, workers)
    for i in range(0, len(process_list), batch_size):
        i_max = min(len(process_list), i + batch_size)
        for j in range(i, i_max):
            print(j)
            print("start_traffic")
            process_list[j].start()
            print("after_traffic")
        for k in range(i, i_max):
            print("traffic to join", k)
            process_list[k].join()
            print("traffic finish join", k)


def main(in_args=None):
    """Build the per-traffic-file configs for the selected dataset and run the pipeline.

    For every traffic file of the chosen dataset, this merges the Slight agent
    settings, the environment settings and the output paths into the base
    configs from ``utils.config``. It then runs ``pipeline_wrapper`` either
    inline or in separate "spawn" processes, started ``in_args.workers`` at a
    time.

    Args:
        in_args: Namespace as returned by ``parse_args()``. When ``None``, the
            arguments are parsed from the command line.

    Returns:
        The experiment memo name (``in_args.memo``).

    Raises:
        ValueError: if *in_args* does not select any known dataset.
    """
    if in_args is None:
        # Allow main() to be called without arguments, e.g. from another script.
        in_args = parse_args()

    count, road_net, traffic_file_list, num_rounds, template = _select_dataset(in_args)

    # road_net is "ROWS_COLS", e.g. "28_7" -> 28 rows x 7 columns.
    NUM_ROW, NUM_COL = (int(part) for part in road_net.split('_'))
    num_intersections = NUM_ROW * NUM_COL
    print('num_intersections:', num_intersections)
    print(traffic_file_list)

    process_list = []
    process_ctx = mp.get_context("spawn") if in_args.multi_process else None
    for traffic_file in traffic_file_list:
        dic_agent_conf_extra = {
            "CNN_layers": [[32, 32]],
            # Slight params
            "NUM_SKILLS": in_args.num_skills,
            "META_CONTROLLER_FREQ": in_args.meta_freq,
            "USE_CVAE": True,
            "CVAE_LATENT_DIM": in_args.cvae_latent,
            "AGENT_EMBED_TYPE": "lstm",
            "GROUP_EMBED_TYPE": "cvae",
            "SKILL_Q_TYPE": "colight_gat",
            "LSTM_EMBED_HIDDEN": in_args.lstm_hidden,
            "LSTM_EMBED_LAYERS": in_args.lstm_layers,
            "LSTM_EMBED_DROPOUT": in_args.lstm_dropout,
            "CVAE_HIDDEN_DIM": in_args.cvae_hidden,
            "CVAE_LEARNING_RATE": in_args.cvae_lr,
            "CVAE_KL_WEIGHT": in_args.cvae_kl,
            # optional CVAE training hparams
            "CVAE_BUFFER_SIZE": 2000,
            # Slight agent uses CVAE_BATCH_SIZE (fallback to 64) when sampling memories.
            "CVAE_BATCH_SIZE": 64,
            "CVAE_EPOCHS": 1,
        }
        deploy_dic_agent_conf = merge(config.DIC_BASE_AGENT_CONF, dic_agent_conf_extra)

        dic_traffic_env_conf_extra = {
            "NUM_ROUNDS": num_rounds,
            "NUM_GENERATORS": in_args.gen,
            "NUM_AGENTS": 1,
            "NUM_INTERSECTIONS": num_intersections,
            "RUN_COUNTS": count,
            "MODEL_NAME": in_args.mod,
            "NUM_ROW": NUM_ROW,
            "NUM_COL": NUM_COL,
            "TRAFFIC_FILE": traffic_file,
            "ROADNET_FILE": "roadnet_{0}.json".format(road_net),
            "LIST_STATE_FEATURE": [
                "cur_phase",
                "traffic_movement_pressure_queue_efficient",
                "lane_enter_running_part",
                "adjacency_matrix",
            ],

            "DIC_REWARD_INFO": {
                "queue_length": -0.25,
            },
        }

        # Read the clock once so the model and records directories of the same
        # run always share an identical suffix. Two separate reads can straddle
        # a second boundary.
        run_stamp = time.strftime('%m_%d_%H_%M_%S', time.localtime(time.time()))
        run_name = traffic_file + "_" + run_stamp
        dic_path_extra = {
            "PATH_TO_MODEL": os.path.join("model", in_args.memo, run_name),
            "PATH_TO_WORK_DIRECTORY": os.path.join("records", in_args.memo, run_name),
            "PATH_TO_DATA": os.path.join("data", template, str(road_net)),
            "PATH_TO_ERROR": os.path.join("errors", in_args.memo)
        }
        deploy_dic_traffic_env_conf = merge(config.dic_traffic_env_conf, dic_traffic_env_conf_extra)
        deploy_dic_path = merge(config.DIC_PATH, dic_path_extra)

        if in_args.multi_process:
            # Processes are only created here; _run_in_batches starts them.
            ppl = process_ctx.Process(target=pipeline_wrapper,
                                      args=(deploy_dic_agent_conf,
                                            deploy_dic_traffic_env_conf,
                                            deploy_dic_path))
            process_list.append(ppl)
        else:
            pipeline_wrapper(dic_agent_conf=deploy_dic_agent_conf,
                             dic_traffic_env_conf=deploy_dic_traffic_env_conf,
                             dic_path=deploy_dic_path)

    if in_args.multi_process:
        _run_in_batches(process_list, in_args.workers)

    return in_args.memo


if __name__ == "__main__":
    # Script entry point: read the CLI options first, then launch the pipeline.
    args = parse_args()
    # main() already requests an explicit "spawn" context for its workers.
    # Forcing the global start method too keeps any process creation that goes
    # through the default context consistent with it.
    mp.set_start_method("spawn", force=True)
    main(args)
