import os

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np

plt.rcParams["font.family"] = "Times New Roman"

# Root directory; results are read from <data_folder>/<env>/rewards/<algo>.npy.
data_folder = "./data/"

# Algorithms to compare. labels[k] and colors[k] describe algos[k], so the
# three lists below must stay the same length and in the same order.
algos = [
    "r_con_nom",
    "r_nov_nom",
]

# One subplot per environment, filled row-major into the 2x5 grid.
envs = [
    "Freeway", "Frogger", "Solaris", "BeamRider", "DefendLine",
    "SaveCenter", "CollectKit", "SlayGhosts", "ThreeRooms", "TMaze",
]

# NOTE(review): "w/o" in the first label reads oddly next to $R^{con}$ —
# confirm this is the intended legend wording.
labels = [
    r"contribution rewards w/o $R^{\text{con}}$",
    r"novelty rewards $R^{\text{nov}}$",
]
colors = [
    "mediumseagreen",
    "cornflowerblue",
]

fig, axs = plt.subplots(2, 5, figsize=(26, 7.8))
n_rows, n_cols = axs.shape

# Fill the grid row-major with one environment per subplot.
for idx, (ax, env) in enumerate(zip(axs.flat, envs)):
    row, col = divmod(idx, n_cols)
    if col == 0:
        # Only the first column carries the y-label; only its bottom cell
        # carries the x-label, so the grid is annotated once per axis.
        ax.set_ylabel('Episode returns', fontsize=24)
        if row == n_rows - 1:
            ax.set_xlabel(r'Steps $\times 10^3$', fontsize=22)

    for algo, label, color in zip(algos, labels, colors):
        data_path = os.path.join(data_folder, env, "rewards", f"{algo}.npy")

        # The saved array is unpacked along its first axis into
        # (steps, mean, std); presumably all three rows share one length.
        data_steps, data_mean, data_std = np.load(data_path)
        steps_k = np.asarray(data_steps) / 1000  # x-axis is in thousands of steps

        # Shaded band of mean +/- one std, then the mean curve on top.
        ax.fill_between(steps_k, data_mean + data_std, data_mean - data_std,
                        alpha=0.3, color=color)
        ax.plot(steps_k, data_mean, color=color, label=label, linewidth=2)

    ax.set_title(env, fontsize=28)

# Every subplot draws the same algorithms with the same labels and colors, so
# the legend entries of the first subplot are representative of all of them.
# Only the mean curves carry labels, so these handles are the Line2D objects.
legend_handles, legend_labels = axs[0][0].get_legend_handles_labels()

# Show entries in reverse plotting order. With the two algorithms here this is
# exactly the order [1, 0] (novelty first, then contribution), but it no longer
# breaks or drops entries when the number of algorithms changes.
order = list(reversed(range(len(legend_handles))))
handles_new = [legend_handles[i] for i in order]
labels_new = [legend_labels[i] for i in order]

# Thicker proxy lines so the colors are easy to tell apart in the legend.
legend_linewidth = 6
handles_new_thick = [Line2D([0], [0], color=handle.get_color(), lw=legend_linewidth)
                     for handle in handles_new]

# One legend row for the whole figure, one column per entry (at least one).
fig.legend(handles_new_thick, labels_new, loc='lower center',
           ncol=max(len(handles_new_thick), 1), fontsize=22, frameon=False)
# Leave room at the bottom for the shared legend.
plt.subplots_adjust(bottom=0.14, hspace=0.26)

plt.savefig("./rewards.svg", bbox_inches='tight', pad_inches=0.05)

plt.show()
