import argparse
import collections
import functools
import gzip
import json
import multiprocessing as mp
import pathlib
import re
import subprocess
import warnings

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
import rich.console
import tqdm

# Display titles for known task ids. Tasks not listed here get a title
# derived from the task id itself (see the title lookup in main()).
TITLES = {
    "dmlab_explore_goal_locations_small": "DMLab Goals Small",
    "crafter_reward": "Crafter",
    "pinpad2_three": "Pin Pad Three",
    "pinpad2_four": "Pin Pad Four",
    "pinpad2_five": "Pin Pad Five",
    "pinpad2_six": "Pin Pad Six",
    "pinpad2_eight": "Pin Pad Eight",
    "loconav_ant_maze_s_50hz": "Ant Maze S",
    "loconav_ant_maze_m_50hz": "Ant Maze M",
    "loconav_ant_maze_l_50hz": "Ant Maze L",
    "loconav_ant_maze_xl_50hz": "Ant Maze XL",
}

# Named color palettes selectable with ``--colors NAME`` when NAME is not a
# matplotlib colormap. Each palette is a tuple of hex colors that is cycled
# through by method index when plotting.
COLORS = {
    "contrast": (
        "#0022ff",
        "#33aa00",
        "#ff0011",
        "#ddaa00",
        "#cc44dd",
        "#0088aa",
        "#001177",
        "#117700",
        "#990022",
        "#885500",
        "#553366",
        "#006666",
    ),
    "gradient": ("#a0da39", "#4ac16d", "#277f8e", "#365c8d", "#46327e", "#440154"),
    "gradient_more": (
        "#fde725",
        "#a0da39",
        "#4ac16d",
        "#1fa187",
        "#277f8e",
        "#365c8d",
        "#46327e",
        "#440154",
    ),
}


