import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns

# Global plot theme: white grid background, fonts scaled 2x and 3-point lines.
sns.set(
    font_scale=2.0,
    rc={"lines.linewidth": 3.0},
    style="whitegrid",
)
# Use the qualitative "Set2" colors as the color cycle for subsequent plots.
sns.set_palette(palette="Set2")

# Load the exported run table and keep only runs that finished on host "mint".
df = pd.read_csv("./logs/wandb_export_2023-12-04T12_30_43.602-05_00.csv")
mask = (df["Hostname"] == "mint") & (df["State"] == "finished")
df = df[mask]
# Distinct TT ranks among the remaining runs, in ascending order.
ranks = sorted(df["tt_rank"].unique())
# Number of TT cores (matched against the `tt_dim` column) whose runs are plotted.
cores_n = 4

plt.figure(dpi=100, figsize=(14, 8))
plt.title(f"Cores: {cores_n}")
# Assign one color per rank up front, following the active color cycle in rank
# order. A rank therefore keeps the same color it would get if every rank were
# plotted, even when some ranks are skipped below.
rank_colors = dict(zip(ranks, sns.color_palette(n_colors=len(ranks))))
for rank in ranks:
    dff = df[(df["tt_rank"] == rank) & (df["tt_dim"] == cores_n)]
    # `ranks` was collected across every tt_dim, so a rank may have no runs at
    # this core count; skip it instead of adding an empty legend entry.
    if dff.empty:
        continue
    plt.scatter(
        dff["cola_flops"],
        dff["test_acc"],
        label=f"rank={rank}",
        color=rank_colors[rank],
    )
plt.xlabel("FLOPs")
plt.ylabel("Test Accuracy")
plt.ylim([50, 85])
plt.xscale("log")
# Anchor the legend just outside the right edge of the axes so it does not
# cover any data points.
plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.show()
