import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns

# Scatter-plot "test_acc" against "cola_flops" (log-scaled x axis) for a subset of
# rows. Marker shape comes from "struct" and colour from "depth".
# The input file is presumably a Weights & Biases export (judging by its name).
filepath = "./logs/wandb_export_2024-03-27T20_54_39.320-04_00.csv"
df = pd.read_csv(filepath)

# Keep only rows whose "struct" value is one of these three; all others are dropped.
dff = df[df["struct"].isin(["btt", "multi", "simple"])]

# `sns.set` is a deprecated alias of `sns.set_theme`. The palette is passed here
# instead of through a separate `sns.set_palette` call.
# NOTE(review): if "depth" is numeric, seaborn maps hue to a continuous colormap
# and ignores this palette. Confirm that this is the intended colouring.
sns.set_theme(
    style="whitegrid",
    palette="Set2",
    font_scale=2.0,
    rc={"lines.linewidth": 3.0},
)

plt.figure(dpi=100, figsize=(20, 10))
ax = sns.scatterplot(
    x="cola_flops",
    y="test_acc",
    data=dff,
    style="struct",
    hue="depth",
    s=200,
)
ax.set_xscale("log")
ax.set_xlabel("FLOPs")
ax.set_ylabel("Test Acc")

# Move seaborn's own legend outside the axes, anchored at the top-right corner.
# Calling plt.legend() here would rebuild the legend from the axes artists, which
# can drop seaborn's legend title/sub-headings.
sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))

plt.tight_layout()
plt.show()
