#!/usr/bin/env python3

import argparse
import json
import math
import numpy as np

import latency_env.misc.argparser_types as at
from latency_env.misc import Argument as Arg
from latency_env.misc import ArgumentList as ArgList
from latency_env.misc import ArgumentParser

def parse_input_json(s : str) -> tuple[str, str, str, object]:
    """
    Parse an input spec of the form ``algname:envname:delayname:path``.

    Only the first three colons separate fields; any further colons are kept
    as part of ``path``. The JSON file at ``path`` is read as UTF-8 and parsed.

    Returns:
        ``(algname, envname, delayname, data)`` where ``data`` is the parsed JSON.

    Raises:
        ValueError: if the spec has fewer than three colons ("Missing first /
            second / third colon"), or the file is not valid JSON
            (``json.JSONDecodeError`` subclasses ``ValueError``).
        OSError: if the file cannot be opened.
    """
    parts = s.split(":", 3)
    if len(parts) < 4:
        # len(parts) - 1 == number of colons found, so it indexes the first missing one.
        ordinal = ("first", "second", "third")[len(parts) - 1]
        raise ValueError(f"Missing {ordinal} colon")
    (algname, envname, delayname, path) = parts

    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(path, encoding="utf-8") as f:
        data = json.load(f)

    return (algname, envname, delayname, data)

# Command-line interface: one or more positional input specs (each parsed and
# loaded by parse_input_json) plus an optional indentation for the output.
parser = ArgumentParser()
_cli_arguments = ArgList(
    jsondata=Arg(
        metavar="algname:envname:delayname:path",
        type=parse_input_json,
        nargs="+",
        help=(
            "Input JSON file(s) to get data from. "
            "Example format: SAC:Ant-v4:MM1:_logs/run1234/eval.json"
        ),
    ),
    indentation=Arg(
        "--indent",
        type=at.nonnegint,
        default=0,
        help="Indentation of output.",
    ),
)
parser += _cli_arguments
args = parser.parse_args()



def _best_return(data):
    """Find the best ``(mean, std)`` return over all evaluations in ``data``.

    Scans every entry of ``data["evaluations"][*]["latency_eval"]`` and keeps
    the entry with the highest mean return. On ties, the first one seen wins
    (the comparison is strict). An entry carries either raw episode
    ``"returns"``, whose mean and std (NumPy default, ddof=0) are computed
    here, or, for Dreamer logs, precomputed ``"return_avg"`` / ``"return_std"``.

    Returns:
        ``(best_mean, best_std, latencies)`` where ``latencies`` is the set of
        latency keys encountered. ``best_mean``/``best_std`` are None when no
        entries were found.

    Raises:
        NotImplementedError: if an entry has neither supported format.
    """
    best_mean = None
    best_std = None
    latencies = set()
    for e in data["evaluations"]:
        for latency, ed in e["latency_eval"].items():
            latencies.add(latency)
            if "returns" in ed:
                ret = np.array(ed["returns"], dtype=np.float64)
                ret_mean = ret.mean()
                if best_mean is None or best_mean < ret_mean:
                    best_mean, best_std = ret_mean, ret.std()
            elif "return_avg" in ed:
                # Special case for dreamer: only summary statistics are logged.
                ret_mean = ed["return_avg"]
                if best_mean is None or best_mean < ret_mean:
                    best_mean, best_std = ret_mean, ed["return_std"]
            else:
                raise NotImplementedError(f"Unsupported {e}")
    return best_mean, best_std, latencies


# Distinct names in first-appearance order; they fix the row/column order of
# the table. alg_lookup[alg][env][delay] holds the raw data plus the best
# mean/std return for that input file.
algs = []
envs = []
delays = []
alg_lookup = {}
for (algname, envname, delayname, data) in args.jsondata:
    if algname not in algs:
        algs.append(algname)
    if envname not in envs:
        envs.append(envname)
    if delayname not in delays:
        delays.append(delayname)

    best_mean, best_std, latencies = _best_return(data)
    # Each file must contain evaluations at exactly one latency. This is an
    # explicit check (not an assert) so it survives `python -O` and reports
    # which input was wrong.
    if len(latencies) != 1:
        raise ValueError(
            f"{algname}:{envname}:{delayname}: expected evaluations at exactly "
            f"one latency, found {sorted(latencies)}"
        )

    # NOTE(review): a repeated alg:env:delay spec silently replaces the earlier one.
    alg_lookup.setdefault(algname, {}).setdefault(envname, {})[delayname] = {
        "data": data,
        "best_mean": best_mean,
        "best_std": best_std,
    }


