from matplotlib import pyplot as plt
import seaborn as sns
from experiments.fns import get_project_data

# Load the "lr_land" project's runs (steps=[-1] presumably selects the last
# logged step of each run -- confirm against get_project_data).
df = get_project_data(project="lr_land", steps=[-1])
target_var = "test_error"
# Show which expr values exist before any filtering is applied.
print(df["expr"].unique())

# Keep only runs whose use_wrong_mult flag is False.
# (Remove the `~` to inspect the use_wrong_mult runs instead.)
df = df.loc[~df["use_wrong_mult"]]

# expr value(s) to plot. Other values that can be swapped in:
#   "(0.5|0|0.5|0|0.5|0.5|0)"
#   "(0.7|0|0.3|0|0.7|0.3|0)"
#   "(0.8|0.2|0|0.2|0.8|0|0.33)"
exprs = ["(0.333|0.333|0.334|0.33|0.33|0.34|0.33)"]
df = df.loc[df["expr"].isin(exprs)]

# Plot `target_var` against learning rate on a log-scaled x axis. Hue encodes
# the "width" column; line/marker style encodes the "expr" column.
sns.set(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})
sns.set_palette("Set2")
plt.figure(dpi=100, figsize=(20, 10))
sns.lineplot(x="lr", y=target_var, data=df, style="expr", hue="width")
# Overlay the individual points. legend=False stops the scatter from adding a
# second, duplicate set of hue/style entries to the legend drawn below.
sns.scatterplot(
    x="lr", y=target_var, data=df, style="expr", hue="width", s=200, legend=False
)
# Use a membership test: str.find returns 0 for a match at the start of the
# string, so a `find(...) > 0` check would label "train_loss" as "Test Error".
plt.ylabel("Train Loss" if "train_loss" in target_var else "Test Error")
plt.xlabel("lr")
plt.xscale("log")
# Anchor the legend's upper-left corner at the axes' top-right, i.e. outside
# the plotting area.
plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.show()
