import os

import seaborn
from matplotlib import pyplot as plt

from PROPS.plotting.utils import get_paths, load_data, plot, get_line_styles

if __name__ == "__main__":
    # Draw a 2x3 grid of return-vs-timestep panels (one per environment),
    # comparing PROPS, two PROPS ablations and three PPO variants, then save
    # the figure under ./figures and display it.

    seaborn.set_theme(style='whitegrid')
    env_ids = ['Swimmer-v4', 'Hopper-v4', 'HalfCheetah-v4', 'Walker2d-v4', 'Ant-v4', 'Humanoid-v4']

    # Result roots; the environment id is appended as a sub-directory.
    ablation_root = '../results_rl_ablations_clip_reg'
    baseline_root = '../results_rl'

    # One row per curve: (legend label, results root, algorithm directory,
    # run sub-directory, extra keyword arguments for get_paths).
    # Rows are processed top to bottom, which fixes the order in which the
    # entries are merged into the per-panel path dictionary.
    # NOTE(review): only PPO-Privileged passes x_scale=0.5; presumably this
    # rescales its x axis inside get_paths -- confirm against that helper.
    curve_specs = (
        ('PROPS', ablation_root, 'ppo_props', 'b_2', {}),
        ('PROPS, no clipping', ablation_root, 'ppo_props', 'no_clip', {}),
        ('PROPS, no regularization', ablation_root, 'ppo_props', 'no_reg', {}),
        ('PPO-Privileged', baseline_root, 'ppo_buffer', 'priv/b_1', {'x_scale': 0.5}),
        ('PPO-Buffer', baseline_root, 'ppo_buffer', 'b_2', {}),
        ('PPO', baseline_root, 'ppo_buffer', 'b_1', {}),
    )

    fig = plt.figure(figsize=(3*3, 2*3))
    # Subplot indices are 1-based, hence start=1.
    for panel_idx, env_id in enumerate(env_ids, start=1):
        ax = plt.subplot(2, 3, panel_idx)

        # Gather the result paths of every curve for this environment.
        path_dict_all = {}
        for label, results_root, algo_dir, run_dir, extra_kwargs in curve_specs:
            curve_paths = get_paths(
                results_dir=f'{results_root}/{env_id}/{algo_dir}/{run_dir}',
                key=label,
                evaluations_name='evaluations',
                **extra_kwargs)
            path_dict_all.update(curve_paths)

        plot(path_dict_all, name='returns')

        # Panel cosmetics: title, axis labels, tick sizes.
        plt.title(env_id, fontsize=20)
        plt.xlabel('Timestep', fontsize=20)
        plt.ylabel('Return', fontsize=20)
        plt.xticks(fontsize=18)
        plt.yticks(fontsize=18)
        # Force scientific notation on the x axis and shrink the offset text.
        plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
        ax.xaxis.get_offset_text().set_fontsize(12)

        plt.tight_layout()

    # Reserve space above the panels for a single figure-level legend.
    fig.subplots_adjust(top=0.85)
    # The legend entries are read from the second panel only.
    # NOTE(review): six curves are requested per panel but ncol=5 -- confirm
    # that wrapping onto a second legend row is intended.
    legend_ax = fig.axes[1]
    legend_handles, legend_labels = legend_ax.get_legend_handles_labels()
    fig.legend(legend_handles, legend_labels, loc='upper center', ncol=5, fontsize=12)

    # Write the figure to disk before opening the interactive window.
    save_dir = 'figures'
    save_name = 'ablate_clip_regularization.png'
    os.makedirs(save_dir, exist_ok=True)
    plt.savefig(f'{save_dir}/{save_name}')

    plt.show()
