from utils.config import DIC_AGENTS
from utils.cityflow import CityFlowEnv
from utils.construct_sample import OfflineConstructSample
from utils.generator import logger
from utils.pipeline import path_check, copy_cityflow_file, copy_conf_file
import os
import time
import numpy as np


class DataFactory:
    """Roll out per-intersection agents in a CityFlow environment and build offline samples.

    On construction the factory copies the configuration files into the work
    directory, instantiates one agent per intersection, creates the
    environment and attaches a logger.  ``train`` then runs the rollout,
    records every step and finally hands the logs to ``OfflineConstructSample``.
    """

    def __init__(self, dic_agent_conf, dic_traffic_env_conf, dic_path):
        """Store the configuration and immediately run ``initialize``.

        Args:
            dic_agent_conf: Agent configuration, forwarded to every agent.
            dic_traffic_env_conf: Environment configuration.  Keys read by this
                class: ``NUM_INTERSECTIONS``, ``MODEL_NAME``, ``RUN_COUNTS``,
                ``MEASURE_TIME`` and ``CITYFLOW_MAX_TIME``.  NOTE: the dict is
                mutated in place when ``MODEL_NAME`` is ``"MaxPressure"``.
            dic_path: Path configuration; ``PATH_TO_WORK_DIRECTORY`` is used for
                logs, the memory file and the generated data.
        """
        self.dic_agent_conf = dic_agent_conf
        self.dic_traffic_env_conf = dic_traffic_env_conf
        self.dic_path = dic_path
        self.agents = []
        self.env = None
        self.logger = None
        # Logs are written directly into the work directory.
        self.path_to_log = os.path.join(self.dic_path["PATH_TO_WORK_DIRECTORY"])
        self.initialize()

    def initialize(self):
        """Copy config files, then build the agents, the environment and the logger."""
        # path_check(self.dic_path)
        copy_conf_file(self.dic_path, self.dic_agent_conf, self.dic_traffic_env_conf)
        copy_cityflow_file(self.dic_path, self.dic_traffic_env_conf)

        # Every intersection gets its own instance of the same agent class.
        agent_name = self.dic_traffic_env_conf["MODEL_NAME"]
        for i in range(self.dic_traffic_env_conf['NUM_INTERSECTIONS']):
            agent = DIC_AGENTS[agent_name](
                dic_agent_conf=self.dic_agent_conf,
                dic_traffic_env_conf=self.dic_traffic_env_conf,
                dic_path=self.dic_path,
                cnt_round=0,
                intersection_id=str(i)
            )
            self.agents.append(agent)

        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.path_to_log, exist_ok=True)

        # MaxPressure runs with a fixed minimum action time; this overrides the
        # caller's configuration in place before the environment is created.
        if agent_name == "MaxPressure":
            self.dic_traffic_env_conf["MIN_ACTION_TIME"] = 30

        self.env = CityFlowEnv(
            path_to_work_directory=self.dic_path["PATH_TO_WORK_DIRECTORY"],
            dic_traffic_env_conf=self.dic_traffic_env_conf
        )
        self.logger = logger(self.path_to_log, self.dic_traffic_env_conf["NUM_INTERSECTIONS"])
        self.logger.set_path_to_log(self.path_to_log)

    @staticmethod
    def _format_memory_line(env_time, action_list, state, reward):
        """Return one tab-separated record (without newline) for ``memories.txt``.

        ``state`` is a sequence of per-intersection mappings; the first element
        of each ``"cur_phase"`` entry is reported as the current phase.
        """
        current_phases = [one_state["cur_phase"][0] for one_state in state]
        return 'time = {0}\taction = {1}\tcurrent_phase = {2}\treward = {3}'.format(
            env_time, str(action_list), str(current_phases), str(reward))

    def _choose_actions(self, total_time, state):
        """Ask the agent of each intersection for an action on its own state."""
        return [
            self.agents[i].choose_action(total_time, one_state)
            for i, one_state in enumerate(state)
        ]

    def train(self):
        """Run the rollout, log every step and build the offline samples.

        The loop advances in strides of ``MEASURE_TIME`` up to ``RUN_COUNTS``.
        Each iteration steps the environment once, logs the transition, and
        appends a summary line to ``memories.txt`` in the work directory.
        Afterwards the logger is flushed and ``OfflineConstructSample`` turns
        the logs into data under ``<work dir>/data``.
        """
        print("================ start train ================")
        env_conf = self.dic_traffic_env_conf
        total_run_cnt = env_conf["RUN_COUNTS"]
        measure_time = env_conf["MEASURE_TIME"]
        cityflow_max_time = env_conf["CITYFLOW_MAX_TIME"]
        # initialize output streams
        file_name_memory = os.path.join(self.dic_path["PATH_TO_WORK_DIRECTORY"], "memories.txt")
        state = self.env.reset()
        print("end reset")
        env_current_time = self.env.get_current_time()  # in seconds
        prev_reward = np.zeros(env_conf["NUM_INTERSECTIONS"])
        start_time = time.time()
        # New actions are only chosen once the environment clock reaches this
        # threshold (or when no action has been chosen yet).
        skip_early = measure_time
        action_list = []
        for total_time in range(0, total_run_cnt, measure_time):
            # Restart the simulation whenever its clock hits a multiple of the
            # maximum episode length, and push the threshold further out.
            if env_current_time > 0 and env_current_time % cityflow_max_time == 0:
                state = self.env.reset()
                env_current_time = self.env.get_current_time()  # in seconds
                skip_early += measure_time * 10

            if env_current_time >= skip_early or not action_list:
                action_list = self._choose_actions(total_time, state)
            next_state, reward = self.env.step(action_list)

            # The first record pairs the pre-step state with the previous
            # reward; the remaining records of the interval use the post-step
            # state and the new reward.
            self.logger.log(env_current_time, state, action_list, prev_reward)
            for offset in range(1, measure_time):
                self.logger.log(env_current_time + offset, next_state, action_list, reward)

            prev_reward = reward

            # Reopened in append mode every step so the file is complete on
            # disk after each iteration; `with` closes it even if formatting
            # or writing raises.
            memory_str = self._format_memory_line(env_current_time, action_list, state, reward)
            with open(file_name_memory, "a") as f_memory:
                f_memory.write(memory_str + "\n")
            env_current_time = self.env.get_current_time()  # in seconds

            state = next_state

        self.logger.batch_log()

        path_to_data = os.path.join(self.dic_path["PATH_TO_WORK_DIRECTORY"], 'data')
        os.makedirs(path_to_data, exist_ok=True)
        cs = OfflineConstructSample(path_to_samples=self.path_to_log, path_to_data=path_to_data,
                                    cnt_round=0, dic_traffic_env_conf=self.dic_traffic_env_conf)
        cs.make_reward_for_system()
        print("Sampling time: {}".format(time.time() - start_time))

