import json
import pickle

from env import TSCEnv
from utilities.utils import set_seed, get_agent, get_config, set_thread, set_logger, release_logger, make_dir
from utilities.snippets import run_an_episode
import numpy as np
import multiprocessing
import argparse
import os
from agent import GPLOptimizer, GPL, GPLP

parser = argparse.ArgumentParser(description='ATSC baselines')
# GPL => GPLight | GPLP => GPLight+
# `choices` rejects unknown names: otherwise any non-'GPLP' value would silently
# run GPL while results are saved under a directory named after the typo.
parser.add_argument('--model', type=str, default='GPLP', choices=['GPL', 'GPLP'], help='[GPL, GPLP]')
parser.add_argument('--dataset', type=str, default='Hangzhou2',
                    help='[Hangzhou1, Hangzhou2, Hangzhou3, Manhattan, Atlanta, Jinan, LosAngeles]')
args = parser.parse_args()

# When True: a single flow is run serially and no result files/directories are
# written (save_result is disabled and the per-flow JSON dump is skipped).
DEBUG = True
cur_agent = args.model
data_name = args.dataset

if not DEBUG:
    # Output directory for the per-flow result JSON files.
    save_path = os.path.join('result_models', cur_agent, data_name)
    # save_path = os.path.join('result', 'k15', data_name)
    os.makedirs(save_path, exist_ok=True)
    # make_dir('log/{}/{}/'.format(data_name, cur_agent))

def run_an_experiment(inter_name, flow_idx, seed):
    """Search a control program for one flow file and evaluate it for one episode.

    Args:
        inter_name: Dataset name; selects ``data/<inter_name>/`` and the episode
            length (``num_step``).
        flow_idx: Index of the flow file ``flow_<flow_idx>.json`` to simulate.
        seed: Random seed passed to ``set_seed`` and stored in the config.

    Returns:
        Tuple ``(travel_time, queue_length, delay, throughput)`` taken from the
        first entry of the ``world_2_average_*`` lists in the episode info.
        When ``DEBUG`` is False the same values, plus the searched expression,
        are also written to ``<save_path>/<flow_idx>.json``.
    """
    # Simulation steps per dataset; datasets not listed here use 3600.
    num_step = {'Atlanta': 900, 'Hangzhou1': 3600, 'Hangzhou2': 3600, 'Hangzhou3': 3600, 'Jinan': 3600,
                'LosAngeles': 1800}
    config = get_config()
    config.update({
        'inter_name': inter_name,
        'seed': seed,
        'flow_idx': flow_idx,  # selects flow_<flow_idx>.json
        "saveReplay": False,
        'save_result': not DEBUG,
        'dir': 'data/{}/'.format(inter_name),
        'flowFile': 'flow_{}.json'.format(flow_idx),
        'cur_agent': cur_agent,
        'render': False,
        'num_step': num_step.get(inter_name, 3600),
    })
    set_seed(seed)
    # set_thread()
    # set_logger(config)

    env = TSCEnv(config)
    # Any model name other than 'GPLP' falls back to GPL.
    agent_cls = GPLP if cur_agent == 'GPLP' else GPL
    # Agents are appended one at a time (rather than built in a comprehension)
    # so env.n_agent is already a list while each agent is constructed.
    env.n_agent = []
    for idx in range(env.n):
        env.n_agent.append(agent_cls(config, env, idx))

    gp_opt = GPLOptimizer(env, config)
    expr, func = gp_opt.search(popsize=25, iterations=20)

    # Every agent is evaluated with the same searched `func`.
    for agent in env.n_agent:
        agent.reset()
        agent.func = func
    info = run_an_episode(env, config, on_training=False, store_experience=False, learn=False)

    ret = (
        info['world_2_average_travel_time'][0],
        info['world_2_average_queue_length'][0],
        info['world_2_average_delay'][0],
        info['world_2_average_throughput'][0]
    )

    print(ret)

    if not DEBUG:
        att, aql, ad, nt = ret
        js = {
            'prog': expr,
            'ATT': att,
            'AQL': aql,
            'AD': ad,
            'NT': nt,
        }
        with open(os.path.join(save_path, f'{flow_idx}.json'), 'w') as f:
            json.dump(js, f, indent=2)

    # Returned so callers (serial loop or Pool.starmap) can aggregate runs.
    return ret


if __name__ == '__main__':

    # DEBUG runs a single flow in-process; otherwise one worker process per flow.
    parallel = not DEBUG
    total_run = 1 if DEBUG else 10
    num_concurrent_p = total_run

    # Same order as the (travel_time, queue_length, delay, throughput) values
    # printed/produced by run_an_experiment.
    metric_names = ('travel_time', 'queue_length', 'delay', 'throughput')
    metrics = {name: [None] * total_run for name in metric_names}
    seed_list = [992832, 284765, 905873, 776383, 198876, 192223, 223341, 182228, 885746, 992817]
    # seed_list = list(range(10))

    # One job per flow file: flow index f_idx uses seed_list[f_idx].
    jobs = [(data_name, f_idx, seed_list[f_idx]) for f_idx in range(total_run)]

    if parallel:
        with multiprocessing.Pool(processes=num_concurrent_p) as pool:
            results = pool.starmap(run_an_experiment, jobs)
    else:
        results = [run_an_experiment(inter_name=name, flow_idx=f_idx, seed=seed)
                   for name, f_idx, seed in jobs]

    # Store per-flow results; a run that returns None leaves its slots as None.
    for f_idx, result in enumerate(results):
        if result is None:
            continue
        for name, value in zip(metric_names, result):
            metrics[name][f_idx] = value

    # Summary across the flows that produced results.
    for name, values in metrics.items():
        finished = [v for v in values if v is not None]
        if finished:
            print(f'{name}: mean={np.mean(finished):.4f}, std={np.std(finished):.4f}')

