import glob
import json
import os
import sys

import numpy as np


# Directory containing the wall-clock benchmark result files, under ROOT.
ROOT = "."
RESULTS_DIR = os.path.join(ROOT, "gptopt/outputs/59_wallclock_test")

# Every algorithm's step time is reported as a ratio to this one's.
base_alg = "muon"

# Display label -> optimizer name. The optimizer name is used as the prefix of
# the result file name ("<optimizer>-lr-*.json"); the order of the pairs below
# is the order in which results are reported.
algs = dict(
    (
        # Plain variants.
        ("muon", "nesgd-adam_infty-lmo"),
        ("muonmax", "nesgd-adam_2-hybrid_prod"),
        ("muonmax-stale", "nesgd-adam_2-hybrid_prod-stale"),
        # "momo" variants.
        ("muon-momo", "nesgd-adam_infty-lmo-momo"),
        ("muon-momo-stale", "nesgd-adam_infty-lmo-momo-stale"),
        ("scion-momo", "nesgd-lmo-momo"),
        ("scion-momo-stale", "nesgd-lmo-momo-stale"),
        ("muonmax-momo", "nesgd-adam_2-hybrid_prod-momo"),
        ("muonmax-momo-stale", "nesgd-adam_2-hybrid_prod-momo-stale"),
    )
)

# Median of the recorded per-step times for each algorithm that has a result
# file. NOTE(review): units are whatever the benchmark wrote into
# "step_times" -- they cancel out in the ratios printed below.
step_times = {}
for alg, optimizer in algs.items():
    # Escape both path parts so glob metacharacters ("*", "?", "[") in them
    # are matched literally; only the learning-rate suffix is a wildcard.
    pattern = os.path.join(
        glob.escape(RESULTS_DIR), f"{glob.escape(optimizer)}-lr-*.json"
    )
    result_paths = glob.glob(pattern)

    if not result_paths:
        # This optimizer's benchmark has not been run (yet). Skip it so a
        # partial result set can still be summarized, and report the skip on
        # stderr so it does not mix with the table printed on stdout.
        print(
            f"no result file for {alg} ({optimizer}) in {RESULTS_DIR}; skipping",
            file=sys.stderr,
        )
        continue
    if len(result_paths) > 1:
        # Several learning-rate runs for one optimizer make the comparison
        # ambiguous; refuse instead of silently picking one. An explicit
        # raise (unlike `assert`) is not stripped under `python -O`.
        raise RuntimeError(
            f"expected exactly one result file for {optimizer}, found "
            f"{len(result_paths)}: {sorted(result_paths)}"
        )

    result_path = result_paths[0]
    with open(result_path, "r", encoding="utf-8") as result_file:
        results = json.load(result_file)

    recorded_times = results["step_times"]
    if len(recorded_times) == 0:
        # np.median of an empty sequence is NaN (with a RuntimeWarning),
        # which would silently poison the report; skip the file instead.
        print(
            f"no step times recorded in {result_path}; skipping {alg}",
            file=sys.stderr,
        )
        continue
    step_times[alg] = np.median(recorded_times)

if base_alg not in step_times:
    # Without the baseline there is nothing to normalize against; fail with a
    # clear message instead of a bare KeyError.
    sys.exit(
        f"baseline {base_alg!r} has no usable results in {RESULTS_DIR}; "
        "cannot compute relative step times"
    )

# Report each algorithm's median step time as a ratio to the baseline's.
base_step_time = step_times[base_alg]
for alg, step_time in step_times.items():
    relative_step_time = step_time / base_step_time
    print(f"{alg:22}: {relative_step_time:.5f}")
