import numpy as np
import os
import sys
from omegaconf import OmegaConf as om
import pandas
from matplotlib import pyplot as plt
from tueplots import bundles, figsizes
import scipy.stats
from plot_utils import barplot_annotate_brackets


def change_figsize(fig_dict, fraction_width, fraction_height=1):
    '''Scale the ``'figure.figsize'`` entry of ``fig_dict`` in place.

    The stored (width, height) pair is multiplied by ``fraction_width`` and
    ``fraction_height`` respectively and written back as a tuple. The same
    dict object is returned so the call can be chained.
    '''
    current = fig_dict['figure.figsize']
    scaled_width = current[0] * fraction_width
    scaled_height = current[1] * fraction_height
    fig_dict['figure.figsize'] = (scaled_width, scaled_height)
    return fig_dict
# Colours assigned, in order, to the bars of the bar plot in the script section.
colors = ['tab:orange', 'tab:blue', 'tab:green', '#cd66b1', 'tab:purple', 'tab:brown']
# Sliding-window width handed to smooth(); a value of 1 leaves curves unsmoothed.
WINDOW = 1
# Global matplotlib styling: first the conference bundle, then a resized figure.
# The order matters because both calls write into plt.rcParams.
#plt.rcParams.update(bundles.neurips2022())
plt.rcParams.update(bundles.iclr2023())
#fig_dict = figsizes.neurips2022(ncols=1)
fig_dict = figsizes.iclr2023(ncols=1)
fig_dict = change_figsize(fig_dict, fraction_width=0.5, fraction_height=0.35)
print(fig_dict)
plt.rcParams.update(fig_dict)
# Axis-label text for each metric key used in the run dictionaries.
YLABELS = {'test_success': 'success rate',
            'test_score': 'return',
            'train_score': 'train/return',
           'test_score_max': 'best rollout return'}

def smooth(vals, window):
    '''Return a moving average of ``vals`` over ``window`` samples.

    For ``window`` of 1 or less the input is handed back untouched. Otherwise
    the window is clamped to ``len(vals)`` and a box filter is applied with
    numpy's ``'same'`` convolution mode. Each output sample is divided by the
    number of input samples that actually fell under the window, so the edges
    are averaged over fewer points instead of being pulled towards zero.
    '''
    if not window > 1:
        return vals

    width = min(window, len(vals))
    box = np.ones(width)
    # Numerator: windowed sums of the data.
    summed = np.convolve(vals, box, 'same')
    # Denominator: how many samples contributed to each windowed sum.
    counts = np.convolve(np.ones(len(vals)), box, 'same')
    return summed / counts


