import pathlib
import pickle

import numpy as np
from joblib import Parallel, delayed

from mcts import MCTS
from tree_env import SyntheticTree


def experiment(algorithm, tree, tau, alpha, number_of_atom):
    """Run a single MCTS trial on ``tree`` and measure its value-estimate error.

    The remaining MCTS hyper-parameters (``exploration_coeff``, ``step_size``,
    ``gamma``) and the simulation budget ``n_simulations`` are taken from the
    module-level configuration of this script.

    Args:
        algorithm: algorithm identifier forwarded to ``MCTS``.
        tree: synthetic tree to search; must expose ``optimal_v_root`` and
            ``max_mean``.
        tau: temperature forwarded to ``MCTS``.
        alpha: ``alpha`` parameter forwarded to ``MCTS``.
        number_of_atom: forwarded to ``MCTS`` as ``number_of_atoms``.

    Returns:
        Tuple ``(diff, diff_uct, regret)``:
        - ``diff`` is ``|v_hat - tree.optimal_v_root|``.
        - ``diff_uct`` is ``|v_hat - tree.max_mean|``.
        - ``regret`` is passed through unchanged.

        ``v_hat`` and ``regret`` are whatever ``MCTS.run`` returns.
    """
    settings = dict(
        exploration_coeff=exploration_coeff,
        algorithm=algorithm,
        tau=tau,
        alpha=alpha,
        number_of_atoms=number_of_atom,
        step_size=step_size,
        gamma=gamma,
        update_type='mean',
    )
    planner = MCTS(**settings)

    value_estimate, run_regret = planner.run(tree, n_simulations)
    return (
        np.abs(value_estimate - tree.optimal_v_root),
        np.abs(value_estimate - tree.max_mean),
        run_regret,
    )


# ---------------------------------------------------------------------------
# Experiment configuration. These are module-level globals: `experiment`
# reads exploration_coeff, step_size, gamma and n_simulations directly.
# ---------------------------------------------------------------------------

# Independent MCTS runs per synthetic tree.
n_exp = 5
# Synthetic trees loaded from cache (or generated) per configuration.
n_trees = 5
# Simulation budget passed to MCTS.run for every run.
n_simulations = 1000

# Branching factors (k) and depths (d) of the synthetic trees to sweep.
ks = [100, 200]
ds = [1, 2]

# Alternative (k, d) sweeps kept for reference:
# ks = [50, 100, 200]
# ds = [1, 2]

# ks = [8]
# ds = [5]

# Hyper-parameters forwarded to MCTS and/or SyntheticTree.
exploration_coeff = 1.
tau = .1
gamma = 1.0
step_size = 0.2
# Algorithms to run: keys are the identifiers passed to MCTS / SyntheticTree;
# values are display names and are not read by this script.
# NOTE(review): keys not listed in the loop's skip rules (e.g. 'w-mcts',
# 'alpha-divergence') run for every alpha x atom-count combination.
# algorithms = {'uct': 'UCT', 'ments': 'MENTS', 'rents': 'RENTS', 'tents': 'TENTS', 'w-mcts': 'W-MCTS'}

# NOTE(review): the dict below repeats the key 'alpha-divergence'; if
# uncommented, only the last entry ('ALPHA-8.0') would survive.
# algorithms = {'alpha-divergence': 'ALPHA-1.5', 'tents': 'TENTS', 'alpha-divergence': 'ALPHA-4.0', 'alpha-divergence': 'ALPHA-8.0'}
# algorithms = {'uct': 'UCT', 'dng': 'DNG', 'fixed-depth-mcts': 'Fixed-Depth-MCTS',
#               'ments': 'MENTS', 'tents': 'TENTS', 'dents': 'DENTS', 'cats': 'CATS', 'pats': 'PATS'}

# algorithms = {'fixed-depth-mcts': 'Fixed-Depth-MCTS'}
algorithms = {'power-uct': 'Power-UCT'}
# algorithms = {'ments': 'MENTS', 'tents': 'TENTS', 'rents': 'RENTS'}
# algorithms = {'uct': 'UCT', 'dng': 'DNG', 'fixed-depth-mcts': 'Fixed-Depth-MCTS', 'ments': 'MENTS', 'rents': 'RENTS',
#               'tents': 'TENTS', 'cats': 'CATS', 'pats': 'PATS'}
# algorithms = {'cats': 'CATS', 'pats': 'PATS'}
# Values swept for the `alpha` argument of MCTS / SyntheticTree.
alphas = [1, 2, 4, 8, 10, 16]
# Values swept for MCTS's `number_of_atoms` (also passed to SyntheticTree).
atoms = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
# alphas = [1.0]

# Root output directory; its name encodes exploration_coeff and tau. Tree
# caches and per-run results go into per-(k, d) subfolders below it.
# NOTE(review): gamma and step_size are passed to SyntheticTree but are not
# part of any cache path, so changing them silently reuses stale cached trees.
folder_name = './log_5_5/expl_%.2f_tau_%.2f' % (exploration_coeff, tau)

