import os

import numpy as np
import seaborn
from matplotlib import pyplot as plt

from PROPS.plotting.utils import get_paths, load_data, plot, get_line_styles

# Per-environment y-axis limits, as (lower, upper) pairs, keyed by
# environment id. Built from an ordered list of pairs so each entry reads as
# one row of a table; insertion order matches the original listing.
ylims = dict([
    ('Hopper-v4', (0, 4000)),
    ('HalfCheetah-v4', (0, 4000)),
    # NOTE(review): (0, 400) is an order of magnitude below the neighbouring
    # MuJoCo entries -- possibly meant to be (0, 4000); confirm before use.
    ('Walker2d-v4', (0, 400)),
    ('Ant-v4', (0, 6000)),
    ('InvertedPendulum-v4', (0, 1100)),
    ('InvertedDoublePendulum-v4', (0, 10000)),
    ('Acrobot-v1', (-500, 0)),
    ('CartPole-v1', (0, 600)),
    ('LunarLander-v2', (-250, 300)),
    ('Swimmer-v4', (0, 150)),
    ('Humanoid-v4', (0, 6500)),
])


def _load_runs(paths, name, success_threshold, max_t, x_scale, m):
    """Load one curve per path and apply the time / length cut-offs.

    Returns ``(runs, timesteps)``. ``runs`` holds one array for every path
    whose ``load_data`` call returned a non-``None`` curve. ``timesteps`` is
    the ``t`` returned alongside the last such curve, or ``None`` if no curve
    was loaded.
    """
    runs = []
    timesteps = None
    for path in paths:
        _, t, avg = load_data(path, name=name, success_threshold=success_threshold)
        if avg is None:
            continue
        if max_t and t is not None:
            # Keep only points whose timestep, once scaled by x_scale, does
            # not exceed max_t.
            keep = np.where(t <= max_t / x_scale)[0]
            avg = avg[keep]
            t = t[keep]
        elif m:
            avg = avg[:m]
        runs.append(avg)
        timesteps = t
    return runs, timesteps


def _mean_and_band(runs):
    """Average ``runs`` point-wise and return ``(mean, lower, upper)``.

    Runs are trimmed to the shortest length so they can be stacked. The band
    is mean +/- standard error, using ``np.std`` (default ddof=0) divided by
    sqrt(number of runs). With a single run the band collapses onto the curve.
    """
    if len(runs) == 1:
        mean = np.asarray(runs[0])
        # Zero-width band anchored on the curve itself. A band at y=0 would
        # add a polygon at zero and drag the y-axis autoscale down to 0.
        return mean, mean.copy(), mean.copy()
    min_len = min(len(run) for run in runs)
    trimmed = [run[:min_len] for run in runs]
    mean = np.average(trimmed, axis=0)
    sem = np.std(trimmed, axis=0) / np.sqrt(len(trimmed))
    return mean, mean - sem, mean + sem


def _line_style(agent):
    """Return matplotlib line kwargs for ``agent``, overriding width/linestyle by label."""
    # Copy so the overrides below cannot leak into whatever mapping
    # get_line_styles hands back.
    style_kwargs = dict(get_line_styles(agent))
    style_kwargs['linewidth'] = 3
    if agent == 'PROPS':
        style_kwargs['linestyle'] = '-'
        style_kwargs['linewidth'] = 6
    elif any(tag in agent for tag in ('ppo_buffer', 'PPO-Buffer', 'b=', 'Buffer')):
        style_kwargs['linestyle'] = '--'
    elif agent == 'PPO' or any(tag in agent for tag in ('ppo,', 'PPO,', 'PPO with')):
        style_kwargs['linestyle'] = ':'
    elif 'Priv' in agent:
        style_kwargs['linestyle'] = '-.'
    return style_kwargs


