from matplotlib import pyplot as plt
import seaborn as sns
from experiments.fns import get_project_data
from experiments.fns import rename_row
from experiments.fns import label_row

# Checkpoint values of the "epoch" column to compare; labelled as E=step+1 in the plot
# (presumably 0-indexed epochs — confirm against get_project_data).
steps = [99, 199, 299, 399, 499]
# Alternative experiment selections, kept for quick switching:
# exprs = ["none", "0.0-0.5-0.5-0.0-0.5-0.0-0.5", "0.5-0.5-0.0-0.0-0.5-0.5-0.0"]
# exprs = ["0.0-0.5-0.5-0.0-0.5-0.0-0.5"]
exprs = ["none", "0.0-0.5-0.5-0.0-0.5-0.0-0.5"]
exprs = [exp + " (BMM0) Adam" for exp in exprs]
# Metric to plot; the alternative is kept for quick switching.
target_var = "train_loss_avg"
# target_var = "test_error"

df = get_project_data(project="cifar_baselines", steps=steps)
# .copy() so the column assignments below write to an independent frame rather than
# a filtered slice of the previous one (avoids pandas SettingWithCopyWarning).
df = df[df["state"] == "finished"].copy()
df["vec"] = df.apply(rename_row, axis=1)
df = df[df["vec"].isin(exprs)].copy()
df["label"] = df.apply(label_row, axis=1)

# Variant that plots both projects together:
# df2 = get_project_data(project="btt3_shuf", steps=steps)
# df2["vec"] = df2.apply(rename_row, axis=1)
# df2["label"] = df2.apply(label_row, axis=1)
# df = pd.concat((df, df2))

# NOTE(review): this reassignment discards the cifar_baselines frame built above, so
# only btt3_shuf runs are plotted — and without the "finished"-state and `exprs`
# filters. Confirm this is intended; otherwise use the concat variant above.
df = get_project_data(project="btt3_shuf", steps=steps)
df["vec"] = df.apply(rename_row, axis=1)
df["label"] = df.apply(label_row, axis=1)

# For each checkpoint epoch, plot the lowest `target_var` per (vec, width) pair
# against `cola_flops`, distinguishing curves by marker/dash style per label.
is_train_loss = target_var.startswith("train_loss")

sns.set(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})
plt.figure(dpi=75, figsize=(25, 15))
sns.set_palette("Set2")
for step in steps:
    # Drop rows without a metric value: an all-NaN (vec, width) group would make
    # idxmin yield NaN (or raise in newer pandas) and the .loc lookup below fail.
    dff = df[df["epoch"] == step].dropna(subset=[target_var])
    if dff.empty:
        continue
    # One row per (vec, width): the run with the smallest target_var.
    dff = dff.loc[dff.groupby(["vec", "width"])[target_var].idxmin()]
    dff["label"] += f" (E={step + 1})"
    sns.scatterplot(x="cola_flops", y=target_var, data=dff, style="label", s=200)
    sns.lineplot(x="cola_flops", y=target_var, data=dff, style="label")

plt.ylabel("Train Loss" if is_train_loss else "Test Error")
plt.xlabel('FLOPs')
plt.xscale('log')
plt.yscale('log')
plt.legend(loc='upper left', bbox_to_anchor=(1, 1))
# The original condition `target_var.find("test") < -1` was never true (str.find
# returns -1 at minimum), so this window was silently ignored. The [1.7, 2.2] range
# fits a loss rather than an error rate, so it is applied to the train-loss target.
# NOTE(review): confirm this window still contains the data being plotted.
if is_train_loss:
    plt.ylim(1.7, 2.2)
plt.tight_layout()
plt.show()
