import numpy as np
import os
import sys
from omegaconf import OmegaConf as om
import pandas
from matplotlib import pyplot as plt
from tueplots import bundles, figsizes

def change_figsize(fig_dict, fraction_width, fraction_height=1):
    '''Scale the ``'figure.figsize'`` entry of ``fig_dict`` in place.

    The width is multiplied by ``fraction_width`` and the height by
    ``fraction_height``. The result is stored as a 2-tuple, and the same
    (mutated) dict is returned for convenience.
    '''
    size = fig_dict['figure.figsize']
    scaled = (size[0] * fraction_width, size[1] * fraction_height)
    fig_dict['figure.figsize'] = scaled
    return fig_dict
# Bar colours, in the same order as the bar-chart labels used in __main__.
colors = ['tab:orange', 'tab:blue', '#cd66b1', 'tab:green', 'tab:purple', 'tab:brown']
# Sliding-window size used by ``smooth`` for every plotted curve.
WINDOW = 10

# NeurIPS 2022 style first, then a single-column figure shrunk to
# 66% of the width and 45% of the height.
plt.rcParams.update(bundles.neurips2022())
fig_dict = change_figsize(figsizes.neurips2022(ncols=1), fraction_width=0.66, fraction_height=0.45)
print(fig_dict)
plt.rcParams.update(fig_dict)

# Human-readable y-axis label for each run-dict field.
YLABELS = {
    'test_success': 'success rate',
    'test_score': 'return',
    'train_score': 'train/return',
    'test_score_max': 'best rollout return',
}

def smooth(vals, window):
    '''Smooths values using a sliding-window mean.

    Uses ``np.convolve`` in ``'same'`` mode, so the output has the same length
    as ``vals``. Near the edges the mean is taken over only the samples that
    fall inside the window. A ``window`` larger than the data is clipped to
    ``len(vals)``.

    Returns ``vals`` unchanged when ``window <= 1`` or ``vals`` is empty.
    An empty input would otherwise produce an empty kernel, which
    ``np.convolve`` rejects.
    '''
    if window <= 1 or len(vals) == 0:
        return vals
    window = min(window, len(vals))
    kernel = np.ones(window)
    sums = np.convolve(vals, kernel, 'same')
    # Convolving a ones-vector counts how many samples contributed at each
    # position, which normalises the partially-covered edge values.
    counts = np.convolve(np.ones(len(vals)), kernel, 'same')
    return sums / counts


class DataClass:
    '''Collection of run dicts with filtering and aggregate statistics.

    ``runs`` keeps every run passed in, ``None`` entries included.
    ``active_runs`` is the current selection: ``filter_runs`` and
    ``cut_short_runs`` narrow it, and ``reset`` restores it from ``runs``.
    Per-field statistics truncate each active run to ``smallest_length``
    entries so that curves of different length can be stacked.
    '''

    def __init__(self, runs):
        self.runs = runs
        self._active_runs = runs
        # Flag (0/1): whether None entries were already stripped from _active_runs.
        self.cleaned = 0

    @property
    def active_runs(self):
        '''
        Current selection of runs with None runs removed; None runs usually
        appear during training. The removal happens lazily on first access
        after construction or ``reset``.
        '''
        if not self.cleaned:
            self._active_runs = [run for run in self._active_runs if run is not None]
            self.cleaned = 1
        return self._active_runs

    @active_runs.setter
    def active_runs(self, val):
        self._active_runs = val

    def filter_runs(self, key, val):
        '''Keep only active runs with ``run[key] == val`` and return them.

        Runs that do not have ``key`` at all are dropped instead of raising
        ``KeyError``. For example, ``get_run_dict`` leaves out ``force_scale``
        when the config has none.
        '''
        active_runs = [run for run in self.active_runs if key in run and run[key] == val]
        self.active_runs = active_runs
        print(f'{len(self.active_runs)} remaining.')
        return active_runs

    def cut_short_runs(self, threshold):
        '''Keep only active runs whose last logged step exceeds ``threshold``.

        Runs with no logged steps are dropped instead of raising ``IndexError``.
        '''
        active_runs = [run for run in self.active_runs
                       if len(run['steps']) > 0 and run['steps'][-1] > threshold]
        self.active_runs = active_runs
        return active_runs

    def _truncated(self, field):
        '''Each active run's ``field`` cut to ``smallest_length`` entries.'''
        active = self.active_runs
        if not active:
            # smallest_length would raise on an empty selection.
            return []
        # Compute once rather than once per run.
        length = self.smallest_length
        return [run[field][:length] for run in active]

    def get_mean(self, field):
        '''Smoothed per-step mean of ``field`` across active runs.'''
        return smooth(np.mean(self._truncated(field), axis=0), window=WINDOW)

    def get_max_stats(self, field):
        '''Mean and std across active runs of each run's maximum ``field`` value.'''
        run_maxima = [np.max(scores) for scores in self._truncated(field)]
        return np.mean(run_maxima), np.std(run_maxima)

    def get_run_scores(self, field):
        '''Smoothed, truncated ``field`` curve for every active run.'''
        return [smooth(scores, window=WINDOW) for scores in self._truncated(field)]

    def get_std(self, field):
        '''Smoothed per-step standard deviation of ``field`` across active runs.'''
        return smooth(np.std(self._truncated(field), axis=0), window=WINDOW)

    def reset(self):
        '''Restore the selection to a copy of all runs; None entries are stripped again.'''
        self.cleaned = 0
        self.active_runs = self.runs.copy()

    def get_max(self, field):
        '''Largest ``field`` value over all active runs, after truncation.'''
        return np.max(self._truncated(field))

    @property
    def steps(self):
        '''Step axis of the first active run, truncated to ``smallest_length``.'''
        return self.active_runs[0]['steps'][:self.smallest_length]

    @property
    def smallest_length(self):
        '''Shortest ``'test_score'`` length among active runs.'''
        return np.min([len(x['test_score']) for x in self.active_runs])

    @property
    def lengths(self):
        '''``'test_score'`` length of each active run.'''
        return [len(x['test_score']) for x in self.active_runs]

    def __len__(self):
        # Counts every run passed in (None entries included), not the active selection.
        return len(self.runs)