def plot(save_dict, name, m=100000, success_threshold=None, return_cutoff=-np.inf):
    """Plot the mean curve and a standard-error band for every agent in ``save_dict``.

    Draws onto the current matplotlib axes via ``plt``.

    Args:
        save_dict: mapping from agent label (used as the legend label) to a
            dict with keys ``'paths'`` (iterable of result paths passed to
            ``load_data``), ``'x_scale'`` (multiplier applied to timesteps)
            and ``'max_t'`` (falsy to disable the timestep cut-off).
        name: forwarded to ``load_data`` as ``name``.
        m: maximum number of points plotted per agent; falsy disables it.
        success_threshold: forwarded to ``load_data``.
        return_cutoff: currently unused; kept for interface compatibility.
    """
    # Note: this function shadows the ``plot`` imported from
    # PROPS.plotting.utils at the top of the file.
    print(os.getcwd())

    for agent, info in save_dict.items():
        x_scale = info['x_scale']
        runs, timesteps = _load_runs(
            info['paths'], name, success_threshold, info['max_t'], x_scale, m)
        if not runs:
            continue

        mean, lower, upper = _mean_and_band(runs)
        style_kwargs = _line_style(agent)

        # Fall back to point indices when no timesteps were loaded.
        if timesteps is None:
            x = np.arange(len(mean))
        else:
            x = np.asarray(timesteps) * x_scale

        if m:
            x, mean, lower, upper = x[:m], mean[:m], lower[:m], upper[:m]

        n_points = len(mean)
        print(n_points)
        print(np.mean(mean), np.median(mean))
        plt.plot(x[:n_points], mean, label=agent, **style_kwargs)
        # An invisible line ('None' linestyle) gets an invisible band as well.
        band_alpha = 0 if style_kwargs.get('linestyle') == 'None' else 0.2
        plt.fill_between(x[:n_points], lower, upper, alpha=band_alpha)

if __name__ == "__main__":
    # Script entry point: one sampling-error panel per environment, comparing
    # the 'PROPS' and 'OS' curves found under the same results directory.

    seaborn.set_theme(style='whitegrid')
    env_ids = ['Swimmer-v4', 'Hopper-v4', 'HalfCheetah-v4', 'Walker2d-v4', 'Ant-v4', 'Humanoid-v4']
    n_envs = len(env_ids)
    label_fontsize = 24

    # 5 inches of width per panel keeps the layout in step with env_ids.
    fig = plt.figure(figsize=(5 * n_envs, 6))
    for subplot_i, env_id in enumerate(env_ids, start=1):
        ax = plt.subplot(1, n_envs, subplot_i)

        root_dir = f'../results_rl_ablations_clip_reg/{env_id}'
        algo = 'props'
        results_dir = f'{root_dir}/{algo}/b_2/no_reg'

        path_dict_all = dict(get_paths(
            results_dir=results_dir,
            key='PROPS',
            evaluations_name='stats'))
        path_dict_all_ref = dict(get_paths(
            results_dir=results_dir,
            key='OS',
            evaluations_name='stats'))

        name = 'kl_mle_target'
        name_ref = 'ref_kl_mle_target'

        plot(path_dict_all, name=name)
        plot(path_dict_all_ref, name=name_ref)

        ax.set_title(f'{env_id}', fontsize=label_fontsize)
        ax.set_xlabel('Timestep', fontsize=label_fontsize)
        ax.set_ylabel('Sampling Error', fontsize=label_fontsize)
        ax.tick_params(axis='both', labelsize=label_fontsize)
        # The scientific-notation offset (e.g. "1e6") is sized separately.
        ax.xaxis.get_offset_text().set_fontsize(label_fontsize)

    # Lay out once, after every panel exists, then reserve room at the top
    # for the shared legend.
    plt.tight_layout()
    fig.subplots_adjust(top=0.7)
    handles, labels = fig.axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', ncol=5, fontsize=36)

    save_dir = 'figures'
    save_name = 'se_updating_target.png'
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, save_name))

    plt.show()
