import os
import glob
import pandas as pd
import re


def natural_key(s):
    """Return a sort key that orders embedded integers numerically.

    ``re.split`` with a capturing group alternates text and digit runs
    (``"a10b"`` -> ``["a", "10", "b"]``), so captured digit runs always sit
    at odd indices. Converting by position (instead of ``str.isdigit``) keeps
    the str/int pattern aligned for every key: ``isdigit`` also accepts
    characters such as superscript digits (U+00B2) that ``\\d`` does not
    capture, which made ``int()`` raise or put an int where a str belongs.

    >>> sorted(["run10", "run2", "run1"], key=natural_key)
    ['run1', 'run2', 'run10']
    """
    parts = re.split(r"(\d+)", s)
    return [int(part) if i % 2 else part for i, part in enumerate(parts)]


# CSV glob per evaluation group; the key becomes the "group" label in the
# output tables, the value is expanded with glob.glob.
GLOBS = {
    group: f"/home/wmar/wmar_audio/outputs/wm_generations/{subdir}/results_*.csv"
    for group, subdir in (
        ("Meta", "wm_eval_meta_full"),
        ("Clusters", "wm_eval_clusters_full_fix"),
    )
}

# ---- collect per-augmentation statistics from every results CSV ----
# Each CSV is expected to have "aug_name", "pval" and "logpval" columns.
# One small DataFrame per file, concatenated once at the end: unlike building
# row dicts via iterrows (which yields one Series per row and so upcasts
# mixed-type rows to a common dtype), this keeps every column's dtype.
frames = []

for name, g in GLOBS.items():
    # Files containing "standard" sort first (False < True); the rest follow
    # in natural (numeric-aware) filename order.
    paths = sorted(
        glob.glob(g),
        key=lambda p: (
            "standard" not in os.path.basename(p),
            natural_key(os.path.basename(p)),
        ),
    )
    for path in paths:
        df = pd.read_csv(path)

        # ---- per-augmentation aggregation ----
        per_aug = (
            df.groupby("aug_name")
              .agg(
                  median_pval=("pval", "median"),
                  mean_logpval=("logpval", "mean"),
              )
              .reset_index()
        )
        if per_aug.empty:
            continue

        # "results_<method>.csv" -> "<method>", with "standard" shown as "baseline".
        stem = os.path.splitext(os.path.basename(path))[0]
        if stem.startswith("results_"):
            stem = stem[len("results_"):]
        method = stem.replace("standard", "baseline")

        frames.append(per_aug.assign(group=name, method=method))

if not frames:
    # Without this, set_index below fails on an empty frame with an opaque KeyError.
    raise SystemExit(
        "no per-augmentation results found under: " + ", ".join(GLOBS.values())
    )

out = (
    pd.concat(frames, ignore_index=True)
      .set_index(["group", "method", "aug_name"])
)

# -------- terminal --------
# Plain-text dump of the aggregated table: median p-values in scientific
# notation, mean log p-values with three decimals.
print(
    out.to_string(
        formatters=dict(
            median_pval="{:.3e}".format,
            mean_logpval="{:.3f}".format,
        )
    )
)

# -------- latex --------
# Underscores are hyphenated only in the table *content* (label cells and
# column headers) so they render in LaTeX text mode. A global str.replace on
# the rendered LaTeX would also rewrite \label{tab:pvals_per_aug} (breaking
# any \ref to it) and, when pandas escapes "_" as "\_", produce "\-"
# (a discretionary hyphen) instead of a visible character.
latex_df = out.reset_index()
for col in ("group", "method", "aug_name"):
    latex_df[col] = latex_df[col].map(
        lambda v: v.replace("_", "-") if isinstance(v, str) else v
    )

latex = latex_df.to_latex(
    index=False,
    column_format="lllcc",
    # Header aliases: displayed names are hyphenated, while the formatters
    # below still key on the real column names.
    header=[str(c).replace("_", "-") for c in latex_df.columns],
    caption="Median p-values and mean log p-values per method and augmentation.",
    label="tab:pvals_per_aug",
    formatters={
        "median_pval": lambda x: f"{x:.3e}",
        "mean_logpval": lambda x: f"{x:.3f}",
    },
)

print("\n" + latex)