def main():
    """Load runs, add summary statistics, bin, save, and plot learning curves.

    Runs are loaded from every ``--indirs`` directory and filtered by the
    ``--tasks``/``--methods`` regexes. Optional "stats_*" summary tasks are
    appended, and the curves are binned along the x-axis. All runs are
    written to ``<outdir>/runs.json.gz`` and plotted, one subplot per task,
    to ``<outdir>/curves.png`` and ``<outdir>/curves.pdf``.
    """
    console = rich.console.Console()
    args = parse_args()
    runs = []
    for directory in args.indirs:
        # When several input directories are combined, prefix seeds with the
        # directory name so that seeds from different directories stay apart.
        seed_prefix = len(args.indirs) > 1 and directory.name
        method_prefix = args.prefix and directory.name
        runs += load_metrics(
            directory,
            args.pattern,
            args.xaxis,
            args.yaxis,
            args.yaxis2,
            seed_prefix,
            method_prefix,
            args.tasks,
            args.methods,
            args.workers,
        )
    tasks = _ordered_matches(runs, "task", args.tasks)
    methods = _ordered_matches(runs, "method", args.methods)
    seeds = natsort(set(run["seed"] for run in runs))
    console.print(f'Tasks ({len(tasks)}): [cyan]{", ".join(tasks)}[/cyan]')
    console.print(f'Methods ({len(methods)}): [cyan]{", ".join(methods)}[/cyan]')
    console.print(f'Seed ({len(seeds)}): [cyan]{", ".join(seeds)}[/cyan]')
    if not runs:
        console.print("Nothing to plot!", style="red")
        return
    args.outdir.mkdir(parents=True, exist_ok=True)

    extra_tasks = []
    if args.stats:
        print("Computing stats...", flush=True)
        if len(tasks) == 1:
            # Self-normalized mean/median are only computed across several tasks.
            args.stats = [s for s in args.stats if s not in ("mean", "median")]
        extra_runs, extra_tasks = compute_stats(runs, args.stats, args.bins)
        runs += extra_runs
        tasks += extra_tasks

    print("Binning runs...", flush=True)
    if args.bins:
        # Every run of one (task, method) pair is binned on the same borders,
        # so its seeds can later be averaged point-wise.
        maxs = collections.defaultdict(list)
        for run in runs:
            maxs[(run["task"], run["method"])].append(run["xs"].max())
        maxs = {k: max(vs) for k, vs in maxs.items()}
        for run in runs:
            if run["task"].startswith("stats_"):
                continue  # Stats runs are already binned by compute_stats.
            max_ = maxs[(run["task"], run["method"])] + 1e-8
            max_ = min(max_, args.xlim[1]) if args.xlim else max_
            step = max(1e-8, max_ / 30) if args.bins < 0 else args.bins
            borders = np.arange(0, max_, step)
            xs, ys = binning(run["xs"], run["ys"], borders, np.nanmean, fill="nan")
            run["xs"], run["ys"] = xs, ys

    print("Saving runs...", flush=True)
    filename = args.outdir / "runs.json.gz"
    with gzip.open(filename, "w") as f:
        f.write(
            json.dumps(
                [
                    {**run, "xs": run["xs"].tolist(), "ys": run["ys"].tolist()}
                    for run in runs
                ]
            ).encode("utf-8")
        )
    console.print(f"Saved [green]{filename}[/green]")

    print("Plotting...", flush=True)
    fig, axes = plots(len(tasks), args.cols, args.size)
    for task, ax in zip(tasks, axes):
        if task in TITLES:
            title = TITLES[task]
        else:
            # Drop everything up to the first underscore (if there is one) and
            # title-case the rest; task ids without underscores are kept whole.
            title = task.split("_", 1)[-1].replace("_", " ").title()
        ax.set_title(title)
        if not task.startswith("stats_"):
            # User-provided limits and ticks apply to raw task curves only,
            # not to the stats_* summary plots.
            if args.xlim:
                ax.set_xlim(*args.xlim)
            if args.ylim:
                ax.set_ylim(*args.ylim)
            if args.xticks:
                ax.set_xticks(args.xticks)
        ax.xaxis.set_major_formatter(smart_format)
    for task, ax in zip(tasks, axes):
        for i, method in enumerate(methods):
            relevant = [
                run for run in runs if run["task"] == task and run["method"] == method
            ]
            if not relevant:
                console.print(f"Missing {method} on {task}!", style="red")
                continue
            # With binning, the runs of a task/method share x-coordinates and
            # are drawn as one mean +/- std curve; otherwise each run is drawn
            # on its own.
            if args.bins and args.agg:
                groups = [relevant]
            else:
                groups = [[run] for run in relevant]
            for group in groups:
                xs = group[0]["xs"]
                ys = np.stack([run["ys"] for run in group], 0)
                mean = reduce(ys, np.nanmean, 0)
                std = reduce(ys, np.nanstd, 0)
                curve(
                    ax,
                    xs,
                    mean,
                    mean - std,
                    mean + std,
                    label=args.labels.get(method, method),
                    order=i,
                    color=args.colors(i),
                )
    legendcols = args.legendcols or min(4, args.cols, len(axes))
    legend(fig, adjust=True, ncol=legendcols)
    # Shade the summary-statistic subplots, which come after the real tasks.
    # Guard against an empty list: axes[-0:] would select every subplot.
    if extra_tasks:
        for ax in axes[-len(extra_tasks) :]:
            ax.set_facecolor((0.9, 0.9, 0.9))
    save(fig, args.outdir / "curves.png")
    save(fig, args.outdir / "curves.pdf")


def _ordered_matches(runs, key, regexes):
    """Return the distinct ``run[key]`` values that match any of ``regexes``.

    Values are grouped by the first regex (in the given order) that matches
    them, and naturally sorted within each group.
    """
    ordered = []
    for regex in regexes:
        found = {run[key] for run in runs if re.search(regex, run[key])}
        for value in natsort(found):
            if value not in ordered:
                ordered.append(value)
    return ordered


