import os
import glob
import json
import pickle
from collections import OrderedDict

from common_plot import get_loss_stats, plot_lr_sensitivity, plot_loss_curves


# Location of the experiment outputs, relative to the working directory.
ROOT = "."
RESULTS_DIR = os.path.join(ROOT, "gptopt/outputs/54_fineweb1B_final_comparison")

# Run name used in the output file names: the results directory name without
# its leading numeric experiment index ("54_<name>" -> "<name>"). Splitting on
# the first underscore (instead of slicing off a fixed three characters) keeps
# this correct if the index ever has a different number of digits; a directory
# name without a numeric prefix is used unchanged.
_results_dirname = os.path.basename(RESULTS_DIR)
_index, _sep, _remainder = _results_dirname.partition("_")
RUN_NAME = _remainder if (_sep and _index.isdigit()) else _results_dirname

# Sweep definition: one run directory is expected per (alg, muon_lr, seed).
algs = ["muon", "scion", "muon-momo-stale", "muonmax-momo-stale"]
muon_lrs = [0.0003, 0.001, 0.003, 0.01, 0.03, 0.1, 0.3, 1.0]
seeds = [42, 43, 44]

# Per-algorithm divisor applied to the tuned muon_lr to obtain the "other_lr"
# value reported in the results table (other_lr = muon_lr / lr_ratios[alg]).
lr_ratios = {
    "muon": 10,
    "scion": 10,
    "muon-momo-stale": 10,
    "muonmax-momo-stale": 1,
}


# Read results.
# Each (alg, muon_lr, seed) combination is looked up in its own directory
# "<RESULTS_DIR>/<alg>_<muon_lr>_<seed>", which must contain exactly one JSON
# file. Note that the directory name embeds the default string form of the
# float learning rate (e.g. "0.0003", "1.0").
total_results = OrderedDict()
for alg in algs:
    for muon_lr in muon_lrs:
        for seed in seeds:
            result_dir = os.path.join(RESULTS_DIR, f"{alg}_{muon_lr}_{seed}")
            results_paths = glob.glob(os.path.join(result_dir, "*.json"))
            # Raise explicitly rather than assert: asserts are stripped under
            # `python -O`, and the message pinpoints the offending run.
            if len(results_paths) != 1:
                raise RuntimeError(
                    f"Expected exactly one JSON results file in {result_dir!r}, "
                    f"found {len(results_paths)}: {results_paths}"
                )
            results_path = results_paths[0]

            # JSON is UTF-8 by specification; do not depend on the locale default.
            with open(results_path, "r", encoding="utf-8") as results_file:
                total_results[(alg, muon_lr, seed)] = json.load(results_file)

# Get loss statistics.
# get_loss_stats comes from common_plot; its four results are unpacked in the
# order it returns them.
(
    final_losses,
    final_loss_stats,
    tuned_final_losses,
    smoothed_loss_stats,
) = get_loss_stats(total_results)

# Print table of best losses.
# One row per algorithm: mean +/- std of the tuned final loss, the selected
# muon_lr ("best_hp") and the derived other_lr (muon_lr / lr_ratios[alg]).
for alg in algs:
    tuned_stats = tuned_final_losses[alg]
    mean, std, muon_lr = tuned_stats["mean"], tuned_stats["std"], tuned_stats["best_hp"]
    other_lr = muon_lr / lr_ratios[alg]
    table_row = f"{alg:22}    {mean:.4f} +/- {std:.4f}    (muon_lr={muon_lr}, other_lr={other_lr})"
    print(table_row)

# Output paths for both figures; they live alongside the results and are
# prefixed with the run name.
lr_sens_path = os.path.join(RESULTS_DIR, f"{RUN_NAME}_lr_sensitivity.pdf")
loss_curves_path = os.path.join(RESULTS_DIR, f"{RUN_NAME}_loss_curves.pdf")

# Make LR sensitivity plot.
ylim = [3.5, 4.0]
plot_lr_sensitivity(
    algs,
    final_loss_stats,
    lr_sens_path,
    ylim=ylim,
)

# Plot loss curves for tuned algs (narrower y-range than the plot above).
xlim = [0.6, 1.0]
ylim = [3.5, 3.8]
plot_loss_curves(
    algs,
    smoothed_loss_stats,
    tuned_final_losses,
    loss_curves_path,
    xlim=xlim,
    ylim=ylim,
)
