from pathlib import Path

from cmx import doc
import numpy as np

class Memoize:
    """Cache the results of ``func`` keyed by the arguments of each call.

    Results are kept in ``self._memo`` and can be persisted with :meth:`save`
    and restored with :meth:`load` (via ``torch.save`` / ``torch.load``).
    Every positional and keyword argument value must be hashable, because it
    becomes part of a dict key.
    """

    def __init__(self, func):
        self._memo = {}
        self._func = func

    @staticmethod
    def _make_key(args, kwargs):
        """Build the cache key for one call.

        Positional-only calls use the bare ``args`` tuple (the same key earlier
        versions produced for such calls, so persisted memos keep hitting).
        Keyword arguments go into a separate, name-sorted tuple, so
        ``f(1, a=2)`` and ``f(1, 'a', 2)`` get different keys, while
        ``f(a=1, b=2)`` and ``f(b=2, a=1)`` share one.
        """
        if not kwargs:
            return args
        # Keyword names are unique strings, so sorting never compares the values.
        return args, tuple(sorted(kwargs.items()))

    def __call__(self, *args, **kwargs):
        key = self._make_key(args, kwargs)
        if key not in self._memo:
            self._memo[key] = self._func(*args, **kwargs)
        return self._memo[key]

    def save(self, key='memoize.pkl'):
        """Persist the whole memo dict to the file ``key`` with ``torch.save``."""
        import torch
        print('storing to memo...')
        torch.save(self._memo, key)

    def load(self, key='memoize.pkl'):
        """Replace the memo dict with the one stored in the file ``key``.

        NOTE(review): PyTorch 2.6+ defaults ``torch.load`` to
        ``weights_only=True``, which may reject non-tensor cached values —
        confirm against the installed torch version.
        """
        import torch
        self._memo = torch.load(key)


# Four-color palette; the first three entries double as adapted-curve colors below.
plot_colors = ['#23aaff', '#ff7575', '#66c56c', '#f4b247']

# Shared gray used for every algorithm's zero-shot curve.
_ZEROSHOT_COLOR = '#777777'

# alg -> (adapted-curve color, zero-shot-curve color), unpacked in that order by plot_line.
# NOTE(review): soda's '#11e56c' is not taken from plot_colors and is close in hue to
# svea's green, while '#f4b247' is not used by any entry here — confirm this is intended.
colors = dict(
    drqv2=(plot_colors[0], _ZEROSHOT_COLOR),
    sac=(plot_colors[1], _ZEROSHOT_COLOR),
    svea=(plot_colors[2], _ZEROSHOT_COLOR),
    soda=('#11e56c', _ZEROSHOT_COLOR),
)


def read_metrics(prefix, alg, distr_type, suffix, y_key='eval/episode_reward'):
    """Read the mean/std curve of ``y_key`` for the runs matched by ``suffix``.

    An ``ML_Logger`` rooted at ``prefix`` is asked for the ``<y_key>@mean`` and
    ``<y_key>@std`` series against ``step@min``, reading from ``path=suffix``
    with ``ignore_errors=True``. ``suffix`` may contain ``*``/``**`` wildcards
    (presumably expanded by ml_logger — confirm).

    ``alg`` and ``distr_type`` are accepted but not used in the body
    (presumably so that they distinguish Memoize cache entries).

    Returns:
        ``(mean, std, step)`` as produced by ``ML_Logger.read_metrics``, or
        ``(None, None, None)`` when the read or the unpacking of its result
        raised; in that case the error is printed together with ``suffix``.
    """
    from ml_logger import ML_Logger

    logger = ML_Logger(prefix=prefix)
    try:
        mean, std, step = logger.read_metrics(
            *(y_key + stat for stat in ('@mean', '@std')),
            x_key="step@min",
            path=suffix,
            ignore_errors=True,
        )
    except Exception as err:  # best-effort: a missing/broken run just yields Nones
        print('Error captured:', err, suffix)
        return None, None, None
    return mean, std, step


def plot_line(alg, distr_type, intensities, return_mean, zs_return_mean, ymin=0, ymax=1200):
    """Draw one algorithm's final and zero-shot return curves on the current axes.

    Args:
        alg: key into ``colors``; also the legend label of the first curve
            (the zero-shot curve is labelled ``'zeroshot_' + alg``).
        distr_type: only echoed in the debug print.
        intensities: x values shared by both curves.
        return_mean: y values of the final-return curve.
        zs_return_mean: y values of the zero-shot-return curve.
        ymin, ymax: y-axis limits; default to the previously fixed 0 and 1200.
    """
    import matplotlib.pyplot as plt

    color, zs_color = colors[alg]
    print('alg', alg, 'distr', distr_type, 'return_mean', list(return_mean))
    # Setting the limits turns off y autoscaling, so the curves below keep them.
    plt.ylim(bottom=ymin, top=ymax)
    plt.plot(intensities, return_mean, color=color, label=alg, linewidth=2, marker='.')
    plt.plot(intensities, zs_return_mean, color=zs_color, label='zeroshot_' + alg, linewidth=2, marker='.')