def get_path():
    '''Placeholder with no implementation yet; always returns None.'''
    return None


def get_log_ids_sorted(path):
    '''Return the integer ids of all ``log_<id>.*`` files in ``path``, ascending.

    The id is the text between ``log_`` and the first ``.``. Files whose id is
    not an integer, such as ``log_backup.txt``, are skipped instead of raising
    ``ValueError``.
    '''
    log_ids = []
    for name in os.listdir(path):
        if not name.startswith('log_'):
            continue
        stem = name.split('.')[0]
        try:
            log_ids.append(int(stem[4:]))
        except ValueError:
            continue
    return sorted(log_ids)


def get_scores(path):
    '''Read ``log.csv`` in ``path`` and return its score curves.

    Returns a 5-tuple ``(steps, train_score, test_score, test_success,
    test_score_max)``. Each element is a list with one entry per CSV row,
    taken from these columns respectively:

    - ``train/steps``
    - ``train/episode_score/mean``
    - ``test/episode_score/mean``
    - ``test/success_rate/mean``
    - ``test/episode_score/max``

    Raises ``KeyError`` naming the column if one is missing.
    '''
    data = pandas.read_csv(os.path.join(path, 'log.csv'))
    columns = ('train/steps',
               'train/episode_score/mean',
               'test/episode_score/mean',
               'test/success_rate/mean',
               'test/episode_score/max')
    return tuple(list(data[column].values) for column in columns)


def get_data(path):
    '''Load every run directory under ``path`` into a ``DataClass``.

    Each sub-directory is parsed with ``get_run_dict``, which may yield None
    for directories without a run; ``DataClass`` filters those out.

    Non-directory entries (stray files) are skipped, because ``get_run_dict``
    would call ``os.listdir`` on them and fail. Entries are visited in sorted
    order so that the run order is reproducible.
    '''
    runs = []
    for entry in sorted(os.listdir(path)):
        entry_path = os.path.join(path, entry)
        if os.path.isdir(entry_path):
            runs.append(get_run_dict(entry_path))
    return DataClass(runs)


def get_run_dict(path):
    '''Build the run dict for the run stored at, or below, ``path``.

    If ``path`` has no ``checkpoints`` directory, recurse into its first
    sub-directory in name order. Return None when there is no sub-directory
    to descend into.

    Otherwise load ``config.yaml`` and the ``log.csv`` curves (via
    ``get_scores``). Return a dict holding the curves, settings parsed from
    the config (environment, HER flag, seed, ``dep`` flag, ...) and every
    entry of ``config.env_args``.
    '''
    if not os.path.isdir(os.path.join(path, 'checkpoints')):
        for folder in sorted(os.listdir(path)):
            subdir = os.path.join(path, folder)
            if os.path.isdir(subdir):
                return get_run_dict(subdir)
        return None

    config = om.load(os.path.join(path, 'config.yaml'))
    steps, score_train, score_test, success, score_test_max = get_scores(path)
    run = {'steps': steps,
           'test_success': success,
           'test_score': score_test,
           'test_score_max': score_test_max,
           'train_score': score_train,
           'environment': config.environment_name,
           'her': 'HER' in config.agent,
           'seed': config.seed,
           'steps_before': 1 if "3e5" in config.agent else 0}
    try:
        run['force_scale'] = config.env_args['force_scale']
    except Exception:
        # Best effort: runs without a force scale simply lack the key.
        print('No force scale attribute')

    # Derive the integer 'dep' flag from which dep_factory(<n>, ...) call
    # appears in the agent string (0 when none does). If several markers
    # match, the last one in this table wins.
    # NOTE(review): only the first marker requires a space after the comma;
    # confirm that is intentional.
    dep_markers = (("dep_factory(1, ", 1),
                   ("dep_factory(2,", 2),
                   ("dep_factory(3,", 3),
                   ("dep_factory(6,", 6),
                   ("dep_factory(11,", 11),
                   ("dep_factory(12,", 12))
    dep_flag = 0
    for marker, flag in dep_markers:
        if marker in config.agent:
            dep_flag = flag
    run['dep'] = dep_flag

    if config['env_args']:
        # Expose every env arg as a top-level key so runs can be filtered on it.
        for k, v in config.env_args.items():
            run[k] = v
    return run


