import os
import glob
import json
import pickle
from collections import OrderedDict

from common_plot import get_loss_stats, plot_lr_sensitivity, plot_loss_curves


# --- Run configuration -------------------------------------------------------
# Output location of the sweep and the prefix used for generated plot files.
ROOT = "."
RESULTS_DIR = os.path.join(ROOT, "gptopt/outputs/63_slim_pajama6B")
RUN_NAME = "slim_pajama6B"

# Algorithms compared in this sweep and the Muon learning rates swept for each.
# Each (alg, lr) pair maps to the run directory f"{alg}_{lr}" under RESULTS_DIR.
algs = [
    "muon",
    "muonmax-momo-stale",
]
muon_lrs = [
    0.0001,
    0.001,
    0.01,
    0.1,
]

# Divisor applied to the Muon LR to obtain the LR reported as "other_lr"
# in the summary table (other_lr = muon_lr / ratio).
lr_ratios = {
    "muon": 10,
    "muonmax-momo-stale": 1,
}

# Single seed used to key every loaded run.
seed = 42


# Read results.
# Every (alg, muon_lr) run directory must hold exactly one *.json results
# file; each one is loaded and stored under the key (alg, muon_lr, seed),
# in sweep order.
total_results = OrderedDict()
for alg in algs:
    for muon_lr in muon_lrs:
        result_dir = os.path.join(RESULTS_DIR, f"{alg}_{muon_lr}")
        results_paths = glob.glob(os.path.join(result_dir, "*.json"))
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, and a bare assert does not say which run is broken.
        if len(results_paths) != 1:
            raise RuntimeError(
                f"Expected exactly one results JSON in {result_dir!r}, "
                f"found {len(results_paths)}: {sorted(results_paths)}"
            )
        results_path = results_paths[0]

        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(results_path, "r", encoding="utf-8") as results_file:
            total_results[(alg, muon_lr, seed)] = json.load(results_file)

# Aggregate the loaded runs into loss statistics: per-run final losses,
# per-(alg, lr) final-loss stats, each algorithm's tuned (best-LR) result,
# and smoothed loss-curve stats used for plotting below.
(
    final_losses,
    final_loss_stats,
    tuned_final_losses,
    smoothed_loss_stats,
) = get_loss_stats(total_results)

# Print one row per algorithm: its tuned final loss (mean +/- std) and the
# learning rates that produced it, where other_lr = muon_lr / lr_ratios[alg].
for alg in algs:
    tuned = tuned_final_losses[alg]
    mean, std, muon_lr = tuned["mean"], tuned["std"], tuned["best_hp"]
    other_lr = muon_lr / lr_ratios[alg]
    row = f"{alg:22}    {mean:.4f} +/- {std:.4f}    "
    row += f"(muon_lr={muon_lr}, other_lr={other_lr})"
    print(row)

# Write the LR-sensitivity plot (built from per-(alg, lr) final-loss stats)
# for all algorithms to a PDF in RESULTS_DIR, with y clipped to [2.5, 4.0].
ylim = [2.5, 4.0]
lr_sens_path = os.path.join(RESULTS_DIR, RUN_NAME + "_lr_sensitivity.pdf")
plot_lr_sensitivity(algs, final_loss_stats, lr_sens_path, ylim=ylim)

# Write smoothed loss curves for each algorithm at its tuned setting to a PDF
# in RESULTS_DIR, zoomed to x in [0.6, 1.0] (presumably the late part of
# training -- confirm against plot_loss_curves) and y in [2.5, 3.0].
xlim, ylim = [0.6, 1.0], [2.5, 3.0]
loss_curves_path = os.path.join(RESULTS_DIR, RUN_NAME + "_loss_curves.pdf")
plot_loss_curves(
    algs,
    smoothed_loss_stats,
    tuned_final_losses,
    loss_curves_path,
    xlim=xlim,
    ylim=ylim,
)
