import numpy as np
import os
import sys
from omegaconf import OmegaConf as om
import pandas
from matplotlib import pyplot as plt
from pudb import set_trace
from tueplots import bundles, figsizes

def change_figsize(fig_dict, fraction_width, fraction_height=1):
    '''Scale the ``'figure.figsize'`` entry of ``fig_dict`` in place.

    The width is multiplied by ``fraction_width`` and the height by
    ``fraction_height``. The entry is stored back as a ``(width, height)``
    tuple and the same (mutated) dict is returned.
    '''
    base_width = fig_dict['figure.figsize'][0]
    base_height = fig_dict['figure.figsize'][1]
    scaled = (base_width * fraction_width, base_height * fraction_height)
    fig_dict['figure.figsize'] = scaled
    return fig_dict

# Width (in samples) of the sliding window handed to `smooth` for every curve.
WINDOW = 5

# Plot styling: start from the tueplots NeurIPS 2022 bundle, then override the
# figure size with the NeurIPS default scaled to full width and 70% height.
# (An ICML 2022 layout can be selected instead via bundles.icml2022() and
# figsizes.icml2022_half().)
plt.rcParams.update(bundles.neurips2022())
fig_dict = change_figsize(
    figsizes.neurips2022(),
    fraction_width=1.0,
    fraction_height=0.7,
)
plt.rcParams.update(fig_dict)

# Human-readable y-axis label for each plottable field of a run dict.
YLABELS = dict(
    test_success='success rate',
    test_score='return',
    train_score='train/return',
    test_score_max='best rollout return',
)
def smooth(vals, window):
    '''Smooth ``vals`` with a sliding-window (moving) average.

    Args:
        vals: 1-D sequence of numbers.
        window: Window width in samples. Values ``<= 1`` disable smoothing and
            ``vals`` is returned unchanged. A window longer than ``vals`` is
            clamped to ``len(vals)``.

    Returns:
        ``vals`` itself when ``window <= 1``, otherwise a numpy array with the
        same length as ``vals``.
    '''
    if window <= 1:
        return vals
    if len(vals) == 0:
        # np.convolve rejects empty operands; there is nothing to smooth.
        return np.asarray(vals, dtype=float)

    window = min(window, len(vals))
    kernel = np.ones(window)
    # Convolving an all-ones signal counts how many real samples fall under the
    # window at each position, so the edges are averaged over fewer samples
    # instead of being dragged towards zero by implicit padding.
    coverage = np.convolve(np.ones(len(vals)), kernel, 'same')
    return np.convolve(vals, kernel, 'same') / coverage


class DataClass:
    '''Collection of run dicts with a filterable "active" subset.

    ``runs`` is the full list and is never modified by the filters. The active
    subset starts as that list, is narrowed by `filter_runs` /
    `cut_short_runs`, and is restored by `reset`. Each run is a dict that must
    provide at least ``'steps'`` and ``'test_score'``; ``None`` entries are
    tolerated and dropped lazily.
    '''

    def __init__(self, runs):
        self.runs = runs
        self._active_runs = runs
        # 0 until the None entries have been stripped from the active subset.
        self.cleaned = 0

    @property
    def active_runs(self):
        '''
        Active runs with None entries removed (per the original note, None
        runs usually appear during training). The cleaning is done once and
        cached until `reset` is called.
        '''
        if not self.cleaned:
            self._active_runs = [run for run in self._active_runs if run is not None]
            self.cleaned = 1
        return self._active_runs

    @active_runs.setter
    def active_runs(self, val):
        self._active_runs = val

    def filter_runs(self, key, val):
        '''Keep only active runs with ``run[key] == val``; print and return the survivors.

        Raises KeyError if an active run lacks ``key``.
        '''
        kept = [run for run in self.active_runs if run[key] == val]
        self.active_runs = kept
        print(f'{len(self.active_runs)} remaining.')
        return kept

    def cut_short_runs(self, threshold):
        '''Keep only active runs whose last recorded step is strictly above ``threshold``.'''
        kept = [run for run in self.active_runs if run['steps'][-1] > threshold]
        self.active_runs = kept
        return kept

    def _truncated(self, field):
        '''Return ``field`` of every active run, cut to the shortest common length.'''
        length = self.smallest_length
        return [run[field][:length] for run in self.active_runs]

    def get_mean(self, field):
        '''Mean of ``field`` across active runs at each index, then smoothed with WINDOW.'''
        return smooth(np.mean(self._truncated(field), axis=0), window=WINDOW)

    def get_run_scores(self, field):
        '''Smoothed ``field`` curve of every active run, truncated to a common length.'''
        return [smooth(scores, window=WINDOW) for scores in self._truncated(field)]

    def get_std(self, field):
        '''Standard deviation of ``field`` across active runs at each index, then smoothed.'''
        return smooth(np.std(self._truncated(field), axis=0), window=WINDOW)

    def reset(self):
        '''Restore the active subset to a shallow copy of all runs (None entries included).'''
        self.cleaned = 0
        self.active_runs = self.runs.copy()

    def get_max(self, field):
        '''Largest single value of ``field`` over all active runs (after truncation).'''
        return np.max(self._truncated(field))

    @property
    def steps(self):
        '''Step axis taken from the first active run, truncated to the common length.'''
        return self.active_runs[0]['steps'][:self.smallest_length]

    @property
    def smallest_length(self):
        '''Length of the shortest ``'test_score'`` series among the active runs.

        Raises ValueError when no runs are active, which is the same exception
        type the numpy reduction used here before raised, with a clearer message.
        '''
        lengths = self.lengths
        if not lengths:
            raise ValueError('no active runs: cannot determine a common length')
        return min(lengths)

    @property
    def lengths(self):
        '''Length of the ``'test_score'`` series of every active run.'''
        return [len(run['test_score']) for run in self.active_runs]

    def __len__(self):
        # Counts all runs (including None entries), not just the active subset.
        return len(self.runs)