class DataClass:
    '''Collection of run dictionaries with filtering and aggregate statistics.

    ``runs`` is the full list as loaded (it may contain ``None`` entries).
    ``active_runs`` is the current working subset: filters narrow it down and
    ``reset`` restores it from ``runs``. All statistics are computed over the
    active runs, with every curve truncated to ``smallest_length`` so that
    runs of different duration can be stacked.
    '''

    def __init__(self, runs):
        self.runs = runs
        self._active_runs = runs
        # True once None entries have been stripped from _active_runs.
        self.cleaned = False

    @property
    def active_runs(self):
        '''
        Current working subset of runs with None entries removed
        (None runs usually appear during training).
        '''
        if not self.cleaned:
            self._active_runs = [run for run in self._active_runs if run is not None]
            self.cleaned = True
        return self._active_runs

    @active_runs.setter
    def active_runs(self, val):
        self._active_runs = val
        # Re-check the new list for None entries on the next read.
        self.cleaned = False

    def filter_runs(self, key, val):
        '''Keep only active runs whose ``run[key] == val``; return the kept list.

        Raises KeyError if an active run does not have ``key``.
        '''
        self.active_runs = [run for run in self.active_runs if run[key] == val]
        print(f'{len(self.active_runs)} remaining.')
        return self.active_runs

    def cut_short_runs(self, threshold):
        '''Drop active runs whose last recorded step is not above ``threshold``.'''
        self.active_runs = [run for run in self.active_runs
                            if run['steps'][-1] > threshold]
        return self.active_runs

    def _truncated(self, field):
        '''Return ``run[field]`` for every active run, cut to ``smallest_length``.'''
        runs = self.active_runs
        if not runs:
            # Nothing to truncate; avoids taking the minimum of an empty list.
            return []
        length = self.smallest_length  # evaluated once, not once per run
        return [run[field][:length] for run in runs]

    def get_mean(self, field):
        '''Mean curve across active runs (per time step), smoothed with WINDOW.'''
        return smooth(np.mean(self._truncated(field), axis=0), window=WINDOW)

    def get_max_stats(self, field):
        '''Return (mean, std) across runs of each run's maximum of ``field``.'''
        per_run_max = [np.max(scores) for scores in self._truncated(field)]
        return np.mean(per_run_max), np.std(per_run_max)

    def get_run_scores(self, field):
        '''Return the individual (truncated, smoothed) curves of all active runs.'''
        return [smooth(scores, window=WINDOW) for scores in self._truncated(field)]

    def get_std(self, field):
        '''Standard deviation across active runs (per time step), smoothed with WINDOW.'''
        return smooth(np.std(self._truncated(field), axis=0), window=WINDOW)

    def get_mean_per_run(self, field):
        '''Return one value per active run: the mean of its truncated ``field`` curve.'''
        return [np.mean(scores) for scores in self._truncated(field)]

    def get_max_per_run(self, field):
        '''Return the per-run MEAN of ``field`` (same as ``get_mean_per_run``).

        Despite its name this does not return maxima: the class used to carry
        two definitions of this method and the mean-based one, defined last,
        shadowed the max-based one. The mean is therefore what callers have
        always received, and it is kept so their results do not change.
        '''
        return self.get_mean_per_run(field)

    def reset(self):
        '''Restore the active subset to a fresh copy of all loaded runs.'''
        self.active_runs = self.runs.copy()

    def get_max(self, field):
        '''Return the single largest value of ``field`` over all active runs.

        Raises ValueError when there are no active runs.
        '''
        return np.max(self._truncated(field))

    @property
    def steps(self):
        '''Step axis of the first active run, cut to ``smallest_length``.'''
        return self.active_runs[0]['steps'][:self.smallest_length]

    @property
    def smallest_length(self):
        '''Length of the shortest ``'test_score'`` curve among the active runs.

        This length is used to truncate every field, not only ``'test_score'``.
        '''
        return np.min([len(x['test_score']) for x in self.active_runs])

    @property
    def lengths(self):
        '''Lengths of the ``'test_score'`` curves of all active runs.'''
        return [len(x['test_score']) for x in self.active_runs]

    def __len__(self):
        # Counts every loaded run, including None entries and filtered-out runs.
        return len(self.runs)


def get_path():
    '''Placeholder with no implementation; always returns None.'''
    return None


def get_log_ids_sorted(path):
    '''Return the numeric ids of all ``log_<id>`` entries in ``path``, ascending.

    An entry counts when its name starts with ``log_`` and the text between
    that prefix and the first ``.`` parses as an integer. Entries with a
    non-numeric suffix (for example ``log_backup.csv``) are skipped instead of
    aborting the whole scan.

    Raises whatever ``os.listdir`` raises when ``path`` cannot be listed.
    '''
    log_ids = []
    for entry in os.listdir(path):
        if not entry.startswith('log_'):
            continue
        stem = entry.split('.')[0]
        try:
            log_ids.append(int(stem[4:]))
        except ValueError:
            # Not a numbered log file; ignore it.
            continue
    log_ids.sort()
    return log_ids