def prepare_runs(runs, cut, dep, her, force_scale=0):
    '''Narrow ``runs`` (a ``DataClass``) in place to one experiment variant.

    First drops runs whose last step is not above ``cut``. Then keeps those
    whose ``dep`` and ``her`` settings match. When ``force_scale == 1``, it
    additionally keeps only runs with ``force_scale == 0.0``.
    '''
    runs.cut_short_runs(int(cut))
    print(f'{dep=} {her=}')
    criteria = [('dep', dep), ('her', her)]
    if force_scale == 1:
        criteria.append(('force_scale', 0.0))
    for key, wanted in criteria:
        runs.filter_runs(key, wanted)


def plot(fig, ax, path, field='test_success', cut=0):
    '''Draw the learning curves for the experiment at ``path`` onto ``ax``.

    ``fig`` is not used here.
    '''
    return plot_single(ax, path, field, cut=cut)

def plot_bar(fig, ax, path, field='test_success', cut=0, means=None, stds=None):
    '''Append max-``field`` statistics for the runs under ``path``.

    Delegates to ``_plot_bar``, which appends a mean and a std per algorithm
    variant found under ``path`` to ``means`` and ``stds``. ``fig`` and
    ``ax`` are not used; drawing the bars is left to the caller.

    Returns the (possibly newly created) ``means`` and ``stds`` lists.
    '''
    # Fresh lists on every call: a ``[]`` default is created once and would
    # silently accumulate values across independent calls.
    if means is None:
        means = []
    if stds is None:
        stds = []
    means, stds = _plot_bar(path, field, cut, means, stds)
    return means, stds


def _append_max_stats(runs, field, cut, dep, means, stds, failure_msg, force_scale=0):
    '''Select runs with ``dep`` (HER off) and append their max-``field`` mean/std.

    Any exception is reported with ``failure_msg`` and nothing is appended.
    An empty selection, for instance, makes ``get_max`` raise. As a result,
    ``means``/``stds`` only grow for variants that actually have runs.
    '''
    try:
        runs.reset()
        prepare_runs(runs, cut, dep=dep, her=0, force_scale=force_scale)
        print(f'{field} max={runs.get_max(field)}')
        mean, std = runs.get_max_stats(field)
    except Exception:
        print(failure_msg)
        return
    means.append(mean)
    stds.append(std)


def _plot_bar(path, field='test_success', cut=0, means=None, stds=None):
    '''Append per-variant max-``field`` stats for the runs under ``path``.

    The variants are chosen from the path name:

    - ``'ablations'`` paths: only DEP (dep=1) runs with ``force_scale == 0``.
    - All other paths: DEP (dep=6), baseline (dep=0), colored (dep=11) and
      OU (dep=12), in that order. ``'td4'`` in the path only changes the
      failure messages.

    Variants that fail (e.g. no matching runs) contribute nothing.
    Returns ``means, stds``; passed-in lists are extended in place.
    '''
    if means is None:
        means = []
    if stds is None:
        stds = []
    runs = get_data(os.path.join(path, 'working_directories'))

    if 'ablations' in path:
        _append_max_stats(runs, field, cut, 1, means, stds, 'depmpo failed', force_scale=1)
        return means, stds

    is_td4 = 'td4' in path
    _append_max_stats(runs, field, cut, 6, means, stds,
                      'dep td4 failed' if is_td4 else 'depmpo failed')
    _append_max_stats(runs, field, cut, 0, means, stds,
                      'td4 failed' if is_td4 else 'mpo failed')
    _append_max_stats(runs, field, cut, 11, means, stds, 'colored mpo failed')
    _append_max_stats(runs, field, cut, 12, means, stds, 'ou mpo failed')
    return means, stds


