import os
import time
from utils.motsc.motsc import MOTSCFakeEnv
from utils.tester import tester
from utils.utils import path_check, copy_conf_file, copy_cityflow_file

# Experiment selection. These names are reused below to build the env config
# ("TRAFFIC_FILE", "ROADNET_FILE") and every path in dic_path.
traffic_file, template, road_net, offline_data = (
    "anon_3_4_jinan_real_2500.json",  # traffic flow file
    "Jinan",                          # city sub-directory under data/
    "3_4",                            # road-network id: roadnet_<id>.json, data/<template>/<id>
    "Fixedtime",                      # sub-directory under offline_dataset/
)

# The road-network id is written "<rows>_<cols>" (e.g. "3_4" -> 3 rows, 4
# columns). Deriving the grid size from it keeps NUM_ROW / NUM_COL and the
# intersection/agent counts consistent when a different network is selected,
# instead of having to edit four hard-coded numbers by hand. A malformed id
# raises ValueError here rather than silently producing a mismatched config.
_num_row, _num_col = (int(part) for part in road_net.split("_"))

# Environment configuration handed to copy_conf_file / copy_cityflow_file /
# tester in the run section of this script.
dic_traffic_env_conf = {
    # Intersections in the NUM_ROW x NUM_COL grid (12 for "3_4").
    "NUM_INTERSECTIONS": _num_row * _num_col,
    "RUN_COUNTS": 3600,
    "YELLOW_TIME": 5,
    "NUM_PHASES": 4,
    "MIN_ACTION_TIME": 60,
    "LIST_STATE_FEATURE": [
        # "lane_num_vehicle",
        "lane_num_waiting_vehicle_in",
        # "traffic_movement_pressure_queue",
        "traffic_movement_pressure_queue_efficient",
        "lane_enter_running_part",
    ],
    "DIC_REWARD_INFO": {
        "lane_num_waiting_vehicle_in": -0.25,
        "traffic_movement_pressure_queue": 0,
    },
    # Phase id -> 8-entry mask aligned with "list_lane_order"
    # (WL, WT, EL, ET, NL, NT, SL, ST); the phases correspond, in order, to
    # "PHASE_LIST": WT_ET, NT_ST, WL_EL, NL_SL.
    "PHASE": {
        1: [0, 1, 0, 1, 0, 0, 0, 0],
        2: [0, 0, 0, 0, 0, 1, 0, 1],
        3: [1, 0, 1, 0, 0, 0, 0, 0],
        4: [0, 0, 0, 0, 1, 0, 1, 0]
    },
    # NOTE(review): equals sum(NUM_LANES) for the current values; confirm the
    # intended relationship before changing either.
    "NUM_LANE": 12,
    "PHASE_MAP": [[1, 4, 12, 13, 14, 15, 16, 17], [7, 10, 18, 19, 20, 21, 22, 23], [0, 3, 18, 19, 20, 21, 22, 23], [6, 9, 12, 13, 14, 15, 16, 17]],
    "FORGET_ROUND": 20,
    "MODEL_NAME": None,
    "NUM_ROW": _num_row,
    "NUM_COL": _num_col,
    "TOP_K_ADJACENCY": 5,

    "ACTION_PATTERN": "set",

    "OBS_LENGTH": 167,
    "MEASURE_TIME": 60,

    "BINARY_PHASE_EXPANSION": True,

    "ALL_RED_TIME": 0,
    "NUM_LANES": [3, 3, 3, 3],

    "INTERVAL": 1,

    "list_lane_order": ["WL", "WT", "EL", "ET", "NL", "NT", "SL", "ST"],
    "PHASE_LIST": ['WT_ET', 'NT_ST', 'WL_EL', 'NL_SL'],
    "TRAFFIC_FILE": traffic_file,
    "ROADNET_FILE": "roadnet_{0}.json".format(road_net),
    # Presumably one agent per intersection, as the original hard-coded
    # values (12 and 12) implied — confirm against the tester/agent code.
    "NUM_AGENTS": _num_row * _num_col
}

# Unique tag for this run: "<traffic_file>_<MM_DD_HH_MM_SS>_MOTSC".
# Computed once so the model and record directories of the same run always
# share an identical suffix (two separate strftime calls could land on
# different seconds and split one run across mismatched directory names).
_run_tag = (traffic_file + "_"
            + time.strftime('%m_%d_%H_%M_%S', time.localtime(time.time()))
            + "_MOTSC")

# Filesystem locations for this run, all rooted at the current working
# directory.
dic_path = {
    # Output: model/<run tag>
    "PATH_TO_MODEL": os.path.join(os.getcwd(), "model", _run_tag),
    # Output: records/test_transition/<run tag>
    "PATH_TO_WORK_DIRECTORY": os.path.join(os.getcwd(), "records", "test_transition", _run_tag),
    # Input: data/<template>/<road_net>
    "PATH_TO_DATA": os.path.join(os.getcwd(), "data", template, str(road_net)),
    # Input: offline_dataset/<offline_data>/<traffic_file>
    "PATH_TO_OFFLINE_DATA": os.path.join(os.getcwd(), "offline_dataset", offline_data, traffic_file)
}

# Agent hyper-parameters, consumed by copy_conf_file and tester in the run
# section of this script. Key meanings are defined by the agent code, not
# here; the groupings below are by name only.
dic_agent_conf = dict(
    # network size / optimisation
    D_DENSE=20,
    LEARNING_RATE=0.001,
    PATIENCE=10,
    BATCH_SIZE=20,
    EPOCHS=100,
    # sampling / memory sizes
    SAMPLE_SIZE=3000,
    MAX_MEMORY_LEN=12000,

    # "Q bar" update schedule
    UPDATE_Q_BAR_FREQ=5,
    UPDATE_Q_BAR_EVERY_C_ROUND=False,

    # discount and normalisation constants
    GAMMA=0.8,
    NORMAL_FACTOR=20,

    # epsilon schedule: start value, decay factor, floor
    EPSILON=0.8,
    EPSILON_DECAY=0.95,
    MIN_EPSILON=0.2,
    LOSS_FUNCTION="mean_squared_error",
)

# --- Run ---------------------------------------------------------------------
# Prepare the run's output location and config snapshots, then build the
# tester from the three configuration dicts defined in this script.
# NOTE(review): path_check presumably creates the dic_path directories if they
# are missing, and the copy_* helpers presumably copy the configs / CityFlow
# files into the work directory — confirm against utils.utils.
path_check(dic_path)
copy_conf_file(dic_path, dic_agent_conf, dic_traffic_env_conf)
copy_cityflow_file(dic_path, dic_traffic_env_conf)
test = tester(dic_agent_conf, dic_path, dic_traffic_env_conf)
# Single-step run, then a scatter plot of the "next_state" feature
# "lane_num_waiting_vehicle_in" against "step".
# NOTE(review): [0] presumably selects intersection 0 — confirm in tester.
test.test_single_step()
test.draw_scatter("step", "next_state", [0], "lane_num_waiting_vehicle_in")
# Full run, then a line plot of the same feature against "iter".
test.test()
test.draw_line("iter", "next_state", [0], "lane_num_waiting_vehicle_in")