def compute_stats(runs, stats, bins):
    """Compute summary "stats_*" runs that aggregate over all tasks.

    Args:
      runs: Run dicts with "task", "method", "seed", "xs" and "ys".
      stats: Names of the statistics to compute, processed in order:
        "tasks" (number of tasks with data), "mean"/"median" (tasks
        normalized by their own min/max), and "atari_mean", "atari_median",
        "atari_record", "atari_record_clip" (normalized by baselines from
        ``~/scores/atari_baselines.json``). "dmlab_mean" uses baselines from
        ``~/scores/dmlab_baselines.json``.
      bins: Bin width passed on to the stats helpers.

    Returns:
      ``(extra_runs, extra_tasks)``: the new run dicts and the naturally
      sorted names of the tasks they belong to.

    Raises:
      NotImplementedError: If a statistic name is unknown.
    """
    baseline_files = {
        "atari": "~/scores/atari_baselines.json",
        "dmlab": "~/scores/dmlab_baselines.json",
    }
    loaded = {}

    def baselines(dataset, low_key, high_key):
        # Read each baseline file at most once, then return per-task
        # (mins, maxs) dicts for the requested low and high baseline columns.
        if dataset not in loaded:
            path = pathlib.Path(baseline_files[dataset]).expanduser()
            loaded[dataset] = json.loads(path.read_text())
        table = loaded[dataset]
        select = lambda key: {k: v[key] for k, v in table.items() if key in v}
        return select(low_key), select(high_key)

    # Normalized scores are capped at 1 (the high baseline) before averaging.
    clipped_mean = lambda vals, axis: np.nanmean(np.minimum(vals, 1), axis)
    self_normalized = {"mean": np.nanmean, "median": np.nanmedian}
    # stat -> (baseline file, low column, high column, output name, aggregator)
    fixed_normalized = {
        "atari_mean": (
            "atari", "random", "human_gamer", "gamer_mean", np.nanmean
        ),
        "atari_median": (
            "atari", "random", "human_gamer", "gamer_median", np.nanmedian
        ),
        "atari_record": (
            "atari", "random", "human_record", "record_mean", np.nanmean
        ),
        "atari_record_clip": (
            "atari", "random", "human_record", "record_mean_clip", clipped_mean
        ),
        "dmlab_mean": ("dmlab", "random", "human", "human_mean", clipped_mean),
    }

    extra_runs = []
    for stat in stats:
        if stat == "tasks":
            extra_runs += stats_num_tasks(runs, bins)
        elif stat in self_normalized:
            extra_runs += stats_self_norm(runs, bins, stat, self_normalized[stat])
        elif stat in fixed_normalized:
            dataset, low, high, name, aggregator = fixed_normalized[stat]
            mins, maxs = baselines(dataset, low, high)
            extra_runs += stats_fixed_norm(runs, bins, mins, maxs, name, aggregator)
        else:
            raise NotImplementedError(stat)
    extra_tasks = natsort(set(run["task"] for run in extra_runs))
    return extra_runs, extra_tasks


def stats_self_norm(runs, bins, name="mean", aggregator=np.nanmean):
    """Summarize all tasks into one min-max normalized curve per method/seed.

    Each task is normalized with the smallest and largest y-value seen in any
    of its runs (across all methods and seeds). Tasks whose min and max are
    (nearly) equal are skipped. Each run is binned (empty bins repeat the
    previous value) and normalized. The curves of one method/seed pair are
    then combined across tasks as ``aggregator(curves, 0)``.

    Args:
      runs: Run dicts with "task", "method", "seed", "xs" and "ys".
      bins: Bin width shared by all tasks, on borders reaching the largest x
        of any task. If <= 0, each task instead gets 30 evenly spaced borders
        up to its own largest x.
      name: Suffix of the produced task name.
      aggregator: Reduction applied over the task axis, e.g. ``np.nanmean``.

    Returns:
      List of run dicts for task ``"stats_normalized_<name>"`` whose ``xs``
      are evenly spaced over [0, 1].
    """
    method_names = natsort(set(r["method"] for r in runs))
    seed_names = natsort(set(r["seed"] for r in runs))

    # Per task: the largest x, and the smallest / largest y over all runs.
    x_end, y_low, y_high = {}, {}, {}
    for r in runs:
        t = r["task"]
        x_end[t] = max(x_end.get(t, 0), max(r["xs"]))
        y_low[t] = min(y_low.get(t, np.inf), min(r["ys"]))
        y_high[t] = max(y_high.get(t, -np.inf), max(r["ys"]))

    if bins <= 0:
        grids = {t: np.linspace(0, end + 1e-8, 30) for t, end in x_end.items()}
    else:
        shared = np.arange(0, max(x_end.values()) + 1e-8, bins)
        grids = {t: shared for t in x_end}

    summary = []
    for method in method_names:
        for seed in seed_names:
            normalized = []
            for r in runs:
                if r["method"] != method or r["seed"] != seed:
                    continue
                t = r["task"]
                low, high = y_low[t], y_high[t]
                if np.isclose(low, high):
                    continue  # A constant task cannot be min-max normalized.
                _, binned = binning(
                    r["xs"], r["ys"], grids[t], np.nanmean, fill="last"
                )
                normalized.append((binned - low) / (high - low))
            if not normalized:
                continue
            normalized = np.array(normalized)
            summary.append(
                {
                    "task": f"stats_normalized_{name}",
                    "method": method,
                    "seed": seed,
                    "xs": np.linspace(0, 1, len(normalized[0])),
                    "ys": reduce(normalized, aggregator, 0),
                }
            )
    return summary