def get_scores(path):
    '''Read ``<path>/log.csv`` and return its training curves as lists.

    Returns a 5-tuple of equally long lists, one entry per CSV row:
    ``(steps, train_score, test_score, test_success, test_score_max)``, taken
    from the columns ``train/steps``, ``train/episode_score/mean``,
    ``test/episode_score/mean``, ``test/success_rate/mean`` and
    ``test/episode_score/max``.

    Raises KeyError naming the missing column(s) if the CSV lacks any of them.
    '''
    data = pandas.read_csv(os.path.join(path, 'log.csv'))
    # Column order here fixes the order of the returned tuple.
    columns = ('train/steps',
               'train/episode_score/mean',
               'test/episode_score/mean',
               'test/success_rate/mean',
               'test/episode_score/max')
    missing = [name for name in columns if name not in data.columns]
    if missing:
        raise KeyError(f'log.csv in {path} is missing column(s): {missing}')
    steps, score_train, score_test, success, score_test_max = (
        list(data[name].values) for name in columns)
    return steps, score_train, score_test, success, score_test_max


def get_data(path):
    '''Load every entry of ``path`` with ``get_run_dict`` and wrap the results.

    The returned ``DataClass`` holds one element per directory entry, in the
    order ``os.listdir`` reports them.
    '''
    loaded = [get_run_dict(os.path.join(path, entry)) for entry in os.listdir(path)]
    return DataClass(loaded)


def get_run_dict(path):
    '''Load one run into a flat dict of curves and settings.

    A run directory is recognised by containing a ``checkpoints``
    sub-directory. When ``path`` has none, the first sub-directory that
    ``os.listdir`` reports is searched recursively (only that one).

    Returns the run dict, or None when ``path`` is not a directory or no run
    directory was found beneath it.
    '''
    if not os.path.isdir(path):
        # Stray files next to the run folders cannot contain a run.
        return None

    if not os.path.isdir(os.path.join(path, 'checkpoints')):
        for folder in os.listdir(path):
            candidate = os.path.join(path, folder)
            if os.path.isdir(candidate):
                return get_run_dict(candidate)
        return None

    config = om.load(os.path.join(path, 'config.yaml'))
    steps, score_train, score_test, success, score_test_max = get_scores(path)
    run = {'steps': steps,
           'test_success': success,
           'test_score': score_test,
           'test_score_max': score_test_max,
           'train_score': score_train,
           'environment': config.environment_name,
           'her': 'HER' in config.agent,
           'seed': config.seed,
           'steps_before': 1 if "3e5" in config.agent else 0}

    # env_args is treated as optional; .get is used so a missing key yields
    # None rather than depending on how the config object reports absent keys.
    env_args = config.get('env_args')
    try:
        run['force_scale'] = env_args['force_scale']
    except Exception:
        print('No force scale attribute')

    # The exploration variant is identified by a marker substring in the agent
    # string. Markers are checked in order and the last match wins.
    dep_markers = (("dep_factory(1, ", 1),
                   ("dep_factory(2,", 2),
                   ("dep_factory(3,", 3),
                   ("dep_factory(6,", 6),
                   ("dep_factory(11,", 11),
                   ("dep_factory(12,", 12))
    dep_flag = 0
    for marker, flag in dep_markers:
        if marker in config.agent:
            dep_flag = flag
    run['dep'] = dep_flag

    # Expose every environment argument as a top-level key so runs can be
    # filtered on it.
    if env_args:
        for k, v in env_args.items():
            run[k] = v
    return run


