import numpy as np
import os
import sys
from omegaconf import OmegaConf as om
import pandas
from matplotlib import pyplot as plt
from matplotlib import cm
from tueplots import bundles, figsizes
import csv


def change_figsize(fig_dict, fraction_width, fraction_height=1):
    '''Scale the ``'figure.figsize'`` entry of ``fig_dict`` in place.

    The width is multiplied by ``fraction_width`` and the height by
    ``fraction_height``; the entry is replaced by a ``(width, height)`` tuple.
    The same ``fig_dict`` object is returned for convenient chaining.
    '''
    current = fig_dict['figure.figsize']
    scaled = (current[0] * fraction_width, current[1] * fraction_height)
    fig_dict['figure.figsize'] = scaled
    return fig_dict
# Module-level plotting configuration (executed at import time).
# Default colour list; note the __main__ bar-plot loop rebinds the name `colors`.
colors = ['tab:orange', 'tab:blue', 'tab:green', '#cd66b1', 'tab:red']
# Moving-average window (in logged samples) used by `smooth` for all curve statistics.
WINDOW = 10
# tueplots NeurIPS 2022 style bundle; fig_dict is applied afterwards, so its
# figure.figsize takes precedence over whatever the bundle sets.
plt.rcParams.update(bundles.neurips2022())
# tueplots NeurIPS 2022 figure size for one subplot column, scaled to 80% width / 60% height.
fig_dict = figsizes.neurips2022(ncols=1)
fig_dict = change_figsize(fig_dict, fraction_width=0.8, fraction_height=0.60)
# Echo the final figure settings (debug output on import).
print(fig_dict)
plt.rcParams.update(fig_dict)
# Human-readable y-axis labels per run-dict metric key (in this file only referenced from commented-out code).
YLABELS = {'test_success': 'success rate',
            'test_score': 'return',
            'train_score': 'train/return',
           'test_score_max': 'best rollout return'}
def smooth(vals, window):
    '''Smooths values using a centred sliding-window average.

    Args:
        vals: 1-D sequence of values.
        window: window width; clamped to ``len(vals)``.

    Returns:
        ``vals`` unchanged when ``window <= 1`` or ``vals`` is empty
        (an empty kernel previously made ``np.convolve`` raise), otherwise a
        numpy array of the same length. Each output is divided by the number
        of samples the window actually covers, so the edges are not biased
        towards zero.
    '''
    if window <= 1:
        return vals
    n = len(vals)
    if n == 0:
        return vals
    window = min(window, n)
    kernel = np.ones(window)
    # Convolving an all-ones signal yields the per-position sample count.
    return np.convolve(vals, kernel, 'same') / np.convolve(np.ones(n), kernel, 'same')


