import os

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np

# Make Times New Roman matplotlib's default font family for all text drawn
# by this script (titles, axis labels, tick labels, legend).
plt.rcParams["font.family"] = "Times New Roman"

# Root directory of the result files; the plotting loop reads
# ``<data_folder>/<env>/<algo>.npy`` for every environment/algorithm pair.
data_folder = "./data/"

# File stems of the ablation variants. `labels` and `colors` below are
# parallel to this list (same index -> same variant).
algos = ["wo-r-nov", "wo-r-con", "durnd"]

# Environments, in panel order (row-major over the 2 x 5 subplot grid).
envs = [
    "Freeway",
    "Frogger",
    "Solaris",
    "BeamRider",
    "DefendLine",
    "SaveCenter",
    "CollectKit",
    "SlayGhosts",
    "ThreeRooms",
    "TMaze",
]

# Legend text per variant.
# NOTE(review): presumably "wo-r-nov" means "without R^nov", which is why it
# is labelled "only R^con" (and vice versa) -- confirm against the data.
labels = [
    r"DuRND with only $R^{\text{con}}$",
    r"DuRND with only $R^{\text{nov}}$",
    "DuRND complete",
]

# Line / band colour per variant.
colors = ["mediumseagreen", "cornflowerblue", "firebrick"]

# Moving-average window per environment; parallel to `envs`.
# A window of 1 leaves the curve unsmoothed.
smooth_window = [2, 2, 2, 3, 4, 4, 2, 2, 1, 1]


def smooth(data, window_size=5):
    """Return the moving average of ``data`` over ``window_size`` samples.

    Only fully overlapping windows are evaluated (``mode='valid'``), so the
    result has ``len(data) - window + 1`` entries, where ``window`` is
    ``window_size`` capped at ``len(data)``.

    Args:
        data: 1-D sequence or array of numbers.
        window_size: Number of consecutive samples averaged per output value.
            Values larger than ``len(data)`` are capped to ``len(data)``, in
            which case the result is a single value: the mean of ``data``.

    Returns:
        A 1-D ``numpy.ndarray`` holding the windowed means.

    Raises:
        ValueError: If ``window_size`` is smaller than 1 or ``data`` is empty.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be at least 1, got {window_size}")

    values = np.asarray(data)
    if values.size == 0:
        raise ValueError("data must not be empty")

    # np.convolve swaps its operands when the kernel is longer than the data;
    # in 'valid' mode that would return window_size - len(data) + 1 wrongly
    # scaled values instead of an average. Capping the window avoids that.
    width = min(window_size, values.size)
    kernel = np.ones(width) / width
    return np.convolve(values, kernel, mode='valid')


# One panel per environment, arranged as 2 rows x 5 columns.
fig, axs = plt.subplots(2, 5, figsize=(26, 7.8))

for idx in range(10):
    row, col = divmod(idx, 5)
    panel = axs[row][col]
    env_name = envs[idx]
    window_size = smooth_window[idx]

    # Only the left-most panel of each row carries the y-axis label.
    if col == 0:
        panel.set_ylabel('Episode returns', fontsize=24)
        panel.tick_params(axis='both', which='both')
        # The x-axis label is drawn once, on the bottom-left panel.
        if idx == 5:
            panel.set_xlabel(r'Steps $\times 10^3$', fontsize=22)

    for algo, color, label in zip(algos, colors, labels):
        # presumably each file stacks three rows: steps, mean return and
        # std of the return -- TODO confirm against the data writer.
        data_steps, data_mean, data_std = np.load(
            os.path.join(data_folder, env_name, f"{algo}.npy")
        )

        # Smooth the mean and the std with this environment's window.
        smooth_mean = smooth(data_mean, window_size)
        smooth_std = smooth(data_std, window_size)
        # NOTE(review): the smoothed series is shorter than the raw one; it is
        # paired with the *first* len(smooth_mean) steps, so each point sits at
        # the start of its averaging window -- confirm this is intended.
        steps_in_thousands = np.array(data_steps[:len(smooth_mean)]) / 1000

        # Shaded band: smoothed mean +/- smoothed std.
        panel.fill_between(
            steps_in_thousands,
            smooth_mean + smooth_std,
            smooth_mean - smooth_std,
            alpha=0.2,
            color=color,
        )
        panel.plot(steps_in_thousands, smooth_mean, color=color, label=label, linewidth=2.5)

        # Console summary computed from the unsmoothed means.
        # NOTE(review): the 0.01 factor on the spread is unexplained here --
        # verify it is intended before quoting these numbers.
        spread = np.std(data_mean) * 0.01 / np.sqrt(len(data_mean))
        print(f"{env_name}-{algo}: {np.mean(data_mean):.2f} +- {spread:.2f}")

    panel.set_title(env_name, fontsize=28)

# Every panel plots the same three variants, so the first panel's labelled
# lines are enough to build one shared legend for the whole figure.
panel_handles, panel_labels = axs[0][0].get_legend_handles_labels()

# Show the entries in reverse plotting order.
display_order = [2, 1, 0]
ordered_labels = [panel_labels[pos] for pos in display_order]

# Use thicker proxy lines so the colours stay readable in the legend; each
# proxy copies the colour of the plotted line it stands for.
legend_linewidth = 6
thick_handles = [
    Line2D([0], [0], color=panel_handles[pos].get_color(), lw=legend_linewidth)
    for pos in display_order
]

fig.legend(
    thick_handles,
    ordered_labels,
    loc='lower center',
    ncol=3,
    fontsize=22,
    frameon=False,
)

# Leave room below the panels for the figure-level legend.
plt.subplots_adjust(bottom=0.14, hspace=0.26)

plt.savefig("./ablation-study.svg", bbox_inches='tight', pad_inches=0.05)

plt.show()
