#!/usr/bin/env python3

import argparse
import json
import math
import numpy as np

import latency_env.misc.argparser_types as at
from latency_env.misc import Argument as Arg
from latency_env.misc import ArgumentList as ArgList
from latency_env.misc import ArgumentParser

def parse_input_json(s : str) -> tuple[str, str, object]:
    """
    Parse one positional input spec of the form ``algname:envname:path``.

    Only the first two colons act as separators, so ``path`` may itself
    contain colons (e.g. a Windows drive letter such as ``C:\\...``).

    :param s: The raw command-line token.
    :returns: ``(algname, envname, data)`` where ``data`` is the decoded JSON
        document read from ``path``.
    :raises ValueError: If ``s`` contains fewer than two colons, or if the file
        is not valid JSON (``json.JSONDecodeError`` subclasses ``ValueError``).
    :raises OSError: If ``path`` cannot be opened.
    """
    # str.partition returns an empty separator when ':' is absent.
    (algname, sep, rest) = s.partition(":")
    if not sep:
        raise ValueError("Missing first colon")

    (envname, sep, path) = rest.partition(":")
    if not sep:
        raise ValueError("Missing second colon")

    # JSON text is UTF-8 by specification; don't depend on the platform locale.
    with open(path, encoding="utf-8") as f:
        data = json.load(f)

    return (algname, envname, data)

# Command-line interface.
#   * positional inputs are converted to (algname, envname, data) tuples via
#     `type=parse_input_json`;
#   * --show-std / --no-show-std toggles `plot_std` (default off);
#   * --indent sets the number of leading spaces on every printed line
#     (validated by `at.nonnegint`, presumably rejecting negatives).
# NOTE(review): `plot_std` is not read by the table-building code visible in
# this file; the std columns are always printed. TODO confirm intent.
parser = ArgumentParser()
_cli_arguments = ArgList(
    jsondata=Arg(
        metavar="algname:envname:path",
        type=parse_input_json,
        nargs="+",
        help=(
            "Input JSON file(s) to get data from. "
            "Example format: SAC:Ant-v4:_logs/run1234/eval.json"
        ),
    ),
    plot_std=Arg(
        "--show-std",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Show standard deviations.",
    ),
    indentation=Arg(
        "--indent",
        type=at.nonnegint,
        default=0,
        help="Indentation of output.",
    ),
)
parser += _cli_arguments
args = parser.parse_args()



def _best_return(data, source):
    """
    Return the ``(mean, std)`` of the best-scoring evaluation entry in ``data``.

    Walks every ``data["evaluations"][*]["latency_eval"][latency]`` entry and
    keeps the (mean, std) pair with the largest mean. Two entry formats are
    accepted:

    * ``{"returns": [...]}``: raw returns; mean and std are computed here with
      ``np.mean`` / ``np.std`` (NumPy's default ``ddof=0``).
    * ``{"return_avg": ..., "return_std": ...}``: pre-aggregated values
      (special case for dreamer).

    A NaN mean (e.g. produced by an empty ``returns`` list) never displaces a
    real best. It is only kept as a placeholder until a real mean is seen.

    :param data: Decoded eval JSON document.
    :param source: Label (``algname:envname``) used in error messages.
    :returns: ``(best_mean, best_std)``.
    :raises NotImplementedError: If an entry has neither supported format.
    :raises ValueError: If the evaluations do not use exactly one latency.
    """
    best_mean = None
    best_std = None
    latencies = set()
    for e in data["evaluations"]:
        for latency, ed in e["latency_eval"].items():
            latencies.add(latency)
            if "returns" in ed:
                ret = np.array(ed["returns"], dtype=np.float64)
                mean, std = ret.mean(), ret.std()
            elif "return_avg" in ed:
                # Special case for dreamer (pre-aggregated return_avg/return_std).
                mean, std = ed["return_avg"], ed["return_std"]
            else:
                raise NotImplementedError(f"Unsupported {e}")

            if best_mean is None:
                best_mean, best_std = mean, std
            elif math.isnan(mean):
                # NaN compares False against everything; never let it win.
                continue
            elif math.isnan(best_mean) or best_mean < mean:
                # Any real mean replaces a NaN placeholder.
                best_mean, best_std = mean, std

    # Explicit check rather than `assert`, which `python -O` strips out.
    if len(latencies) != 1:
        raise ValueError(
            f"{source}: expected evaluations at exactly one latency, "
            f"found {len(latencies)}: {sorted(latencies)}"
        )
    return best_mean, best_std


