import os
import glob
import json
import pickle
from collections import OrderedDict

from common_plot import get_loss_stats, plot_lr_sensitivity, plot_loss_curves


# Base path that RESULTS_DIR is joined onto; "." resolves against the current
# working directory at run time.
ROOT = "."
# Directory containing one sub-directory per run. The LR sensitivity plots
# produced by this script are written here as well.
RESULTS_DIR = os.path.join(ROOT, "gptopt/outputs/55_fineweb1B_lb_ablation")
# Prefix used for the generated plot file names.
RUN_NAME = "fineweb1B_lb_ablation"

# Sweep grid: every (alg, loss_lb, muon_lr) combination corresponds to a run
# directory named "<alg>_<loss_lb>_<muon_lr>" under RESULTS_DIR.
algs = [
    "muon-momo",
    "muonmax-momo",
]
muon_lrs = [0.001, 0.01, 0.1, 1.0]
# NOTE(review): "lb" presumably stands for "lower bound" of the loss -- confirm
# against the training configuration.
loss_lbs = [0.0, 1.6, 2.4, 2.8, 3.2]
# Single seed value, stored as the last component of every result key.
seed = 42

# Read results: load the single results JSON of every (alg, loss_lb, muon_lr) run.
# Entries are keyed by (f"{alg}-{loss_lb}", muon_lr, seed), i.e. the loss_lb value
# is folded into the algorithm label so each (alg, loss_lb) pair is its own entry.
# Insertion order follows the loop nesting below (alg, then loss_lb, then muon_lr).
total_results = OrderedDict()
for alg in algs:
    for loss_lb in loss_lbs:
        for muon_lr in muon_lrs:

            # Each run directory must hold exactly one JSON file.
            result_dir = os.path.join(RESULTS_DIR, f"{alg}_{loss_lb}_{muon_lr}")
            results_paths = glob.glob(os.path.join(result_dir, "*.json"))

            # Explicit check rather than `assert`: assertions are stripped under
            # `python -O`, which would turn a missing file into a bare IndexError
            # and let an ambiguous directory silently load an arbitrary file.
            if len(results_paths) != 1:
                raise RuntimeError(
                    f"Expected exactly one results JSON in {result_dir!r}, "
                    f"found {len(results_paths)}: {sorted(results_paths)}"
                )

            results_path = results_paths[0]
            # JSON is UTF-8 by specification; do not rely on the locale default.
            with open(results_path, "r", encoding="utf-8") as results_file:
                total_results[(f"{alg}-{loss_lb}", muon_lr, seed)] = json.load(results_file)

# Get loss statistics.
# NOTE(review): get_loss_stats comes from common_plot; the meaning of each returned
# object is not visible here. Only final_loss_stats and tuned_final_losses are
# used by this script.
final_losses, final_loss_stats, tuned_final_losses, smoothed_loss_stats = get_loss_stats(total_results)

# Print table of best losses.
# The label column is sized to the longest label (never narrower than the
# previous fixed width of 12) so that the numeric columns line up. The labels
# built when loading results ("<alg>-<loss_lb>") are longer than 12 characters,
# so a fixed width of 12 never padded them.
# NOTE(review): assumes tuned_final_losses is keyed by those labels -- confirm
# against get_loss_stats.
label_width = max([12, *(len(str(label)) for label in tuned_final_losses)])
for alg in tuned_final_losses:
    mean = tuned_final_losses[alg]["mean"]
    std = tuned_final_losses[alg]["std"]
    # "best_hp" is reported as the muon_lr value in the printed row.
    muon_lr = tuned_final_losses[alg]["best_hp"]
    print(f"{alg:{label_width}}    {mean:.4f} +/- {std:.4f}    (muon_lr={muon_lr})")

# Make LR sensitivity plot: one PDF per base algorithm, containing that
# algorithm's "<alg>-<loss_lb>" variants.
ylim = [3.5, 4.0]
# Labels (first key component) for which results were actually loaded.
loaded_labels = {key[0] for key in total_results}
for alg in algs:
    # Rebuild the labels with the same f-string used for the result keys instead
    # of stripping a fixed four characters from each key: the fixed-width slice
    # silently drops any loss_lb whose text is not exactly three characters
    # (e.g. 10.0 or 2.75). Order follows loss_lbs, which is also the order in
    # which the keys were inserted.
    current_algs = [
        label
        for label in (f"{alg}-{loss_lb}" for loss_lb in loss_lbs)
        if label in loaded_labels
    ]
    lr_sens_path = os.path.join(RESULTS_DIR, f"{RUN_NAME}_{alg}_lr_sensitivity.pdf")
    plot_lr_sensitivity(current_algs, final_loss_stats, lr_sens_path, ylim=ylim)