def stats_fixed_norm(runs, bins, mins, maxs, name="mean", aggregator=np.nanmean):
    """Summarize all tasks into one baseline-normalized curve per method/seed.

    Each run is binned (empty bins repeat the previous value) and mapped
    through ``(y - mins[task]) / (maxs[task] - mins[task])``. The curves of
    one method/seed pair are then combined across tasks as
    ``aggregator(curves, 0)``.

    Args:
      runs: Run dicts with "task", "method", "seed", "xs" and "ys".
      bins: Bin width shared by all tasks, on borders reaching the largest x
        of any task. If <= 0, each task instead gets 30 evenly spaced borders
        up to its own largest x.
      mins, maxs: Dicts mapping task name to the scores that normalize to 0
        and 1 respectively.
      name: Suffix of the produced task name.
      aggregator: Reduction applied over the task axis, e.g. ``np.nanmean``.

    Returns:
      List of run dicts for task ``"stats_<name>"`` whose ``xs`` are evenly
      spaced over [0, 1].

    Raises:
      KeyError: If a run's task has no entry in ``mins``/``maxs``.
    """
    method_names = natsort(set(r["method"] for r in runs))
    seed_names = natsort(set(r["seed"] for r in runs))

    x_end = {}
    for r in runs:
        x_end[r["task"]] = max(x_end.get(r["task"], 0), max(r["xs"]))

    if bins <= 0:
        grids = {t: np.linspace(0, end + 1e-8, 30) for t, end in x_end.items()}
    else:
        shared = np.arange(0, max(x_end.values()) + 1e-8, bins)
        grids = {t: shared for t in x_end}

    summary = []
    for method in method_names:
        for seed in seed_names:
            selected = [r for r in runs if r["method"] == method and r["seed"] == seed]
            normalized = []
            for r in selected:
                _, binned = binning(
                    r["xs"], r["ys"], grids[r["task"]], np.nanmean, fill="last"
                )
                key = r["task"]
                if key == "atari_jamesbond" and "atari_james_bond" in mins:
                    # The baseline table spells this game differently.
                    key = "atari_james_bond"
                normalized.append((binned - mins[key]) / (maxs[key] - mins[key]))
            if not normalized:
                continue
            summary.append(
                {
                    "task": f"stats_{name}",
                    "method": method,
                    "seed": seed,
                    "xs": np.linspace(0, 1, len(normalized[0])),
                    "ys": reduce(normalized, aggregator, 0),
                }
            )
    return summary


def stats_num_tasks(runs, bins):
    """Count, per bin, how many runs of each method/seed pair contain data.

    Each run of a method/seed pair is binned with ``fill="nan"``, and the
    bins holding a finite value are summed across those runs. There is
    presumably one run per task, hence the name "stats_number_of_tasks".

    Args:
      runs: Run dicts with "task", "method", "seed", "xs" and "ys".
      bins: Bin width shared by all tasks, on borders reaching the largest x
        of any task. If <= 0, each task instead gets 30 evenly spaced borders
        up to its own largest x.

    Returns:
      List of run dicts for task ``"stats_number_of_tasks"`` whose ``xs`` are
      evenly spaced over [0, 1] and whose ``ys`` are the per-bin counts.
    """
    method_names = natsort(set(r["method"] for r in runs))
    seed_names = natsort(set(r["seed"] for r in runs))

    x_end = {}
    for r in runs:
        x_end[r["task"]] = max(x_end.get(r["task"], 0), max(r["xs"]))

    if bins <= 0:
        grids = {t: np.linspace(0, end + 1e-8, 30) for t, end in x_end.items()}
    else:
        shared = np.arange(0, max(x_end.values()) + 1e-8, bins)
        grids = {t: shared for t in x_end}

    summary = []
    for method in method_names:
        for seed in seed_names:
            has_data = [
                np.isfinite(
                    binning(
                        r["xs"], r["ys"], grids[r["task"]], np.nanmean, fill="nan"
                    )[1]
                )
                for r in runs
                if r["method"] == method and r["seed"] == seed
            ]
            if not has_data:
                continue
            summary.append(
                {
                    "task": "stats_number_of_tasks",
                    "method": method,
                    "seed": seed,
                    "xs": np.linspace(0, 1, len(has_data[0])),
                    "ys": np.sum(has_data, 0),
                }
            )
    return summary


