#!/usr/bin/env python3

import argparse
import itertools
import math
import os

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import AutoMinorLocator

import latency_env.misc.argparser_types as at
from latency_env.misc import Argument as Arg
from latency_env.misc import ArgumentList as ArgList
from latency_env.misc import ArgumentParser

# Named matplotlib line styles; values are valid ``linestyle=`` arguments.
# https://matplotlib.org/stable/gallery/lines_bars_and_markers/linestyles.html
COLOR_MODIFIERS = {
    "solid": "-",
    "dotted": ":",
    "dashed": "--",
    "dashdot": "-.",
    "densedashdot": (0, (3, 1.5, 1, 1, 1, 1.5)),  # (offset, on/off dash sequence)
}

# Base colours as RGBA tuples with components in [0, 1].
COLOR_RGB_MAP = {
    "red":            (1.0, 0.0, 0.0, 1.0),
    "darkred":        (0.5, 0.0, 0.0, 1.0),
    "green":          (0.0, 0.7, 0.0, 1.0),
    "darkgreen":      (0.0, 0.4, 0.0, 1.0),
    "blue":           (0.0, 0.0, 1.0, 1.0),
    "yellow":         (1.0, 0.7, 0.0, 1.0),
    "darkyellow":     (0.9, 0.6, 0.0, 1.0),
    "purple":         (1.0, 0.0, 1.0, 1.0),
    "darkpurple":     (0.7, 0.0, 0.7, 1.0),
    "cyan":           (0.0, 0.7, 0.8, 1.0),
    "orange":         (1.0, 0.4, 0.0, 1.0),
    "darkorange":     (0.9, 0.35, 0.0, 1.0),
    "teal":           (0.0, 0.7, 0.4, 1.0),
}


def blend_color_value(cv1, cv2, q=0.5):
    """Mix two colour channel values; *q* is the weight of *cv2* (0..1).

    Returns ``sqrt((1 - q) * cv1**2 + q * cv2**2)``, the additive-mixing rule
    from https://stackoverflow.com/questions/726549/algorithm-for-additive-color-mixing-for-rgb-values
    """
    return math.sqrt((1 - q) * cv1 * cv1 + q * cv2 * cv2)


def blend_alpha_value(av1, av2, q=0.5):
    """Linearly interpolate two alpha values; q=0 gives *av1*, q=1 gives *av2*."""
    return (1 - q) * av1 + q * av2


def _register_color(name, rgba):
    """Add *name* (solid line) and every "<modifier>-<name>" variant to COLOR_MAP."""
    COLOR_MAP[name] = (*rgba, COLOR_MODIFIERS["solid"])
    for mod, modval in COLOR_MODIFIERS.items():
        COLOR_MAP[f"{mod}-{name}"] = (*rgba, modval)


# COLOR_MAP: colour name -> (r, g, b, a, linestyle).
COLOR_MAP = {}
for _name, _rgba in COLOR_RGB_MAP.items():
    _register_color(_name, _rgba)

# Pairwise mixes of the base colours, reachable as "<c1>-<c2>" in either order
# (both names map to the same value).
for (_c1_name, _c1), (_c2_name, _c2) in itertools.combinations(COLOR_RGB_MAP.items(), 2):
    _mixed = (
        blend_color_value(_c1[0], _c2[0]),
        blend_color_value(_c1[1], _c2[1]),
        blend_color_value(_c1[2], _c2[2]),
        blend_alpha_value(_c1[3], _c2[3]),
    )
    _register_color(f"{_c1_name}-{_c2_name}", _mixed)
    _register_color(f"{_c2_name}-{_c1_name}", _mixed)


