import os
import glob
import json
import pickle
from collections import OrderedDict

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

from common_plot import get_loss_stats, plot_lr_sensitivity, plot_loss_curves, DISPLAY_NAMES


# Location of the sweep outputs and the base name used for generated artifacts.
ROOT = "."
RESULTS_DIR = os.path.join(ROOT, "gptopt/outputs/62_slim_pajama1B_double_grid")
RUN_NAME = "slim_pajama1B_double_grid"

# Swept grid: every algorithm is combined with every (muon_lr, other_lr) pair.
algs = ["muon", "scion", "muon-momo-stale", "muonmax-momo-stale"]
muon_lrs = [0.0001, 0.001, 0.01, 0.1]
other_lrs = [0.00001, 0.0001, 0.001, 0.01]
# Used as the last component of every results key below.
seed = 42

# Read results.
# Each config is expected in "<RESULTS_DIR>/<alg>_<muon_lr>_<other_lr>/" with exactly one
# JSON file. Results are keyed by (alg, (muon_lr, other_lr), seed).
total_results = OrderedDict()
for alg in algs:
    for muon_lr in muon_lrs:
        for other_lr in other_lrs:
            result_dir = os.path.join(RESULTS_DIR, f"{alg}_{muon_lr}_{other_lr}")
            results_paths = glob.glob(os.path.join(result_dir, "*.json"))
            # Explicit check rather than `assert`: asserts are stripped under `python -O`,
            # which would turn a missing run into a bare IndexError and let a directory
            # with several JSON files silently load an arbitrary one.
            if len(results_paths) != 1:
                raise RuntimeError(
                    f"Expected exactly one results JSON in {result_dir!r}, "
                    f"found {len(results_paths)}: {sorted(results_paths)}"
                )
            results_path = results_paths[0]

            # Pin the encoding so decoding does not depend on the platform locale.
            with open(results_path, "r", encoding="utf-8") as results_file:
                total_results[(alg, (muon_lr, other_lr), seed)] = json.load(results_file)

# Get loss statistics.
# NOTE(review): get_loss_stats comes from common_plot and is not visible here. Judging by
# how the results are used further down, final_loss_stats is keyed by (alg, hp) and
# tuned_final_losses by alg, each entry exposing "mean"/"std" (plus "best_hp" for the
# tuned entries) -- confirm against common_plot.
final_losses, final_loss_stats, tuned_final_losses, smoothed_loss_stats = get_loss_stats(total_results)

# Print table of best losses.
# One line per algorithm: final loss (mean +/- std) at its best LR pair, followed by
# that (muon_lr, other_lr) pair as reported in the "best_hp" entry.
print("Best loss for each alg with tuned LRs:")
for alg in algs:
    tuned = tuned_final_losses[alg]
    mean, std = tuned["mean"], tuned["std"]
    muon_lr, other_lr = tuned["best_hp"]
    loss_summary = f"{mean:.4f} +/- {std:.4f}"
    lr_summary = f"(muon_lr={muon_lr}, other_lr={other_lr})"
    print(f"{alg:22}    {loss_summary}    {lr_summary}")

# Print all losses.
# For every algorithm: the fraction of its configs whose mean final loss is below 3.7,
# then a text table with one row per other_lr and one column per muon_lr.
print("")
print("All losses for each alg:")
for alg in algs:
    print(alg)

    # 1.0 for each config of this algorithm under the threshold, 0.0 otherwise.
    under_threshold = [
        float(final_loss_stats[(name, hp)]["mean"] < 3.7)
        for (name, hp) in final_loss_stats
        if name == alg
    ]
    percent_under = np.mean(under_threshold)
    print(f"Configs under 3.7: {percent_under}")
    print("")

    # Header row: blank corner cell followed by the muon learning rates.
    header_cells = [f"        {muon_lr:13}" for muon_lr in muon_lrs]
    print("        " + "".join(header_cells))

    # Body rows: other_lr label followed by "mean +/- std" for each muon_lr.
    for other_lr in other_lrs:
        row_cells = []
        for muon_lr in muon_lrs:
            entry = final_loss_stats[(alg, (muon_lr, other_lr))]
            mean, std = entry["mean"], entry["std"]
            row_cells.append(f"    {mean:.4f} +/- {std:.4f}")
        print(f"{other_lr:8}" + "".join(row_cells))

    # Separator between algorithms: blank line, rule, blank line.
    print("", "----", "", sep="\n")

# Make 2D LR sensitivity plots.
# loss_grids[alg][i, j] is the mean final loss at other_lrs[i] (row) and muon_lrs[j]
# (column).
loss_grids = {
    alg: np.array([
        [final_loss_stats[(alg, (muon_lr, other_lr))]["mean"] for muon_lr in muon_lrs]
        for other_lr in other_lrs
    ])
    for alg in algs
}
# Shared color scale so the four panels are directly comparable.
min_loss = np.min([np.min(loss_grids[alg]) for alg in algs])
max_loss = np.max([np.max(loss_grids[alg]) for alg in algs])
norm = Normalize(vmin=min_loss, vmax=max_loss)

assert len(algs) == 4, "the 2x2 subplot layout assumes exactly four algorithms"
fig, axes = plt.subplots(2, 2, figsize=(9, 8), constrained_layout=True)
axes = axes.ravel()
x_labels = [str(other_lr) for other_lr in other_lrs]
y_labels = [str(muon_lr) for muon_lr in muon_lrs]
for i, alg in enumerate(algs):
    ax = axes[i]
    # imshow maps array rows to the y axis and columns to the x axis. The grid is stored
    # as (other_lr, muon_lr), so transpose it to put the base LR on x and the Muon LR on
    # y, matching the axis labels and tick labels set below. Without the transpose the
    # square grid plots without error but every off-diagonal cell is mislabelled.
    heatmap = ax.imshow(loss_grids[alg].T, cmap="viridis", norm=norm, aspect="auto", origin="lower")
    ax.set_title(DISPLAY_NAMES[alg])
    # Raw strings: "\e" is not a valid escape sequence and triggers a SyntaxWarning on
    # recent Python versions; the resulting label text is unchanged.
    ax.set_xlabel(r"$\eta_b$ (Base LR)")
    ax.set_ylabel(r"$\eta_m$ (Muon LR)")
    ax.set_xticks(np.arange(len(x_labels)))
    ax.set_xticklabels(x_labels)
    ax.set_yticks(np.arange(len(y_labels)))
    ax.set_yticklabels(y_labels)
    if i == 0:
        # All panels share `norm`, so a single colorbar spanning every axis suffices.
        cbar = fig.colorbar(heatmap, ax=axes, shrink=0.9)
        cbar.set_label("Loss")

plot_path = os.path.join(RESULTS_DIR, f"{RUN_NAME}.pdf")
plt.savefig(plot_path)
