from utils.config import DIC_AGENTS
from copy import deepcopy
from utils.cityflow import CityFlowEnv
from utils.travel_time import travel_time_metric
import json
import os
import pickle
import pandas as pd


def test(model_dir, cnt_round, run_cnt, _dic_traffic_env_conf, dic_path):
    """Roll out the agents of one training round in CityFlow and log the results.

    Configuration is read from the records directory, agents are built (and
    their saved networks loaded when the model is a learned one), the
    simulation is stepped, and rewards, per-intersection vehicle records and
    travel times are written under
    ``<PATH_TO_WORK_DIRECTORY>/test_round/round_<cnt_round>``.

    Args:
        model_dir: Model directory; the records directory is derived from it
            by replacing every "model" substring with "records".
        cnt_round: Round index; selects the ``round_<cnt_round>`` networks and
            names the output directory.
        run_cnt: Value stored as ``RUN_COUNTS``; the number of environment
            steps is ``int(run_cnt / MIN_ACTION_TIME)``.
        _dic_traffic_env_conf: Environment configuration. It is deep-copied,
            and replaced entirely by ``anon_env.conf`` when that file exists
            in the records directory.
        dic_path: Path dictionary; ``PATH_TO_WORK_DIRECTORY`` is read here and
            the whole dict is passed on to the agents.
    """
    dic_traffic_env_conf = deepcopy(_dic_traffic_env_conf)
    records_dir = model_dir.replace("model", "records")
    model_round = "round_%d" % cnt_round
    with open(os.path.join(records_dir, "agent.conf"), "r") as f:
        dic_agent_conf = json.load(f)
    env_conf_path = os.path.join(records_dir, "anon_env.conf")
    if os.path.exists(env_conf_path):
        # The configuration saved with the records wins over the caller's.
        with open(env_conf_path, "r") as f:
            dic_traffic_env_conf = json.load(f)
    dic_traffic_env_conf["RUN_COUNTS"] = run_cnt

    model_name = dic_traffic_env_conf["MODEL_NAME"]
    needs_network = model_name in dic_traffic_env_conf["LIST_MODEL_NEED_TO_UPDATE"]
    if needs_network:
        # Evaluate greedily: switch exploration off.
        dic_agent_conf["EPSILON"] = 0
        dic_agent_conf["MIN_EPSILON"] = 0

    agents = [
        DIC_AGENTS[model_name](
            dic_agent_conf=dic_agent_conf,
            dic_traffic_env_conf=dic_traffic_env_conf,
            dic_path=dic_path,
            cnt_round=0,
            intersection_id=str(i)
        )
        for i in range(dic_traffic_env_conf['NUM_AGENTS'])
    ]

    if needs_network:
        for agent in agents:
            agent.load_network("{0}_inter_{1}".format(model_round, agent.intersection_id))

    path_to_log = os.path.join(dic_path["PATH_TO_WORK_DIRECTORY"], "test_round", model_round)
    # exist_ok avoids the check-then-create race when several tests start together.
    os.makedirs(path_to_log, exist_ok=True)

    # These models take the whole state at once and return the complete
    # action list themselves; every other model acts on its own state slice.
    takes_full_state = model_name in (
        "EfficientPressLight", "EfficientColight",
        "EfficientMPLight", "Attend",
        "AdvancedMPLight", "AdvancedColight", "AdvancedDQN",
    )

    env = CityFlowEnv(
        path_to_work_directory=dic_path["PATH_TO_WORK_DIRECTORY"],
        dic_traffic_env_conf=dic_traffic_env_conf
    )
    try:
        metric = travel_time_metric(env)
        logger = cityflow_logger(path_to_log, dic_traffic_env_conf["NUM_INTERSECTIONS"])
        total_time = dic_traffic_env_conf["RUN_COUNTS"]
        num_steps = int(total_time / dic_traffic_env_conf["MIN_ACTION_TIME"])
        state = env.reset()
        for step_num in range(num_steps):
            if takes_full_state:
                # As before, the action list returned by the last agent is used.
                action_list = []
                for agent in agents:
                    action_list = agent.choose_action(step_num, state)
            else:
                action_list = [
                    agent.choose_action(step_num, state[i])
                    for i, agent in enumerate(agents)
                ]

            next_state, reward = env.step(action_list)
            logger.log_reward(reward)
            logger.log_reward_info(env.get_reward_info(dic_traffic_env_conf["LOG_REWARD_INFO"]))
            metric.update()
            state = next_state

        for inter_ind in range(dic_traffic_env_conf["NUM_INTERSECTIONS"]):
            logger.log_vehicle(env.get_dic_vehicle_arrive_leave_time(inter_ind))
        logger.batch_log()
        for feature in dic_traffic_env_conf["LOG_REWARD_INFO"].keys():
            logger.batch_log_reward(feature)
        metric.update(done=True)
        metric.log_travel_time(path_to_log)
    finally:
        # Always shut the environment down, even when the rollout or logging fails.
        env.end_cityflow()
    
