#!/usr/bin/env python3

import argparse
import json
import math
import numpy as np
import os

import latency_env.misc.argparser_types as at
from latency_env.misc import Argument as Arg
from latency_env.misc import ArgumentList as ArgList
from latency_env.misc import ArgumentParser


def parse_input_suffix(s : str):
    """
    Split an ``algname:suffix`` input spec at its first colon.

    Only the first colon separates the two parts, so the suffix itself may
    contain further colons.

    Returns:
        A ``(algname, suffix)`` tuple of strings.

    Raises:
        ValueError: If ``s`` contains no colon at all.
    """
    algname, colon, suffix = s.partition(":")
    if not colon:
        # partition() yields an empty separator when ":" is absent.
        raise ValueError("Missing first colon")
    return (algname, suffix)


# Command-line interface of this script. The options are declared as a
# keyword-named ArgList and attached to the project's ArgumentParser with
# `+=`. The rest of the script reads the parsed values as attributes of
# `args` under the keyword names used here (args.bm_prefix, args.delay_tag,
# args.env_suffix, args.noise_percs, args.indentation, args.alg_suffixes).
# NOTE(review): presumably ArgList maps each keyword to the attribute name on
# the parsed namespace -- confirm against latency_env.misc.
parser = ArgumentParser()
parser += ArgList(
    # Directory component placed directly below the log root when locating
    # each eval.json.
    bm_prefix = Arg("--bm-prefix", metavar="path", type=str, required=True,
                            help="Example: benchmark-v3-tc4"),
    # Prefix of the per-algorithm run directory: "<tag>-<algsuffix>".
    delay_tag = Arg("--tag", metavar="tag", type=str, required=True),
    # Appended to the "noisy-" / "noiseperc<N>-" environment directory name.
    env_suffix = Arg("--env-suffix", metavar="suffix", type=str, required=True),
    # Noise levels to tabulate; the script iterates over this value and emits
    # one column group per entry.
    # NOTE(review): presumably parsed into a list of non-negative ints from a
    # comma-separated string -- confirm in argparser_types.
    noise_percs = Arg("--noise-percs", type=at.nonnegint_commalist,
                     required=True, help="Example: --noise-percs 0,5,10,15"),
    # Number of spaces prepended to every printed output line.
    indentation = Arg("--indent", type=at.nonnegint, default=0,
                      help="Indentation of output."),
    # Positional "algname:suffix" specs (one or more); each is converted to an
    # (algname, suffix) tuple by parse_input_suffix.
    alg_suffixes = Arg(metavar="suffix", type=parse_input_suffix, nargs="+",
                       help="Example: BPQL:bpql-mem24"),
)
args = parser.parse_args()


# Collect, for every (noise level, algorithm) pair, the statistics of the best
# evaluation recorded in that run's eval.json:
#   lookup[noise][algname] == {"best_mean": ..., "best_std": ...}
# "Best" is the evaluation entry with the highest mean return; best_std is the
# standard deviation belonging to that same entry.
lookup = {}
for noise in args.noise_percs:
    for (algname, algsuffix) in args.alg_suffixes:
        if noise == 5:
            # Use the regular "noisy results"
            json_path = os.path.join(
                "_logs",
                args.bm_prefix,
                f"noisy-{args.env_suffix}",
                f"{args.delay_tag}-{algsuffix}",
                "eval.json",
            )
        else:
            json_path = os.path.join(
                "_logs/noise-eval",
                args.bm_prefix,
                f"noiseperc{noise}-{args.env_suffix}",
                f"{args.delay_tag}-{algsuffix}",
                "eval.json",
            )
        # Read as UTF-8 explicitly instead of relying on the locale default.
        with open(json_path, encoding="utf-8") as f:
            data = json.load(f)

        best_mean = None
        best_std = None
        latencies = set()
        for e in data["evaluations"]:
            for latency, ed in e["latency_eval"].items():
                latencies.add(latency)
                if "returns" in ed:
                    # Raw per-episode returns: derive mean/std ourselves.
                    ret = np.array(ed["returns"], dtype=np.float64)
                    ret_mean = ret.mean()
                    ret_std = ret.std()
                elif "return_avg" in ed:
                    # Special case for dreamer: pre-aggregated statistics.
                    ret_mean = ed["return_avg"]
                    ret_std = ed["return_std"]
                else:
                    raise NotImplementedError(f"Unsupported {e}")

                if best_mean is None or (best_mean < ret_mean):
                    best_mean = ret_mean
                    best_std = ret_std

        # Explicit checks instead of `assert`, which is removed under
        # `python -O` and would let bad data through to the formatting code.
        if len(latencies) != 1:
            raise ValueError(
                f"Expected exactly one latency in {json_path}, "
                f"found {len(latencies)}: {sorted(latencies)}"
            )

        results_for_noise = lookup.setdefault(noise, {})
        if algname in results_for_noise:
            raise ValueError(
                f"Algorithm name {algname!r} was given more than once"
            )
        results_for_noise[algname] = {
            "best_mean": best_mean,
            "best_std": best_std,
        }


