from all_data import get_all_data
import pandas as pd

# Print one LaTeX table row per domain with "mean $\pm$ std" of each score
# column, computed over the non-training rows of a single batch.

include_mlp = True  # Set to False if you want to exclude MLP scores
df = get_all_data(include_mlp=include_mlp)

# Restrict to the batch being reported and to rows not flagged as training,
# then drop the bookkeeping columns that are not used below.
df = df[df["batch_id"] == "37290af5-1d7d-4d35-88fe-6fcf8ffc5868"]
df = df[~df["is_train"]]
df = df.drop(columns=["batch_id", "is_train", "instance"])

# Score columns in the order they are printed: "score" (the GNN column),
# optionally "mlp", then "prost". The "mlp" column is only touched when
# include_mlp is True, so the script keeps working if the loader leaves that
# column out when include_mlp is False.
# NOTE(review): presumably get_all_data omits "mlp" in that case -- confirm.
score_columns = ["score", "mlp", "prost"] if include_mlp else ["score", "prost"]

# Group once and reuse the grouping for both aggregates.
grouped = df[["domain", *score_columns]].groupby("domain")
df_means = grouped.mean().reset_index()
df_stds = grouped.std().reset_index()

# Mean columns keep their names; std columns get a "_std" suffix.
full_df = df_means.merge(df_stds, on="domain", suffixes=("", "_std"))

for _, row in full_df.iterrows():
    cells = [
        f"{row[col]:.2f} $\\pm$ {row[col + '_std']:.2f}" for col in score_columns
    ]
    # Cells are joined with " & " and the row ends with a LaTeX line break.
    print(" & ".join([f"{row['domain']}", *cells]) + " \\\\")
