from os.path import join
import pandas as pd
from os import listdir
from matplotlib import pyplot as plt
import seaborn as sns
from experiments.fns import check_cond
from experiments.fns import get_details

# Global plot styling: white-grid background, enlarged fonts and thicker
# lines, using seaborn's "Set1" colour palette for successive curves.
sns.set(rc={"lines.linewidth": 3.0}, font_scale=2.0, style="whitegrid")
sns.set_palette("Set1")

# Root of the sweep; every entry inside it is considered as a candidate run.
sweep_dir = "./logs/kron"


def check_kron_mlp(filepath):
    """Return ``check_cond``'s verdict on whether ``filepath`` used ``kron_mlp``.

    Delegates to ``check_cond`` with ``key="model"`` and ``val="kron_mlp"``.
    The result is used as a truthy filter over the entries of the sweep
    directory.
    """
    target_model = "kron_mlp"
    return check_cond(filepath, val=target_model, key="model")


# Collect the sweep entries whose config selects the Kronecker MLP.
# ``listdir`` returns entries in arbitrary, filesystem-dependent order, so
# sort them to make the assembled DataFrame reproducible run to run.
paths = [
    join(sweep_dir, d)
    for d in sorted(listdir(sweep_dir))
    if check_kron_mlp(join(sweep_dir, d))
]

# Gather training-accuracy ("TrA") rows from every matching run.
# NOTE(review): assumes get_details returns a list of 4-item rows ordered as
# (Epoch, TrAcc, Metric, Size) to match the column names below — confirm
# against experiments.fns.get_details.
data = []
for path in paths:
    data += get_details(path, key="hidden_size", col="TrA")
df = pd.DataFrame(data, columns=("Epoch", "TrAcc", "Metric", "Size"))
sizes = sorted(df["Size"].unique())

# One curve per hidden size: a line through the points plus markers on them.
plt.figure(dpi=100, figsize=(10, 8))
for size in sizes:
    # plt.plot joins points in the order given, so order by epoch to keep the
    # curve monotone in x. A stable sort keeps input order among equal epochs
    # (relevant if several runs share the same hidden size).
    df_fil = df[df["Size"] == size].sort_values("Epoch", kind="mergesort")
    (line,) = plt.plot(df_fil["Epoch"], df_fil["TrAcc"], label=f"Kron({size})")
    # Reuse the line's colour so the markers always match their curve instead
    # of relying on plot() and scatter() advancing colour cycles in lockstep.
    plt.scatter(df_fil["Epoch"], df_fil["TrAcc"], color=line.get_color())

plt.xlabel("Epoch")
plt.ylabel("Train Accuracy")
if sizes:
    # With no matching runs there is nothing to label; skip the legend to
    # avoid matplotlib's "No artists with labels" warning.
    plt.legend()
# NOTE(review): limits assume accuracies are reported as percentages — confirm.
plt.ylim([90, 100.5])
plt.tight_layout()
plt.show()