def prepare_runs(runs, cut, dep, her, force_scale=0):
    '''Narrow ``runs`` down to one configuration and print summary numbers.

    Runs that did not get past ``cut`` steps are dropped first, then the
    remainder is filtered on ``dep`` and ``her``. With ``force_scale == 1`` a
    further filter keeps only runs whose ``force_scale`` is 0.0. If any runs
    survive, the average of the mean and of the std ``test_score`` curves are
    printed. ``runs`` is modified in place; nothing is returned.
    '''
    runs.cut_short_runs(int(cut))
    print(f'{dep=} {her=}')
    for key, wanted in (('dep', dep), ('her', her)):
        runs.filter_runs(key, wanted)
    if force_scale == 1:
        runs.filter_runs('force_scale', 0.0)

    if len(runs.active_runs) == 0:
        return
    print(f'{dep=} {her=}')
    print(np.mean(runs.get_mean('test_score')))
    print(np.mean(runs.get_std('test_score')))


def plot(fig, ax, path, field='test_success', cut=0):
    '''Draw the learning curves found under ``path`` onto ``ax``.

    Thin wrapper around ``plot_single``; ``fig`` is accepted but not used.
    '''
    plot_single(ax, path, field=field, cut=cut)

def plot_bar(fig, ax, path, field='test_success', cut=0, means=None, stds=None):
    '''Collect per-configuration bar heights and errors for ``path``.

    Thin wrapper around ``_plot_bar``; ``fig`` and ``ax`` are accepted but not
    used. ``means`` and ``stds`` are accumulator lists that are extended in
    place and returned, so results from several paths can be chained. When
    omitted, fresh lists are created for this call (they are not shared
    between calls).
    '''
    # Create the lists here so that the callee always receives real lists.
    means = [] if means is None else means
    stds = [] if stds is None else stds
    means, stds = _plot_bar(path, field, cut, means, stds)
    return means, stds


def _append_max_stats(runs, field, cut, means, stds, dep, force_scale=0):
    '''Select one configuration from ``runs`` and append its max-return stats.'''
    runs.reset()
    prepare_runs(runs, cut, dep=dep, her=0, force_scale=force_scale)
    # get_max raises when no run matches, which lets the caller skip this bar.
    print(f'{field} max={runs.get_max(field)}')
    mean, std = runs.get_max_stats(field)
    means.append(mean)
    stds.append(std)


def _plot_bar(path, field='test_success', cut=0, means=None, stds=None):
    '''Append (mean, std) of the per-run maximum of ``field`` for each
    configuration found under ``<path>/working_directories``.

    Which configurations are attempted depends on substrings of ``path``:

    * ``'ablations'``: only dep=1 with the force-scale filter enabled.
    * ``'td4'``: dep=0 (TD4), then dep=11 and dep=12.
    * otherwise: dep=6 (DEP-MPO), dep=0 (MPO), then dep=11 and dep=12.
      The MPO step also writes ``run_means_mpo.npy`` to the working directory.

    A configuration that fails (for example because no run matches) is
    reported on stdout and skipped, so the result lists only contain entries
    for configurations that were found. DEP-TD4 is intentionally not
    evaluated.

    ``means`` and ``stds`` are extended in place and returned; fresh lists are
    created when they are omitted.
    '''
    means = [] if means is None else means
    stds = [] if stds is None else stds
    runs = get_data(os.path.join(path, 'working_directories'))

    if 'ablations' in path:
        try:
            _append_max_stats(runs, field, cut, means, stds, dep=1, force_scale=1)
        except Exception as e:
            print(f'depmpo failed: {e}')
        return means, stds

    if 'td4' in path:
        try:
            _append_max_stats(runs, field, cut, means, stds, dep=0)
        except Exception as e:
            print(f'td4 failed: {e}')
    else:
        try:
            _append_max_stats(runs, field, cut, means, stds, dep=6)
        except Exception as e:
            print(f'depmpo failed: {e}')
        try:
            _append_max_stats(runs, field, cut, means, stds, dep=0)
            # NOTE(review): no matching 'run_means_depmpo.npy' is written in
            # the DEP-MPO step above -- confirm where that file comes from.
            np.save('run_means_mpo.npy', runs.get_max_per_run(field))
        except Exception as e:
            print(f'mpo failed: {e}')

    try:
        _append_max_stats(runs, field, cut, means, stds, dep=11)
    except Exception as e:
        print(f'colored mpo failed: {e}')
    try:
        _append_max_stats(runs, field, cut, means, stds, dep=12)
    except Exception as e:
        print(f'ou mpo failed: {e}')
    return means, stds