def _plot_mean_std(ax, runs, field, cut, dep, color, label, failure_msg, fmt=()):
    '''Plot the mean ``field`` curve with a +/-1 std band for runs with ``dep`` (HER off).

    The line and its band share ``color``. Any exception (e.g. no matching
    runs) is reported with ``failure_msg`` and that curve is skipped.
    '''
    try:
        runs.reset()
        prepare_runs(runs, cut, dep=dep, her=0)
        print(f'{field} max={runs.get_max(field)}')
        # Compute each statistic once instead of re-running it per use.
        steps = runs.steps
        mean = runs.get_mean(field)
        std = runs.get_std(field)
        ax.plot(steps, mean, *fmt, color=color, label=label)
        ax.fill_between(steps, mean - std, mean + std, alpha=0.1, color=color)
    except Exception:
        print(failure_msg)


def plot_single(ax, path, field, cut=0):
    '''Plot every algorithm variant found under ``path`` onto ``ax``.

    Paths containing ``'td4'`` are drawn as DEP-TD4/TD4; all others as
    DEP-MPO/MPO. The colored-noise (dep=11) and OU (dep=12) variants are
    attempted for every path.
    '''
    print(cut)
    runs = get_data(os.path.join(path, 'working_directories'))
    if 'td4' in path:
        _plot_mean_std(ax, runs, field, cut, 6, '#cd66b1', 'DEP-TD4', 'dep td4 failed', fmt=('-',))
        _plot_mean_std(ax, runs, field, cut, 0, 'tab:green', 'TD4', 'td4 failed')
    else:
        _plot_mean_std(ax, runs, field, cut, 6, 'tab:orange', 'DEP-MPO', 'depmpo failed')
        _plot_mean_std(ax, runs, field, cut, 0, 'tab:blue', 'MPO', 'mpo failed')
    # Colored = brown and OU = purple, matching the module-level ``colors``
    # used for the beta-MPO / OU-MPO bars. Line and std band share one colour.
    _plot_mean_std(ax, runs, field, cut, 11, 'tab:brown', 'MPO-Colored', 'colored mpo failed')
    _plot_mean_std(ax, runs, field, cut, 12, 'tab:purple', 'MPO-OU', 'ou mpo failed')
    ax.set_xlabel('steps')
    if field == 'test_success':
        ax.set_ylim([0, 1.2])
    ax.set_xlim([0, 0.53e8])


if __name__ == '__main__':
    LINEPLOT = 0
    BARPLOT = 1
    prefix = '/folder/'
    fields = ['test_score', 'test_score_max']
    # Experiment folders under ``prefix``. The bar chart relies on this order:
    # bars are placed (and labelled) in the order their stats are appended.
    paths = ['ostrich_tuned_final', 'ostrich_paper_td4_my_parallel', 'ostrich_tuned_ou',
             'ostrich_tuned_colored_corrected']

    if LINEPLOT:
        # squeeze=False keeps ``axs`` iterable even with a single field.
        fig, axs = plt.subplots(1, len(fields), sharey=True, squeeze=False)
        axs = axs[0]
        for ax, field in zip(axs, fields):
            for path in paths:
                plot(fig, ax, f'{prefix}{path}', field=field, cut=50000000)
            # Per-axes call: plt.ticklabel_format only affects the current axes.
            ax.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))
        axs[0].set_yticks([0, 2000, 4000, 6000])
        axs[0].set_ylabel('return')
        # Save only after every panel is drawn. Saving inside the field loop
        # would write files whose later panels are still empty.
        for field in fields:
            fig.savefig(f'./new_ostrich_evaluate_baselines_{field}.pdf')

    if BARPLOT:
        fig, axs = plt.subplots(1, len(fields), sharey=True, squeeze=False)
        axs = axs[0]
        labels = ['DEP-MPO', 'MPO', 'DEP-TD4', 'TD4', 'OU-MPO', r'$\beta$-MPO']
        for ax, field in zip(axs, fields):
            means = []
            stds = []
            for path in paths:
                # NOTE(review): the line plot uses cut=50000000; confirm that
                # the 10x smaller 5000000 is intended here.
                means, stds = plot_bar(fig, ax, f'{prefix}{path}', field=field, cut=5000000,
                                       means=means, stds=stds)
            x = np.arange(len(means))
            bars = ax.bar(x, means, yerr=stds)
            for bar, color in zip(bars, colors):
                bar.set_color(color)
            ax.set_xticks(x)
            ax.set_xticklabels(labels, rotation=45)
        axs[0].set_ylabel('max return')
        for field in fields:
            print('saving')
            fig.savefig(f'./new_ostrich_barplot_baselines_{field}.pdf')