def load_metrics(
    directory,
    pattern,
    xaxis,
    yaxis,
    yaxis2,
    seed_prefix=None,
    method_prefix=None,
    tasks=(r".*",),
    methods=(r".*",),
    workers=1,
):
    """Find and load all runs below ``directory`` whose files match ``pattern``.

    Each matched file is read as ``.../<task>/<method>/<seed>/<file>``. Runs
    whose task or method matches none of the ``tasks`` / ``methods`` regexes
    (via ``re.search``) are skipped. When given, ``seed_prefix`` and
    ``method_prefix`` are prepended as ``"<prefix>_<name>"``. Files are
    parsed by ``load_run``, in a process pool when ``workers > 1``.

    Returns:
      List of loaded run dicts; runs that ``load_run`` could not load are
      dropped.
    """
    printer = rich.console.Console()
    root = directory.expanduser().resolve()
    task_res = [re.compile(expr) for expr in tasks]
    method_res = [re.compile(expr) for expr in methods]

    def wanted(text, regexes):
        return any(rx.search(text) for rx in regexes)

    pending = []
    for path in root.glob(pattern):
        task, method, seed = path.parts[-4:-1]
        if not (wanted(task, task_res) and wanted(method, method_res)):
            continue
        if seed_prefix:
            seed = f"{seed_prefix}_{seed}"
        if method_prefix:
            method = f"{method_prefix}_{method}"
        pending.append({"task": task, "method": method, "seed": seed, "filename": path})

    printer.print(f"Loading {len(pending)} runs from [green]{root}[/green]...")
    jobs = [functools.partial(load_run, item, xaxis, yaxis, yaxis2) for item in pending]
    if workers > 1:
        with mp.Pool(workers) as pool:
            handles = [pool.apply_async(job) for job in jobs]
            loaded = [handle.get() for handle in tqdm.tqdm(handles)]
    else:
        loaded = [job() for job in tqdm.tqdm(jobs)]
    return [item for item in loaded if item is not None]


def load_run(run, xaxis, yaxis, yaxis2):
    """Load one run's curve from its JSON-lines metrics file.

    Pops ``run["filename"]`` and reads that file. The ``xaxis`` column is
    stored as ``run["xs"]`` and the ``yaxis`` column (or ``yaxis2`` when
    ``yaxis`` is absent) as ``run["ys"]``. Only rows where both values are
    present are kept.

    Returns:
      The updated ``run`` dict, or None when the file cannot be loaded or no
      row has both values. A message is printed in both cases.
    """
    try:
        filename = run.pop("filename")
        try:
            df = pd.read_json(filename, lines=True)
        except ValueError:
            # Parse line by line so that individual corrupt lines are skipped
            # instead of discarding the whole run.
            records = []
            lines = pathlib.Path(filename).read_text().split("\n")
            for i, line in enumerate(lines):
                if not line:
                    continue
                try:
                    records.append(json.loads(line))
                except ValueError:
                    print(f"Skipping invalid JSON line {i} in {filename}.")
            df = pd.DataFrame(records)
        column = yaxis if yaxis in df.columns else yaxis2
        df = df[[xaxis, column]].dropna()
    except Exception as e:
        # Best effort: report unreadable runs and skip them.
        rich.console.Console().print(
            f'Exception loading {run["method"]} on {run["task"]}:\n {e}', style="red"
        )
        return None
    if df.empty:
        # Empty curves would break the max()/min() reductions applied to
        # xs/ys when binning and computing stats, so skip them here.
        rich.console.Console().print(
            f'No {column} data for {run["method"]} on {run["task"]}.', style="red"
        )
        return None
    run["xs"] = df[xaxis].to_numpy()
    run["ys"] = df[column].to_numpy()
    return run


