import pandas as pd


def summarise_results(
    dataset: str, metric: str, split_type: str, task: str, all: bool = False
):
    """Aggregate baseline results per (model, embedding) pair.

    Reads ``results/{dataset}/baseline_{task}_{dataset}.csv`` (first CSV
    column used as the index), keeps only the rows whose ``split_type``
    column equals *split_type*, and computes the mean and standard error of
    the mean of the requested metric for every (model, embedding) group.

    Args:
        dataset: Dataset name; used in both the directory and the file name.
        metric: Metric name; columns are looked up as ``{prefix}_{metric}``.
        split_type: Value of the ``split_type`` column to keep.
        task: Task name; used in the file name.
        all: When true, summarise the train, val and test columns; otherwise
            only the test column.

    Returns:
        A DataFrame indexed by (model, embedding) with two-level columns
        (metric column, "mean" | "sem").
    """
    csv_path = f"results/{dataset}/baseline_{task}_{dataset}.csv"
    results = pd.read_csv(csv_path, index_col=0)
    selected = results.loc[results["split_type"] == split_type]

    # Only the set of summarised columns depends on the `all` flag.
    prefixes = ("train", "val", "test") if all else ("test",)
    metric_columns = [f"{prefix}_{metric}" for prefix in prefixes]

    per_group = selected.groupby(["model", "embedding"])[metric_columns]
    return per_group.agg(["mean", "sem"])


if __name__ == "__main__":
    # Configuration of the single results table rendered by this script.
    dataset = "tim"
    metric = "spearman"
    split_type = "holdout"
    task = "regression"
    regressors = ["KNN", "Ridge", "RandomForest"]

    # Rows: (model, embedding); columns: (f"test_{metric}", "mean" | "sem").
    df = summarise_results(dataset, metric, split_type, task)

    # Wide table: one row per embedding, a mean and a sem column per regressor.
    df_all = pd.DataFrame(
        index=["AF2", "ESM-1B", "ESM-2", "ESM-IF1", "EVE (z)", "ONEHOT (MSA)"],
        columns=[
            f"{regressor}/{metric}/{stat}"
            for regressor in regressors
            for stat in ("mean", "sem")
        ],
    )
    for regressor in regressors:
        # Selecting on the first index level leaves series keyed by embedding.
        mean = df.loc[(regressor,), (f"test_{metric}", "mean")]
        sem = df.loc[(regressor,), (f"test_{metric}", "sem")]
        df_all.loc[mean.index, f"{regressor}/{metric}/mean"] = mean
        df_all.loc[sem.index, f"{regressor}/{metric}/sem"] = sem

    # The frame starts out empty (object dtype); cast before rounding.
    df_all = df_all.astype(float)
    df_all = df_all.round(decimals=2)

    # One "$mean \pm sem$" LaTeX math cell per (embedding, regressor).
    df_summary = pd.DataFrame(index=df_all.index, columns=regressors)
    for regressor in regressors:
        df_summary[regressor] = (
            "$"
            + df_all[f"{regressor}/{metric}/mean"].astype(str)
            # Raw string: "\p" is not a valid escape sequence in a plain
            # literal and triggers a SyntaxWarning on recent Python versions.
            + r" \pm "
            + df_all[f"{regressor}/{metric}/sem"].astype(str)
            + "$"
        )

    # Centre the column headers in the rendered table (raw strings for "\m").
    df_summary = df_summary.rename(
        columns={
            "KNN": r"\multicolumn{1}{c}{KNN}",
            "Ridge": r"\multicolumn{1}{c}{Ridge}",
            "RandomForest": r"\multicolumn{1}{c}{RandomForest}",
        }
    )

    # Display names for the embeddings; rows are sorted by these names.
    embedding_names = {
        "ONEHOT (MSA)": "MSA (1-HOT)",
        "ESM-2": "ESM-2",
        "ESM-1B": "ESM-1B",
        "ESM-IF1": "ESM-IF1",
        "EVE (z)": "EVE",
        "AF2": "Evoformer (AF2)",
    }
    df_summary = df_summary.rename(index=embedding_names)
    df_summary = df_summary.sort_index()

    print(
        df_summary.style.to_latex(
            hrules=True,
            column_format="l|ccc",
            caption=f"{dataset}, {task}, {metric}, {split_type}",
            position_float="centering",
        )
    )
