import os
import json
import pathlib
import numpy as np
from datetime import datetime
from tqdm import tqdm
import multiprocessing as mp

from utils import NumpyEncoder
from algos.mcts import MCTS
from envs.envs import get_env, Occupancy_MDP

# Output directory for experiment results: the 'data' folder that sits next to
# this file. Kept as a plain string with a trailing slash because experiment
# names are appended to it by string concatenation.
DATA_FOLDER_PATH = f"{pathlib.Path(__file__).parent}/data/"
print(DATA_FOLDER_PATH)

# Default experiment configuration, used when the file is run as a script.
CONFIG = {
    # Number of experiments to run (one seed per experiment).
    "N": 10,
    # Size of the multiprocessing worker pool.
    "num_processors": 10,
    # Name of the environment, resolved through get_env.
    "env": "adversarial_mdp_10",
    # Truncation length.
    "H": 100,
    # MCTS number of tree expansion steps per timestep.
    "n_iter_per_timestep": 1_000,
}

def create_exp_name(args: dict) -> str:
    """Build a timestamped experiment identifier.

    The result has the form
    ``<env>_<algo>_gamma_<gamma>_<YYYY-MM-DD-HH-MM-SS>``, where the timestamp
    is the local time at the moment of the call.

    Args:
        args: Mapping with the keys ``'env'`` and ``'algo'`` (both must be
            strings) and ``'gamma'`` (any value; rendered with ``str``).

    Returns:
        The experiment name.

    Raises:
        KeyError: If one of the required keys is missing.
        TypeError: If ``args['env']`` or ``args['algo']`` is not a string.
    """
    pieces = (
        args['env'],
        args['algo'],
        'gamma',
        str(args['gamma']),
        datetime.today().strftime('%Y-%m-%d-%H-%M-%S'),
    )
    return '_'.join(pieces)


def simulate_MCTS(mdp, H, n_iter_per_timestep=1_000):
    """Run one H-step episode, choosing every action with a newly built MCTS.

    ``mdp`` is wrapped in an ``Occupancy_MDP`` and an initial extended state
    is sampled from it. At each of the ``H`` timesteps a new ``MCTS`` instance
    is created at the current extended state, trained for
    ``n_iter_per_timestep`` iterations and asked for its best action, which
    is then applied to the occupancy MDP.

    Note: the ``terminated`` flag returned by ``step`` is not inspected, so
    the loop always performs exactly ``H`` environment steps.

    Args:
        mdp: Base MDP handed to ``Occupancy_MDP``.
        H: Number of timesteps to simulate (also passed to ``Occupancy_MDP``).
        n_iter_per_timestep: Number of MCTS learning iterations per timestep.

    Returns:
        Sum of the rewards returned by ``step`` over the ``H`` timesteps.
    """
    # Extended MDP and its initial state.
    occupancy_mdp = Occupancy_MDP(mdp, H)
    current_state = occupancy_mdp.sample_initial_state()

    # UCB exploration constant, identical for every planner built below.
    ucb_constant = np.sqrt(2)

    total_reward = 0.0
    for _ in tqdm(range(H)):
        # Plan from scratch at the current extended state.
        planner = MCTS(
            initial_state=current_state,
            env=occupancy_mdp,
            K_ucb=ucb_constant,
            rollout_policy=None,
        )
        planner.learn(n_iters=n_iter_per_timestep)
        action = planner.best_action()

        # Apply the chosen action; the termination flag is deliberately unused.
        current_state, step_reward, _terminated = occupancy_mdp.step(current_state, action)
        total_reward += step_reward

    return total_reward


def run(cfg, seed):
    """Execute a single seeded MCTS experiment.

    Seeds NumPy's global random generator with ``seed``, builds the
    environment named by ``cfg["env"]`` and simulates it with
    ``simulate_MCTS``.

    Args:
        cfg: Configuration dict; the keys ``"env"``, ``"H"`` and
            ``"n_iter_per_timestep"`` are read.
        seed: Seed passed to ``np.random.seed``.

    Returns:
        The value returned by ``simulate_MCTS`` for this run.
    """
    print('Running seed=', seed)
    np.random.seed(seed)

    # Build the MDP for this run.
    mdp = get_env(cfg["env"])
    print("env", mdp)

    return simulate_MCTS(
        mdp=mdp,
        H=cfg["H"],
        n_iter_per_timestep=cfg["n_iter_per_timestep"],
    )


def main(cfg):
    """Run ``cfg["N"]`` seeded experiments in parallel and store the results.

    A folder named after the experiment is created under
    ``DATA_FOLDER_PATH``. ``run`` is then executed for the seeds
    ``0 .. cfg["N"] - 1`` on a pool of ``cfg["num_processors"]`` worker
    processes, and the configuration, the per-seed results and the
    environment description are written to ``exp_data.json`` in that folder.

    File format: the payload is JSON-encoded twice, so ``exp_data.json``
    contains a single JSON string whose content is itself the JSON document.
    Readers have to decode twice, e.g. ``json.loads(json.load(f))``.
    NOTE(review): this is kept unchanged so existing result files and loaders
    stay compatible; confirm before switching to a single encoding.

    Args:
        cfg: Configuration dict; the keys ``"env"``, ``"N"`` and
            ``"num_processors"`` are read here, and the whole dict is
            forwarded to ``run`` and stored in the output file.

    Returns:
        The experiment name (the name of the created output folder).
    """
    # Setup experiment data folder.
    env = get_env(cfg["env"])
    exp_name = create_exp_name({'env': cfg['env'],
                                'algo': "mcts",
                                'gamma': env['gamma']})
    exp_path = DATA_FOLDER_PATH + exp_name
    os.makedirs(exp_path, exist_ok=True)
    print('\nExperiment ID:', exp_name)
    print('Config:')
    print(cfg)

    # Simulate.
    print('\nSimulating...')

    with mp.Pool(processes=cfg["num_processors"]) as pool:
        # One task per seed; starmap blocks until every run has finished.
        f_vals = pool.starmap(run, [(cfg, seed) for seed in range(cfg["N"])])
        # Let the workers exit cleanly before the context manager tears the
        # pool down.
        pool.close()
        pool.join()

    # Drop the 'f' entry of the environment description before serialising
    # (presumably it is not JSON-serialisable -- verify against get_env).
    env["f"] = None

    exp_data = {
        "config": cfg,
        "f_vals": np.array(f_vals),
        "env": env,
    }

    # Serialise first, so that an encoding error neither leaks a file handle
    # nor leaves an empty exp_data.json behind.
    dumped = json.dumps(exp_data, cls=NumpyEncoder)

    # Dump dict. The already-encoded string is encoded again on purpose (see
    # the docstring for the resulting file format).
    with open(exp_path + "/exp_data.json", "w") as f:
        json.dump(dumped, f)

    return exp_name


if __name__ == "__main__":
    # Script entry point: run the experiment with the module-level defaults.
    main(cfg=CONFIG)
