from utils.motsc.motsc import MOTSCFakeEnv
from utils.mopo.mopo import MOPOFakeEnv
from utils.mo_base.mo_base import MOBASEFakeEnv
from utils.cityflow import CityFlowEnv
from utils.generator import Generator, logger
from utils.construct_sample import ConstructSample
from utils.updater import Updater
from utils import model_test
from utils.tester import tester

import json
import shutil
import os
import time
import copy
from multiprocessing import Process


def path_check(dic_path):
    """Create the work and model directories, refusing to reuse non-default ones.

    ``PATH_TO_WORK_DIRECTORY`` is checked first, then ``PATH_TO_MODEL``.
    For each path:

    - if it does not exist, it is created (with parents);
    - if it exists, it is accepted only when it equals its shared default
      (``"records/default"`` / ``"model/default"``).

    Args:
        dic_path: mapping containing at least ``PATH_TO_WORK_DIRECTORY`` and
            ``PATH_TO_MODEL``.

    Raises:
        FileExistsError: if an existing directory is not the default one. The
            message names the offending key and path.
    """
    checks = (
        ("PATH_TO_WORK_DIRECTORY", "records/default"),
        ("PATH_TO_MODEL", "model/default"),
    )
    for key, default_path in checks:
        path = dic_path[key]
        if not os.path.exists(path):
            os.makedirs(path)
        elif path != default_path:
            raise FileExistsError("{0} already exists: {1}".format(key, path))


def copy_conf_file(dic_path, dic_agent_conf, dic_traffic_env_conf, path=None):
    """Snapshot the agent and traffic-env configs as indented JSON files.

    Writes ``agent.conf`` (from ``dic_agent_conf``) and ``traffic_env.conf``
    (from ``dic_traffic_env_conf``) into ``path``, overwriting existing files.

    Args:
        dic_path: mapping; ``PATH_TO_WORK_DIRECTORY`` is used when ``path`` is None.
        dic_agent_conf: JSON-serializable agent configuration.
        dic_traffic_env_conf: JSON-serializable environment configuration.
        path: destination directory (defaults to the work directory).
    """
    if path is None:
        path = dic_path["PATH_TO_WORK_DIRECTORY"]
    outputs = (
        ("agent.conf", dic_agent_conf),
        ("traffic_env.conf", dic_traffic_env_conf),
    )
    for file_name, conf in outputs:
        # `with` guarantees the handle is flushed and closed
        # (the original leaked it).
        with open(os.path.join(path, file_name), "w") as f:
            json.dump(conf, f, indent=4)


def copy_cityflow_file(dic_path, dic_traffic_env_conf, path=None):
    """Copy the CityFlow traffic-flow and road-network files into ``path``.

    Both files are read from ``dic_path["PATH_TO_DATA"]``, using the names in
    ``dic_traffic_env_conf["TRAFFIC_FILE"]`` and
    ``dic_traffic_env_conf["ROADNET_FILE"]``, in that order. Each keeps its
    file name in the destination.

    Args:
        dic_path: mapping with ``PATH_TO_DATA`` and, when ``path`` is None,
            ``PATH_TO_WORK_DIRECTORY``.
        dic_traffic_env_conf: mapping with ``TRAFFIC_FILE`` and ``ROADNET_FILE``.
        path: destination directory (defaults to the work directory).
    """
    destination_dir = dic_path["PATH_TO_WORK_DIRECTORY"] if path is None else path
    source_dir = dic_path["PATH_TO_DATA"]
    for conf_key in ("TRAFFIC_FILE", "ROADNET_FILE"):
        file_name = dic_traffic_env_conf[conf_key]
        shutil.copy(os.path.join(source_dir, file_name),
                    os.path.join(destination_dir, file_name))


