#!/usr/bin/env python3

import argparse
import json
import math
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
from matplotlib.ticker import AutoMinorLocator

import latency_env.misc.argparser_types as at
from latency_env.misc import Argument as Arg
from latency_env.misc import ArgumentList as ArgList
from latency_env.misc import ArgumentParser


# Default figure size in inches (used as defaults for --width / --height).
DEFAULT_WIDTH = 8.0
DEFAULT_HEIGHT = 6.0

# Line-style modifiers that can be prefixed to a color name ("dashed-red").
# Values are matplotlib linestyle specifications, see
# https://matplotlib.org/stable/gallery/lines_bars_and_markers/linestyles.html
COLOR_MODIFIERS = dict(
    solid="-",
    dotted=":",
    dashed="--",
    dashdot="-.",
    densedashdot=(0, (3, 1.5, 1, 1, 1, 1.5)),
)

# Base colors as (red, green, blue, alpha) tuples.
COLOR_RGB_MAP = {
    "red": (1.0, 0.0, 0.0, 1.0),
    "darkred": (0.5, 0.0, 0.0, 1.0),
    "green": (0.0, 0.7, 0.0, 1.0),
    "darkgreen": (0.0, 0.4, 0.0, 1.0),
    "blue": (0.0, 0.0, 1.0, 1.0),
    "yellow": (1.0, 0.7, 0.0, 1.0),
    "darkyellow": (0.9, 0.6, 0.0, 1.0),
    "purple": (0.8, 0.0, 1.0, 1.0),
    "darkpurple": (0.7, 0.0, 0.7, 1.0),
    "pink": (0.9, 0.2, 0.9, 1.0),
    "cyan": (0.0, 0.7, 0.8, 1.0),
    "orange": (1.0, 0.4, 0.0, 1.0),
    "darkorange": (0.8, 0.3, 0.0, 1.0),
    "teal": (0.0, 0.7, 0.4, 1.0),
}

# Maps "<color>" and "<modifier>-<color>" to a 5-tuple
# (red, green, blue, alpha, linestyle). A bare color name uses the solid style.
COLOR_MAP = {}
for color_name, rgba in COLOR_RGB_MAP.items():
    COLOR_MAP[color_name] = (*rgba, COLOR_MODIFIERS["solid"])
    COLOR_MAP.update(
        (f"{style_name}-{color_name}", (*rgba, style))
        for style_name, style in COLOR_MODIFIERS.items()
    )


# COLOR_MAP key used for each noise level (noise levels are given in promille).
NOISE_COLORS = {
    0: "red",
    25: "darkgreen",
    50: "blue",
    75: "pink",
    100: "dashed-red",
    150: "dashed-cyan",
    200: "dashed-darkpurple",
}


# Command-line interface. The rest of the script reads each option from
# `args` under the keyword name used here (e.g. `args.indentation` for
# `--indent`, `args.show_legend` for `--legend`).
parser = ArgumentParser()
parser += ArgList(
    bm_prefix=Arg(
        "--bm-prefix",
        metavar="prefix",
        type=str,
        required=True,
        help="Example: noise-eval-v2/ant/bpql-noise",
    ),
    noise_promilles=Arg(
        "--noise-promilles",
        type=at.nonnegint_commalist,
        required=True,
        help="Example: --noise-promilles 0,25,50,75,100,150,200",
    ),
    delays=Arg(
        "--delays",
        type=at.nonnegint_commalist,
        required=True,
        help="Example: --delays 3,6,9,12,15,18,21,24",
    ),
    indentation=Arg(
        "--indent",
        type=at.nonnegint,
        default=0,
        help="Indentation of output.",
    ),
    gen_table=Arg(
        "--gen-table",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Generate a table with data instead of a plot.",
    ),
    show_legend=Arg(
        "--legend",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Show legend.",
    ),
    flat_legend=Arg(
        "--flat-legend",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Show a flat legend below the plot instead.",
    ),
    save_legend_to_file=Arg(
        "--save-legend",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Don't show the plot, only save the legend to file.",
    ),
    output_file=Arg(
        "-o",
        "--output-file",
        type=str,
        default=None,
        help="Specific output file name. (Default: show only)",
    ),
    # NOTE(review): width/height are parsed with at.posint although the
    # defaults are floats; presumably fractional sizes are rejected - confirm.
    width=Arg(
        "--width",
        type=at.posint,
        default=DEFAULT_WIDTH,
        help="Width of the figure in inches.",
    ),
    height=Arg(
        "--height",
        type=at.posint,
        default=DEFAULT_HEIGHT,
        help="Height of the figure in inches.",
    ),
    fontsize=Arg(
        "--fontsize",
        type=at.posint,
        default=None,
        help="Custom font size for all text.",
    ),
)
args = parser.parse_args()