def _plot_mean_band(ax, runs, field, cut, dep, color, label,
                    marker_x=None, save_name=None):
    '''Plot mean +/- std of ``field`` for one configuration onto ``ax``.

    ``marker_x``, if given, adds an unclipped diamond marker at (marker_x, 0).
    ``save_name``, if given, stores ``runs.get_max_per_run(field)`` to that file.
    '''
    runs.reset()
    prepare_runs(runs, cut, dep=dep, her=0)
    # get_max raises when no run matches, which lets the caller skip this curve.
    print(f'{field} max={runs.get_max(field)}')
    steps = runs.steps
    mean = runs.get_mean(field)
    std = runs.get_std(field)
    ax.plot(steps, mean, color=color, label=label)
    ax.fill_between(steps, mean - std, mean + std, alpha=0.1, color=color)
    if marker_x is not None:
        # clip_on=False keeps the marker visible although it sits on the axis edge.
        ax.plot(marker_x, 0, zorder=10, clip_on=False, marker='d', markersize=3,
                color=color)
    if save_name is not None:
        np.save(save_name, runs.get_max_per_run(field))


def plot_single(ax, path, field, cut=0):
    '''Draw the learning curves of all configurations under ``path`` on ``ax``.

    Runs are loaded from ``<path>/working_directories``. If ``path`` contains
    ``'td4'`` the TD4 curve (dep=0) is drawn; otherwise DEP-MPO (dep=6) and
    MPO (dep=0). MPO-Colored (dep=11) and MPO-OU (dep=12) are attempted in
    both cases. DEP-TD4 is intentionally not drawn. For ``'test_score_max'``
    a diamond marker is added at a fixed x position for DEP-MPO, TD4 and MPO.
    Those three also write a ``run_means_*_ostrich.npy`` file.

    A configuration that fails (for example because no run matches) is
    reported on stdout and skipped. Finally the x label and axis limits are
    set.
    '''
    print(cut)
    plot_marker = field == 'test_score_max'
    runs = get_data(os.path.join(path, 'working_directories'))

    if 'td4' in path:
        try:
            _plot_mean_band(ax, runs, field, cut, dep=0, color='tab:green',
                            label='TD4',
                            marker_x=10.3e6 if plot_marker else None,
                            save_name='run_means_td4_ostrich.npy')
        except Exception as e:
            print(f'td4 failed: {e}')
    else:
        try:
            _plot_mean_band(ax, runs, field, cut, dep=6, color='tab:orange',
                            label='DEP-MPO',
                            marker_x=21e6 if plot_marker else None,
                            save_name='run_means_depmpo_ostrich.npy')
        except Exception as e:
            print(f'depmpo failed: {e}')
        try:
            _plot_mean_band(ax, runs, field, cut, dep=0, color='tab:blue',
                            label='MPO',
                            marker_x=93e6 if plot_marker else None,
                            save_name='run_means_mpo_ostrich.npy')
        except Exception as e:
            print(f'mpo failed: {e}')

    try:
        _plot_mean_band(ax, runs, field, cut, dep=11, color='tab:purple',
                        label='MPO-Colored')
    except Exception as e:
        print(f'colored mpo failed: {e}')
    try:
        _plot_mean_band(ax, runs, field, cut, dep=12, color='tab:brown',
                        label='MPO-OU')
    except Exception as e:
        print(f'ou mpo failed: {e}')

    ax.set_xlabel('steps')
    if field == 'test_success':
        ax.set_ylim([0, 1.2])
    ax.set_xlim([0, 1.45e8])