def generator_wrapper(cnt_round, cnt_gen, dic_path, dic_agent_conf, dic_traffic_env_conf, offline=False, env=None):
    """Build a ``Generator`` for one (round, generator index) pair and run it.

    Kept as a plain module-level function because ``Pipeline.run`` uses it as
    a ``multiprocessing.Process`` target as well as calling it directly.

    Args:
        cnt_round: index of the current training round.
        cnt_gen: index of this generator within the round.
        dic_path, dic_agent_conf, dic_traffic_env_conf: configuration mappings
            forwarded unchanged to ``Generator``.
        offline: forwarded to ``Generator``.
        env: pre-built environment forwarded to ``Generator`` (may be None).
    """
    generator_kwargs = dict(
        cnt_round=cnt_round,
        cnt_gen=cnt_gen,
        dic_path=dic_path,
        dic_agent_conf=dic_agent_conf,
        dic_traffic_env_conf=dic_traffic_env_conf,
        offline=offline,
        env=env,
    )
    round_generator = Generator(**generator_kwargs)
    print("make generator")
    round_generator.generate()
    print("generator_wrapper end")



def updater_wrapper(cnt_round, dic_agent_conf, dic_traffic_env_conf, dic_path):
    """Load this round's samples and update every agent's network.

    Kept as a plain module-level function because ``Pipeline.run`` uses it as
    a ``multiprocessing.Process`` target as well as calling it directly.

    Args:
        cnt_round: index of the current training round.
        dic_agent_conf, dic_traffic_env_conf, dic_path: configuration mappings
            forwarded unchanged to ``Updater``.
    """
    network_updater = Updater(cnt_round=cnt_round,
                              dic_agent_conf=dic_agent_conf,
                              dic_traffic_env_conf=dic_traffic_env_conf,
                              dic_path=dic_path)
    # Samples must be loaded before the networks can be trained on them.
    for step in (network_updater.load_sample_for_agents,
                 network_updater.update_network_for_agents):
        step()
    print("updater_wrapper end")