def get_path():
    '''Placeholder with no implementation: does nothing and returns None.'''
    pass


def get_log_ids_sorted(path):
    '''Return the numeric ids of all ``log_<id>[.ext]`` entries in ``path``, ascending.

    The id is the text between the ``log_`` prefix and the first ``.`` of the
    file name. Entries whose id is not an integer (e.g. ``log_final.csv``) are
    skipped instead of aborting the whole scan.
    '''
    log_ids = []
    for file in os.listdir(path):
        if not file.startswith('log_'):
            continue
        stem = file.split('.')[0]
        try:
            log_ids.append(int(stem[len('log_'):]))
        except ValueError:
            # Not a numbered log file; ignore it.
            continue
    log_ids.sort()
    return log_ids


def get_scores(path):
    '''Load the logged metrics of one run from ``<path>/log.csv``.

    Returns:
        A 5-tuple of equally long lists, one value per CSV row, in this order:
        ``train/steps``, ``train/episode_score/mean``,
        ``test/episode_score/mean``, ``test/success_rate/mean``,
        ``test/episode_score/max``.

    Raises:
        KeyError: if one of those columns is missing from the CSV.
    '''
    data = pandas.read_csv(os.path.join(path, 'log.csv'))
    columns = (
        'train/steps',
        'train/episode_score/mean',
        'test/episode_score/mean',
        'test/success_rate/mean',
        'test/episode_score/max',
    )
    # Indexing (rather than DataFrame.get) makes a missing column fail with its
    # name instead of an AttributeError on None.
    return tuple(list(data[column].values) for column in columns)


def get_data(path):
    '''Builds one run dict per directory entry of `path` and wraps them in a DataClass.

    Each entry is passed through `get_run_dict`, so the collection may contain
    None for entries that did not yield a run.
    '''
    entries = os.listdir(path)
    return DataClass([get_run_dict(os.path.join(path, entry)) for entry in entries])


def get_run_dict(path):
    '''Build the run dict for the run stored at (or below) ``path``.

    A directory counts as a run when it contains a ``checkpoints`` folder; its
    ``config.yaml`` and ``log.csv`` are then read. Otherwise the search
    descends into the first sub-directory found.

    Returns:
        A dict with the score series (``steps``, ``test_success``,
        ``test_score``, ``test_score_max``, ``train_score``), the settings
        derived from the config (``environment``, ``her``, ``seed``,
        ``steps_before``, ``dep``) and one entry per ``env_args`` item, or
        None when ``path`` is not a directory or no run directory is found.
    '''
    if not os.path.isdir(path):
        # Stray regular files next to the run folders are not runs.
        return None

    if not os.path.isdir(os.path.join(path, 'checkpoints')):
        for folder in os.listdir(path):
            candidate = os.path.join(path, folder)
            if os.path.isdir(candidate):
                # Only the first sub-directory is followed.
                return get_run_dict(candidate)
        return None

    config = om.load(os.path.join(path, 'config.yaml'))
    steps, score_train, score_test, success, score_test_max = get_scores(path)
    run = {'steps': steps,
           'test_success': success,
           'test_score': score_test,
           'test_score_max': score_test_max,
           'train_score': score_train,
           'environment': config.environment_name,
           'her': 'HER' in config.agent,
           'seed': config.seed,
           'steps_before': 1 if "3e5" in config.agent else 0}
    try:
        run['force_scale'] = config.env_args['force_scale']
    except Exception:
        # Best effort: configs without env_args / force_scale simply lack the key.
        pass

    # Substrings of the agent spec mapped to a "dep" flag; 0 when none matches.
    # When several match, the last one in this table wins.
    dep_markers = (
        ("dep_factory(1, ", 1),
        ("dep_factory(2,", 2),
        ("dep_factory(3,", 3),
        ("dep_factory(6,", 6),
        ("dep_factory(11,", 11),
        ("dep_factory(12,", 12),
    )
    dep_flag = 0
    for marker, flag in dep_markers:
        if marker in config.agent:
            dep_flag = flag
    run['dep'] = dep_flag

    if config['env_args']:
        for k, v in config.env_args.items():
            run[k] = v
    return run


