import numpy as np
import matplotlib.pyplot as plt


# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
# Each result file is assumed to hold n_exp * n_trees runs of n_simulations
# steps (the plotting code reshapes the loaded arrays to that size).
n_exp = 5
n_trees = 5
n_simulations = 1000

# (k, d) grid used only by the disabled heatmap section at the end of the file.
# Earlier heatmap sweeps used k in {50, 100, 200} and d in {1, 2}.
k_heat = [2, 4, 6, 8, 10, 12, 14, 16]
d_heat = [1, 2, 3, 4]

# (k, d) settings to plot, paired element-wise: one figure column per pair.
# Earlier sweeps used k in {50, 100, 200} with d in {1, 2}.
k = [16, 16, 14, 16, 16]
d = [1, 2, 3, 3, 4]

# Default hyper-parameters; they are only used to build the log folder paths.
exploration_coeff = 1.
tau = .1

# Algorithms to plot and their display names, paired by position.
# Note that 'dents' is displayed as "BTS".
algs = ['uct', 'power-uct', 'dng', 'fixed-depth-mcts', 'ments', 'rents',
        'tents', 'dents', 'cats', 'pats']
algs_legend = ["UCT", "Power-UCT", "DNG", "Fixed-Depth-MCTS", "MENTS", "RENTS",
               "TENTS", "BTS", "CATS", "PATS"]

# One colour per algorithm (the colours of matplotlib's default property
# cycle). Currently referenced only by commented-out code.
algs_legend_color = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
                     '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']

# Values of the `alpha` parameter swept by the outer plotting loop (earlier
# legend drafts labelled it "p"; earlier sweeps used 1, 2, 4, 8, 10, 16).
alphas = [8]

folder_name = './log_5_5/expl_%.2f_tau_%.2f' % (exploration_coeff, tau)
folder_name_uct = './log_5_5/expl_%.2f_tau_%.2f' % (exploration_coeff, tau)

# zip() truncates silently and the legend pairs labels with curves by
# position, so mismatched lengths would drop panels or mislabel curves
# without any error; fail loudly instead.
if len(k) != len(d):
    raise ValueError('k and d must have the same length (got %d and %d)'
                     % (len(k), len(d)))
if len(algs) != len(algs_legend):
    raise ValueError('algs and algs_legend must have the same length '
                     '(got %d and %d)' % (len(algs), len(algs_legend)))

# PLOTS
# Figure layout: a 3-row grid with one column per (k, d) pair; only the first
# row is populated. Each panel overlays, for every algorithm, the mean value
# estimation error per simulation step with a +/- 2 standard-error band. Its
# y-axis runs from 0 to the largest mean error plotted in that panel.


def _run_location(alg, kk, dd, alpha):
    """Return ``(subfolder, alpha)`` locating the saved run of *alg* at (kk, dd).

    *alpha* is the value currently in effect; it is returned unchanged for
    algorithms without a dedicated setting below. Unlike an in-loop
    reassignment, this does not touch `exploration_coeff` / `tau`.
    """
    cell = '/k_%d_d_%d' % (kk, dd)
    if alg in {'dng', 'ments', 'rents', 'tents', 'power-uct'}:
        return folder_name_uct + cell, 1
    if alg == 'bts':
        # NOTE(review): 'bts' does not appear in `algs` (the "BTS" legend
        # entry belongs to 'dents'), so this branch is never taken with the
        # current configuration; confirm which folder the 'dents' runs use.
        return './log_5_5/expl_%.2f_tau_%.2f' % (.75, .5) + cell, 1
    if alg in {'uct', 'fixed-depth-mcts'}:
        return './log_5_5/expl_%.2f_tau_%.2f' % (.05, .1) + cell, 1
    if alg in {'pats', 'cats'}:
        return folder_name_uct + cell, 10
    return folder_name + cell, alpha


n_runs = n_exp * n_trees  # runs aggregated per result file
steps = np.arange(n_simulations)

plt.figure()