# Skip rules for (algorithm, alpha, atom count) combinations, hoisted out of
# the loop and named.
# Algorithms that are only run with at most 10 atoms.
_ATOMS_CAPPED_AT_10 = frozenset({'uct', 'dng', 'fixed-depth-mcts', 'power-uct', 'ments', 'tents',
                                 'rents', 'pats', 'cats', 'dents'})
# Algorithms that are only run with alpha == 10.
_ALPHA_FIXED_AT_10 = frozenset({'pats', 'cats'})
# Algorithms that are only run with alpha <= 1.
_ALPHA_CAPPED_AT_1 = frozenset({'uct', 'dng', 'fixed-depth-mcts', 'ments', 'tents', 'rents', 'dents'})


def _should_skip(alg, alpha, number_of_atom):
    """Return True if the (algorithm, alpha, atom count) combination is not run."""
    if alg in _ATOMS_CAPPED_AT_10 and number_of_atom > 10:
        return True
    if alg in _ALPHA_FIXED_AT_10 and alpha != 10:
        return True
    return alg in _ALPHA_CAPPED_AT_1 and alpha > 1


def _tree_path(subfolder, w, alg, alpha, number_of_atom):
    """Return the pickle-cache path of synthetic tree number ``w``."""
    return subfolder + '/tree%d_%s_%f_%d.pkl' % (w, alg, alpha, number_of_atom)


def _load_or_create_tree(path, k, d, alg, alpha, number_of_atom):
    """Load the cached SyntheticTree at ``path``, or build and cache a new one.

    A missing cache file is generated as before. An empty or truncated file
    (e.g. left behind by an interrupted run) is regenerated instead of
    aborting the whole sweep.
    """
    try:
        # Only files this script wrote itself are unpickled; never point this
        # at untrusted data.
        with open(path, 'rb') as f:
            return pickle.load(f)
    except FileNotFoundError:
        print('Tree not found! Creating new tree...')
    except (EOFError, pickle.UnpicklingError):
        print('Tree cache %s is unreadable! Creating new tree...' % path)

    tree = SyntheticTree(k, d, alg, tau, alpha, number_of_atom, gamma, step_size)
    # Write to a temp file and rename, so an interrupted dump never leaves a
    # truncated cache behind at `path`.
    tmp_path = pathlib.Path(path + '.tmp')
    with open(tmp_path, 'wb') as f:
        pickle.dump(tree, f)
    tmp_path.replace(path)
    return tree


# Run-averaged curves reduced to their last entry, indexed by
# (alpha index, branching-factor index, depth index).
# NOTE(review): there is no axis for the atom count or the algorithm. When
# several (algorithm, atom count) pairs run for the same alpha, the one
# processed last overwrites the others in these arrays.
diff_heatmap = np.zeros((len(alphas), len(ks), len(ds)))
diff_uct_heatmap = np.zeros_like(diff_heatmap)
regret_heatmap = np.zeros_like(diff_heatmap)
for x, k in enumerate(ks):
    for y, d in enumerate(ds):
        subfolder_name = folder_name + '/k_' + str(k) + '_d_' + str(d)
        pathlib.Path(subfolder_name).mkdir(parents=True, exist_ok=True)
        for number_of_atom in atoms:
            for t, alpha in enumerate(alphas):
                for alg in algorithms:
                    if _should_skip(alg, alpha, number_of_atom):
                        continue
                    print('Branching factor: %d, Depth: %d, Alg: %s' % (k, d, alg))
                    trees = [
                        _load_or_create_tree(_tree_path(subfolder_name, w, alg, alpha, number_of_atom),
                                             k, d, alg, alpha, number_of_atom)
                        for w in range(n_trees)
                    ]
                    # A single pool dispatch covers all n_trees * n_exp runs.
                    # Results keep the tree-major order of a per-tree loop.
                    out = np.array(Parallel(n_jobs=-1)(
                        delayed(experiment)(alg, tree, tau, alpha, number_of_atom)
                        for tree in trees
                        for _ in range(n_exp)
                    ))

                    diff = out[:, 0]
                    diff_uct = out[:, 1]
                    regret = out[:, 2]

                    # Average over all runs, then keep the last entry
                    # (presumably the final simulation step).
                    diff_heatmap[t, x, y] = diff.mean(0)[-1]
                    diff_uct_heatmap[t, x, y] = diff_uct.mean(0)[-1]
                    regret_heatmap[t, x, y] = regret.mean(0)[-1]

                    np.save(subfolder_name + '/diff_%s_%f_%d.npy' % (alg, alpha, number_of_atom), diff)
                    np.save(subfolder_name + '/diff_uct_%s_%f_%d.npy' % (alg, alpha, number_of_atom), diff_uct)
                    np.save(subfolder_name + '/regret_%s_%f_%d.npy' % (alg, alpha, number_of_atom), regret)

# The root folder is otherwise only created as a parent of a (k, d)
# subfolder, so make sure it exists even when ks or ds is empty.
pathlib.Path(folder_name).mkdir(parents=True, exist_ok=True)
np.save(folder_name + '/diff_heatmap.npy', diff_heatmap)
np.save(folder_name + '/diff_uct_heatmap.npy', diff_uct_heatmap)
np.save(folder_name + '/regret_heatmap.npy', regret_heatmap)
