import pandas as pd


# df = pd.read_csv("./parsed_with_rel_reg.csv")


# Load the mode-switch experiment results and keep only the rows analysed below.
df = pd.read_csv("./data_mode_switch_results.csv", low_memory=False)
# df_mix_tf = pd.read_csv("./mixed_tf_results.csv")
# df = pd.concat([df, df_mix_tf])

# Row filters: "Test" subset only, drop GaussianNB runs, exclude two datasets.
keep_rows = (
    (df["Subset"] == "Test")
    & (df["Classifier"] != "GaussianNB")
    & ~df["Dataset"].isin(["RoadSafety", "Electricity"])
)
# .copy() makes `df` an independent frame, so the column assignments below do
# not write into a slice of the unfiltered frame (SettingWithCopyWarning).
df = df[keep_rows].copy()

# True where the row has an encoder recorded (computed before relabelling,
# which only renames non-null values and so does not affect this flag).
df["Use Encoder"] = df["Encoder"].notna()

# Shorter display labels for encoders and distillation methods.
df["Encoder"] = df["Encoder"].replace(
    {
        "MLP": "FFN",
        "TF-MultiHead": "TF*",
        "GNN-MultiHead": "GNN*",
        "MLP-MultiHead": "FFN*",
    }
)
df["Distill Method"] = df["Distill Method"].replace(
    {
        "KMeans": "KM",
        "Agglo": "AG",
    }
)

# Maximum number of datasets listed in each report.
TOP_N = 10

def _with_regret(frame, min_score=0.7, max_score=1.0):
    """Return the rows of qualifying (Dataset, Classifier) groups with an "Ori Regret" column.

    For each group, the baseline is the mean "Score" of its rows where
    "Distill Method" == "Original" and "Data Parse Mode" == "mixed".
    "Ori Regret" is that baseline minus each row's "Score".

    A group is dropped when its baseline is missing (no baseline rows),
    equals ``max_score``, or is below ``min_score``.

    Returns an empty frame (with an "Ori Regret" column) when no group
    qualifies.
    """
    parts = []
    for _, grp in frame.groupby(["Dataset", "Classifier"]):
        ori_sco = grp.loc[
            (grp["Distill Method"] == "Original") & (grp["Data Parse Mode"] == "mixed"),
            "Score",
        ].mean()
        # Without baseline rows the mean is NaN. NaN fails both comparisons
        # below, so it must be excluded explicitly or the whole group would
        # carry NaN regrets into the reports.
        if pd.isna(ori_sco) or ori_sco == max_score or ori_sco < min_score:
            continue
        # assign() builds a new frame instead of writing into the group slice.
        parts.append(grp.assign(**{"Ori Regret": ori_sco - grp["Score"]}))
    if not parts:
        # pd.concat([]) raises ValueError; return an empty result instead.
        return frame.iloc[:0].assign(**{"Ori Regret": pd.Series(dtype=float)})
    return pd.concat(parts)


def _print_top_configs(
    w_reg,
    classifier,
    label,
    top_n,
    *,
    n=100,
    distill="KM",
    encoder="TF*",
    parse_mode="onehot",
):
    """Print the lowest-regret "Short Name" per dataset for one experiment setup.

    Rows of ``w_reg`` are restricted to the given classifier, N, distill
    method, encoder and parse mode. They are averaged per
    ("Short Name", "Dataset") and sorted by mean "Ori Regret".

    For each dataset, the config with the smallest mean regret is printed,
    for at most ``top_n`` datasets. Datasets are ordered by that best regret.

    Returns the sorted per-(Short Name, Dataset) mean table.
    """
    sel = w_reg[
        (w_reg["Classifier"] == classifier)
        & (w_reg["N"] == n)
        & (w_reg["Distill Method"] == distill)
        & (w_reg["Encoder"] == encoder)
        & (w_reg["Data Parse Mode"] == parse_mode)
    ]
    table = (
        sel.groupby(["Short Name", "Dataset"])[["Ori Regret", "Score"]]
        .mean()
        .reset_index()
        # Stable sort: equal regrets keep groupby key order, so the reported
        # config is deterministic rather than dependent on sort internals.
        .sort_values("Ori Regret", kind="stable")
    )
    # The first row per dataset in the sorted table is its best config, and
    # first occurrences come out ordered by that best regret.
    best = table.drop_duplicates("Dataset").head(top_n)
    print(f"**Top {top_n} config for {label}**")
    for d, method, regret_score, run_score in zip(
        best["Dataset"], best["Short Name"], best["Ori Regret"], best["Score"]
    ):
        # Regret is baseline minus score, so the means add back to the baseline.
        print(
            f"- {d} / `{method}` -- Original: {run_score+regret_score:.4f}, Run: {run_score:.4f}"
        )
    print()
    return table


# Regret is computed once and shared by both reports.
w_reg = _with_regret(df)
_print_top_configs(w_reg, "XGBClassifier", "XGB", TOP_N)
_print_top_configs(w_reg, "LogisticRegression", "LR", TOP_N)
