import os
import glob
import json
import pickle
from collections import OrderedDict

from common_plot import get_loss_stats, plot_lr_sensitivity, plot_loss_curves


# Sweep output locations, all resolved relative to ROOT (the current working
# directory by default).
ROOT = "."
TRUNC_DIR = os.path.join(ROOT, "gptopt/outputs/49_truncated_scale_tune")
VANILLA_DIR = os.path.join(ROOT, "gptopt/outputs/53_no_truncation_scale_sweep")
EXTRA_DIR = os.path.join(ROOT, "gptopt/outputs/58_truncation_extra_lrs")

# Prefix used for the file names of the generated plots.
RUN_NAME = "fineweb1B_trunc_comparison"

# One row per base algorithm:
#   (name, optimizer string used in run directory names,
#    tuned "muon" learning rate, tuned "other" learning rate)
_BASE_ALGS = (
    ("muon", "nesgd-adam_infty-lmo", 0.01, 1e-3),
    ("scion", "nesgd-lmo", 0.01, 1e-3),
    ("muonmax", "nesgd-hybrid_prod-adam_2", 0.001, 0.001),
)

# Per-algorithm settings, keyed by base algorithm name.
alg_settings = {
    name: {
        "optimizer": optimizer,
        "tuned_lrs": {"muon": muon_lr, "other": other_lr},
    }
    for name, optimizer, muon_lr, other_lr in _BASE_ALGS
}

# Algorithms to report: every base algorithm first, then its "-momo"
# counterpart (the key under which the results from TRUNC_DIR are stored).
algs = [name + suffix for suffix in ("", "-momo") for name, _, _, _ in _BASE_ALGS]

# Multipliers applied to each algorithm's tuned "muon" learning rate.
scales = [0.3, 1.0, 3.0, 10.0, 30.0, 100.0]

# Seed component of every result key.
seed = 42

def _load_single_result(result_dir):
    """Load the single JSON results file stored directly inside ``result_dir``.

    Args:
        result_dir: Path of a run directory expected to contain exactly one
            ``*.json`` file.

    Returns:
        The decoded JSON content of that file.

    Raises:
        RuntimeError: If the directory does not contain exactly one ``*.json``
            file (this also covers a missing directory, for which the glob
            matches nothing).
    """
    results_paths = glob.glob(os.path.join(result_dir, "*.json"))
    if len(results_paths) != 1:
        raise RuntimeError(
            f"Expected exactly one results JSON in {result_dir!r}, "
            f"found {len(results_paths)}."
        )
    with open(results_paths[0], "r") as results_file:
        return json.load(results_file)


# Read results. Keys are (algorithm name, muon learning rate, seed).
total_results = OrderedDict()
for alg, settings in alg_settings.items():
    optimizer = settings["optimizer"]
    tuned_muon_lr = settings["tuned_lrs"]["muon"]

    for scale in scales:
        muon_lr = scale * tuned_muon_lr
        # Both sweeps name their run directories "<optimizer>_<scale>".
        run_name = f"{optimizer}_{scale}"

        # Results without truncation are stored under the plain algorithm name.
        total_results[(alg, muon_lr, seed)] = _load_single_result(
            os.path.join(VANILLA_DIR, run_name)
        )

        # Results with truncation are stored under "<alg>-momo".
        total_results[(alg + "-momo", muon_lr, seed)] = _load_single_result(
            os.path.join(TRUNC_DIR, run_name)
        )

# Read additional results with extended LR ranges. Here the run directories
# are named "<alg>_<muon_lr>", where <alg> may carry the "-momo" suffix.
# NOTE(review): these keys use a learning rate parsed from the directory name,
# whereas the keys above use `scale * tuned_lr`. The same nominal learning rate
# could therefore appear under two slightly different float keys - confirm the
# extra runs never overlap with the main sweeps. An exactly equal key
# overwrites the earlier entry.
# The listing is sorted so that the insertion order into `total_results` does
# not depend on the arbitrary order returned by os.listdir.
for run_name in sorted(os.listdir(EXTRA_DIR)):
    run_dir = os.path.join(EXTRA_DIR, run_name)
    if not os.path.isdir(run_dir):
        # Ignore stray files (logs, OS metadata, ...) next to the run directories.
        continue

    # Split at the first underscore: "<alg>_<muon_lr>".
    alg, separator, lr_text = run_name.partition("_")
    if not separator:
        raise ValueError(
            f"Run directory {run_name!r} in {EXTRA_DIR!r} is not named '<alg>_<muon_lr>'."
        )
    muon_lr = float(lr_text)

    is_known_alg = alg in alg_settings or (
        alg.endswith("-momo") and alg[: -len("-momo")] in alg_settings
    )
    if not is_known_alg:
        raise ValueError(
            f"Unknown algorithm {alg!r} in run directory {run_name!r} of {EXTRA_DIR!r}."
        )

    total_results[(alg, muon_lr, seed)] = _load_single_result(run_dir)

# Aggregate the raw per-run results into the statistics consumed by the table
# and the plots below.
(
    final_losses,
    final_loss_stats,
    tuned_final_losses,
    smoothed_loss_stats,
) = get_loss_stats(total_results)

# Print one table row per algorithm: mean and std of the tuned final loss,
# plus the selected hyperparameter (reported here as the muon learning rate).
for alg in algs:
    tuned = tuned_final_losses[alg]
    mean, std, muon_lr = tuned["mean"], tuned["std"], tuned["best_hp"]
    print(f"{alg:12}    {mean:.4f} +/- {std:.4f}    (muon_lr={muon_lr})")

# LR sensitivity plot; like the loss-curve plot below, it is written into
# VANILLA_DIR.
ylim = [3.5, 4.0]
lr_sens_path = os.path.join(VANILLA_DIR, f"{RUN_NAME}_lr_sensitivity.pdf")
plot_lr_sensitivity(algs, final_loss_stats, lr_sens_path, ylim=ylim)

# Loss curves of the tuned algorithms, restricted to the given axis limits.
xlim, ylim = [0.6, 1.0], [3.5, 3.8]
loss_curves_path = os.path.join(VANILLA_DIR, f"{RUN_NAME}_loss_curves.pdf")
plot_loss_curves(
    algs,
    smoothed_loss_stats,
    tuned_final_losses,
    loss_curves_path,
    xlim=xlim,
    ylim=ylim,
)
