import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns

# CSV exported from Weights & Biases containing one row per logged run/step.
# An earlier export that was also plotted with this script:
#   ./logs/wandb_export_2024-04-04T10_55_36.777-04_00.csv
filepath = "./logs/wandb_export_2024-04-04T13_41_04.973-04_00.csv"

# Run selection. Values used in earlier versions of this plot:
#   OPTIMIZER = "sgd"
#   EXPR0 = "eg,g,aeg->a"
OPTIMIZER = "adamw"
EXPR0 = "eg,bdg,adeg->abd"

df = pd.read_csv(filepath)
# Apply both selection criteria with a single boolean mask; .loc avoids
# chained indexing and makes the selection explicit.
df = df.loc[(df["optimizer"] == OPTIMIZER) & (df["expr0"] == EXPR0)]
if df.empty:
    # Without this check the script would silently draw an empty figure.
    raise ValueError(
        f"No rows in {filepath!r} with optimizer={OPTIMIZER!r} and expr0={EXPR0!r}"
    )

# Optional extra filters that were used in some analyses:
#   keep only the last epochs:        df = df[df["epoch"] > 198]
#   show diverged runs (NaN loss):    df["train_loss"] = df["train_loss"].fillna(100)

# Large fonts and thick lines so the figure stays readable when shrunk.
sns.set_theme(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})
sns.set_palette("Set2")
fig, ax = plt.subplots(dpi=100, figsize=(20, 10))

# The line shows seaborn's per-lr aggregate of train_loss (by default the mean,
# with an uncertainty band when several runs share an lr); the points show the
# individual rows. Only the line plot contributes legend entries: if both plots
# did, ax.legend() below would collect both sets of labelled artists and list
# every entry twice.
# To plot test accuracy instead, use y="test_acc" and adjust the y-limits.
sns.lineplot(x="lr", y="train_loss", data=df, style="expr0", hue="width", ax=ax)
sns.scatterplot(
    x="lr",
    y="train_loss",
    data=df,
    style="expr0",
    hue="width",
    s=200,
    legend=False,
    ax=ax,
)
ax.set_ylabel("Train Loss")
ax.set_xlabel("lr")
ax.set_xscale("log")
# Fixed y-range tuned for this export; values outside it are clipped.
ax.set_ylim(1.95, 2.35)
# Place the legend outside the axes, to the right of the plot area.
ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
fig.tight_layout()
plt.show()
