import numpy as np
import os
import sys
from omegaconf import OmegaConf as om
import pandas
from matplotlib import pyplot as plt
from matplotlib import cm
from pudb import set_trace
from tueplots import bundles, figsizes

def change_figsize(fig_dict, fraction_width, fraction_height=1):
    """Scale the ``'figure.figsize'`` entry of an rcParams-style dict.

    The dict is modified in place and the same object is returned, so the call
    can be chained into ``plt.rcParams.update``.

    Args:
        fig_dict: mapping containing a ``'figure.figsize'`` (width, height) pair.
        fraction_width: multiplier applied to the width.
        fraction_height: multiplier applied to the height (default 1 = unchanged).
    """
    current = fig_dict['figure.figsize']
    scaled_width = current[0] * fraction_width
    scaled_height = current[1] * fraction_height
    fig_dict['figure.figsize'] = scaled_width, scaled_height
    return fig_dict

# Sliding-window width passed to smooth() by the DataClass statistics helpers.
WINDOW = 10

# Figure style: NeurIPS 2022 tueplots bundle, single-column figure size scaled
# to 80% width and 60% height. (An ICML 2022 bundle was used at some point;
# swap the bundle/figsize calls here to change venue.)
plt.rcParams.update(bundles.neurips2022())
fig_dict = change_figsize(figsizes.neurips2022(ncols=1), fraction_width=0.8, fraction_height=0.6)
print(fig_dict)
plt.rcParams.update(fig_dict)

# Y-axis label for each run-dict field that can be plotted.
YLABELS = {
    'test_success': 'success rate',
    'test_score': 'return',
    'train_score': 'train/return',
    'test_score_max': 'best rollout return',
    'test_success_max': 'best rollout success',
}
def smooth(vals, window):
    '''Smooth ``vals`` with a moving average of width ``window``.

    Uses ``np.convolve(..., mode='same')``, so the output has the same length
    as the input. Near the edges the sum is divided by the number of samples
    that actually fell inside the window, not by ``window``, so edge values
    are not biased towards zero.

    Args:
        vals: 1-D sequence of numbers.
        window: window width; clamped to ``len(vals)``.

    Returns:
        ``vals`` unchanged when ``window <= 1`` or ``vals`` is empty,
        otherwise a NumPy array of smoothed values.
    '''
    if window <= 1:
        return vals
    n = len(vals)
    if n == 0:
        # Clamping would give window == 0 and np.convolve rejects empty kernels.
        return vals
    window = min(window, n)
    kernel = np.ones(window)
    sums = np.convolve(vals, kernel, mode='same')
    # Per-position count of real samples under the window (edge correction).
    counts = np.convolve(np.ones(n), kernel, mode='same')
    return sums / counts