# Render the collected results as a LaTeX table (a tabularray `tblr` wrapped
# in an `adjustbox`) and print it. Layout: column 1 holds the algorithm name;
# each noise level then owns three consecutive columns "mean | ± | std".
# Within each noise level the highest mean (ties included) is set in bold.
noise_count = len(args.noise_percs)
# Column indices (1-based) of the mean / ± / std column of every noise group.
mean_cols = ",".join(str((i * 3) + 2) for i in range(noise_count))
pm_cols = ",".join(str((i * 3) + 3) for i in range(noise_count))
std_cols = ",".join(str((i * 3) + 4) for i in range(noise_count))

lines = []
lines += [
    r"\begin{adjustbox}{max width=\textwidth}",
    r"\begin{tblr}{",
    r"  colspec = {",
    r"      Q[l,t]",
]
lines += [r"     |Q[r,t]Q[c,t]Q[r,t]"] * noise_count
lines += [
    r"  },",
    r"  colsep = 6pt,",
    # The header cell of each group spans its three columns, centred.
    rf"  cell{{1}}{{{mean_cols}}} = {{c=3}}{{c}},",
    # Tighten the spacing around the "±" so mean ± std reads as one value.
    rf"  column{{{mean_cols}}} = {{rightsep=0pt}},",
    rf"  column{{{pm_cols}}} = {{colsep=3pt}},",
    rf"  column{{{std_cols}}} = {{leftsep=0pt}},",
    r"}",
    r"\hline",
]

# Header row: one "<noise>% noise" cell per noise level.
for noise in args.noise_percs:
    lines += [f"& {noise}\\% noise &&"]
lines += [r"\\\hline"]

last_row = len(args.alg_suffixes) - 1
for i, (alg, _) in enumerate(args.alg_suffixes):
    # The ACDA row (name compared case-insensitively) gets a shaded background.
    if alg.lower().strip() == "acda":
        lines += [r"\SetRow{bg=black!10}"]
    lines += [alg]
    for noise in args.noise_percs:
        d = lookup[noise][alg]

        alg_mean, alg_std = (d["best_mean"], d["best_std"])
        is_best_alg = all(lookup[noise][alg2]["best_mean"] <= alg_mean for (alg2, _) in args.alg_suffixes)

        if is_best_alg:
            lines += [rf"& $\bm{{{alg_mean:.2f}}}$ & $\bm{{\pm}}$ & $\bm{{{alg_std:.2f}}}$ % {noise}% noise"]
        else:
            lines += [rf"& ${alg_mean:.2f}$ & $\pm$ & ${alg_std:.2f}$ % {noise}% noise"]

    # Terminate every row except the last one. This must be measured against
    # the number of algorithm rows (not the number of noise levels), otherwise
    # rows lose their `\\` whenever there are more algorithms than noise levels.
    if i < last_row:
        lines += [r"\\\hline[dotted,fg=black!50]"]

lines += [
    r"\end{tblr}",
    r"\end{adjustbox}",
]

pad = " " * args.indentation
print("\n".join(pad + line for line in lines))