def plots(amount, cols=4, size=(2, 2.3), xticks=4, yticks=5, grid=(1, 1), **kwargs):
    """Create a figure holding a grid of ``amount`` subplots.

    Args:
      amount: Number of subplots needed.
      cols: Maximum number of columns; as many rows as needed are added.
      size: Width and height of a single subplot in inches.
      xticks, yticks: Maximum number of major ticks per axis.
      grid: Minor subdivisions of each major tick interval for the light
        background grid, as an (x, y) pair or one number for both. A falsy
        value disables the grid.
      **kwargs: Forwarded to ``plt.subplots``.

    Returns:
      ``(fig, axes)`` where ``axes`` is a flat array of the first ``amount``
      axes; the remaining axes of the grid are hidden.
    """
    ncols = min(cols, amount)
    nrows = int(np.ceil(amount / ncols))
    figsize = (ncols * size[0], nrows * size[1])
    fig, grid_axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False, **kwargs)
    flat_axes = grid_axes.flatten()
    if grid and not hasattr(grid, "__len__"):
        grid = (grid, grid)
    for ax in flat_axes:
        ax.xaxis.set_major_locator(ticker.MaxNLocator(xticks))
        ax.yaxis.set_major_locator(ticker.MaxNLocator(yticks))
        if not grid:
            continue
        ax.grid(which="both", color="#eeeeee")
        ax.xaxis.set_minor_locator(ticker.AutoMinorLocator(int(grid[0])))
        ax.yaxis.set_minor_locator(ticker.AutoMinorLocator(int(grid[1])))
        # Minor ticks only drive the grid lines; hide the tick marks.
        ax.tick_params(which="minor", length=0)
    for ax in flat_axes[amount:]:
        ax.axis("off")
    return fig, flat_axes[:amount]


def curve(ax, xs, ys, low=None, high=None, label=None, order=0, **kwargs):
    """Draw the finite points of ``ys`` as a line, plus an optional band.

    The band between ``low`` and ``high`` is filled (at 20% opacity) only
    when ``low`` is given and at least two finite points exist. A lower
    ``order`` places both the line and the band on top (higher zorder).
    Extra keyword arguments go to both ``ax.plot`` and ``ax.fill_between``.
    """
    mask = np.isfinite(ys)
    visible_xs = xs[mask]
    ax.plot(visible_xs, ys[mask], label=label, zorder=1000 - order, **kwargs)
    if low is None or mask.sum() <= 1:
        return
    ax.fill_between(
        visible_xs,
        low[mask],
        high[mask],
        zorder=100 - order,
        alpha=0.2,
        lw=0,
        **kwargs,
    )


def legend(fig, mapping=None, adjust=False, **kwargs):
    """Draw one figure-level legend that collects the labelled artists of all axes.

    Args:
      fig: Matplotlib figure whose axes are scanned for legend entries.
      mapping: Optional dict renaming labels; duplicate labels keep only the
        handle seen last.
      adjust: If not ``False``, re-layout the figure with ``tight_layout`` so
        the subplots leave room for the legend. A number is used as the
        padding, otherwise 0.5. ``loc`` must then consist of two words, such
        as "lower center".
      **kwargs: Override the default ``fig.legend`` options.
    """
    options = {
        "fontsize": "medium",
        "numpoints": 1,
        "labelspacing": 0,
        "columnspacing": 1.2,
        "handlelength": 1.5,
        "handletextpad": 0.5,
        "ncol": 4,
        "loc": "lower center",
    }
    options.update(kwargs)
    handle_of = {}
    for ax in fig.axes:
        handles, labels = ax.get_legend_handles_labels()
        for handle, label in zip(handles, labels):
            if mapping and label in mapping:
                label = mapping[label]
            handle_of[label] = handle
    leg = fig.legend(handle_of.values(), handle_of.keys(), **options)
    leg.get_frame().set_edgecolor("white")
    if adjust is False:
        return
    pad = adjust if isinstance(adjust, (int, float)) else 0.5
    # Legend bounding box in figure-relative coordinates (0..1).
    box = leg.get_window_extent(fig.canvas.get_renderer())
    box = box.transformed(fig.transFigure.inverted())
    vertical, horizontal = options["loc"].split()
    # Shrink the layout rectangle away from the side the legend sits on.
    bottom = {"lower": box.y1, "center": 0, "upper": 0}[vertical]
    top = {"lower": 1, "center": 1, "upper": box.y0}[vertical]
    left = {"left": box.x1, "center": 0, "right": 0}[horizontal]
    right = {"left": 1, "center": 1, "right": box.x0}[horizontal]
    fig.tight_layout(rect=[left, bottom, right, top], h_pad=pad, w_pad=pad)


