from collections import OrderedDict

import numpy as np
import matplotlib.pyplot as plt


# Shared palette: every variant of an optimizer family reuses that family's colour.
_BLUE = "#3A86FF"
_GREY = "#595959"  # grey-black
_PINK = "#EF476F"  # light red/pink

# Matplotlib line-style codes used below.
_SOLID = "-"
_DOTTED = ":"
_DASHED = "--"

# Line colour per algorithm key (muon* -> blue, scion* -> grey, muonmax* -> pink).
COLOR_MAP = {
    "muon": _BLUE,
    "scion": _GREY,
    "muonmax": _PINK,
    "muonmax-stale": _PINK,
    "muon-momo": _BLUE,
    "muon-momo-stale": _BLUE,
    "scion-momo": _GREY,
    "scion-momo-stale": _GREY,
    "muonmax-momo": _PINK,
    "muonmax-momo-stale": _PINK,
    "muon-momo-0.0": _BLUE,
    "muon-momo-1.6": _BLUE,
    "muon-momo-2.4": _BLUE,
    "muon-momo-2.8": _BLUE,
    "muon-momo-3.2": _BLUE,
    "muonmax-momo-0.0": _PINK,
    "muonmax-momo-1.6": _PINK,
    "muonmax-momo-2.4": _PINK,
    "muonmax-momo-2.8": _PINK,
    "muonmax-momo-3.2": _PINK,
    "muon-clip": _BLUE,
    "muonmax-clip": _PINK,
}

# Line style per algorithm key: base variants solid, "-momo" dotted, "-clip" dashed.
# Keys missing here are looked up by the plotting code with a None fallback.
STYLE_MAP = {
    "muon": _SOLID,
    "scion": _SOLID,
    "muonmax": _SOLID,
    "muonmax-stale": _SOLID,
    "muon-momo": _DOTTED,
    "muon-momo-stale": _DOTTED,
    "scion-momo": _DOTTED,
    "scion-momo-stale": _DOTTED,
    "muonmax-momo": _DOTTED,
    "muonmax-momo-stale": _DOTTED,
    "muon-clip": _DASHED,
    "muonmax-clip": _DASHED,
}

# Line opacity for the "-momo-<value>" keys; opacity grows with the numeric suffix.
ALPHA_MAP = {
    "muon-momo-0.0": 0.2,
    "muon-momo-1.6": 0.4,
    "muon-momo-2.4": 0.6,
    "muon-momo-2.8": 0.8,
    "muon-momo-3.2": 1.0,

    "muonmax-momo-0.0": 0.2,
    "muonmax-momo-1.6": 0.4,
    "muonmax-momo-2.4": 0.6,
    "muonmax-momo-2.8": 0.8,
    "muonmax-momo-3.2": 1.0,
}

# Legend label per algorithm key ("-stale" keys share the label of their base key).
DISPLAY_NAMES = {
    # Base optimizers.
    "muon": "MuonAdam",
    "scion": "Scion",
    "muonmax": "MuonMax",
    "muonmax-stale": "MuonMax",
    # "-momo" variants.
    "muon-momo": "MuonAdam-Momo",
    "muon-momo-stale": "MuonAdam-Momo",
    "scion-momo": "Scion-Momo",
    "scion-momo-stale": "Scion-Momo",
    "muonmax-momo": "MuonMax-Momo",
    "muonmax-momo-stale": "MuonMax-Momo",
    # "-momo-<value>" keys, labelled with the matching F_* value.
    "muon-momo-0.0": "Muon-Momo ($F_* = 0.0$)",
    "muon-momo-1.6": "Muon-Momo ($F_* = 1.6$)",
    "muon-momo-2.4": "Muon-Momo ($F_* = 2.4$)",
    "muon-momo-2.8": "Muon-Momo ($F_* = 2.8$)",
    "muon-momo-3.2": "Muon-Momo ($F_* = 3.2$)",
    "muonmax-momo-0.0": "MuonMax-Momo ($F_* = 0.0$)",
    "muonmax-momo-1.6": "MuonMax-Momo ($F_* = 1.6$)",
    "muonmax-momo-2.4": "MuonMax-Momo ($F_* = 2.4$)",
    "muonmax-momo-2.8": "MuonMax-Momo ($F_* = 2.8$)",
    "muonmax-momo-3.2": "MuonMax-Momo ($F_* = 3.2$)",
    # "-clip" variants.
    "muon-clip": "Muon-Clip",
    "muonmax-clip": "MuonMax-Clip",
}

