import os
from os.path import join
import re
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
from trainkit.saving import load_object

# Which pickled quantity to plot, and the x/y axis limits used for it.
# A non-positive y_max leaves the y-axis auto-scaled.
part = "sigmas"
x_max, y_max = 6_000, -1
# part = "act_norms"
# x_max, y_max = 6_000, 1_000
root_dir = "Downloads/assets_struct"

# One record per training run: (run directory, iters_div, legend label).
# Keeping these together means the derived sequences below cannot drift out of
# sync when runs are added, removed or reordered. Each plotted point's index is
# divided by iters_div to get its x position on the "Iters" axis.
runs = [
    ("out-wiki103-btt_2024-01-19_182517", 4 / 2, "BTT(fp16)"),
    ("out-wiki103-btt_2024-01-19_224248", 8 / 200, "BTT(fp32)"),
    ("out-wiki103-dense_2024-01-19_210253", 4 / 200, "Dense(fp16)"),
    ("out-wiki103-btt_spect_2024-01-20_220729", 4 / 200, "BTT(fp16)-qk"),
    ("out-wiki103-btt_spect_2024-01-22_030241", 4 / 200, "BTT(fp16)-gamma"),
    ("out-wiki103-btt_spect_2024-01-21_120444", 4 / 200, "BTT(fp16)-wd"),
    ("out-wiki103-btt_spect_2024-01-22_074242", 4 / 200, "BTT(fp16)-norm"),
]
iters_divs = tuple(iters_div for _, iters_div, _ in runs)
cases = [case for _, _, case in runs]
filepaths = [join(root_dir, f"{run_dir}/{part}.pkl") for run_dir, _, _ in runs]

# Layer-name suffixes to plot; one figure is drawn per suffix.
# Uncomment exactly one alternative.
# posfixes = [f"{i}.mlp.c_proj" for i in range(12)]
posfixes = [f"{i}.mlp.c_fc" for i in range(12)]
# posfixes = [f"{i}.attn.c_attn" for i in range(12)]
# posfixes = [f"{i}.attn.c_proj" for i in range(12)]


def get_plot_labels(path):
    """Derive a y-axis label and a title from a result file path.

    The path is scanned for the tokens ``dense``, ``btt``, ``act_norms`` and
    ``sigmas``. The LAST token found selects the y-axis label and the FIRST
    selects the title. When only one token is present, both come from it.

    Args:
        path: Path of a pickled result, e.g.
            ``"out-wiki103-btt_2024-01-19_182517/sigmas.pkl"``.

    Returns:
        ``(ylabel, title)`` tuple of display strings.

    Raises:
        ValueError: If ``path`` contains none of the recognised tokens.
    """
    labels = {"btt": "BTT", "act_norms": "act norms (mean)", "sigmas": "sigma", "dense": "Dense"}
    # Build the pattern from the dict keys so the two can never disagree.
    pattern = "(" + "|".join(re.escape(token) for token in labels) + ")"
    matches = re.findall(pattern, path)
    if not matches:
        raise ValueError(f"No recognised label token ({', '.join(labels)}) in path: {path!r}")
    ylabel, title = labels[matches[-1]], labels[matches[0]]
    return ylabel, title


def _find_layer(result, posfix):
    """Return ``(layer_name, pack)`` for the entry of *result* matching *posfix*.

    A match on a dotted-name boundary (``name == posfix`` or
    ``name.endswith("." + posfix)``) is preferred. Without this, ``"1.mlp.c_fc"``
    would also match ``"...11.mlp.c_fc"`` and the choice would depend on dict
    order. If no boundary match exists, the first plain suffix match is used.

    Raises:
        KeyError: If no key of *result* ends with *posfix*.
    """
    for layer_name, pack in result.items():
        if layer_name == posfix or layer_name.endswith("." + posfix):
            return layer_name, pack
    for layer_name, pack in result.items():
        if layer_name.endswith(posfix):
            return layer_name, pack
    raise KeyError(f"No layer ending with {posfix!r}; available layers: {list(result)}")


# zip() below would silently drop runs if the per-run settings disagree in
# length, so fail loudly before doing any expensive loading.
if not len(filepaths) == len(iters_divs) == len(cases):
    raise ValueError(
        f"Per-run settings disagree in length: {len(filepaths)} filepaths, "
        f"{len(iters_divs)} iters_divs, {len(cases)} cases"
    )

# Result paths are relative to the user's home directory. expanduser("~") is a
# portable replacement for os.environ["HOME"].
home_dir = os.path.expanduser("~")
results = [load_object(join(home_dir, filepath)) for filepath in filepaths]

sns.set(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})
# Alternative palettes:
# pal = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3']
# pal = ['#1b9e77', '#d95f02', '#7570b3', '#e7298a', '#66a61e', '#e6ab02']
# pal = ['#8c510a', '#d8b365', '#f6e8c3', '#f5f5f5', '#c7eae5', '#5ab4ac', '#01665e']
# pal = ['#1b9e77', '#d95f02', '#7570b3', '#e7298a', '#66a61e', '#e6ab02', '#a6761d']
pal = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00', '#ffff33', '#a65628']
sns.set_palette(sns.color_palette(pal))

# One figure per layer suffix, with one line+marker series per run.
for posfix in posfixes:
    plt.figure(dpi=100, figsize=(16, 8))
    layer_name = ylabel = None
    for filepath, result, iters_div, case in zip(filepaths, results, iters_divs, cases):
        ylabel, _ = get_plot_labels(filepath)
        layer_name, pack = _find_layer(result, posfix)
        # Only the first field of each recorded entry is plotted ("sigmas" or
        # "act_norms", depending on `part`).
        values = [value for value, *_ in pack]
        iters = np.arange(len(values)) / iters_div
        print(layer_name)
        (line,) = plt.plot(iters, values, label=case)
        # Tie the markers' colour to their line explicitly rather than relying
        # on two independent colour cycles staying in step.
        plt.scatter(iters, values, color=line.get_color())
    if layer_name is None:
        # No runs configured: nothing to decorate.
        plt.show()
        continue
    # Figure-level decorations are applied once. Title and y-label come from
    # the last run, which is what repeated per-run calls ended up showing.
    # NOTE(review): [21:] strips a fixed-length prefix of the layer name,
    # presumably the model's module path — confirm against the saved keys.
    plt.title(layer_name[21:])
    plt.xlabel("Iters")
    plt.ylabel(ylabel)
    plt.xlim([0, x_max])
    if y_max > 0:
        plt.ylim([0, y_max])
    plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
    plt.tight_layout()
    plt.show()