class DataClass:
    '''Collection of run dicts with a mutable "active" selection.

    ``runs`` keeps every run passed in (possibly including ``None``);
    ``active_runs`` is the subset that the filters and statistics work on, and
    ``reset`` restores the full selection. Statistics truncate every curve to
    the length of the shortest active run (measured on ``'test_score'``).
    '''

    def __init__(self, runs):
        self.runs = runs
        self._active_runs = runs
        # Set to 1 once ``None`` entries have been stripped from the selection.
        self.cleaned = 0

    @property
    def active_runs(self):
        '''
        Removes None runs, they usually appear during training.

        The cleaning happens once per selection (until ``reset``).
        '''
        if not self.cleaned:
            self._active_runs = [run for run in self._active_runs if run is not None]
            self.cleaned = 1
        return self._active_runs

    @active_runs.setter
    def active_runs(self, val):
        self._active_runs = val

    def filter_runs(self, key, val):
        '''Keep only active runs with ``run[key] == val`` and return them.

        Raises KeyError if an active run has no ``key`` entry.
        '''
        self.active_runs = [run for run in self.active_runs if run[key] == val]
        print(f'{len(self.active_runs)} remaining.')
        return self.active_runs

    def cut_short_runs(self, threshold):
        '''Keep only active runs whose last logged step exceeds ``threshold``.'''
        self.active_runs = [run for run in self.active_runs if run['steps'][-1] > threshold]
        return self.active_runs

    def _truncated(self, field):
        '''``field`` of every active run, cut to the shortest run's length.'''
        runs = self.active_runs
        if not runs:
            return []
        # ``smallest_length`` scans all runs: evaluate it once, not once per run.
        length = self.smallest_length
        return [run[field][:length] for run in runs]

    def get_mean(self, field):
        '''Smoothed across-run mean curve of ``field``.'''
        return smooth(np.mean(self._truncated(field), axis=0), window=WINDOW)

    def get_max_stats(self, field, stat=None):
        '''Mean and std across runs of a per-run summary of ``field``.

        Args:
            field: run-dict metric key.
            stat: ``'avg'`` summarises each run by the mean of its truncated
                curve; any other value uses its maximum. When ``None``, the
                module-level ``METRIC`` (bound by the ``__main__`` loop) is
                used if defined, else ``'avg'``. Previously ``stat`` was
                ignored and a missing ``METRIC`` raised NameError.

        Returns:
            ``(mean, std)`` of the per-run summaries.
        '''
        if stat is None:
            stat = globals().get('METRIC', 'avg')
        summarise = np.mean if stat == 'avg' else np.max
        per_run = [summarise(scores) for scores in self._truncated(field)]
        return np.mean(per_run), np.std(per_run)

    def get_run_scores(self, field):
        '''Individually smoothed, truncated ``field`` curve of every active run.'''
        return [smooth(scores, window=WINDOW) for scores in self._truncated(field)]

    def get_std(self, field):
        '''Smoothed across-run standard-deviation curve of ``field``.'''
        return smooth(np.std(self._truncated(field), axis=0), window=WINDOW)

    def reset(self):
        '''Restore the full set of runs as the active selection.'''
        self.cleaned = 0
        self.active_runs = self.runs.copy()

    def get_max(self, field):
        '''Largest value of ``field`` over all active runs (truncated curves).'''
        return np.max(self._truncated(field))

    @property
    def steps(self):
        '''Step axis of the first active run, truncated to the shortest run.'''
        return self.active_runs[0]['steps'][:self.smallest_length]

    @property
    def smallest_length(self):
        '''Length of the shortest ``'test_score'`` curve among the active runs.'''
        return np.min([len(x['test_score']) for x in self.active_runs])

    @property
    def lengths(self):
        '''Length of each active run's ``'test_score'`` curve.'''
        return [len(x['test_score']) for x in self.active_runs]

    def __len__(self):
        # Counts all runs passed in (including ``None``), not just the active ones.
        return len(self.runs)


def get_path():
    '''Placeholder: not implemented yet, always returns ``None``.'''
    return None


def get_log_ids_sorted(path):
    '''Return the integer ids of all ``log_<id>[.ext]`` entries in ``path``, ascending.

    The id is the part between ``log_`` and the first ``.``. Entries whose id
    is not an integer (e.g. ``log_old.csv``) are skipped instead of raising
    ``ValueError``. ``log.csv`` itself does not start with ``log_`` and is ignored.
    '''
    log_ids = []
    for name in os.listdir(path):
        if not name.startswith('log_'):
            continue
        stem = name.split('.')[0]
        try:
            log_ids.append(int(stem[4:]))
        except ValueError:
            continue
    return sorted(log_ids)


def get_scores(path):
    '''Read ``<path>/log.csv`` and return its metric columns.

    Returns:
        ``(steps, score_train, score_test, success, score_test_max)``: five
        lists with one entry per CSV row, taken from the columns
        ``train/steps``, ``train/episode_score/mean``,
        ``test/episode_score/mean``, ``test/success_rate/mean`` and
        ``test/episode_score/max``.

    Raises:
        KeyError: if one of those columns is missing.
    '''
    data = pandas.read_csv(os.path.join(path, 'log.csv'))
    columns = (
        'train/steps',
        'train/episode_score/mean',
        'test/episode_score/mean',
        'test/success_rate/mean',
        'test/episode_score/max',
    )
    # list(...) keeps the numpy scalar element types of the original per-row loop.
    return tuple(list(data[column].values) for column in columns)


def get_data(path):
    '''Return a ``DataClass`` holding one run dict per sub-directory of ``path``.

    Each entry comes from ``get_run_dict`` and may be ``None`` when no
    ``checkpoints`` directory is found for it; ``DataClass`` drops those lazily.
    Plain files (e.g. ``.DS_Store``) are skipped, because ``get_run_dict``
    would call ``os.listdir`` on them and fail. Entries are visited in sorted
    order so that the run order is reproducible.
    '''
    runs = [
        get_run_dict(os.path.join(path, entry))
        for entry in sorted(os.listdir(path))
        if os.path.isdir(os.path.join(path, entry))
    ]
    return DataClass(runs)