def get_loss_stats(total_results):
    """Aggregate per-run results into final-loss and smoothed-loss statistics.

    Args:
        total_results: Mapping from ``(alg, hp, seed)`` to a per-run mapping
            providing at least the sequences ``"val_losses"`` and ``"losses"``.

    Returns:
        A 4-tuple ``(final_losses, final_loss_stats, tuned_final_losses,
        smoothed_loss_stats)`` of ``OrderedDict`` objects:

        * ``final_losses``: ``(alg, hp)`` -> list with the last entry of
          ``"val_losses"`` for each seed.
        * ``final_loss_stats``: ``(alg, hp)`` -> ``{"mean", "std", "max",
          "min"}`` of those final losses, as Python floats.
        * ``tuned_final_losses``: ``alg`` -> the stats of the setting with the
          lowest mean final loss, plus its ``"best_hp"``.
        * ``smoothed_loss_stats``: ``(alg, hp)`` -> ``{"mean", "std", "max",
          "min"}`` arrays of the exponentially smoothed ``"losses"`` curves,
          reduced across seeds.
    """
    # Collect the final validation loss of every run, grouped by setting.
    final_losses = OrderedDict()
    for (alg, hp, _seed), run_results in total_results.items():
        final_losses.setdefault((alg, hp), []).append(run_results["val_losses"][-1])

    # Summarise the final losses of each setting across its seeds.
    final_loss_stats = OrderedDict()
    for setting, run_losses in final_losses.items():
        final_loss_stats[setting] = {
            "mean": float(np.mean(run_losses)),
            "std": float(np.std(run_losses)),
            "max": float(np.max(run_losses)),
            "min": float(np.min(run_losses)),
        }

    # Pick the best setting (lowest mean final loss) for each algorithm.
    tuned_final_losses = OrderedDict()
    for (alg, hp), run_stats in final_loss_stats.items():
        incumbent = tuned_final_losses.get(alg)
        if incumbent is not None:
            new_mean = run_stats["mean"]
            old_mean = incumbent["mean"]
            # Any comparison against NaN is False, so a diverged (NaN) setting
            # seen first would otherwise never be replaced. Treat NaN as worse
            # than every real value; ties and all-NaN keep the first setting.
            improves = new_mean < old_mean or (
                np.isnan(old_mean) and not np.isnan(new_mean)
            )
            if not improves:
                continue
        tuned_final_losses[alg] = {"best_hp": hp, **run_stats}

    # Exponentially smooth each run's training-loss curve.
    smooth_beta = 0.95
    smoothed_losses = OrderedDict()
    for (alg, hp, _seed), run_results in total_results.items():
        cur_losses = run_results["losses"]
        # Seed the running average with the first loss so the curve starts
        # at that value instead of being biased towards zero.
        smooth_avg = cur_losses[0]
        smooth_cur_losses = []
        for x in cur_losses:
            smooth_avg = smooth_beta * smooth_avg + (1 - smooth_beta) * x
            smooth_cur_losses.append(float(smooth_avg))
        smoothed_losses.setdefault((alg, hp), []).append(smooth_cur_losses)

    # Reduce the smoothed curves across seeds (axis 0), step by step. The
    # curves of one setting are stacked into a single array, so they are
    # expected to all have the same length.
    smoothed_loss_stats = OrderedDict()
    for setting, curves in smoothed_losses.items():
        cur_smooth_losses = np.array(curves)
        smoothed_loss_stats[setting] = {
            "mean": np.mean(cur_smooth_losses, axis=0),
            "std": np.std(cur_smooth_losses, axis=0),
            "max": np.max(cur_smooth_losses, axis=0),
            "min": np.min(cur_smooth_losses, axis=0),
        }

    return final_losses, final_loss_stats, tuned_final_losses, smoothed_loss_stats