if __name__ == '__main__':
    # Switches selecting which figure(s) to produce.
    LINEPLOT = 1
    BARPLOT = 0
    # String prepended verbatim to every experiment folder name below.
    prefix = 'folder'
    fields = ['test_score', 'test_score_max']
    if LINEPLOT:
        # One panel per metric, sharing the return axis.
        fig, axs = plt.subplots(1, len(fields), sharey=True)
        for ax, field in zip(axs, fields):
            for path in ['ostrich_tuned_final', 'ostrich_paper_td4_my_parallel']:
                path = f'{prefix}{path}'
                print(f'{path=}')
                plot(fig, ax, path, field=field, cut=110000000)
            # NOTE(review): plt.ticklabel_format acts on pyplot's current axes,
            # which is not necessarily `ax` -- confirm this is intended.
            plt.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))
            #plt.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
            axs[0].set_yticks([0, 2000, 4000, 6000])
            axs[0].set_ylabel('return')
            axs[0].set_ylim([0, 6500])
            axs[1].set_ylim([0, 6500])
            #axs[0].set_ylabel(YLABELS[field])
            # Saved once per field, i.e. after each panel has been filled.
            fig.savefig(f'./new_ostrich_evaluate_main_{field}.pdf')
        # plt.show()
    if BARPLOT:
        # Re-apply the style with a smaller figure for the bar plot.
        #plt.rcParams.update(bundles.neurips2022())
        plt.rcParams.update(bundles.iclr2023())
        #fig_dict = figsizes.neurips2022(ncols=1)
        fig_dict = figsizes.iclr2023(ncols=1)
        fig_dict = change_figsize(fig_dict, fraction_width=0.3, fraction_height=0.4)
        print(fig_dict)
        plt.rcParams.update(fig_dict)
        #fields = ['test_score', 'test_score_max']
        fields = ['test_score']
        fig, axs = plt.subplots(1, len(fields), sharey=True)
        # With a single panel plt.subplots returns a bare Axes; wrap it so the
        # loop below can zip over it.
        axs = [axs]
        for ax, field in zip(axs, fields):
            means = []
            stds = []
            for path in ['ostrich_tuned_final','ostrich_tuned_final_td4']: # 'ostrich_paper_td4_my_parallel']:
                path = f'{prefix}{path}'
                # means/stds accumulate across both paths.
                means, stds = plot_bar(fig, ax, path, field=field, cut=90000000, means=means, stds=stds)
                #means, stds = plot_bar(fig, ax, path, field=field, cut=5000000, means=means, stds=stds)
        plt.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))
        #plt.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
        #axs[0].set_yticks([0, 2000, 4000, 6000])
        #axs[0].set_ylabel(YLABELS[field])
        # `ax`, `means` and `stds` below are the values left over from the loop.
        x = np.arange(len(means))
        bars = ax.bar(x, means, yerr=stds)
        ax.set_ylim([0, 8000])
        for bar, color in zip(bars, colors):
            bar.set_color(color)
        # Two-sample t-test between the per-run values of DEP-MPO and MPO.
        # NOTE(review): 'run_means_depmpo.npy' is not written by any code
        # visible in this file -- confirm it is produced elsewhere.
        data1 = np.load('run_means_depmpo.npy')
        data2 = np.load('run_means_mpo.npy')
        val = scipy.stats.ttest_ind(data1, data2).pvalue
        barplot_annotate_brackets(fig, ax, 0, 1, val, x, means, yerr=stds, fs=8, maxasterix=4)
        # NOTE(review): assumes exactly three bars were collected, in this order.
        labels = ['DEP-MPO', 'MPO', 'TD4']
        #from pudb import set_trace; set_trace()
        ax.set_xticks(x)
        ax.set_xticklabels(labels)
        # White label: invisible on a white background but still present.
        ax.set_xlabel('label', color='white')

        axs[0].set_ylabel('max return')
        #plt.show()
        print('saving')
        fig.savefig(f'./performance_histo.pdf')
