import os
import glob
import json
import pickle
from collections import OrderedDict

from common_plot import get_loss_stats, plot_lr_sensitivity, plot_loss_curves


ROOT = "."
RESULTS_DIR = os.path.join(ROOT, "gptopt/outputs/61_slim_pajama1B_first")
RUN_NAME = "slim_pajama1B_sweep"

# Algorithms whose runs are loaded, tabulated and plotted (in this order).
algs = ["muon", "scion", "muon-momo-stale", "muonmax-momo-stale"]

# Swept "muon" learning rates per algorithm. The momo variants additionally
# sweep the two largest values. Each entry is its own list object.
_narrow_lr_grid = [0.0001, 0.001, 0.01, 0.1]
_wide_lr_grid = _narrow_lr_grid + [1.0, 10.0]
muon_lrs = {
    "muon": list(_narrow_lr_grid),
    "scion": list(_narrow_lr_grid),
    "muon-momo-stale": list(_wide_lr_grid),
    "muonmax-lmo-momo-stale": list(_wide_lr_grid),
    "muonmax-momo-stale": list(_wide_lr_grid),
}

seeds = [42]

# Divisor applied to the tuned muon_lr to report "other_lr" in the results table.
lr_ratios = dict(
    [
        ("muon", 10),
        ("scion", 10),
        ("muon-momo-stale", 10),
        ("muonmax-lmo-momo-stale", 1),
        ("muonmax-momo-stale", 1),
    ]
)


# Read results: exactly one JSON file per (alg, muon_lr) run directory, stored
# in total_results under the key (alg, muon_lr, seed) in sweep order.
if len(seeds) != 1:
    # Run directories are not seed-specific yet (see TODO below), so looping
    # over several seeds would silently load the same file once per seed.
    raise ValueError(
        f"Result directories are not per-seed yet; expected exactly one seed, got {seeds}"
    )

total_results = OrderedDict()
for alg in algs:
    for muon_lr in muon_lrs[alg]:
        for seed in seeds:
            # TODO: switch to the per-seed directory name f"{alg}_{muon_lr}_{seed}"
            # once multiple seeds are run.
            result_dir = os.path.join(RESULTS_DIR, f"{alg}_{muon_lr}")
            results_paths = glob.glob(os.path.join(result_dir, "*.json"))
            # Explicit check rather than `assert` (which `python -O` strips),
            # reporting the offending directory and what was found there.
            if len(results_paths) != 1:
                raise RuntimeError(
                    f"Expected exactly one results JSON in {result_dir!r}, "
                    f"found {len(results_paths)}: {sorted(results_paths)}"
                )
            results_path = results_paths[0]

            # JSON text is UTF-8; don't depend on the platform's locale encoding.
            with open(results_path, "r", encoding="utf-8") as results_file:
                total_results[(alg, muon_lr, seed)] = json.load(results_file)

# Aggregate the loaded runs with common_plot.get_loss_stats. Below, only
# tuned_final_losses is read directly; each tuned_final_losses[alg] is used
# for its "mean", "std" and "best_hp" entries. The other results feed the plots.
(
    final_losses,
    final_loss_stats,
    tuned_final_losses,
    smoothed_loss_stats,
) = get_loss_stats(total_results)

# One row per algorithm: best final loss (mean +/- std) and the LRs achieving it.
# other_lr is the tuned muon_lr divided by that algorithm's configured ratio.
for alg in algs:
    tuned = tuned_final_losses[alg]
    mean, std, muon_lr = tuned["mean"], tuned["std"], tuned["best_hp"]
    other_lr = muon_lr / lr_ratios[alg]
    row = f"{alg:22}    {mean:.4f} +/- {std:.4f}    (muon_lr={muon_lr}, other_lr={other_lr})"
    print(row)

# Both figures are saved as PDFs inside RESULTS_DIR, prefixed with RUN_NAME.
lr_sens_path = os.path.join(RESULTS_DIR, f"{RUN_NAME}_lr_sensitivity.pdf")
loss_curves_path = os.path.join(RESULTS_DIR, f"{RUN_NAME}_loss_curves.pdf")

# Figure 1: LR sensitivity of every algorithm, drawn from final_loss_stats.
ylim = [3.1, 4.0]
plot_lr_sensitivity(algs, final_loss_stats, lr_sens_path, ylim=ylim)

# Figure 2: smoothed loss curves for the tuned configuration of each algorithm.
# NOTE(review): the xlim units are defined by common_plot.plot_loss_curves — confirm.
xlim, ylim = [0.6, 1.0], [3.0, 3.6]
plot_loss_curves(
    algs,
    smoothed_loss_stats,
    tuned_final_losses,
    loss_curves_path,
    xlim=xlim,
    ylim=ylim,
)