# Extra keyword arguments forwarded to the axis-label calls; stays empty
# unless --fontsize was given.
label_kwargs = {} if args.fontsize is None else {"fontsize": args.fontsize}

# Load the evaluation results.
#
# lookup[noise][delay] holds summary statistics for one (noise, delay)
# combination. Every latency evaluation found in the corresponding eval.json
# is reduced to (mean, std) of its returns; these pairs are sorted by mean and
# the entries at the top, the 50% and the 90% position of that ranking are
# stored. Combinations whose file cannot be read, cannot be parsed, or contains
# no evaluations are skipped, so lookup[noise] may lack some delays.
#
# Warnings are written to stderr so that they never end up inside the LaTeX
# table that --gen-table prints to stdout.
lookup = {}
for noise in args.noise_promilles:
    lookup[noise] = {}
    for delay in args.delays:

        json_path = f"{args.bm_prefix}_0.{noise:03d}-delay_{delay}/eval.json"
        try:
            with open(json_path) as f:
                data = json.load(f)
        except (OSError, ValueError) as err:
            # OSError: missing/unreadable file. ValueError: invalid JSON
            # (json.JSONDecodeError) or undecodable bytes (UnicodeDecodeError).
            # Anything else (e.g. KeyboardInterrupt) is allowed to propagate.
            print(f"WARNING: skipping noise={noise} delay={delay} ({err})", file=sys.stderr)
            continue

        # One (mean, std) pair per latency evaluation, over all evaluations.
        rets = []
        for e in data["evaluations"]:
            for ed in e["latency_eval"].values():
                ret = np.array(ed["returns"], dtype=np.float64)
                rets.append((ret.mean(), ret.std()))

        if not rets:
            # Without any evaluation there is nothing to rank; indexing below
            # would raise IndexError.
            print(f"WARNING: skipping noise={noise} delay={delay} (no evaluations in {json_path})",
                  file=sys.stderr)
            continue

        # Ascending by mean return, so the best evaluation is last.
        rets = sorted(rets, key=lambda k: k[0])
        median_idx = int(len(rets) * 0.5)
        p90_idx = int(len(rets) * 0.9)

        lookup[noise][delay] = {
            "best_mean": rets[-1][0],
            "best_std": rets[-1][1],
            "median_mean": rets[median_idx][0],
            "median_std": rets[median_idx][1],
            "90%_mean": rets[p90_idx][0],
            "90%_std": rets[p90_idx][1],
        }


# Produce the output: either a LaTeX table on stdout (--gen-table) or a
# matplotlib figure that is shown or saved to --output-file.
if args.gen_table:
    # Table layout: one label column followed by three columns per delay
    # (mean, "+-" sign, std). The comma-separated lists below are the indices
    # of the first, middle and last column of every such triple.
    n_delays = len(args.delays)
    first_cols, middle_cols, last_cols = (
        ",".join(str(i * 3 + offset) for i in range(n_delays))
        for offset in (2, 3, 4)
    )

    lines = [
        r"\begin{adjustbox}{max width=\textwidth}",
        r"\begin{tblr}{",
        r"  colspec = {",
        r"      Q[l,t]",
    ]
    lines += [r"     |Q[r,t]Q[c,t]Q[r,t]"] * n_delays
    lines += [
        r"  },",
        r"  colsep = 6pt,",
        rf"  cell{{1}}{{{first_cols}}} = {{c=3}}{{c}},",
        rf"  column{{{first_cols}}} = {{rightsep=0pt}},",
        rf"  column{{{middle_cols}}} = {{colsep=3pt}},",
        rf"  column{{{last_cols}}} = {{leftsep=0pt}},",
        r"}",
        r"\hline",
    ]

    # Header row: one three-column-wide cell per delay.
    for delay in args.delays:
        lines += [rf"& \SetCell[c=3]{{c}} Delay {delay} &&"]
    lines += [r"\\\hline"]

    # One row per noise level; noise is given in promille, printed in percent.
    last_row = len(args.noise_promilles) - 1
    for i, noise in enumerate(args.noise_promilles):
        lines += [f"{noise/10.0}\\% noise"]
        for delay in args.delays:
            d = lookup[noise].get(delay)
            if d is None:
                # This combination was skipped while loading; emit a
                # placeholder spanning the three columns instead of crashing.
                lines += [rf"& \SetCell[c=3]{{c}} -- && % Delay {delay}"]
                continue

            d_mean, d_std = d["best_mean"], d["best_std"]
            # Bold the entry when no other available noise level reaches a
            # higher mean for this delay (ties are all bold).
            is_best = all(
                lookup[n2][delay]["best_mean"] <= d_mean
                for n2 in args.noise_promilles
                if delay in lookup[n2]
            )
            if is_best:
                lines += [rf"& $\bm{{{d_mean:.2f}}}$ & $\bm{{\pm}}$ & $\bm{{{d_std:.2f}}}$ % Delay {delay}"]
            else:
                lines += [rf"& ${d_mean:.2f}$ & $\pm$ & ${d_std:.2f}$ % Delay {delay}"]

        if i < last_row:
            lines += [r"\\\hline[dotted,fg=black!50]"]

    lines += [
        r"\end{tblr}",
        r"\end{adjustbox}",
    ]

    indent = " " * args.indentation
    print("\n".join(indent + line for line in lines))