if __name__ == '__main__':
    import pandas as pd
    import matplotlib.pyplot as plt
    from params_proto.neo_hyper import Sweep
    from pathlib import Path

    import os
    from collections import defaultdict
    from os.path import join as pJoin

    # ml_logger prefix under which each algorithm's adaptation runs are stored.
    algorithm_prefixes = {
        'drqv2': 'model-free/model-free/drqv2_invariance/iclr2022prep/adaptation/run/drqv2',
        'sac': 'model-free/model-free/invr_thru_inf/adapt/dmcgen/run/sac',
        'svea': 'model-free/model-free/invr_thru_inf/adapt/dmcgen/run/svea',
        'soda': 'model-free/model-free/invr_thru_inf/adapt/dmcgen/run/soda',
        # 'pad': 'model-free/model-free/invr_thru_inf/adapt/dmcgen/run/pad'
    }

    # Metric-file patterns (relative to an algorithm prefix) for one distraction type
    # at one intensity. Example run directories:
    #   separate_buffers/random-random/aug/finger-spin/rep4/col-intsty0.5000/f5460b4253/100
    #   separate_buffers/random-random/aug/walker-walk/col-intsty0.5000/e1692cfa62/300
    distr_suffixes = {
        'background': 'separate_buffers/random-random/aug/*/**/bac-intsty{intensity:.4f}/**/metrics.pkl',
        'color': 'separate_buffers/random-random/aug/*/**/col-intsty{intensity:.4f}/**/metrics.pkl',
        'camera': 'separate_buffers/random-random/aug/*/**/cam-intsty{intensity:.4f}/**/metrics.pkl',
        # 'all': (
        #     'separate_buffers/random-random/aug/*/**/bac-intsty{intensity:.4f}/**/metrics.pkl',
        #     'separate_buffers/random-random/aug/*/**/col-intsty{intensity:.4f}/**/metrics.pkl',
        #     'separate_buffers/random-random/aug/*/**/cam-intsty{intensity:.4f}/**/metrics.pkl'
        # ),
    }
    # The x-axis intensity is scaled by this factor before being formatted into the
    # path, e.g. color intensity 1.0 reads the runs under col-intsty0.5000.
    intensity_coef = {'background': 1.0, 'color': 0.5, 'camera': 0.25}
    # x-axis intensities 0.2, 0.4, 0.6, 0.8, 1.0 (linspace's leading 0 is dropped).
    intensities = list(np.linspace(0, 1, 5 + 1))[1:]

    entries = defaultdict(list)
    csv_file = 'plot_data.csv'

    # Memoize metric reads and reuse the results persisted by earlier runs.
    read_metrics = Memoize(read_metrics)
    if os.path.isfile('memoize.pkl'):
        read_metrics.load()

    if not os.path.isfile(csv_file):
        # Collect one row per (algorithm, distraction, intensity) point of the plot.
        for alg, alg_prefix in algorithm_prefixes.items():
            for distr_type, distr_suffix in distr_suffixes.items():
                for intensity in intensities:
                    if isinstance(distr_suffix, (list, tuple)):
                        # Combined distraction: patterns are ordered background, color, camera.
                        coefs = (intensity_coef['background'], intensity_coef['color'], intensity_coef['camera'])
                        suffix = tuple(
                            pattern.format(intensity=intensity * coef)
                            for pattern, coef in zip(distr_suffix, coefs)
                        )
                    else:
                        suffix = distr_suffix.format(intensity=intensity * intensity_coef[distr_type])
                    print('suffix', suffix)
                    mean, std, step = read_metrics(alg_prefix, alg, distr_type, suffix)

                    if mean is None:
                        print('mean is None!!', alg_prefix, suffix)
                        continue

                    entries['algorithm'].append(alg)
                    entries['distraction'].append(distr_type)
                    entries['intensity'].append(intensity)
                    # Last evaluation -> final return; first evaluation -> zero-shot return.
                    entries['return_mean'].append(mean[-1])  # TODO: Should take max with argmax (for std)
                    entries['return_std'].append(std[-1])  # TODO: use argmax(mean) to pick the correct std
                    entries['zeroshot_return_mean'].append(mean[0])
                    entries['zeroshot_return_std'].append(std[0])
                    # Checkpoint the memo after every successful read so an interrupted
                    # run can resume without re-reading.
                    read_metrics.save()

        if not entries:
            # An empty CSV would make every later run skip collection and then fail
            # on the missing columns below.
            raise SystemExit(f'No metrics could be read; {csv_file} was not written.')
        df = pd.DataFrame.from_dict(entries)
        df.to_csv(csv_file)
    else:
        df = pd.read_csv(csv_file)

    out_dir = Path(__file__).stem
    # plt.savefig does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)

    # Plot lines from the data frame, one figure per distraction type.
    legend_handles, legend_labels = [], []
    with doc.table() as table:
        r = table.figure_row()
        for column, distr_type in enumerate(sorted(distr_suffixes.keys())):
            with r:
                plt.figure(figsize=(4.25, 3.5))
                plt.xlabel('Intensity')
                if column == 0:
                    plt.ylabel('Episode return')
                is_distr = df['distraction'] == distr_type
                for alg in algorithm_prefixes.keys():
                    # Combine the masks in one step; chaining df[m1][m2] reindexes m2 and warns.
                    df_ = df[is_distr & (df['algorithm'] == alg)]
                    plot_line(alg, distr_type, df_['intensity'], df_['return_mean'], df_['zeroshot_return_mean'])

                # Lay out only after the labels and lines exist so the PDF does not clip them.
                plt.tight_layout()
                # Remember the labelled lines; the standalone legend below is built from them.
                legend_handles, legend_labels = plt.gca().get_legend_handles_labels()
                r.savefig(f"{out_dir}/{distr_type}.png", title=f"{distr_type}", dpi=300, bbox_inches='tight', pad_inches=0)
                plt.savefig(f"{out_dir}/{distr_type}.pdf")

        # Finally generate a standalone legend image from the last figure's lines.
        plt.figure(figsize=(4.25, 3.5))
        plt.legend(legend_handles, legend_labels, frameon=False)
        ax = plt.gca()
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        for spine in ax.spines.values():
            spine.set_visible(False)
        r.savefig(f"{out_dir}/legend.png", title="legend", dpi=300, bbox_inches='tight', pad_inches=0)
        plt.savefig(f"{out_dir}/legend.pdf")

    plt.close('all')
    doc.flush()