def parse_input_json(s: str):
    """Parse a ``[label[%color]:]path`` argument and load the JSON file it names.

    Text before the first ``:`` is the label spec and everything after it is
    the path. Inside the label spec, text after the first ``%`` is a colour
    name (a key of COLOR_MAP). An empty or missing label falls back to the
    file's ``"trainer_prefix"`` field. An empty or missing colour is returned
    as None, so the caller picks one automatically.

    Because the first ``:`` separates label from path, a path that itself
    contains ``:`` must be given with an explicit label prefix.

    Returns:
        ``(label, color, data)`` where *data* is the decoded JSON document.

    Raises:
        argparse.ArgumentTypeError: if the file cannot be read or decoded, or
            no label was given and the file has no ``"trainer_prefix"``.
    """
    import json

    spec, sep, path = s.partition(":")
    if not sep:
        # No ":" at all: the whole argument is the path.
        spec, path = "", s
    label, _, color = spec.partition("%")
    label = label or None
    color = color or None

    try:
        # JSON text is UTF-8; don't depend on the locale's default encoding.
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as err:  # ValueError covers JSON/Unicode decode errors
        raise argparse.ArgumentTypeError(f"cannot load {path!r}: {err}") from err

    if label is None:
        try:
            label = data["trainer_prefix"]
        except (KeyError, TypeError) as err:
            raise argparse.ArgumentTypeError(
                f"{path!r} has no 'trainer_prefix'; give a label as 'label:{path}'"
            ) from err

    return (label, color, data)

# Command-line interface. The keyword names given to ArgList are the attribute
# names read back from ``args`` further down (e.g. --runavg is read as
# ``args.running_average_n``).
parser = ArgumentParser()
parser += ArgList(
    # Positional input file(s). parse_input_json is passed as the type
    # converter, so (assuming Arg forwards it to argparse) each file is loaded
    # during parse_args(); the script then iterates args.jsondata as
    # (label, color, data) tuples.
    jsondata = Arg(metavar="[label[%color]:]path", type=parse_input_json, nargs="+",
                   help="Input JSON file(s) to plot data from. "
                        "Example format: furuta:_logs/run1234/eval.json"),
    # Stored in label_kwargs and forwarded as ``fontsize=`` to text-drawing calls.
    fontsize = Arg("--fontsize", type=at.posint, default=None, help="Custom font size for all text."),
    # Draw a shaded band from mean - std to mean + std around each curve.
    plot_std = Arg("--plot-std", action=argparse.BooleanOptionalAction, default=False,
                   help="Plot standard deviations."),
    # Width N of the trailing running average applied to the plotted mean and
    # std; N=1 plots the raw evaluation points.
    running_average_n = Arg("--runavg", "--running-average", metavar="N", type=at.posint, default=1,
                            help="For how many steps to plot running average for."),
    show_legend = Arg("--legend", action=argparse.BooleanOptionalAction, default=True,
                      help="Show legend."),
    # Single-row legend (one column per label) anchored at the lower centre
    # instead of to the right of the axes.
    flat_legend = Arg("--flat-legend", action=argparse.BooleanOptionalAction, default=False,
                      help="Show a flat legend below the plot instead."),
    # Only produces a file when -o is also given; without -o the figure is
    # shown as usual. Axis labels and grid are skipped in this mode.
    save_legend_to_file = Arg("--save-legend", action=argparse.BooleanOptionalAction, default=False,
                              help="Don't show the plot, only save the legend to file."),
    # Without -o each figure is shown interactively. A name without extension
    # gets ".pdf", and --fontsize appends "-fontsize<N>" to the name.
    output_file = Arg("-o", "--output-file", metavar="FILE", type=str, default=None,
                      help="Specific output file name. (Default: show only)"),
    title = Arg("-t", "--title", type=str, default=None,
                help="Title of the generated plot."),
)
args = parser.parse_args()

# Extra keyword arguments for text-drawing calls; carries a font size only when
# --fontsize was given.
label_kwargs = {}
if args.fontsize is not None:
    label_kwargs |= {"fontsize": args.fontsize}