# Preamble: keep the table within \textwidth, then declare the column spec.
# There is one left-aligned label column, then for every environment a group
# of right-aligned columns (one per delay process). Each group is introduced
# by a "|" rule, and columns within a group are separated by a grey rule.
_delay_group = r"     |" + r"|[fg=black!50]".join("Q[r,t]" for _ in delays)
lines = [
    r"\begin{adjustbox}{max width=\textwidth}",
    r"\begin{tblr}{",
    r"  colspec = {",
    r"      Q[l,t]",
    *([_delay_group] * len(envs)),
    r"  },",
    r"  colsep = 3pt,",
    r"}",
    r"\hline",
]

# Short environment names (as given in the input specs) -> Gymnasium IDs shown
# in the table. Names not listed here are displayed unchanged, because the
# table code looks them up with ENVNAME_CONV.get(env, env).
ENVNAME_CONV = dict(
    ant="Ant-v4",
    humanoid="Humanoid-v4",
    halfcheetah="HalfCheetah-v4",
    hopper="Hopper-v4",
    walker2d="Walker2d-v4",
)
# Header row 1: each environment name is merged across its group of delay
# columns. The trailing "&"s step over the cells covered by \SetCell.
_span = len(delays)
lines.append(r"\small\emph{Gymnasium env.}")
lines.extend(
    rf"& \SetCell[c={_span}]{{c}} \texttt{{{ENVNAME_CONV.get(env, env)}}} " + "&" * (_span - 1)
    for env in envs
)
lines.append(r"\\\hline")

# Header row 2: the delay-process names, repeated under every environment.
_delay_cells = "".join(r"& \SetCell[c=1]{c} {\small$\text{" + delay + "}$}" for delay in delays)
lines.append(r"\small\emph{Delay process}")
lines.extend([_delay_cells] * len(envs))
lines.append(r"\\\hline")

# Body: one row per algorithm, with one cell per (env, delay) column in the
# same order as the header. In each column the highest best-mean return is set
# in bold (tied values are all bolded). A combination that was not supplied on
# the command line is rendered as "--" instead of aborting with a KeyError.
for i, alg in enumerate(algs):
    if "mbpac" in alg.lower():
        # Shade rows of algorithms whose name contains "mbpac".
        lines += [r"\SetRow{bg=black!10}"]
    lines += [alg]
    for env in envs:
        parts = []
        for delay in delays:
            d = alg_lookup[alg].get(env, {}).get(delay)
            if d is None:
                parts.append(r"& --")
                continue

            alg_mean = d["best_mean"]
            # Compare only against algorithms that have a result in this column.
            is_best_alg = all(
                alg_lookup[alg2][env][delay]["best_mean"] <= alg_mean
                for alg2 in algs
                if delay in alg_lookup[alg2].get(env, {})
            )

            if is_best_alg:
                parts.append(rf"& $\bm{{{alg_mean:.2f}}}$")
            else:
                parts.append(rf"& ${alg_mean:.2f}$")

        # The trailing LaTeX comment labels the environment group in the generated source.
        lines += [" ".join(parts) + f" % {ENVNAME_CONV.get(env, env)}"]

    # Dotted separator between algorithm rows; none after the last row.
    if i < (len(algs) - 1):
        lines += [r"\\\hline[dotted,fg=black!50]"]

# Close the environments opened in the preamble, then emit the whole table with
# every line prefixed by the requested indentation.
lines.extend([r"\end{tblr}", r"\end{adjustbox}"])
_prefix = " " * args.indentation
print("\n".join(f"{_prefix}{line}" for line in lines))






