import json

from overtraining.plotting.shared import *

# Load the experiment grid. Only its top-level entries (run-name strings) are
# used by this script. JSON text is UTF-8, so decode it explicitly rather than
# relying on the platform's locale encoding.
with open("exp_data/grid.json", "r", encoding="utf-8") as f:
    data = json.load(f)


configs = set()


for run_name in sorted(data):
    # A run name must have exactly nine "-"-separated fields (the unpacking
    # raises ValueError otherwise). Only the architecture descriptor (field 0)
    # and D (field 5) feed into the configuration below.
    arch, bs, lr, wd, _, D, warm, _, _ = run_name.split("-")
    # The descriptor is three "_"-separated parts in the order d, l, h. Keep
    # only the text after the last "=" in each part.
    d, l, h = (part.split("=")[-1] for part in arch.split("_"))
    # N = D / 20, reported in units of 1e9 with three decimals. The factor 20
    # presumably encodes a tokens-per-parameter ratio; confirm with the author.
    N = int(D) / 20
    configs.add((int(l), int(h), int(d), f"{int(N) / 1_000_000_000:.3f}"))


# Sort by N numerically. The formatted string would sort lexicographically,
# putting "10.000" before "2.000". Break ties by (n_layers, n_heads, d_model)
# so the row order does not depend on set iteration order, which varies
# between interpreter runs because the tuples contain str (hash randomization).
configs = sorted(configs, key=lambda cfg: (float(cfg[-1]), cfg[:-1]))
# Stringify every field for the LaTeX table.
str_configs = [[str(field) for field in cfg] for cfg in configs]

# Column headers, one per value in each row of str_configs:
# (n_layers, n_heads, d_model, N in units of 1e9).
cols = [
    r"$n_{layers}$",
    r"$n_{heads}$",
    r"$d_{model}$",
    r"$N$ (B)",
]

# make_latex_table presumably comes from the wildcard import of
# overtraining.plotting.shared. The three trailing "None" strings are passed
# through unchanged; their meaning is defined there.
print(make_latex_table(cols, str_configs, "None", "None", "None"))
# Number of distinct configurations found in the grid.
print(len(configs))