# Colours requested explicitly via "label%color:path" take precedence. Validate
# them now so a typo fails with a clear message rather than a KeyError while
# plotting.
label_colors = {label: color for (label, color, _) in args.jsondata if color is not None}
for _label, _color in label_colors.items():
    if _color not in COLOR_MAP:
        raise ValueError(f"unknown color {_color!r} for label {_label!r}")

# Candidate names for automatic assignment, in a fixed (alphabetical) order so
# the same inputs always get the same colours.
_color_candidates = sorted(COLOR_MAP)

# latency_evals: latency_eval key -> label -> per-label series. Each list is
# aligned with "ticks" (the evaluation's "iterations" value).
latency_evals = {}
for label, color, data in args.jsondata:
    if color is None and label not in label_colors:
        # Several names map to the same appearance ("red" / "solid-red",
        # "red-blue" / "blue-red"). Compare by the (r, g, b, a, linestyle)
        # value, not the name, so two labels never get identical lines.
        taken = {COLOR_MAP[c] for c in label_colors.values()}
        picked = next((c for c in _color_candidates if COLOR_MAP[c] not in taken), None)
        if picked is None:
            raise RuntimeError(f"no distinct color left for label {label!r}")
        label_colors[label] = picked
    for e in data["evaluations"]:
        for latency, ed in e["latency_eval"].items():
            per_label = latency_evals.setdefault(latency, {})
            if "returns" in ed:
                entry = per_label.setdefault(label, {
                    "ticks": [],
                    "returns": [],
                    "lengths": [],
                })
                entry["ticks"].append(e["iterations"])
                entry["returns"].append(ed["returns"])
                entry["lengths"].append(ed["lengths"])
            elif "return_avg" in ed:
                # Special case for dreamer results: only mean/std are stored.
                entry = per_label.setdefault(label, {
                    "ticks": [],
                    "return_avgs": [],
                    "return_stds": [],
                })
                entry["ticks"].append(e["iterations"])
                entry["return_avgs"].append(ed["return_avg"])
                entry["return_stds"].append(ed["return_std"])
            else:
                raise RuntimeError(f"unsupported {e}")

# Fixed six-colour RGBA palette (red, green, blue, yellow, purple, cyan), all
# fully opaque. The plotting loop takes its colours from COLOR_MAP via
# label_colors and does not read this list.
colors = [
    (red, green, blue, 1.0)
    for red, green, blue in (
        (1.0, 0.0, 0.0),  # red
        (0.0, 0.7, 0.0),  # green
        (0.0, 0.0, 1.0),  # blue
        (1.0, 0.7, 0.0),  # yellow
        (1.0, 0.0, 1.0),  # purple
        (0.0, 0.7, 0.8),  # cyan
    )
]

def _running_average(a, n):
    """Return the trailing running mean of the 1-D array *a* over up to *n* samples.

    Entry ``k`` averages ``a[max(0, k + 1 - n):k + 1]``, so the first ``n - 1``
    entries average over the shorter prefix available. ``n == 1`` returns the
    values unchanged (as a new array).
    """
    return np.array(
        [a[max(0, (k + 1) - n):k + 1].sum() / min(k + 1, n) for k in range(a.shape[0])]
    )


def _output_path(output_file, fontsize, latency, tag_latency):
    """Build the output file name for one latency's figure.

    The extension of *output_file* is kept (".pdf" when it has none). Dots in
    directory names are ignored. "-fontsize<N>" is appended when a font size
    is set, and "-L<latency>" when *tag_latency* is true, so figures for
    different latencies do not overwrite each other.
    """
    oname, oending = os.path.splitext(output_file)
    if not oending:
        oending = ".pdf"
    if fontsize is not None:
        oname += f"-fontsize{int(fontsize)}"
    if tag_latency:
        oname += f"-L{latency}"
    return f"{oname}{oending}"