class DataClass:
    '''Holds loaded run dicts and offers filtering plus aggregate statistics.

    ``runs`` is the full list as loaded; it may contain ``None`` entries (see
    ``active_runs``). The filtering methods narrow ``active_runs`` and
    ``reset`` restores it from ``runs``. Before stacking, every statistic
    truncates each run's series to ``smallest_length``, the shortest
    ``'test_score'`` length among the active runs.
    '''

    def __init__(self, runs):
        self.runs = runs
        self._active_runs = runs
        # 0 until None entries have been stripped from _active_runs.
        self.cleaned = 0

    @property
    def active_runs(self):
        '''
        Currently selected runs. On first access (and after ``reset``), ``None``
        runs are removed; they usually appear while training is in progress.
        '''
        if not self.cleaned:
            self._active_runs = [run for run in self._active_runs if run is not None]
            self.cleaned = 1
        return self._active_runs

    @active_runs.setter
    def active_runs(self, val):
        self._active_runs = val

    def filter_runs(self, key, val):
        '''Keep only active runs with ``run[key] == val``; return the new list.

        Raises KeyError if an active run has no ``key`` entry.
        '''
        active_runs = [run for run in self.active_runs if run[key] == val]
        self.active_runs = active_runs
        print(f'{len(self.active_runs)} remaining.')
        return active_runs

    def cut_short_runs(self, threshold):
        '''Keep only active runs whose last ``'steps'`` value exceeds ``threshold``.'''
        active_runs = [run for run in self.active_runs if run['steps'][-1] > threshold]
        self.active_runs = active_runs
        return active_runs

    def _truncated(self, field):
        '''Return ``field`` of every active run, cut to ``smallest_length``.

        Raises ValueError when there are no active runs.
        '''
        # Computed once: smallest_length scans all active runs.
        length = self.smallest_length
        return [run[field][:length] for run in self.active_runs]

    def get_mean(self, field):
        '''Smoothed per-step mean of ``field`` across active runs.'''
        return smooth(np.mean(self._truncated(field), axis=0), window=WINDOW)

    def get_max_stats(self, field, metric=None):
        '''Mean and std across runs of a per-run summary of ``field``.

        Args:
            field: run-dict key holding a numeric series.
            metric: ``'avg'`` summarises each run by its mean, anything else by
                its max. When omitted, the module-level ``METRIC`` global is
                used if it is defined, otherwise ``'max'``.

        Returns:
            ``(mean, std)`` of the per-run summaries.
        '''
        truncated_scores = self._truncated(field)
        if metric is None:
            # METRIC is only assigned by the __main__ script, so fall back
            # instead of raising NameError when this module is imported.
            metric = globals().get('METRIC', 'max')
        if metric == 'avg':
            per_run = [np.mean(scores) for scores in truncated_scores]
        else:
            per_run = [np.max(scores) for scores in truncated_scores]
        return np.mean(per_run), np.std(per_run)

    def get_run_scores(self, field):
        '''List of smoothed, truncated ``field`` series, one per active run.'''
        return [smooth(scores, window=WINDOW) for scores in self._truncated(field)]

    def get_std(self, field):
        '''Smoothed per-step standard deviation of ``field`` across active runs.'''
        return smooth(np.std(self._truncated(field), axis=0), window=WINDOW)

    def reset(self):
        '''Restore ``active_runs`` to a copy of all loaded runs (None removed on next access).'''
        self.cleaned = 0
        self.active_runs = self.runs.copy()

    def get_max(self, field):
        '''Largest value of ``field`` over all active runs (truncated).'''
        return np.max(self._truncated(field))

    @property
    def steps(self):
        '''Step values of the first active run, truncated to ``smallest_length``.'''
        return self.active_runs[0]['steps'][:self.smallest_length]

    @property
    def smallest_length(self):
        '''Shortest ``'test_score'`` length among active runs (ValueError if none).'''
        return np.min([len(x['test_score']) for x in self.active_runs])

    @property
    def lengths(self):
        '''``'test_score'`` length of each active run.'''
        return [len(x['test_score']) for x in self.active_runs]

    def __len__(self):
        # Counts every loaded run, including None entries and filtered-out runs.
        return len(self.runs)


def get_path():
    '''Unimplemented placeholder; always returns None.'''
    return None


def get_log_ids_sorted(path):
    '''Return the integer ids of all ``log_<id>[.<ext>]`` entries in ``path``, ascending.

    The id is the text between ``log_`` and the first ``.``. Entries whose id
    part is not an integer (e.g. ``log_abc.txt``) are skipped rather than
    aborting the whole scan.
    '''
    log_ids = []
    for name in os.listdir(path):
        if not name.startswith('log_'):
            continue
        stem = name.split('.')[0]
        try:
            log_ids.append(int(stem[4:]))
        except ValueError:
            continue
    return sorted(log_ids)


def get_scores(path):
    '''Read the evaluation log ``<path>/log.csv`` and return its series as lists.

    Returns:
        ``(steps, train_score, test_score, test_success, test_score_max,
        test_success_max)``, read from the columns ``train/steps``,
        ``train/episode_score/mean``, ``test/episode_score/mean``,
        ``test/success_rate/mean``, ``test/episode_score/max`` and
        ``test/success_rate/max``. A CSV without data rows yields six empty lists.

    Raises:
        KeyError: if one of those columns is missing.
    '''
    data = pandas.read_csv(os.path.join(path, 'log.csv'))
    columns = (
        'train/steps',
        'train/episode_score/mean',
        'test/episode_score/mean',
        'test/success_rate/mean',
        'test/episode_score/max',
        'test/success_rate/max',
    )
    return tuple(list(data[column].values) for column in columns)


def get_data(path):
    '''Load every run directory directly under ``path`` into a DataClass.

    Each sub-directory is passed to get_run_dict(); its result (a run dict, or
    None when no run was found) is collected. Entries are visited in sorted
    order so the run order, and hence ``DataClass.steps``, is reproducible.
    Plain files are skipped because get_run_dict() lists its argument as a
    directory and would raise NotADirectoryError.
    '''
    runs = []
    for name in sorted(os.listdir(path)):
        entry = os.path.join(path, name)
        if not os.path.isdir(entry):
            continue
        runs.append(get_run_dict(entry))
    return DataClass(runs)