def get_run_dict(path):
    '''Collect the scores and settings of the run stored at ``path``.

    When ``path`` has no ``checkpoints`` sub-directory, the first
    sub-directory reported by ``os.listdir`` is searched recursively, and
    ``None`` is returned if there is none.

    Otherwise ``config.yaml`` and ``log.csv`` are read from ``path``. The
    returned dict holds the score curves plus the fields used for filtering:
    ``environment``, ``her``, ``seed``, ``force_scale``, and ``dep``. ``dep``
    is the ``dep_factory(N,`` variant found in the agent string, with N one of
    1, 2, 3, 5, 6, 11, 12, 13, or 0 if none matches. Every ``env_args`` entry
    is copied in too. If ``param_choice.csv`` two directories above ``path``
    has a ``DEP.test_episode_every`` column, that column's value from the last
    row is added as a float.
    '''
    if not os.path.isdir(os.path.join(path, 'checkpoints')):
        children = (os.path.join(path, name) for name in os.listdir(path))
        nested = next((child for child in children if os.path.isdir(child)), None)
        if nested is None:
            return None
        return get_run_dict(nested)

    config = om.load(os.path.join(path, 'config.yaml'))
    steps, score_train, score_test, success, score_test_max = get_scores(path)
    run = {
        'steps': steps,
        'test_success': success,
        'test_score': score_test,
        'test_score_max': score_test_max,
        'train_score': score_train,
        'environment': config.environment_name,
        'her': 'HER' in config.agent,
        'seed': config.seed,
        'force_scale': config.env_args['force_scale'],
    }

    grandparent = os.path.dirname(os.path.dirname(os.path.abspath(path)))
    with open(os.path.join(grandparent, 'param_choice.csv')) as handle:
        rows = list(csv.reader(handle))
    header = rows[0]
    if 'DEP.test_episode_every' in header:
        column = header.index('DEP.test_episode_every')
        run['test_episode_every'] = float(rows[-1][column])

    dep_flag = 0
    # Checked in ascending order; a later match overrides an earlier one.
    for variant in (1, 2, 3, 5, 6, 11, 12, 13):
        if f"dep_factory({variant}," in config.agent:
            dep_flag = variant
    run['dep'] = dep_flag

    if config['env_args']:
        for name, value in config.env_args.items():
            run[name] = value
    return run


def prepare_runs(runs, cut, dep, her, force_scale, test_episode_every=0, with_dep_test_episode_every=0):
    '''Narrow ``runs`` in place to a single experimental configuration.

    Args:
        runs: object with ``active_runs``, ``cut_short_runs`` and
            ``filter_runs`` (a ``DataClass``); its selection is replaced.
        cut: runs whose last logged step is not greater than ``int(cut)``
            are dropped.
        dep: required ``run['dep']`` value.
        her: required ``run['her']`` value.
        force_scale: ``1`` keeps runs with ``force_scale == 0.003``; any
            other value keeps runs with ``force_scale == 0``.
        test_episode_every: ``0`` disables this filter. Other values up to
            100 must match ``run['test_episode_every']`` exactly. Larger
            values select the runs whose value is 1e8 (presumably "no greedy
            test episodes" — confirm against param_choice.csv).
        with_dep_test_episode_every: if truthy, the first active run must
            have a ``test_episode_every`` entry.

    Raises:
        Exception: if ``with_dep_test_episode_every`` is set and the first
            active run lacks ``test_episode_every``.
    '''
    if with_dep_test_episode_every and 'test_episode_every' not in runs.active_runs[0]:
        raise Exception('Skipping this run')
    runs.cut_short_runs(int(cut))
    print(f'{dep=} {her=} {force_scale=} {test_episode_every=}')
    runs.filter_runs('dep', dep)
    runs.filter_runs('her', her)
    runs.filter_runs('force_scale', 0.003 if force_scale == 1 else 0)

    if test_episode_every == 0:
        return
    if test_episode_every <= 100:
        runs.filter_runs('test_episode_every', test_episode_every)
    else:
        # get_run_dict stores test_episode_every via float(), so compare
        # numerically: the former string literal '100000000' could never be
        # equal to a float and silently removed every run.
        runs.filter_runs('test_episode_every', 100000000)