class cityflow_logger():
    """Buffer rewards and vehicle records in memory and flush them to files.

    Rewards are appended step by step with ``log_reward`` and
    ``log_reward_info``; vehicle dictionaries are appended with
    ``log_vehicle``. ``batch_log`` and ``batch_log_reward`` write the buffers
    into ``path_to_log`` and empty them afterwards.
    """

    def __init__(self, path_to_log, num_inters):
        """Remember the output directory and how many intersections to dump.

        Args:
            path_to_log: Directory that receives every output file.
            num_inters: Number of vehicle CSV files ``batch_log`` writes.
        """
        self.path_to_log = path_to_log
        self.num_inters = num_inters
        # One entry per log_vehicle() call; batch_log() uses the position in
        # this list as the intersection index.
        self.dic_vehicle_all_inters = []
        # Maps a reward name to the list of values logged so far.
        self.rewards = {"sys_reward": []}

    def _dump_and_clear(self, file_name, records):
        """Pickle ``records`` to ``file_name`` in the log directory, then empty the list."""
        with open(os.path.join(self.path_to_log, file_name), "wb") as f:
            pickle.dump(records, f)
        records.clear()

    def batch_log(self):
        """Write one vehicle CSV per intersection plus the system-reward pickle.

        The CSV for intersection ``i`` is ``vehicle_inter_<i>.csv``, built from
        the i-th logged vehicle dictionary with its keys as the row index and
        missing values written as "nan". The system rewards go to
        ``reward_record.pkl``. Both buffers are emptied afterwards.

        Raises:
            IndexError: If fewer than ``num_inters`` vehicle dictionaries were
                logged.
        """
        for inter_ind in range(self.num_inters):
            path_to_log_file = os.path.join(self.path_to_log, "vehicle_inter_{0}.csv".format(inter_ind))
            df = pd.DataFrame.from_dict(self.dic_vehicle_all_inters[inter_ind], orient="index")
            df.to_csv(path_to_log_file, na_rep="nan")

        self.dic_vehicle_all_inters.clear()
        self._dump_and_clear("reward_record.pkl", self.rewards["sys_reward"])

    def batch_log_reward(self, feature):
        """Write the values logged for ``feature`` to ``reward_record_<feature>.pkl``.

        A feature that was never logged (for example when no step was run)
        produces a pickle holding an empty list instead of raising KeyError.
        The feature's buffer is emptied afterwards.
        """
        records = self.rewards.get(feature, [])
        self._dump_and_clear("reward_record_{}.pkl".format(feature), records)

    def log_reward(self, reward):
        """Append one value to the system-reward buffer."""
        self.rewards["sys_reward"].append(reward)

    def log_reward_info(self, reward_info):
        """Append each value of ``reward_info`` to the buffer named by its key."""
        for feature, value in reward_info.items():
            self.rewards.setdefault(feature, []).append(value)

    def log_vehicle(self, dic_vehicle):
        """Append the vehicle dictionary of the next intersection."""
        self.dic_vehicle_all_inters.append(dic_vehicle)