def get_run_dict(path):
    '''Build the run dict for the run stored at or below ``path``.

    A run directory is recognised by a ``checkpoints`` sub-directory. If
    ``path`` is not one, the search descends into its first sub-directory (in
    sorted order) and returns that result; None is returned when ``path`` has
    no sub-directories.

    For a run directory, ``config.yaml`` is loaded with OmegaConf and the scores
    come from get_scores(). The returned dict holds the score series
    (``steps``, ``test_success``, ``test_success_max``, ``test_score``,
    ``test_score_max``, ``train_score``), ``environment``, ``her`` (whether
    ``'HER'`` occurs in ``config.agent``), ``seed``, ``force_scale``, ``dep``
    and every key of ``config.env_args``.
    '''
    if not os.path.isdir(os.path.join(path, 'checkpoints')):
        for folder in sorted(os.listdir(path)):
            if os.path.isdir(os.path.join(path, folder)):
                return get_run_dict(os.path.join(path, folder))
        return None

    config = om.load(os.path.join(path, 'config.yaml'))
    steps, score_train, score_test, success, score_test_max, success_max = get_scores(path)
    run = {'steps': steps,
           'test_success': success,
           'test_success_max': success_max,
           'test_score': score_test,
           'test_score_max': score_test_max,
           'train_score': score_train,
           'environment': config.environment_name,
           'her': 'HER' in config.agent,
           'seed': config.seed,
           # NOTE(review): assumes env_args always defines force_scale; the
           # `if config['env_args']` guard below suggests it can be empty.
           'force_scale': config.env_args['force_scale']}

    # DEP variant id, detected from the dep_factory(...) call in the agent
    # spec (presumably a string). Patterns are checked in order and the last
    # match wins; agents with none of them get 0. The dep=1 pattern
    # deliberately includes the trailing space.
    dep_patterns = (
        (1, "dep_factory(1, "),
        (2, "dep_factory(2,"),
        (3, "dep_factory(3,"),
        (5, "dep_factory(5,"),
        (6, "dep_factory(6,"),
        (11, "dep_factory(11,"),
        (12, "dep_factory(12,"),
        (13, "dep_factory(13,"),
    )
    dep_flag = 0
    for flag, needle in dep_patterns:
        if needle in config.agent:
            dep_flag = flag
    run['dep'] = dep_flag

    if config['env_args']:
        for k, v in config.env_args.items():
            run[k] = v
    return run


def prepare_runs(runs, cut, dep, her, force_scale, test_episode_every=0, with_dep_test_episode_every=0):
    '''Narrow ``runs`` (a DataClass) in place to one experimental variant.

    Args:
        runs: DataClass whose active runs are filtered.
        cut: minimum final step; shorter runs are dropped (converted with int()).
        dep: required ``'dep'`` value.
        her: required ``'her'`` value.
        force_scale: 1 selects runs with force_scale 0.003, anything else 0.
        test_episode_every: 0 disables this filter; 1..100 (or any other
            non-zero value <= 100) is matched exactly; values above 100 select
            runs whose ``test_episode_every`` equals the string ``'100000000'``.
            NOTE(review): confirm the config really stores that value as a string.
        with_dep_test_episode_every: when truthy, raise unless the first active
            run has a ``test_episode_every`` entry.

    Raises:
        Exception: ``'Skipping this run'`` for the check above; the filters
            themselves may raise KeyError/IndexError.
    '''
    if with_dep_test_episode_every and 'test_episode_every' not in runs.active_runs[0]:
        raise Exception('Skipping this run')

    runs.cut_short_runs(int(cut))
    print(f'{dep=} {her=} {force_scale=} {test_episode_every=}')
    runs.filter_runs('dep', dep)
    runs.filter_runs('her', her)
    runs.filter_runs('force_scale', 0.003 if force_scale == 1 else 0)

    if test_episode_every != 0:
        wanted = test_episode_every if test_episode_every <= 100 else '100000000'
        runs.filter_runs('test_episode_every', wanted)



def plot(fig, ax, path, field='test_success', cut=0):
    '''Draw the learning curves of ``field`` for the experiment at ``path`` onto ``ax``.

    ``fig`` is accepted for call-site symmetry with plot_bar() but not used.
    '''
    return plot_single(ax, path, field=field, cut=cut)

def plot_bar(fig, ax, path, field='test_success', cut=0, means=None, stds=None):
    '''Append bar statistics for the experiment at ``path`` to ``means``/``stds``.

    ``fig`` and ``ax`` are unused here; the caller draws the bars. Fresh lists
    are created when ``means``/``stds`` are omitted. The old ``[]`` defaults
    were shared between calls and silently accumulated values.

    Returns:
        ``(means, stds)`` as returned by _plot_bar().
    '''
    if means is None:
        means = []
    if stds is None:
        stds = []
    return _plot_bar(path, field, cut, means, stds)