def plot(fig, ax, path, field='test_success', cut=0):
    '''Draw the learning curves of ``field`` for experiment ``path`` onto ``ax``.

    Thin wrapper around ``plot_single``; ``fig`` is accepted for call-site
    symmetry with ``plot_bar`` but not used.
    '''
    plot_single(ax, path, field, cut=cut)


def plot_bar(fig, ax, path, field='test_success', cut=0, means=None, stds=None):
    '''Collect bar-chart statistics of ``field`` for experiment ``path``.

    ``_plot_bar`` extends ``means`` / ``stds`` in place. When they are
    omitted, fresh lists are created; the former ``[]`` defaults were shared
    across calls. For ``field == 'test_success'`` the y-range of ``ax`` is set
    to [0, 1.2] here, on the axes actually passed in. ``fig`` is unused and
    kept for parity with ``plot``.

    Returns:
        ``(means, stds)``.
    '''
    means = [] if means is None else means
    stds = [] if stds is None else stds
    _plot_bar(path, field, cut, means, stds)
    if field == 'test_success':
        ax.set_ylim([0, 1.2])
    return means, stds

def plot_single(ax, path, field, cut=0):
    '''Plot mean +/- std curves of ``field`` for every ablation under ``path``.

    Runs are loaded once from ``<path>/working_directories``. Each row of the
    configuration table is selected with ``prepare_runs`` and drawn as a line
    with a shaded one-std band (alpha 0.1). Force variants (``force_scale=1``)
    are dashed. A configuration that cannot be drawn (no matching runs,
    missing data, ...) is reported and skipped independently, so the other
    curves are still plotted.
    '''
    print(cut)
    runs = get_data(os.path.join(path, 'working_directories'))
    # (dep, force_scale, test_episode_every, with_dep_test_episode_every, label, color)
    configs = (
        (1, 0, 0, 0, 'InitialDEP-MPO', 'tab:green'),
        (1, 1, 0, 0, 'InitialDEP-MPO-Force', 'tab:green'),
        (2, 0, 0, 0, 'AverageDEP-MPO', 'tab:orange'),
        (2, 1, 0, 0, 'AverageDEP-MPO-Force', 'tab:orange'),
        (3, 0, 0, 0, 'Switch-MPO', 'tab:pink'),
        (3, 1, 0, 0, 'Switch-MPO-Force', 'tab:pink'),
        (6, 0, 0, 0, 'FULLDEP-MPO', 'tab:purple'),
        (6, 1, 0, 0, 'FULLDEP-MPO-Force', 'tab:purple'),
        (6, 0, 200, 1, 'FULLDEP-MPO-nogreedy', 'black'),
        (6, 1, 200, 1, 'FULLDEP-MPO-Force-nogreedy', 'black'),
        (5, 0, 3, 1, 'Switcher-MPO-episode3', 'yellow'),
        (5, 1, 3, 1, 'switcher-MPO-Force-episode3', 'brown'),
    )
    for dep, force_scale, test_every, with_test_every, label, color in configs:
        try:
            runs.reset()
            prepare_runs(runs, cut, dep=dep, her=0, force_scale=force_scale,
                         test_episode_every=test_every,
                         with_dep_test_episode_every=with_test_every)
            print(f'{field} max={runs.get_max(field)}')
            steps = runs.steps
            mean = runs.get_mean(field)
            std = runs.get_std(field)
            line_style = {'linestyle': '--'} if force_scale else {}
            ax.plot(steps, mean, color=color, label=label, **line_style)
            ax.fill_between(steps, mean - std, mean + std, alpha=0.1, color=color)
        except Exception as err:
            print(f'{label} skipped: {err!r}')

    if field == 'test_success':
        ax.set_ylim([0, 1.2])

