import re

import pandas as pd

# Load the evaluation results; rows carry a "name" column from which the
# base model is derived further down.
df = pd.read_csv(
    "results/kge_openbiolink.csv",
)
# Optional: uncomment to drop runs whose name mentions mixMLP1, mixMLP2 or mixMLP8.
# df = df[~df["name"].str.contains("mixMLP1|mixMLP2|mixMLP8")]


# Extract the base model name (keeping KGE type, dimension, and mix information)
def extract_base_name(name):
    """Collapse a run name to its base model label.

    A name must start with one of the KGE types ConvE, Hadamard, RESCAL or
    Complex. That may be followed by the suffix
    "_mixMLP4_entsamp0.001_div0.1" and must then have "_200D" or "_1000D".
    Such a name maps to "<type>_<dim>", or to "<type>_mix_<dim>" when the
    suffix is present. Anything after the dimension is ignored, because the
    match is anchored only at the start. Names that do not match are
    returned unchanged.

    Args:
        name: The run name string.

    Returns:
        The base model label, or ``name`` itself when it is not recognised.
    """
    pattern = (
        r"^(ConvE|Hadamard|RESCAL|Complex)(_mixMLP4_entsamp0\.001_div0\.1)?_(200D|1000D)"
    )
    found = re.match(pattern, name)
    if found is None:
        # Unrecognised naming scheme: hand the name back untouched.
        return name

    kge_type, mix_suffix, dimension = found.groups()
    mix_info = "mix" if mix_suffix else ""
    # With no mix info the template yields "<type>__<dim>"; collapse the
    # double underscore into a single one.
    return f"{kge_type}_{mix_info}_{dimension}".replace("__", "_")


# Add a column with the base model name so that runs sharing a configuration
# are aggregated together.
df["base_model"] = df["name"].apply(extract_base_name)

# Define the metrics to analyze (each must be a column of the CSV)
metrics = ["NLL", "MRR", "MR", "HITS@1", "HITS@3", "HITS@10"]

# Group by base model and calculate mean and std for each metric.
# pandas' "std" aggregation is the sample standard deviation (ddof=1), so a
# base model backed by a single run gets NaN, which is printed as "nan".
results = []
for metric in metrics:
    grouped = df.groupby("base_model")[metric].agg(["mean", "std"])
    # One record per (metric, model): rows stay metric-major, and models come
    # out in the sorted order produced by groupby.
    results.extend(
        {"Model": model, "Metric": metric, "Mean": mean, "Std": std}
        for model, mean, std in grouped.itertuples()
    )

# Convert to DataFrame for easier display. Passing the columns explicitly keeps
# the schema valid when no results were produced (e.g. an empty CSV);
# otherwise results_df["Model"] below would raise KeyError.
results_df = pd.DataFrame(results, columns=["Model", "Metric", "Mean", "Std"])

# Print the results in a readable format
print("Model Performance Statistics:")
print("============================")
for model in sorted(results_df["Model"].unique()):
    # Hadamard is shown under the name DistMult; this affects display only,
    # the saved CSV keeps the original label.
    display_model = model.replace("Hadamard", "DistMult")
    print(f"\n{display_model}:")
    model_results = results_df[results_df["Model"] == model]
    for _, row in model_results.iterrows():
        # Format to 3 significant figures using g format with precision 3
        print(f"  {row['Metric']}: {row['Mean']:.3g} ± {row['Std']:.3g}")

# Optionally save the results to a new CSV
results_df.to_csv("results/model_statistics.csv", index=False)
print("\nResults saved to 'results/model_statistics.csv'")