else:
    fig, ax = plt.subplots(1, 1)
    # Honor --width / --height (their defaults give the former fixed 8x6 size).
    fig.set_size_inches(args.width, args.height)

    # One curve per noise level: best mean return over delay, with a shaded
    # band of +/- one standard deviation.
    for noise in args.noise_promilles:
        # Only delays that were actually loaded for this noise level.
        X = np.array(sorted(lookup[noise].keys()), dtype=np.int64)

        Y_mean = np.array([lookup[noise][delay]["best_mean"] for delay in X], dtype=np.float64)
        Y_std = np.array([lookup[noise][delay]["best_std"] for delay in X], dtype=np.float64)

        Y_upperstd = Y_mean + Y_std
        Y_lowerstd = Y_mean - Y_std

        # NOTE(review): raises KeyError for noise levels without an entry in
        # NOISE_COLORS.
        color = COLOR_MAP[NOISE_COLORS[noise]]
        fill_color = tuple(color[:3]) + (color[3] * 0.2 * 0.5,)
        edge_color = tuple(color[:3]) + (color[3] * 0.4 * 0,) # Temp: no edge color
        linestyle = color[4]

        ax.fill_between(X, Y_lowerstd, Y_upperstd, color=fill_color, edgecolor=edge_color)
        ax.plot(X, Y_mean, label=f"{noise/10.0}%", color=color[:4], linestyle=linestyle)

    # Axis decoration is skipped when only the legend is going to be saved.
    # NOTE(review): --fontsize is only applied to the two axis labels here,
    # although its help text says "all text".
    if not args.save_legend_to_file:
        ax.set_xlabel("Delay", **label_kwargs)
        ax.set_ylabel("Return", **label_kwargs)

        ax.xaxis.set_minor_locator(AutoMinorLocator(n=4))
        ax.yaxis.set_minor_locator(AutoMinorLocator(n=4))

        ax.grid(which="major", alpha=0.8)
        ax.grid(which="minor", alpha=0.3)

    if args.save_legend_to_file:
        # The legend is anchored outside the axes; only its bounding box is
        # written to the file further below.
        if args.flat_legend:
            legend = ax.legend(loc="lower center", ncol=len(args.noise_promilles), bbox_to_anchor=(0.5, -0.5))
        else:
            legend = ax.legend(loc="upper right", frameon=False, bbox_to_anchor=(1.85, 1.025))
    elif args.show_legend:
        if args.flat_legend:
            ax.legend(loc="lower center", ncol=len(args.noise_promilles))
        else:
            # Shrink the axes horizontally to make room for the legend on the
            # right-hand side.
            box = ax.get_position()
            ax.set_position([box.x0, box.y0, box.width * 0.825, box.height])
            ax.legend(loc="upper right", bbox_to_anchor=(1.30, 1.015))

    if args.output_file is not None:
        # Split off the extension with os.path.splitext so that a dot in a
        # directory name (e.g. "results.v2/plot") is not mistaken for it.
        # Without an extension the output defaults to PDF.
        oname, oending = os.path.splitext(args.output_file)
        if not oending:
            oending = ".pdf"
        if args.fontsize is not None:
            oname += f"-fontsize{int(args.fontsize)}"
        outname = f"{oname}{oending}"
        if args.save_legend_to_file:
            fig = legend.figure
            # Draw once so the legend's window extent is known, then crop the
            # saved image to that extent (converted from pixels to inches).
            fig.canvas.draw()
            bbox = legend.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
            fig.savefig(outname, dpi="figure", bbox_inches=bbox)
            print(f"Saved legend to file: {outname}")
        else:
            plt.savefig(outname, bbox_inches="tight")
            print(f"Saved plot to file: {outname}")
    else:
        plt.show()