def prepare_runs(runs, cut, dep, her, force_scale=0):
    '''Narrow `runs` in place to one experimental condition.

    Drops runs that do not pass `cut` (as an int) via `cut_short_runs`, then
    keeps only runs whose 'dep' and 'her' entries equal `dep` and `her`. When
    `force_scale` is truthy, additionally keeps only runs with
    'force_scale' equal to 0.0.
    '''
    threshold = int(cut)
    runs.cut_short_runs(threshold)
    print(f'{dep=} {her=}')
    for key, wanted in (('dep', dep), ('her', her)):
        runs.filter_runs(key, wanted)
    if not force_scale:
        return
    runs.filter_runs('force_scale', 0.0)


def plot(fig, ax, path, limes, cut=0, field='test_success'):
    '''Draws the curves of the experiment at `path` onto `ax` via `plot_single`.

    `fig` is accepted for the caller's convenience but is not used.
    '''
    plot_single(ax=ax, path=path, field=field, limes=limes, cut=cut)


def _plot_variant(ax, runs, field, cut, dep, her, color, label, failure_msg,
                  fmt=None, force_scale=0):
    '''Plot mean +/- std of `field` for one (dep, her) condition, best effort.

    Any failure (typically: no run matches the condition) is reported by
    printing `failure_msg` and the curve is skipped.
    '''
    try:
        runs.reset()
        prepare_runs(runs, cut, dep=dep, her=her, force_scale=force_scale)
        print(f'{field} max={runs.get_max(field)}')
        steps = runs.steps
        mean = runs.get_mean(field)
        # Pass the format string positionally only when one was requested so
        # solid curves keep matplotlib's default line style.
        fmt_args = (fmt,) if fmt else ()
        ax.plot(steps, mean, *fmt_args, color=color, label=label)
        std = runs.get_std(field)
        ax.fill_between(steps, mean - std, mean + std, alpha=0.1, color=color)
    except Exception:
        print(failure_msg)


def plot_single(ax, path, field, limes, cut=0):
    '''Draw the DEP-MPO, MPO, HER-MPO and HER-DEP-MPO curves of one experiment.

    Args:
        ax: Matplotlib axes to draw on.
        path: Experiment folder; runs are read from its
            ``working_directories`` sub-folder.
        field: Run-dict key to plot (e.g. ``'test_success'``).
        limes: Upper limit of the x-axis.
        cut: Threshold handed to `prepare_runs` to drop runs that are too short.

    Conditions for which no curve can be produced are skipped with a printed
    message, so a folder holding only some of the variants still plots.
    '''
    print(cut)
    runs = get_data(os.path.join(path, 'working_directories'))
    dep = 6

    # For folders with 'ablations' in the path, the DEP-MPO curve is further
    # restricted to runs with force_scale == 0.0.
    _plot_variant(ax, runs, field, cut, dep=dep, her=0,
                  color='tab:orange', label='DEP-MPO', failure_msg='depmpo failed',
                  force_scale=1 if 'ablations' in path else 0)
    _plot_variant(ax, runs, field, cut, dep=0, her=0,
                  color='tab:blue', label='MPO', failure_msg='mpo failed')
    _plot_variant(ax, runs, field, cut, dep=0, her=1, fmt='--',
                  color='tab:purple', label='HER-MPO', failure_msg='hermpo  mpo failed')
    _plot_variant(ax, runs, field, cut, dep=dep, her=1, fmt='--',
                  color='tab:red', label='HER-DEP-MPO', failure_msg='herdepmpo failed')

    ax.set_xlabel('steps')
    if field == 'test_success':
        ax.set_ylim([-0.01, 1.01])
        ax.set_yticks([0, 0.5, 1.0])
    ax.set_xlim([0, limes])


if __name__ == '__main__':
    prefix = '/folder/'
    # One subplot per primary experiment folder, with its x-axis limit.
    data = ['neurips/arm26_her/arm26-dep-her-mpo_final',
            'neurips/arm750_her/arm750_optim_final_imol_switcher_working_FINAL',
            'ostrich_neurips/ostrich_foraging']
    lims = [6.5e6, 2.2e7, 1.5e7]
    # Extra experiment folders overlaid on the subplot whose primary path
    # contains the key (checked in this order; first match wins).
    extra_paths = {
        'arm750': ['neurips/arm750_her/arm750_ablations_neurips',
                   'neurips/arm750_her/arm750_herdeprun'],
        'ostrich': ['ostrich_neurips/ostrich_foraging_deprun/'],
    }
    for field in ['test_success']:
        fig, axs = plt.subplots(1, len(data), sharey=True)
        for ax, path, limes in zip(axs, data, lims):
            print(f'data {path}')
            overlays = next((extra for key, extra in extra_paths.items() if key in path), [])
            for rel_path in [path, *overlays]:
                plot(fig, ax, f'{prefix}{rel_path}', field=field, limes=limes,
                     cut=int(limes - 10000))
        # plt.ticklabel_format only touches the current (last created) axes,
        # so apply the scientific x-tick format to every subplot explicitly.
        for ax in axs:
            ax.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))
        axs[0].set_ylabel('success rate')
        fig.savefig(f'./reaching_evaluate_{field}.pdf')