# One figure per latency. With a single latency the output name is exactly as
# before; with several, each file gets a "-L<latency>" suffix.
_tag_latency = len(latency_evals) > 1
for latency, algdata in latency_evals.items():
    fig, ax = plt.subplots(1, 1)
    fig.set_size_inches(8, 6)
    print(f"Plotting for {latency}")
    for label, ld in algdata.items():
        X = np.array(ld["ticks"], dtype=np.int64)

        if "returns" in ld:
            # Each row holds the returns recorded at one tick.
            Y_data = np.array(ld["returns"], dtype=np.float64)
            Y_mean = Y_data.mean(axis=1)
            Y_std = Y_data.std(axis=1)
        else:
            # Only precomputed statistics are stored ("return_avgs"/"return_stds").
            Y_mean = np.array(ld["return_avgs"], dtype=np.float64)
            Y_std = np.array(ld["return_stds"], dtype=np.float64)

        Y_mean = _running_average(Y_mean, args.running_average_n)
        Y_std = _running_average(Y_std, args.running_average_n)
        Y_upperstd = Y_mean + Y_std
        Y_lowerstd = Y_mean - Y_std

        # COLOR_MAP entries are (r, g, b, a, linestyle).
        color = COLOR_MAP[label_colors[label]]
        fill_color = tuple(color[:3]) + (color[3] * 0.2 * 0.5,)
        edge_color = tuple(color[:3]) + (0.0,)  # edge disabled for now (fully transparent)
        linestyle = color[4]

        if args.plot_std:
            ax.fill_between(X, Y_lowerstd, Y_upperstd, color=fill_color, edgecolor=edge_color)
        ax.plot(X, Y_mean, label=label, color=color[:4], linestyle=linestyle)

    if not args.save_legend_to_file:
        ax.set_xlabel("Steps", **label_kwargs)
        ax.set_ylabel("Return", **label_kwargs)
        if args.title is not None:
            ax.set_title(args.title, **label_kwargs)
        if args.fontsize is not None:
            # --fontsize is documented as applying to all text, tick labels included.
            ax.tick_params(labelsize=args.fontsize)

        ax.xaxis.set_minor_locator(AutoMinorLocator(n=4))
        ax.yaxis.set_minor_locator(AutoMinorLocator(n=4))

        ax.grid(which="major", alpha=0.8)
        ax.grid(which="minor", alpha=0.3)

    if args.save_legend_to_file:
        # Legend placed outside the axes; it is cropped out on its own when saving.
        if args.flat_legend:
            legend = ax.legend(loc="lower center", ncol=len(algdata),
                               bbox_to_anchor=(0.5, -0.5), **label_kwargs)
        else:
            legend = ax.legend(loc="upper right", frameon=False,
                               bbox_to_anchor=(1.85, 1.025), **label_kwargs)
    elif args.show_legend:
        if args.flat_legend:
            ax.legend(loc="lower center", ncol=len(algdata), **label_kwargs)
        else:
            # Shrink the axes to 82.5 % of their width to make room for a
            # legend to the right of the plot.
            box = ax.get_position()
            ax.set_position([box.x0, box.y0, box.width * 0.825, box.height])
            ax.legend(loc="upper right", bbox_to_anchor=(1.35, 1.025), **label_kwargs)

    if args.output_file is not None:
        outname = _output_path(args.output_file, args.fontsize, latency, _tag_latency)
        if args.save_legend_to_file:
            # Draw once so the legend has a size, then crop the saved figure to
            # the legend's bounding box (display pixels converted to inches).
            legend_fig = legend.figure
            legend_fig.canvas.draw()
            bbox = legend.get_window_extent().transformed(legend_fig.dpi_scale_trans.inverted())
            legend_fig.savefig(outname, dpi="figure", bbox_inches=bbox)
            print(f"Saved legend to file: {outname}")
        else:
            fig.savefig(outname, bbox_inches="tight")
            print(f"Saved plot to file: {outname}")
        # Release the figure; otherwise every latency's figure stays open.
        plt.close(fig)
    else:
        plt.show()