def plot_lr_sensitivity(algs, final_loss_stats, plot_path, xlim=None, ylim=None):
    """Plot mean final validation loss against the hyperparameter, per algorithm.

    For each algorithm one line of mean final losses is drawn over its sorted
    hyperparameter values (log-scaled x-axis), with a shaded band spanning the
    per-setting min and max. The figure is saved to ``plot_path`` as a PDF and
    then closed.

    Args:
        algs: Algorithm keys to plot; each must be present in ``COLOR_MAP``
            and ``DISPLAY_NAMES``.
        final_loss_stats: Mapping ``(alg, hp)`` -> dict with ``"mean"``,
            ``"min"`` and ``"max"`` entries.
        plot_path: Destination of the PDF.
        xlim: Optional x-axis limits; left to matplotlib when None.
        ylim: Optional y-axis limits; when None they are derived from the
            plotted means with a 10% margin on each side.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    lowest = float("inf")
    highest = -float("inf")

    for alg in algs:
        alg_hps = sorted({hp for (a, hp) in final_loss_stats if a == alg})
        alg_losses = [final_loss_stats[(alg, hp)]["mean"] for hp in alg_hps]
        color = COLOR_MAP[alg]
        # Style and alpha are optional per algorithm; None lets matplotlib
        # fall back to its defaults.
        linestyle = STYLE_MAP.get(alg)
        alpha = ALPHA_MAP.get(alg)
        display_name = DISPLAY_NAMES[alg]
        line_obj, = ax.plot(
            alg_hps,
            alg_losses,
            label=display_name,
            linewidth=2,
            color=color,
            linestyle=linestyle,
            markersize=6,
            marker="o",
            alpha=alpha,
        )
        line_color = line_obj.get_color()

        # Shade the range between the best and worst seed of every setting.
        alg_min_losses = [final_loss_stats[(alg, hp)]["min"] for hp in alg_hps]
        alg_max_losses = [final_loss_stats[(alg, hp)]["max"] for hp in alg_hps]
        ax.fill_between(alg_hps, alg_min_losses, alg_max_losses, color=line_color, alpha=0.2)

        # NaN-aware reductions: a single diverged (NaN) setting must not stop
        # the remaining values from contributing to the automatic y-limits.
        highest = max(highest, np.nanmax(alg_losses))
        lowest = min(lowest, np.nanmin(alg_losses))

    if ylim is None:
        ylim = [lowest * 0.9, highest * 1.1]

    ax.set_xscale("log")
    if xlim is not None:
        ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    # Raw string: "\e" is not a valid escape sequence in a normal literal.
    ax.set_xlabel(r"$\eta_m$ (Muon LR)")
    ax.set_ylabel("Final Validation Loss")
    ax.legend(loc="upper right", fontsize=10)
    ax.grid(axis="both", lw=0.2, ls="--", zorder=0)
    fig.subplots_adjust(top=0.95, bottom=0.15, left=0.15, right=0.95)

    fig.savefig(plot_path, format="pdf", bbox_inches="tight")
    plt.close(fig)


def plot_loss_curves(algs, smoothed_loss_stats, tuned_final_losses, plot_path, xlim=None, ylim=None):
    """Plot the smoothed training-loss curve of each algorithm's best setting.

    For each algorithm the mean smoothed loss of its tuned hyperparameter is
    drawn with a shaded band of one standard deviation above and below. The
    x-axis is the step index divided by the curve length. The figure is saved
    to ``plot_path`` as a PDF and then closed.

    Args:
        algs: Algorithm keys to plot; each must be present in ``COLOR_MAP``
            and ``DISPLAY_NAMES``.
        smoothed_loss_stats: Mapping ``(alg, hp)`` -> dict with ``"mean"`` and
            ``"std"`` arrays over training steps.
        tuned_final_losses: Mapping ``alg`` -> dict containing ``"best_hp"``.
        plot_path: Destination of the PDF.
        xlim: Optional x-axis limits; left to matplotlib when None.
        ylim: Optional y-axis limits; when None they are derived from the
            shaded bands with a 10% margin on each side.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    lowest = float("inf")
    highest = -float("inf")
    # Automatic y-limits ignore the first 40% of every curve.
    bound_start = 0.4

    for alg in algs:
        best_hp = tuned_final_losses[alg]["best_hp"]
        mean_losses = smoothed_loss_stats[(alg, best_hp)]["mean"]
        std_losses = smoothed_loss_stats[(alg, best_hp)]["std"]

        # Normalised position of each step, in [0, 1).
        percentage = np.arange(len(mean_losses)) / len(mean_losses)
        color = COLOR_MAP[alg]
        # Style is optional per algorithm; None uses matplotlib's default.
        linestyle = STYLE_MAP.get(alg)
        display_name = DISPLAY_NAMES[alg]
        line_obj, = ax.plot(
            percentage,
            mean_losses,
            linewidth=2,
            label=display_name,
            color=color,
            linestyle=linestyle,
        )
        line_color = line_obj.get_color()
        upper_losses = mean_losses + std_losses
        lower_losses = mean_losses - std_losses
        ax.fill_between(percentage, upper_losses, lower_losses, color=line_color, alpha=0.2)

        start = round(len(mean_losses) * bound_start)
        highest = max(highest, np.max(upper_losses[start:]))
        lowest = min(lowest, np.min(lower_losses[start:]))

    if ylim is None:
        ylim = [lowest * 0.9, highest * 1.1]

    ax.set_xlabel("Epochs")
    ax.set_ylabel("Train Loss")
    if xlim is not None:
        ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.legend(loc="upper right", fontsize=10)
    fig.subplots_adjust(top=0.99, bottom=0.155, left=0.12, right=0.99)

    fig.savefig(plot_path, format="pdf", bbox_inches="tight")
    plt.close(fig)
