import functools
import os

import matplotlib.pyplot as plt
from cmx import doc
from ml_logger import ML_Logger
from tqdm import tqdm

# Environment names to plot, one per line in the literal below.
# NOTE(review): the surrounding whitespace is stripped before the lines are
# split, so the trailing [1:] slice discards the first listed entry
# (ball_in_cup-catch) and 14 environments remain. Confirm this exclusion is
# intentional.
all_games = """
ball_in_cup-catch
cartpole-swingup
cheetah-run
finger-spin
reacher-easy
walker-walk
cartpole-balance
cartpole-balance_sparse
cartpole-swingup_sparse
hopper-hop
hopper-stand
pendulum-swingup
reacher-hard
walker-run
walker-stand
""".strip().splitlines()[1:]


def memoize(f):
    """Return a wrapper around ``f`` that caches results per call signature.

    The cache lives for the lifetime of the returned wrapper and is never
    evicted. All positional and keyword argument values must be hashable;
    an unhashable argument raises ``TypeError`` when the key is looked up.

    Args:
        f: The callable to cache.

    Returns:
        A callable with the same signature as ``f`` that invokes ``f`` only
        the first time it sees a given combination of arguments.
    """
    memo = {}

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # Positional and keyword arguments are kept in separate slots so that
        # f(1, 'x', 2) and f(1, x=2) cannot share a cache entry. Keyword items
        # are sorted by name so the order they were passed in is irrelevant.
        key = (args, tuple(sorted(kwargs.items())))
        if key not in memo:
            memo[key] = f(*args, **kwargs)
        return memo[key]

    return wrapper


# Hex colour palette (not referenced in the visible part of this script).
colors = ['#23aaff', '#ff7575', '#66c56c', '#f4b247']

# Logger instance used to read metrics; the prefix is presumably the root
# under which the experiment logs live -- confirm against ml_logger.
loader = ML_Logger(prefix="model-free/model-free/baselines/drqv2/train")

# Replace the loader's file-listing and pickle-loading methods with memoized
# versions so repeated calls with the same arguments reuse the first result.
for _cached_name in ("glob", "load_pkl"):
    setattr(loader, _cached_name, memoize(getattr(loader, _cached_name)))

with doc.hide:
    def plot_line(path, color, label, metrics_loader=loader, linewidth=2, linestyle=None):
        """Draw one reward curve with a shaded band on the current figure.

        Reads three aggregates of ``eval/episode_reward/mean`` (``@mean``,
        ``@84%`` and ``@16%``) from ``metrics_loader.read_metrics``, using
        ``step@min`` as the x axis and a bin size of 5. The exact meaning of
        these aggregate suffixes is defined by ml_logger -- presumably the
        mean and the 84th/16th percentiles; confirm there.

        Args:
            path: Path or glob of the metrics files, passed through to
                ``read_metrics``.
            color: Colour used for both the line and the shaded band.
            label: Legend label for the line.
            metrics_loader: Object providing ``read_metrics``; defaults to the
                module-level ``loader``.
            linewidth: Width of the centre line.
            linestyle: Matplotlib line style for the centre line.
        """
        center, upper, lower, steps = metrics_loader.read_metrics(
            "eval/episode_reward/mean@mean",
            "eval/episode_reward/mean@84%",
            "eval/episode_reward/mean@16%",
            x_key="step@min",
            bin_size=5,
            path=path,
        )
        # Centre line first, then the translucent band between the two bounds.
        plt.plot(
            steps.to_list(),
            center.to_list(),
            color=color,
            label=label,
            linewidth=linewidth,
            linestyle=linestyle,
        )
        plt.fill_between(steps, lower, upper, alpha=0.15, color=color)

# Heading and introduction passed to the cmx document (presumably rendered as
# markdown in the generated report -- confirm against cmx).
# NOTE(review): the heading says "from Pixel" while the sentence below says
# "from state space"; confirm which one describes these runs and align the text.
doc @ """
# DrQv2 Results (from Pixel)

Below are DrQv2 baselines from state space with the DrQ code base.
"""

# Output directory for the saved figures: this script's file name with its
# last three characters (the ".py" suffix) removed.
fig_dir = os.path.basename(__file__)[:-3]

with doc.table() as table:
    for i, env_name in enumerate(tqdm(all_games)):
        # Start a new row of figures every five environments.
        if i % 5 == 0:
            r = table.figure_row()
        # Close the figure left over from the previous iteration.
        plt.close()
        with r:
            plt.figure(figsize=(4.25, 3.5))

            plot_line(path=f"drqv2/{env_name}/**/metrics.pkl", color="purple", label="DrQ", metrics_loader=loader)
            plt.xlabel('Steps')
            # Every panel keeps its own y axis because the reward scales differ
            # between environments; only the first panel gets the "Reward" label.
            if i == 0:
                plt.ylabel('Reward')

            plt.tight_layout()
            r.savefig(f"{fig_dir}/{env_name}.png",
                      title=f"{env_name}", dpi=300, bbox_inches='tight', pad_inches=0)

    # Legend-only panel: build the legend on the last figure (still open at
    # this point), then hide the axes and remove the plotted artists so that
    # only the legend is left to save.
    plt.legend(frameon=False)
    ax = plt.gca()
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    # Remove artists one by one instead of clearing ax.collections / ax.lines:
    # recent matplotlib releases expose those as read-only views, so calling
    # .clear() on them fails there. Iterate over a snapshot because remove()
    # changes the underlying containers.
    for artist in [*ax.collections, *ax.lines]:
        artist.remove()
    for spine in ax.spines.values():
        spine.set_visible(False)
    r.savefig(f"{fig_dir}/legend.png", title="Legend", dpi=300,
              bbox_inches='tight', pad_inches=0)
    plt.close()

doc.flush()