# Bars collected by _plot_bar, as
# (dep, force_scale, test_episode_every, with_dep_test_episode_every).
# Order matters: callers match the returned lists to tick labels by position.
_BAR_VARIANTS_BACKGROUND = (
    (13, 0, 0, 0),
)
_BAR_VARIANTS_ABLATIONS = (
    (1, 0, 0, 0),
    (1, 1, 0, 0),
    (2, 0, 0, 0),
    (2, 1, 0, 0),
    (3, 0, 0, 0),
    (3, 1, 0, 0),
    (6, 0, 0, 0),
    (6, 1, 0, 0),
    (6, 0, 200, 1),
    (6, 1, 200, 1),
    (5, 0, 3, 1),
    (5, 1, 3, 1),
)
_BAR_VARIANTS_FINAL = (
    (0, 0, 0, 0),
)


def _bar_variants_for(path):
    '''Pick the bar variant table for ``path``; 'background' takes precedence over 'final'.'''
    if 'background' in path:
        return _BAR_VARIANTS_BACKGROUND
    if 'final' not in path:
        return _BAR_VARIANTS_ABLATIONS
    return _BAR_VARIANTS_FINAL


def _plot_bar(path, field='test_success', cut=0, means=None, stds=None):
    '''Collect per-variant summary statistics of ``field`` for the runs under ``path``.

    Loads ``<path>/working_directories``. For every variant chosen by
    _bar_variants_for(), it filters the runs with prepare_runs() and appends
    ``runs.get_max_stats(field)`` to ``means``/``stds``. A variant that fails
    (e.g. no run matches its filters) is skipped with a printed message, so the
    returned lists may be shorter than the variant table.

    Returns:
        ``(means, stds)``: the lists passed in (extended), or new lists when omitted.
    '''
    means = [] if means is None else means
    stds = [] if stds is None else stds
    print(cut)
    runs = get_data(os.path.join(path, 'working_directories'))
    for dep, force_scale, test_episode_every, with_tee in _bar_variants_for(path):
        try:
            runs.reset()
            prepare_runs(runs, cut, dep=dep, her=0, force_scale=force_scale,
                         test_episode_every=test_episode_every,
                         with_dep_test_episode_every=with_tee)
            print(f'{field} max={runs.get_max(field)}')
            mean, std = runs.get_max_stats(field)
        except Exception as err:
            # Best effort: a missing variant must not abort the whole figure.
            if (dep, force_scale) == (1, 1):
                print('InitDEP Force failed')
            else:
                print(f'Skipping dep={dep} force_scale={force_scale} '
                      f'test_episode_every={test_episode_every}: {err!r}')
            continue
        means.append(mean)
        stds.append(std)
    return means, stds


# Curves drawn by plot_single, as (dep, force_scale, test_episode_every,
# with_dep_test_episode_every, color, linestyle, label). A linestyle of None
# leaves Matplotlib's default line style in place.
_LINE_VARIANTS_BACKGROUND = (
    (13, 0, 0, 0, 'tab:green', None, 'InitialDEP-MPO'),
)
_LINE_VARIANTS_DEFAULT = (
    (1, 0, 0, 0, 'tab:green', None, 'InitialDEP-MPO'),
    (1, 1, 0, 0, 'tab:green', '--', 'InitialDEP-MPO-Force'),
    (2, 0, 0, 0, 'tab:orange', None, 'AverageDEP-MPO'),
    (2, 1, 0, 0, 'tab:orange', '--', 'AverageDEP-MPO-Force'),
    (3, 0, 0, 0, 'tab:pink', None, 'Switch-MPO'),
    (3, 1, 0, 0, 'tab:pink', '--', 'Switch-MPO-Force'),
    (6, 0, 0, 0, 'tab:purple', None, 'FULLDEP-MPO'),
    (6, 1, 0, 0, 'tab:purple', '--', 'FULLDEP-MPO-Force'),
    (6, 0, 200, 1, 'black', None, 'FULLDEP-MPO-nogreedy'),
    (6, 1, 200, 1, 'black', '--', 'FULLDEP-MPO-Force-nogreedy'),
    (5, 0, 3, 1, 'yellow', None, 'Switcher-MPO-episode3'),
    (5, 1, 3, 1, 'brown', '--', 'switcher-MPO-Force-episode3'),
)