# Collect, in first-seen order, the algorithm and environment names, and map
# alg -> env -> {"data", "best_mean", "best_std"}.
# NOTE: a repeated algname:envname pair overwrites the earlier entry.
algs = []
envs = []
alg_lookup = {}
for (algname, envname, data) in args.jsondata:
    if algname not in algs:
        algs.append(algname)
        alg_lookup[algname] = {}
    if envname not in envs:
        envs.append(envname)

    best_mean, best_std = _best_return(data, f"{algname}:{envname}")
    alg_lookup[algname][envname] = {
        "data": data,
        "best_mean": best_mean,
        "best_std": best_std,
    }


# Table preamble. Column 1 holds the algorithm name; the k-th environment
# (0-based) then occupies tblr columns 3k+2 (mean), 3k+3 (\pm) and 3k+4 (std).
_mean_cols = ",".join(str(3 * k + 2) for k in range(len(envs)))
_pm_cols = ",".join(str(3 * k + 3) for k in range(len(envs)))
_std_cols = ",".join(str(3 * k + 4) for k in range(len(envs)))

lines = [
    r"\begin{adjustbox}{max width=\textwidth}",
    r"\begin{tblr}{",
    r"  colspec = {",
    r"      Q[l,t]",
    *[r"     |Q[r,t]Q[c,t]Q[r,t]" for _ in envs],
    r"  },",
    r"  colsep = 6pt,",
    # Row-1 cell of each mean column spans its environment's three columns.
    rf"  cell{{1}}{{{_mean_cols}}} = {{c=3}}{{c}},",
    # No right padding on the mean column, 3pt around \pm, no left padding on std.
    rf"  column{{{_mean_cols}}} = {{rightsep=0pt}},",
    rf"  column{{{_pm_cols}}} = {{colsep=3pt}},",
    rf"  column{{{_std_cols}}} = {{leftsep=0pt}},",
    r"}",
    r"\hline",
]

# Display names for environments given in short form on the command line.
# Names not listed here are printed verbatim.
ENVNAME_CONV = {
    "ant": "Ant-v4",
    "humanoid": "Humanoid-v4",
    "halfcheetah": "HalfCheetah-v4",
    "hopper": "Hopper-v4",
    "walker2d": "Walker2d-v4",
}

# Header row: each environment title goes in the first of its three columns,
# followed by two empty cells ("&&").
lines.extend(rf"& \texttt{{{ENVNAME_CONV.get(env, env)}}} &&" for env in envs)
lines.append(r"\\\hline")

# Table body: one row per algorithm, with three cells (mean, \pm, std) per
# environment. The best mean in each environment column is set in bold.
for i, alg in enumerate(algs):
    if alg.lower().strip() == "mbpac":
        # Shade the MBPAC row.
        lines += [r"\SetRow{bg=black!10}"]
    lines += [alg]
    for env in envs:
        env_label = ENVNAME_CONV.get(env, env)
        d = alg_lookup[alg].get(env)
        if d is None:
            # No input was given for this (algorithm, environment) pair:
            # emit placeholder cells instead of failing with a KeyError.
            lines += [rf"& -- & & % {env_label}"]
            continue

        alg_mean, alg_std = (d["best_mean"], d["best_std"])
        # Best when no algorithm with results for this environment has a
        # higher mean; ties are all bolded.
        is_best_alg = all(
            alg_lookup[alg2][env]["best_mean"] <= alg_mean
            for alg2 in algs
            if env in alg_lookup[alg2]
        )

        if is_best_alg:
            lines += [rf"& $\bm{{{alg_mean:.2f}}}$ & $\bm{{\pm}}$ & $\bm{{{alg_std:.2f}}}$ % {env_label}"]
        else:
            lines += [rf"& ${alg_mean:.2f}$ & $\pm$ & ${alg_std:.2f}$ % {env_label}"]

    # Dotted separator between algorithm rows (none after the last row).
    if i < (len(algs) - 1):
        lines += [r"\\\hline[dotted,fg=black!50]"]

# Close the environments opened in the preamble (innermost first).
lines.extend([
    r"\end{tblr}",
    r"\end{adjustbox}",
])

# Print the table, prefixing every line with --indent spaces.
_indent_prefix = " " * args.indentation
print("\n".join(_indent_prefix + text_line for text_line in lines))