class Pipeline:
    """Drive the round-based generate / sample / update / test training loop.

    Construction prepares the work and model directories and snapshots the
    configs and CityFlow input files. ``run`` then executes ``NUM_ROUNDS``
    rounds and appends per-phase wall-clock times to ``running_time.csv`` in
    the work directory.
    """

    # Header of the per-round timing log (the file is tab-separated despite its
    # ``.csv`` name).
    _RUNNING_TIME_HEADER = ("generator_time\tmaking_samples_time\tupdate_network_time"
                            "\ttest_evaluation_times\tall_times\n")

    def __init__(self, dic_agent_conf, dic_traffic_env_conf, dic_path, offline=False):
        """Store the configuration and prepare the output directories.

        Args:
            dic_agent_conf: agent configuration mapping.
            dic_traffic_env_conf: environment/training configuration mapping.
            dic_path: mapping of filesystem paths (work dir, model dir, data dir...).
            offline: forwarded to every ``Generator``.

        Raises:
            FileExistsError: propagated from ``path_check`` when a non-default
                work or model directory already exists.
        """
        self.dic_agent_conf = dic_agent_conf
        self.dic_traffic_env_conf = dic_traffic_env_conf
        self.dic_path = dic_path
        self.offline = offline

        self.initialize()

    def initialize(self):
        """Create output directories, then copy configs and CityFlow files into the work dir."""
        path_check(self.dic_path)
        copy_conf_file(self.dic_path, self.dic_agent_conf, self.dic_traffic_env_conf)
        copy_cityflow_file(self.dic_path, self.dic_traffic_env_conf)

    def run(self, multi_process=False):
        """Run ``NUM_ROUNDS`` training rounds.

        Each round consists of four phases:

        1. generate experience: run ``NUM_GENERATORS`` generators, or, in
           ``OFFLINE_METHOD == "SIMPLE"`` mode, only batch-log through ``logger``;
        2. build training samples with ``ConstructSample``;
        3. update the networks, if ``MODEL_NAME`` is in
           ``LIST_MODEL_NEED_TO_UPDATE``;
        4. evaluate with ``model_test.test``.

        The timing log is truncated and given a header at the start. One row
        of phase times (seconds) is then appended per round.

        Args:
            multi_process: run each generator, and the updater, in its own
                ``multiprocessing.Process`` instead of in-process.
        """
        with open(self._running_time_path(), "w") as f_time:
            f_time.write(self._RUNNING_TIME_HEADER)

        env = self._build_env()

        for cnt_round in range(self.dic_traffic_env_conf["NUM_ROUNDS"]):
            print("round %d starts" % cnt_round)
            round_start_time = time.time()

            if self._uses_generator():
                generator_total_time = self._run_generators(cnt_round, env, multi_process)
            else:
                generator_total_time = 0.00
                self._log_offline_batches(cnt_round)

            making_samples_total_time = self._make_samples(cnt_round)
            update_network_total_time = self._update_network(cnt_round, multi_process)
            test_evaluation_total_time = self._test_evaluation(cnt_round)

            print("Generator time: ", generator_total_time)
            print("Making samples time:", making_samples_total_time)
            print("update_network time:", update_network_total_time)
            print("test_evaluation time:", test_evaluation_total_time)

            # Measure once so the printed and the logged totals agree.
            round_total_time = time.time() - round_start_time
            print("round {0} ends, total_time: {1}".format(cnt_round, round_total_time))
            with open(self._running_time_path(), "a") as f_time:
                f_time.write("{0}\t{1}\t{2}\t{3}\t{4}\n".format(generator_total_time, making_samples_total_time,
                                                                update_network_total_time, test_evaluation_total_time,
                                                                round_total_time))

    def _running_time_path(self):
        """Return the path of the per-round timing log in the work directory."""
        return os.path.join(self.dic_path["PATH_TO_WORK_DIRECTORY"], "running_time.csv")

    def _uses_generator(self):
        """Return True unless ``OFFLINE_METHOD`` is set to ``"SIMPLE"``."""
        return self.dic_traffic_env_conf.get("OFFLINE_METHOD") != "SIMPLE"

    def _build_env(self):
        """Create the environment the generators interact with.

        Returns:
            A ``CityFlowEnv`` when ``OFFLINE_METHOD`` is absent, a fake
            environment for ``MOTSC`` / ``MOPO`` / ``MOBASE``, or None for
            ``SIMPLE``, which never runs a generator.

        Raises:
            NotImplementedError: for an unknown ``OFFLINE_METHOD``.
        """
        conf = self.dic_traffic_env_conf
        if "OFFLINE_METHOD" not in conf:
            return CityFlowEnv(self.dic_path["PATH_TO_WORK_DIRECTORY"], conf)

        method = conf["OFFLINE_METHOD"]
        if method == "SIMPLE":
            return None
        # Fake envs get their own deep copy of the config.
        if method == "MOTSC":
            env = MOTSCFakeEnv(self.dic_path, copy.deepcopy(conf))
        elif method == "MOPO":
            env = MOPOFakeEnv(self.dic_path, copy.deepcopy(conf))
        elif method == "MOBASE":
            env = MOBASEFakeEnv(os.path.join(self.dic_path["PATH_TO_OFFLINE_DATA"], "data"),
                                copy.deepcopy(conf))
        else:
            raise NotImplementedError("unknown OFFLINE_METHOD: {0}".format(method))
        self._run_fake_env_diagnostics(env)
        return env

    def _run_fake_env_diagnostics(self, env):
        """Run the ``tester`` single-step and full checks and plots against a fake env."""
        fake_env_tester = tester(self.dic_agent_conf, self.dic_path, self.dic_traffic_env_conf, fake_env=env)
        fake_env_tester.test_single_step()
        fake_env_tester.draw_scatter("step", "next_state", [0], "lane_num_waiting_vehicle_in")
        fake_env_tester.draw_line("step", "next_state", [0], "lane_num_waiting_vehicle_in")
        fake_env_tester.test()
        fake_env_tester.draw_line("iter", "next_state", [0], "lane_num_waiting_vehicle_in")

    def _run_generators(self, cnt_round, env, multi_process):
        """Run all generators for ``cnt_round``; return the elapsed seconds.

        Per the original author's note, generators use the trained agents to
        run a round and log features to ``./inter_{inter_id}.pkl``.
        """
        print("==============  generator =============")
        generator_start_time = time.time()
        num_generators = self.dic_traffic_env_conf["NUM_GENERATORS"]
        if multi_process:
            print("-------------- use multi-process for generator -------------")
            process_list = []
            for cnt_gen in range(num_generators):
                p = Process(target=generator_wrapper,
                            args=(cnt_round, cnt_gen, self.dic_path,
                                  self.dic_agent_conf, self.dic_traffic_env_conf, self.offline, env)
                            )
                print("before")
                p.start()
                print("end")
                process_list.append(p)
            print("before join")
            for i, p in enumerate(process_list):
                print("generator %d to join" % i)
                p.join()
                print("generator %d finish join" % i)
                # A crashed child would otherwise go unnoticed in multi-process mode.
                if p.exitcode != 0:
                    print("WARNING: generator %d exited with code %s" % (i, p.exitcode))
            print("end join")
        else:
            for cnt_gen in range(num_generators):
                generator_wrapper(cnt_round=cnt_round,
                                  cnt_gen=cnt_gen,
                                  dic_path=self.dic_path,
                                  dic_agent_conf=self.dic_agent_conf,
                                  dic_traffic_env_conf=self.dic_traffic_env_conf,
                                  offline=self.offline,
                                  env=env)
        return time.time() - generator_start_time

    def _log_offline_batches(self, cnt_round):
        """SIMPLE mode: call ``logger.batch_log`` once per generator index instead of simulating.

        NOTE(review): presumably this re-packages pre-recorded offline data into
        the per-round log layout; confirm against ``utils.generator.logger``.
        """
        for cnt_gen in range(self.dic_traffic_env_conf["NUM_GENERATORS"]):
            generator_logger = logger(self.dic_path["PATH_TO_WORK_DIRECTORY"],
                                      self.dic_traffic_env_conf["NUM_INTERSECTIONS"])
            generator_logger.update_path_to_log(cnt_round=cnt_round, cnt_gen=cnt_gen)
            generator_logger.batch_log()

    def _make_samples(self, cnt_round):
        """Build this round's training samples with ``ConstructSample``; return elapsed seconds.

        ``ConstructSample`` is given ``<work dir>/train_round`` as its sample
        path, which is created if missing.
        """
        print("==============  make samples =============")
        making_samples_start_time = time.time()
        train_round = os.path.join(self.dic_path["PATH_TO_WORK_DIRECTORY"], "train_round")
        os.makedirs(train_round, exist_ok=True)
        cs = ConstructSample(path_to_samples=train_round, cnt_round=cnt_round,
                             dic_traffic_env_conf=self.dic_traffic_env_conf, dic_path=self.dic_path)
        cs.make_reward_for_system()
        return time.time() - making_samples_start_time

    def _update_network(self, cnt_round, multi_process):
        """Update agent networks if the model requires it; return elapsed seconds."""
        print("==============  update network =============")
        update_network_start_time = time.time()
        conf = self.dic_traffic_env_conf
        if conf["MODEL_NAME"] in conf["LIST_MODEL_NEED_TO_UPDATE"]:
            if multi_process:
                p = Process(target=updater_wrapper,
                            args=(cnt_round,
                                  self.dic_agent_conf,
                                  conf,
                                  self.dic_path))
                p.start()
                print("update to join")
                p.join()
                print("update finish join")
                if p.exitcode != 0:
                    print("WARNING: updater exited with code %s" % p.exitcode)
            else:
                updater_wrapper(cnt_round=cnt_round,
                                dic_agent_conf=self.dic_agent_conf,
                                dic_traffic_env_conf=conf,
                                dic_path=self.dic_path)
        return time.time() - update_network_start_time

    def _test_evaluation(self, cnt_round):
        """Evaluate the round's model via ``model_test.test``; return elapsed seconds."""
        print("==============  test evaluation =============")
        test_evaluation_start_time = time.time()
        model_test.test(self.dic_path["PATH_TO_MODEL"], cnt_round,
                        self.dic_traffic_env_conf["TEST_COUNTS"], self.dic_traffic_env_conf, self.dic_path)
        return time.time() - test_evaluation_start_time
