import os

import seaborn
from matplotlib import pyplot as plt

from PROPS.plotting.utils import get_paths, load_data, plot, get_line_styles

if __name__ == "__main__":
    # Draw one subplot per environment, each overlaying the 'returns' curves of
    # the 'ppo_props' runs stored under b_2, b_4 and b_8, then save the grid to
    # figures/ablation_b.png and display it.

    seaborn.set_theme(style='whitegrid')

    # Shorten this list (e.g. to ['Hopper-v4']) to plot a subset of environments.
    env_ids = ['Swimmer-v4', 'Hopper-v4', 'HalfCheetah-v4', 'Walker2d-v4', 'Ant-v4', 'Humanoid-v4']
    b_values = [2, 4, 8]
    algo = 'ppo_props'
    n_rows, n_cols = 2, 3

    fig = plt.figure(figsize=(3 * n_cols, 3 * n_rows))

    for subplot_index, env_id in enumerate(env_ids, start=1):
        ax = plt.subplot(n_rows, n_cols, subplot_index)
        root_dir = f'../results_rl_ablations_b/{env_id}'

        for b in b_values:
            # Each setting is looked up and plotted on its own; the legend key
            # is rendered as math text, e.g. "$b = 2$".
            path_dict_all = {}
            path_dict_all.update(get_paths(
                results_dir=f'{root_dir}/{algo}/b_{b}/',
                key=rf'$b = {b}$',
                evaluations_name='evaluations'))
            plot(path_dict_all, name='returns')

        # Axis cosmetics do not depend on b, so apply them once per subplot,
        # after every curve has been drawn.
        plt.title(f'{env_id}', fontsize=20)
        plt.xlabel('Timestep', fontsize=20)
        plt.ylabel('Return', fontsize=20)
        plt.xticks(fontsize=14)
        plt.yticks(fontsize=14)
        plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
        ax.xaxis.get_offset_text().set_fontsize(14)

    # Lay out the finished grid once, then reserve the top strip for the
    # shared figure-level legend.
    plt.tight_layout()
    fig.subplots_adjust(top=0.85)

    # The shared legend is taken from a single subplot. Prefer the second one
    # (as before), but fall back to the first when only one environment is
    # plotted instead of failing with an IndexError; skip the legend entirely
    # when there are no subplots.
    axes = fig.axes
    if axes:
        legend_ax = axes[1] if len(axes) > 1 else axes[0]
        handles, labels = legend_ax.get_legend_handles_labels()
        fig.legend(handles, labels, loc='upper center', ncol=4, fontsize=17)

    save_dir = 'figures'
    save_name = 'ablation_b.png'
    os.makedirs(save_dir, exist_ok=True)
    plt.savefig(f'{save_dir}/{save_name}')

    plt.show()
