from .config import DIC_AGENTS
import time
import os
import copy
import numpy as np
import pickle
from utils.tester import tester

class Generator:
    """Roll out one simulation episode with the configured agents and log it.

    The generator drives ``env`` step by step, asks the agents for actions,
    records every simulated second through a ``logger`` instance and finally
    dumps the per-intersection logs to
    ``<PATH_TO_WORK_DIRECTORY>/train_round/round_<cnt_round>/generator_<cnt_gen>``.
    """

    # Models whose agent is handed the state of *all* intersections at once and
    # whose return value is used directly as the full action list.
    _WHOLE_STATE_MODELS = frozenset({
        "EfficientPressLight", "EfficientColight",
        "EfficientMPLight", "Attend",
        "AdvancedMPLight", "AdvancedColight", "AdvancedDQN",
    })

    def __init__(self, cnt_round, cnt_gen, dic_path, dic_agent_conf, dic_traffic_env_conf, offline=False, env=None):
        """Store the configuration, build the logger and (for a new run) the agents.

        Args:
            cnt_round: Training round index; part of the log path, forwarded to
                the agents, and used to decide when the tester runs.
            cnt_gen: Generator index; part of the log path.
            dic_path: Path configuration; ``"PATH_TO_WORK_DIRECTORY"`` is read here.
            dic_agent_conf: Agent configuration; deep-copied so later changes by
                the caller do not affect this generator.
            dic_traffic_env_conf: Environment configuration. Keys read by this
                class: ``NUM_AGENTS``, ``NUM_INTERSECTIONS``, ``MODEL_NAME``,
                ``RUN_COUNTS`` and ``MIN_ACTION_TIME``.
            offline: Stored on the instance; not otherwise used by this class.
            env: Environment object providing ``reset()``, ``step(actions)`` and
                ``get_current_time()``.
        """
        self.cnt_round = cnt_round
        self.cnt_gen = cnt_gen
        self.dic_path = dic_path
        self.dic_agent_conf = copy.deepcopy(dic_agent_conf)
        self.dic_traffic_env_conf = dic_traffic_env_conf
        self.agents = [None]*dic_traffic_env_conf['NUM_AGENTS']
        self.path_to_log = os.path.join(self.dic_path["PATH_TO_WORK_DIRECTORY"], "train_round",
                                        "round_"+str(self.cnt_round), "generator_"+str(self.cnt_gen))
        self.offline = offline
        self.logger = logger(self.dic_path["PATH_TO_WORK_DIRECTORY"], self.dic_traffic_env_conf["NUM_INTERSECTIONS"])

        self.env = env

        # NOTE(review): agents are only built when the log directory did not
        # exist yet; if it already exists ``self.agents`` stays a list of None
        # and generate() cannot choose actions. Confirm whether callers rely on
        # this before making agent creation unconditional.
        if not os.path.exists(self.path_to_log):
            # exist_ok guards against another process creating the directory
            # between the check above and this call.
            os.makedirs(self.path_to_log, exist_ok=True)
            self._create_agents()

        self.logger.update_path_to_log(cnt_round, cnt_gen)

    def _create_agents(self):
        """Instantiate one agent per configured agent slot and report the time taken."""
        start_time = time.time()
        agent_cls = DIC_AGENTS[self.dic_traffic_env_conf["MODEL_NAME"]]
        for i in range(self.dic_traffic_env_conf['NUM_AGENTS']):
            self.agents[i] = agent_cls(
                dic_agent_conf=self.dic_agent_conf,
                dic_traffic_env_conf=self.dic_traffic_env_conf,
                dic_path=self.dic_path,
                cnt_round=self.cnt_round,
                intersection_id=str(i)
            )
        print("Create intersection agent time: ", time.time()-start_time)

    def _choose_actions(self, step_num, state, whole_state, num_agents):
        """Return the action list for one decision step."""
        if whole_state:
            # Every agent sees the complete state; the last agent's answer is
            # taken as the action list for all intersections.
            action_list = []
            for i in range(num_agents):
                action_list = self.agents[i].choose_action(step_num, state)
            return action_list
        # Otherwise agent i decides from the i-th entry of the state only.
        return [self.agents[i].choose_action(step_num, state[i]) for i in range(num_agents)]

    def generate(self):
        """Run the episode, log every simulated second and dump the logs.

        The environment is stepped ``int(RUN_COUNTS / MIN_ACTION_TIME)`` times.
        When ``cnt_round`` is a multiple of 10 the tester is run afterwards on
        the same environment and agents. Timing information is printed.
        """
        conf = self.dic_traffic_env_conf

        reset_env_start_time = time.time()
        state = self.env.reset()
        reset_env_time = time.time() - reset_env_start_time
        running_start_time = time.time()

        # Loop invariants, computed once instead of on every iteration.
        min_action_time = conf["MIN_ACTION_TIME"]
        max_steps = int(conf["RUN_COUNTS"] / min_action_time)
        num_agents = conf["NUM_AGENTS"]
        whole_state = conf["MODEL_NAME"] in self._WHOLE_STATE_MODELS

        # Reward passed along with the very first log entry, before any step.
        prev_reward = np.zeros(conf["NUM_INTERSECTIONS"])

        # env.step() returns no termination flag, so the step budget is the
        # only stopping criterion.
        for step_num in range(max_steps):
            step_start_time = time.time()
            action_list = self._choose_actions(step_num, state, whole_state, num_agents)

            state, action_list = np.array(state), np.array(action_list)
            current_time = self.env.get_current_time()
            next_state, reward = self.env.step(action_list)

            # First second of the interval: state the decision was based on.
            self.logger.log(current_time, state, action_list, prev_reward)
            # Remaining seconds of the interval: resulting state, same action.
            for offset in range(1, min_action_time):
                self.logger.log(current_time + offset, next_state, action_list, reward)

            prev_reward = reward

            print("time: {0}, running_time: {1}".format(self.env.get_current_time() -
                                                        min_action_time,
                                                        time.time()-step_start_time))

            state = next_state

        if self.cnt_round % 10 == 0:
            test = tester(self.dic_agent_conf, self.dic_path, self.dic_traffic_env_conf, fake_env=self.env, agents=self.agents, round=self.cnt_round)
            test.test_single_step()
            test.draw_scatter("step", "next_state", [0], "lane_num_waiting_vehicle_in")
            test.draw_line("step", "next_state", [0], "lane_num_waiting_vehicle_in")
            test.test()
            test.draw_line("iter", "next_state", [0], "lane_num_waiting_vehicle_in")

        # Note: running_time includes the tester run above when it happens.
        running_time = time.time() - running_start_time
        log_start_time = time.time()
        print("start logging.......................")
        self.logger.batch_log()   # writes one inter_{}.pkl per intersection
        log_time = time.time() - log_start_time
        print("reset_env_time: ", reset_env_time)
        print("running_time: ", running_time)
        print("log_time: ", log_time)

            
