import functools
import os

import matplotlib.pyplot as plt
from cmx import doc
from ml_logger import ML_Logger
from tqdm import tqdm

# DMControl tasks from the SODA benchmark set, plus a few additional tasks.
soda_envs = ['Walker-walk', 'Walker-stand', 'Cartpole-swingup', 'Ball_in_cup-catch', 'Finger-spin']
extra_envs = ['Reacher-easy', 'Cheetah-run', 'Cartpole-balance', 'Finger-turn_easy']
# Lower-cased environment names, SODA tasks first, in the order they are plotted.
all_envs = [name.lower() for name in (*soda_envs, *extra_envs)]
# Baseline algorithms compared in every panel (also used as run-path prefixes).
algorithms = ['sac', 'soda', 'pad', 'svea']


def memoize(f):
    """Wrap ``f`` so each distinct argument combination is computed only once.

    The cache key keeps positional and keyword arguments apart and sorts the
    keyword arguments by name. As a result, ``g(x=1, y=2)`` and ``g(y=2, x=1)``
    share one entry, while ``g("x", 1)`` and ``g(x=1)`` get separate entries.

    All arguments must be hashable. The cache is unbounded and is never evicted.
    """
    memo = {}

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # Flattening args and kwargs into a single tuple (the previous scheme)
        # made ``f("k", v)`` and ``f(k=v)`` produce the same key. Keyword names
        # are unique, so sorting the items never has to compare values.
        key = (args, tuple(sorted(kwargs.items())))
        if key not in memo:
            memo[key] = f(*args, **kwargs)
        return memo[key]

    return wrapper


# One line color per entry in `algorithms`, matched by index.
colors = ['#23aaff', '#ff7575', '#66c56c', '#f4b247']
loader = ML_Logger(prefix="model-free/model-free/baselines/dmc_gen/run")

# Replace the loader's `glob` and `load_pkl` with memoized versions so repeated
# calls with identical arguments reuse the first result.
loader.glob, loader.load_pkl = memoize(loader.glob), memoize(loader.load_pkl)

def plot_line(path, color, label, metrics_loader=loader, linewidth=2, linestyle=None):
    """Draw one reward curve with a shaded band onto the current matplotlib axes.

    Reads three aggregates of ``train/episode_reward/mean`` (``@mean``,
    ``@84%``, ``@16%``) against ``step@min`` from the runs matching ``path``,
    using ``metrics_loader.read_metrics`` with ``bin_size=5``. The mean is drawn
    as a line, and the area between the 16% and 84% series is filled with the
    same color at 15% opacity.

    NOTE(review): the 84%/16% keys presumably denote percentiles (roughly a
    +/-1 std band) -- confirm against ml_logger's aggregation syntax.
    """
    reward_keys = (
        "train/episode_reward/mean@mean",
        "train/episode_reward/mean@84%",
        "train/episode_reward/mean@16%",
    )
    avg, upper, lower, xs = metrics_loader.read_metrics(*reward_keys, x_key="step@min", bin_size=5, path=path)

    plt.plot(xs.to_list(), avg.to_list(), color=color, label=label,
             linewidth=linewidth, linestyle=linestyle)
    plt.fill_between(xs, lower, upper, alpha=0.15, color=color)

doc @ """
# Baseline training results

These algorithms are taken from Nick Hansen's [dmcontrol-generalization-benchmark](https://github.com/nicklashansen/dmcontrol-generalization-benchmark) repo.
"""

# Images are written to a directory named after this script, e.g. `foo.py` -> `foo/`.
fig_dir = os.path.splitext(os.path.basename(__file__))[0]

with doc.table() as table:
    for i, env_name in enumerate(tqdm(all_envs)):
        # Open a new table row every 5 panels; otherwise keep filling the current one.
        if i % 5 == 0:
            r = table.figure_row()
        plt.close()
        with r:
            plt.figure(figsize=(4.25, 3.5))

            for j, alg in enumerate(algorithms):
                plot_line(path=f"{alg}/{env_name}/**/metrics.pkl", color=colors[j], label=alg,
                          metrics_loader=loader)
            plt.xlabel('Steps')
            # Every panel keeps its y-axis ticks because reward scales differ between
            # envs, but only the very first panel carries the "Reward" label.
            if i == 0:
                plt.ylabel('Reward')

            plt.tight_layout()
            r.savefig(f"{fig_dir}/{env_name}.png",
                      title=f"{env_name}", dpi=300, bbox_inches='tight', pad_inches=0)

    # Turn the last panel into a standalone legend image. First build the legend
    # from its lines, then hide the axes, spines and plotted data so only the
    # legend remains.
    plt.legend(frameon=False)
    ax = plt.gca()
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    # Since matplotlib 3.5, `Axes.collections` and `Axes.lines` are read-only
    # ArtistLists (their `.clear()` was removed in 3.7), so detach each artist
    # explicitly. The legend draws its own proxy handles and is unaffected.
    for artist in [*ax.collections, *ax.lines]:
        artist.remove()
    for spine in ax.spines.values():
        spine.set_visible(False)
    r.savefig(f"{fig_dir}/legend.png", title="Legend", dpi=300,
              bbox_inches='tight', pad_inches=0)
plt.close()
doc.flush()