def smart_format(x, pos=None):
    """Format a tick value compactly, e.g. 2000 -> "2K" and 2.5e6 -> "2.5M".

    Uses the matplotlib tick-formatter signature ``(x, pos)``; ``pos`` is
    unused. Values below 1000 in magnitude are printed plainly, larger ones
    with a K/M/B suffix.
    """
    if abs(x) < 1e3:
        if float(int(x)) == float(x):
            return str(int(x))
        text = str(round(x, 10))
        if "e" in text or "E" in text:
            # Scientific notation: stripping zeros would corrupt the exponent
            # (e.g. "1e-10" -> "1e-1").
            return text
        # Drop trailing zeros, and the dot as well if nothing follows it.
        return text.rstrip("0").rstrip(".")
    if abs(x) < 1e6:
        return f"{x/1e3:.0f}K" if x == x // 1e3 * 1e3 else f"{x/1e3:.1f}K"
    if abs(x) < 1e9:
        return f"{x/1e6:.0f}M" if x == x // 1e6 * 1e6 else f"{x/1e6:.1f}M"
    return f"{x/1e9:.0f}B" if x == x // 1e9 * 1e9 else f"{x/1e9:.1f}B"


def save(fig, filename):
    """Write ``fig`` to ``filename``, creating missing parent directories.

    ``~`` in the path is expanded. PDF outputs are additionally cropped in
    place with the external ``pdfcrop`` tool. If that tool is not installed,
    a hint is printed instead.
    """
    target = pathlib.Path(filename).expanduser()
    target.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(target)
    rich.console.Console().print(f"Saved [green]{target}[/green]")
    if target.suffix != ".pdf":
        return
    try:
        subprocess.call(["pdfcrop", str(target), str(target)])
    except FileNotFoundError:
        print("Install LaTeX to crop PDF outputs.")


def binning(xs, ys, borders, reducer=np.nanmean, fill="nan"):
    """Aggregate ``ys`` over the x-intervals ``(borders[i], borders[i + 1]]``.

    Samples are sorted by x first. Each bin takes ``reducer`` of the
    y-values whose x lies in the half-open interval. A sample exactly at
    ``borders[0]`` therefore falls into no bin.

    Args:
      xs, ys: Sample coordinates (arrays or sequences of equal length).
      borders: Increasing bin edges.
      reducer: Reduction over the y-values of one bin.
      fill: How to fill an empty bin: "nan" or "last" (repeat the previous
        bin's value). A leading empty bin is always NaN.

    Returns:
      ``(borders[1:], values)``: the right edge and the aggregated value of
      every bin.
    """
    if not isinstance(xs, np.ndarray):
        xs = np.array(xs)
    if not isinstance(ys, np.ndarray):
        ys = np.array(ys)
    perm = np.argsort(xs)
    xs, ys = xs[perm], ys[perm]
    # Number of samples with x <= each edge. Since xs is sorted, the samples
    # of bin i are exactly xs[counts[i]:counts[i + 1]].
    counts = [(xs <= edge).sum() for edge in borders]
    values = []
    for lo, hi in zip(counts[:-1], counts[1:]):
        if lo < hi:
            values.append(reduce(ys[lo:hi], reducer))
        elif values:
            values.append({"nan": np.nan, "last": values[-1]}[fill])
        else:
            values.append(np.nan)
    return borders[1:], np.array(values)