def plot_single(ax, path, field, cut=0):
    '''Draw mean +/- std learning curves of ``field`` for each variant onto ``ax``.

    Loads ``<path>/working_directories``. Paths containing 'background' use the
    background variant table; all other paths use the default table. Each
    variant is filtered with prepare_runs() and drawn independently: one that
    fails (e.g. no run matches) is skipped with a printed message and does not
    affect the others.
    '''
    print(cut)
    runs = get_data(os.path.join(path, 'working_directories'))
    variants = _LINE_VARIANTS_BACKGROUND if 'background' in path else _LINE_VARIANTS_DEFAULT
    for dep, force_scale, test_episode_every, with_tee, color, linestyle, label in variants:
        try:
            runs.reset()
            prepare_runs(runs, cut, dep=dep, her=0, force_scale=force_scale,
                         test_episode_every=test_episode_every,
                         with_dep_test_episode_every=with_tee)
            print(f'{field} max={runs.get_max(field)}')
            steps = runs.steps
            mean = runs.get_mean(field)
            std = runs.get_std(field)
        except Exception as err:
            if (dep, force_scale) == (1, 1):
                print('InitDEP Force failed')
            else:
                print(f'Skipping dep={dep} force_scale={force_scale} '
                      f'test_episode_every={test_episode_every}: {err!r}')
            continue
        line_kwargs = {} if linestyle is None else {'linestyle': linestyle}
        ax.plot(steps, mean, color=color, label=label, **line_kwargs)
        ax.fill_between(steps, mean - std, mean + std, alpha=0.1, color=color)

    ax.set_xlabel('steps')
    if field == 'test_success':
        ax.set_ylim([0, 1.2])
    ax.set_ylabel(YLABELS[field])
    # No legend is drawn here; the curves carry labels, so callers can add one
    # with ax.legend(...) if needed.


if __name__ == '__main__':
    LINEPLOT = 1
    BARPLOT = 1
    prefix = '/folder/'
    fields = ['test_score', 'test_score_max']
    if LINEPLOT:
        # METRIC must stay a module-level global: DataClass.get_max_stats reads it.
        METRIC = 'avg'
        # squeeze=False keeps axs indexable even when only one field is plotted.
        fig, axs = plt.subplots(1, len(fields), sharey=False, squeeze=False)
        axs = axs.ravel()
        for ax, field in zip(axs, fields):
            for path in ['arm750_her/arm750_ablations_neurips/']:
                plot(fig, ax, f'{prefix}{path}', field=field, cut=20000000)
            # Format this subplot; plt.ticklabel_format() would only touch the
            # figure's current axes (the last subplot created).
            ax.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))
            axs[0].set_ylabel('return')
            # NOTE(review): saving inside the loop writes one PDF per field, each a
            # snapshot of the figure so far (earlier files have empty subplots).
            fig.savefig(f'./new_reacher_evaluate_{field}.pdf')
        plt.close(fig)
    if BARPLOT:
        labels = ['MPO', 'init', 'init-force', 'avg', 'avg-force', 'det', 'det-force', 'stoch',
                  r'\textbf{stoch-force}', 'stoch-force-noback']
        bar_paths = ['arm750_her/arm750_optim_final_imol_switcher_working_FINAL',
                     'arm750_her/arm750_ablations_neurips/',
                     'arm750_her/arm750_depbackgroundablation']
        metrics = ['max', 'avg']
        # Looping on the global METRIC selects the summary used by get_max_stats.
        for METRIC in metrics:
            fields = ['test_success', 'test_success_max']
            fig, axs = plt.subplots(1, len(fields), sharey=True, squeeze=False)
            axs = axs.ravel()
            for ax, field in zip(axs, fields):
                means = []
                stds = []
                for path in bar_paths:
                    plot_bar(fig, ax, f'{prefix}{path}', field=field, cut=20000000,
                             means=means, stds=stds)
                # At most one bar per label; variants beyond the labels are dropped.
                means = means[:len(labels)]
                stds = stds[:len(labels)]
                x = np.arange(len(means))
                bars = ax.bar(x, means, yerr=stds)
                colors = iter(cm.autumn(np.linspace(0.3, 1, 16)))
                for bar, color in zip(bars, colors):
                    bar.set_color(color)
                if bars:
                    # First bar (from the first path) is highlighted.
                    bars[0].set_color('tab:blue')
                ax.set_xticks(x)
                # Label count must match tick count when some variants were skipped.
                ax.set_xticklabels(labels[:len(x)], rotation=45)
            if METRIC == 'avg':
                axs[0].set_ylabel('cumul. success rate')
            else:
                axs[0].set_ylabel('max success rate')
            fig.savefig(f'./new_humanreacher_ablation_barplot_{fields[-1]}_{METRIC}.pdf')
            plt.close(fig)