def _plot_bar(path, field='test_success', cut=0, means=None, stds=None, ax=None, stat=None):
    '''Append bar-chart statistics for the experiment folder ``path``.

    The configurations evaluated depend on the folder name:

    * ``'background'`` in ``path``: the background ablation (dep=13, force);
    * ``'final'`` not in ``path``: the eight DEP ablations (dep 1/2/3/6,
      each without and with force);
    * otherwise: the baseline (dep=0, force).

    For each configuration the ``(mean, std)`` of ``runs.get_max_stats`` is
    appended to ``means`` / ``stds``. The caller labels bars by position, so a
    configuration that cannot be evaluated appends ``nan``; silently dropping
    it used to shift every later bar onto the wrong label. For the same reason
    the nogreedy / episode-3 variants drawn by ``plot_single`` are not
    evaluated here, since the bar labels have no slot for them.

    Args:
        path: experiment folder containing ``working_directories``.
        field: run-dict metric key, e.g. ``'test_score'``.
        cut: a run's last step must exceed ``int(cut)`` for it to be kept.
        means, stds: lists extended in place; fresh lists when ``None``.
        ax: optional axes; when given and ``field == 'test_success'`` its
            y-range is set to [0, 1.2]. The old code reached for a global
            ``ax`` that only existed inside the ``__main__`` block.
        stat: forwarded to ``get_max_stats`` when not ``None``.

    Returns:
        ``(means, stds)``.
    '''
    means = [] if means is None else means
    stds = [] if stds is None else stds
    print(cut)
    runs = get_data(os.path.join(path, 'working_directories'))

    # (dep, force_scale) pairs, in bar order.
    if 'background' in path:
        configs = ((13, 1),)
    elif 'final' not in path:
        configs = ((1, 0), (1, 1), (2, 0), (2, 1), (3, 0), (3, 1), (6, 0), (6, 1))
    else:
        configs = ((0, 1),)

    for dep, force_scale in configs:
        try:
            runs.reset()
            prepare_runs(runs, cut, dep=dep, her=0, force_scale=force_scale,
                         with_dep_test_episode_every=0)
            print(f'{field} max={runs.get_max(field)}')
            if stat is None:
                mean, std = runs.get_max_stats(field)
            else:
                mean, std = runs.get_max_stats(field, stat=stat)
        except Exception as err:
            print(f'bar dep={dep} force_scale={force_scale} failed: {err!r}')
            # Placeholder keeps later bars aligned with their positional labels.
            mean, std = np.nan, np.nan
        means.append(mean)
        stds.append(std)

    if ax is not None and field == 'test_success':
        ax.set_ylim([0, 1.2])
    return means, stds


if __name__ == '__main__':
    # Toggle which figures to produce.
    LINEPLOT = 0
    BARPLOT = 1
    prefix = '/folder/'
    fields = ['test_score', 'test_score_max']
    if LINEPLOT:
        fig, axs = plt.subplots(1, len(fields), sharey=True)
        for ax, field in zip(axs, fields):
            for path in ['ostrich_ablations']:
                path = f'{prefix}{path}'
                plot(fig, ax, path, field=field, cut=4000000)
            plt.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))
            axs[0].set_yticks([0, 2000, 4000, 6000])
            axs[0].set_ylabel('return')
            # METRIC is only bound by the bar-plot loop below (it raised
            # NameError here), and the line plots do not depend on it.
            fig.savefig(f'./new_ostrich_evaluate_{field}.pdf')
    if BARPLOT:
        # Positional bar labels, in the order plot_bar appends statistics
        # (presumably 1 MPO baseline, 8 DEP ablations, 1 background ablation).
        labels = ['MPO', 'init', 'init-force', 'avg', 'avg-force', 'det', 'det-force', 'stoch',
                  r'\textbf{stoch-force}',
                  'stoch-force-noback']
        # METRIC must remain a module-level name: DataClass.get_max_stats reads it.
        for METRIC in ['avg', 'max']:
            fig, axs = plt.subplots(1, len(fields), sharey=True)
            for ax, field in zip(axs, fields):
                means = []
                stds = []
                for path in ['ostrich_tuned_final', 'ostrich_ablations', 'ostrich_depbackgroundablation']:
                    path = f'{prefix}{path}'
                    plot_bar(fig, ax, path, field=field, cut=40000000, means=means, stds=stds)
                # Truncate means and stds together so yerr matches the number of bars.
                means = means[:len(labels)]
                stds = stds[:len(labels)]
                x = np.arange(len(means))
                bars = ax.bar(x, means, yerr=stds)
                bar_colors = iter(cm.autumn(np.linspace(0.3, 1, 16)))
                for bar, color in zip(bars, bar_colors):
                    bar.set_color(color)
                if len(bars):
                    bars[0].set_color('tab:blue')  # highlight the 'MPO' baseline bar
                ax.set_xticks(x)
                ax.set_xticklabels(labels[:len(x)], rotation=45)

            if METRIC == 'avg':
                axs[0].set_ylabel('cumul. return')
            else:
                axs[0].set_ylabel('maximum return')
            # NOTE: `field` here is the last plotted field; filename kept as before.
            fig.savefig(f'./new_ostrich_ablations_barplot_{field}_{METRIC}.pdf')