for count_plot, (kk, dd) in enumerate(zip(k, d)):
    plt.subplot(3, len(k), 1 + count_plot)
    max_diff_uct = 0
    for alpha in alphas:
        # `run_alpha` deliberately carries over from one algorithm to the
        # next: an algorithm without a dedicated setting inherits the value
        # left by the previous one. With the current `algs` order, 'dents'
        # therefore loads alpha=1 (left by 'tents'), not the swept value.
        # NOTE(review): confirm that is the intended file for 'dents'.
        run_alpha = alpha
        for alg in algs:
            subfolder_name, run_alpha = _run_location(alg, kk, dd, run_alpha)
            atoms = 10  # every algorithm's result files are tagged with 10 atoms
            diff_uct = np.load(subfolder_name + '/diff_uct_%s_%f_%d.npy'
                               % (alg, run_alpha, atoms))
            avg_diff_uct = diff_uct.mean(0)
            if alg in {"ments", "rents", "tents"}:
                plt.plot(avg_diff_uct, linewidth=1, linestyle='dotted',
                         markersize=0.5)
            else:
                plt.plot(avg_diff_uct, linewidth=1)

            # Twice the standard error of the mean over all runs (roughly a
            # 95% band under a normal approximation). The reshape requires
            # each file to contain n_runs * n_simulations values.
            err = 2 * np.std(diff_uct.reshape(n_runs, n_simulations),
                             axis=0) / np.sqrt(n_runs)
            # '_nolegend_' keeps the bands out of the figure legend.
            plt.fill_between(steps, avg_diff_uct - err, avg_diff_uct + err,
                             alpha=.1, label='_nolegend_')
            max_diff_uct = max(max_diff_uct, avg_diff_uct.max())

    # Panel formatting, applied once after all curves are drawn. The x tick
    # labels are hidden and written into the x label instead.
    plt.title('k=%d  d=%d' % (kk, dd), fontsize='small')
    plt.tick_params(axis='x', which='both', bottom=False, top=False,
                    labelbottom=False)
    plt.xticks([0, 500, 1000], ['0', '500', '1000'], fontsize='small')
    plt.xlabel('0     500    1000\n# Simulations', fontsize='small')
    plt.yticks(fontsize='small')
    if count_plot == 0:
        plt.ylabel(r'Value Estimation Error', fontsize='small')
    plt.grid()
    plt.ylim(0, max_diff_uct)

plt.subplots_adjust(hspace=1., wspace=.4)

# LEGEND
# A single legend serves the whole figure. It is attached to the panel that
# was selected last (the right-most column) and shifted below and to the left
# of it with `bbox_to_anchor`, so it sits underneath the row of panels.
# The handles are that panel's curves, drawn in `algs` order, so they pair with
# `algs_legend` by position; the shaded error bands are fill collections (and
# labelled '_nolegend_'), not lines, so they are left out.
legend_handles = list(plt.gca().get_lines())
plt.legend(legend_handles, algs_legend, fontsize='small', loc="upper center",
           bbox_to_anchor=(-2.8, -.5), ncol=5, frameon=False)

plt.savefig("results.pdf", bbox_inches='tight', pad_inches=0.1)

# HEATMAPS (disabled). Kept for reference; it expects heatmap arrays in
# `folder_name`. NOTE(review): `tick.label` may not exist in recent matplotlib
# releases (`tick.label1` is the equivalent) — verify before re-enabling.
# diff = np.load(folder_name + '/diff_heatmap.npy')
# diff_uct = np.load(folder_name + '/diff_uct_heatmap.npy')
# regret = np.load(folder_name + '/regret_heatmap.npy')
#
# diffs = [diff, diff_uct, regret]
# titles_diff = [r'$\varepsilon_\Omega$', r'$\varepsilon_{UCT}$', 'R']
# for t, d in zip(titles_diff, diffs):
#     fig, axs = plt.subplots(nrows=3, ncols=2)
#     fig.suptitle(t, fontsize='x-large')
#     max_d = d.max()
#     for i, ax in enumerate(axs.flat):
#         im = ax.imshow(d[i], cmap=plt.get_cmap('inferno'))
#         ax.set_title(algs_legend[i], fontsize='x-large')
#         ax.set_xticks(np.arange(len(d_heat)))
#         for tick in ax.xaxis.get_major_ticks():
#             tick.label.set_fontsize('x-large')
#         for tick in ax.yaxis.get_major_ticks():
#             tick.label.set_fontsize('x-large')
#         ax.set_yticks(np.arange(len(k_heat)))
#         ax.set_xticklabels(d_heat)
#         ax.set_yticklabels(k_heat)
#         im.set_clim(0, max_d)
#
#     cbar = fig.colorbar(im, ax=axs[:, 1], shrink=0.6)
#     for t in cbar.ax.get_yticklabels():
#         t.set_fontsize('x-large')

plt.show()
