import pandas as pd
import seaborn as sns
import numpy as np
from matplotlib import pyplot as plt
from trainkit.saving import load_object
from experiments.fns import scale_laws_fn_log

# Plot each MLP structure's measured error against compute (FLOPs), overlaid
# with its fitted scaling-law curve. Both axes are log-scaled.
RUNS_PATH = "./logs/mlp_scale.csv"
COEFFS_PATH = "./logs/coeffs_mlp.pkl"
FLOPS_COL = "cola_flops"
# y-axis column and label; use ("test_error", "Test Error") to plot test error.
METRIC_COL = "train_error_avg"
METRIC_LABEL = "Train Error Avg"

runs = pd.read_csv(RUNS_PATH)
structs = list(runs["struct"].unique())
# Indexed by struct name (the entry is passed to scale_laws_fn_log) and by a
# shared "offset" key whose value is added to the exponentiated fit.
coeffs = load_object(COEFFS_PATH)

# Evaluation grid for the fitted curves. The x-axis is log-scaled, so sample
# geometrically: np.linspace would put almost every point in the top decade and
# draw the low-FLOP end as a single straight segment. To span only the observed
# data instead, use runs[FLOPS_COL].min() / .max() as the endpoints.
x = np.geomspace(2e5, 5e8, num=200)
# Loop-invariant: the fit function is evaluated on natural-log FLOPs.
log_x = np.log(x)

sns.set(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})
pal = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00', '#ffff33', '#a65628']
sns.set_palette(sns.color_palette(pal))
plt.figure(dpi=100, figsize=(12, 6))
for i, struct in enumerate(structs):
    df = runs[runs["struct"] == struct]
    # Give the points and their fitted curve the same explicit colour. Relying
    # on the implicit colour cycle can desync them, because scatter and plot
    # may or may not share a property cycle depending on the matplotlib version.
    color = pal[i % len(pal)]
    plt.scatter(
        df[FLOPS_COL].values, df[METRIC_COL].values, color=color, label=struct
    )
    logy = scale_laws_fn_log(coeffs[struct], log_x)
    plt.plot(x, np.exp(logy) + coeffs["offset"], color=color)
plt.ylabel(METRIC_LABEL)
plt.xlabel('FLOPs')
plt.xscale('log')
plt.yscale('log')
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.tight_layout()
plt.show()