class logger():
    """Collect per-intersection log entries in memory and pickle them to disk.

    Entries are kept in ``inter_log``, a dict mapping each intersection index
    in ``range(num_inters)`` to a list of ``{"time", "state", "action"}`` dicts.
    """

    def __init__(self, path_to_work_dir, num_inters):
        """Create empty logs for ``num_inters`` intersections.

        Args:
            path_to_work_dir: Base directory under which log directories are built.
            num_inters: Number of intersections to keep a log for.
        """
        # Output directory; unset until set_path_to_log/update_path_to_log is called.
        self.path_to_log = None
        self.path_to_work_dir = path_to_work_dir
        self.num_inters = num_inters
        self.inter_log = {inter_id: [] for inter_id in range(self.num_inters)}

    def set_path_to_log(self, path_to_log):
        """Set the output directory directly, without creating it."""
        self.path_to_log = path_to_log

    def update_path_to_log(self, cnt_round, cnt_gen):
        """Point the logger at the directory for a round/generator and create it.

        The directory is
        ``<path_to_work_dir>/train_round/round_<cnt_round>/generator_<cnt_gen>``.
        """
        self.path_to_log = os.path.join(self.path_to_work_dir, "train_round",
                                        "round_"+str(cnt_round), "generator_"+str(cnt_gen))
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.path_to_log, exist_ok=True)

    def log(self, cur_time, state_before_action, action, reward):
        """Append one entry per intersection for the given time.

        Args:
            cur_time: Value stored under ``"time"``.
            state_before_action: Indexable by intersection; element ``i`` is
                stored under ``"state"`` for intersection ``i``.
            action: Indexable by intersection; element ``i`` is stored under
                ``"action"`` for intersection ``i``.
            reward: Accepted for call compatibility but currently not stored.
        """
        for inter_ind in range(self.num_inters):
            entry = {"time": cur_time,
                     "state": state_before_action[inter_ind],
                     "action": action[inter_ind]}
            self.inter_log[inter_ind].append(entry)

    def batch_log(self):
        """Pickle each intersection's entries to ``inter_<id>.pkl`` in ``path_to_log``."""
        for inter_id in range(self.num_inters):
            path_to_log_file = os.path.join(self.path_to_log, "inter_{0}.pkl".format(inter_id))
            with open(path_to_log_file, "wb") as f:
                pickle.dump(self.inter_log[inter_id], f)