def reduce(values, reducer=np.nanmean, *args, **kwargs):
    """Return ``reducer(values, *args, **kwargs)`` with RuntimeWarnings muted.

    Bins can be empty or all-NaN. NaN-aware NumPy reductions emit a
    RuntimeWarning for such slices, which would otherwise flood the output.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=RuntimeWarning)
        result = reducer(values, *args, **kwargs)
    return result


def natsort(sequence):
    """Return the items of ``sequence`` sorted in natural order.

    Runs of ASCII digits compare numerically, so "seed2" sorts before
    "seed10". Everything else compares as text.
    """
    pattern = re.compile(r"([0-9]+)")

    def key(text):
        # re.split with a capturing group alternates plain text (even indices)
        # and the captured digit runs (odd indices). Converting by position
        # keeps ints and strs aligned across keys, so comparisons never mix
        # types. Checking str.isdigit() instead would also match characters
        # such as "²" that int() cannot parse.
        parts = pattern.split(text)
        return [int(part) if i % 2 else part for i, part in enumerate(parts)]

    return sorted(sequence, key=key)


def parse_args(argv=None):
    """Parse and post-process the command-line options.

    Args:
      argv: Arguments to parse; ``None`` lets argparse use ``sys.argv[1:]``.

    Returns:
      The ``argparse.Namespace`` with these fields post-processed:
        indirs: Tuple of ``~``-expanded input directories.
        outdir: ``~``-expanded ``--outdir`` joined with the stem of the first
          input directory.
        labels: Dict mapping method name to display label, built from
          ``--labels METHOD LABEL [METHOD LABEL ...]``.
        colors: Callable mapping a method index to a color. It is a
          matplotlib colormap when a single colormap name is given.
          Otherwise it cycles through the named ``COLORS`` palette or the
          given colors.
        stats: List of statistic names; ``--stats none`` yields ``[]``.
    """
    # type=bool would treat any non-empty string (even "False") as True, so
    # accept exactly "False" or "True".
    boolean = lambda x: bool(["False", "True"].index(x))
    parser = argparse.ArgumentParser()
    parser.add_argument("--indirs", nargs="+", type=pathlib.Path, required=True)
    parser.add_argument("--outdir", type=pathlib.Path, required=True)
    parser.add_argument("--pattern", type=str, default="**/scores.jsonl")
    parser.add_argument("--prefix", type=boolean, default=False)
    parser.add_argument("--xaxis", type=str, default="step")
    parser.add_argument("--yaxis", type=str, default="episode/score")
    parser.add_argument("--yaxis2", type=str, default="eval_episode/score")
    parser.add_argument("--tasks", nargs="+", default=[r".*"])
    parser.add_argument("--methods", nargs="+", default=[r".*"])
    parser.add_argument("--bins", type=float, default=-1)
    parser.add_argument("--agg", type=boolean, default=True)
    parser.add_argument("--size", nargs=2, type=float, default=[2.5, 2.3])
    parser.add_argument("--cols", type=int, default=6)
    parser.add_argument("--legendcols", type=int, default=0)
    parser.add_argument("--xlim", nargs=2, type=float, default=None)
    parser.add_argument("--ylim", nargs=2, type=float, default=None)
    parser.add_argument("--xticks", nargs="+", type=float, default=None)
    parser.add_argument("--labels", nargs="+", default=[])
    parser.add_argument("--colors", type=str, nargs="+", default=["contrast"])
    parser.add_argument("--workers", type=int, default=12)
    parser.add_argument(
        "--stats", type=str, nargs="*", default=["mean", "median", "tasks"]
    )
    args = parser.parse_args(argv)
    args.indirs = tuple(x.expanduser() for x in args.indirs)
    args.outdir = args.outdir.expanduser() / args.indirs[0].stem
    if len(args.labels) % 2 != 0:
        parser.error("--labels expects pairs: METHOD LABEL [METHOD LABEL ...]")
    # Pair up the flat list: [m1, l1, m2, l2] -> {m1: l1, m2: l2}.
    args.labels = dict(zip(args.labels[0::2], args.labels[1::2]))
    if len(args.colors) == 1:
        try:
            args.colors = plt.get_cmap(args.colors[0])
        except ValueError:
            palette = COLORS.get(args.colors[0], args.colors)
            args.colors = lambda i: palette[i % len(palette)]
    else:
        # Several explicit colors: cycle through them by method index.
        palette = tuple(args.colors)
        args.colors = lambda i: palette[i % len(palette)]
    if args.stats == ["none"]:
        args.stats = []
    return args


# Run the full load/bin/save/plot pipeline when executed as a script.
if __name__ == "__main__":
    main()
