import os
import re
from os.path import join
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
from trainkit.saving import load_object

# ---- Plot configuration ------------------------------------------------------
# Record indices are divided by this factor to produce the x-axis values.
iters_div = 4 / 200
# Which pickle to load from each run directory (other option: "sigmas").
part = "act_norms"
# Upper y-limit of the plot; a non-positive value skips the ylim call.
y_max = 10_000

# Display names, index-aligned with `filepaths` below.
cases = ["BTT(fp16)", "BTT(fp32)", "Dense(fp16)", "BTT(fp16)-qk"]
# Result pickles, relative to the user's home directory, one per run.
filepaths = [
    f"Downloads/out-wiki103-btt_2024-01-19_182517/{part}.pkl",
    f"Downloads/out-wiki103-btt_2024-01-19_224248/{part}.pkl",
    f"Downloads/out-wiki103-dense_2024-01-19_210253/{part}.pkl",
    f"Downloads/out-wiki103-btt_spect_2024-01-20_220729/{part}.pkl",
]

# Which run to plot (index into `cases` / `filepaths`).
idx = 1
case = cases[idx]
filepath = filepaths[idx]

# Only layers whose name ends with this suffix are plotted. Other suffixes
# used previously: "5.attn.c_attn", "5.mlp.c_fc", "mlp.c_fc", "c_attn",
# "attn.c_proj".
posfix = "mlp.c_proj"


def get_plot_labels(path):
    """Derive the y-axis label and plot title from a results file path.

    The path is scanned left to right for the tokens ``dense``, ``btt``,
    ``act_norms`` and ``sigmas``. The last token found selects the y-axis
    label (normally taken from the pickle file name). The first token found
    selects the title (normally the model type in the run directory name).

    Args:
        path: Results file path, e.g.
            ``"Downloads/out-wiki103-btt_2024-01-19_182517/act_norms.pkl"``.

    Returns:
        A ``(ylabel, title)`` tuple of display strings.

    Raises:
        ValueError: If ``path`` contains none of the recognised tokens.
    """
    # Named `labels` rather than `map` so the builtin is not shadowed.
    labels = {"btt": "BTT", "act_norms": "act norms (mean)", "sigmas": "sigma", "dense": "Dense"}
    matches = re.findall(r"(dense|btt|act_norms|sigmas)", path)
    if not matches:
        # Previously this surfaced as a bare IndexError from matches[-1].
        raise ValueError(f"no recognised model/metric token in path: {path!r}")
    return labels[matches[-1]], labels[matches[0]]


# Load the per-layer records. expanduser("~") resolves the home directory
# portably; indexing os.environ["HOME"] raises KeyError where $HOME is unset
# (e.g. on Windows).
results = load_object(join(os.path.expanduser("~"), filepath))
# `title` is not used below: the figure is titled with the more specific `case`.
ylabel, title = get_plot_labels(filepath)

sns.set(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})
# Interleave green and grey shades (dropping the lightest of each) into one palette.
greens = ['#edf8e9', '#c7e9c0', '#a1d99b', '#74c476', '#41ab5d', '#238b45', '#005a32'][1:]
greys = ['#f7f7f7', '#d9d9d9', '#bdbdbd', '#969696', '#737373', '#525252', '#252525'][1:]
pal = [color for pair in zip(greens, greys) for color in pair]
sns.set_palette(sns.color_palette(pal))

plt.figure(dpi=100, figsize=(16, 8))
plt.title(case)
plotted_any = False
for layer_name, pack in results.items():
    # List every recorded layer so a suitable `posfix` can be chosen.
    print(layer_name)
    if not layer_name.endswith(posfix):
        continue
    # The first field of each record is the plotted quantity (labelled `ylabel`).
    values = [v for v, *_ in pack]
    iters = np.arange(len(values)) / iters_div
    # NOTE(review): [21:] strips a fixed-length layer-name prefix; confirm it
    # matches the naming used when the records were saved.
    plt.plot(iters, values, label=layer_name[21:])
    plt.scatter(iters, values)
    plotted_any = True

# The axis decorations do not depend on the layer, so apply them once after
# all curves are drawn.
plt.xlabel("Iters")
plt.ylabel(ylabel)
if y_max > 0:
    plt.ylim([0, y_max])
if plotted_any:
    # Avoid the "No artists with labels" warning when no layer matched.
    plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.show()
