import os
import glob
import pandas as pd
import re


def natural_key(s: str) -> list:
    """Return a sort key that orders embedded numbers by numeric value.

    The string is split on runs of decimal digits; digit runs become ``int``
    and everything else stays ``str``, so ``"f2"`` sorts before ``"f10"``.
    Because the split uses a capturing group, the result always alternates
    text, number, text, ... (text parts may be empty), which keeps keys from
    different strings comparable position by position.

    Args:
        s: The string to build a key for (e.g. a file name).

    Returns:
        A list alternating ``str`` and ``int`` parts.
    """
    # ``isdecimal`` mirrors what ``\d`` matches. ``isdigit`` would also accept
    # characters such as "²" that ``\d`` never splits out and ``int`` rejects.
    return [
        int(part) if part.isdecimal() else part
        for part in re.split(r"(\d+)", s)
    ]


# Experiment groups to include in the report.
#
# Each key is the label stored in the "group" field of the summary rows; each
# value is a glob pattern whose matching CSV files are read and summarized.
# Entries that are commented out are simply not part of the dict and are
# therefore skipped; uncomment an entry to include that experiment again.
GLOBS = {
    # "Meta (aug)": "/home/wmar/wmar_audio/outputs/wm_generations/wm_eval_meta_test/results_*.csv",
    # "Meta": "/home/wmar/wmar_audio/outputs/wm_generations/wm_eval_meta_noaug_test/results_*.csv",
    # "Clustering 1": "/home/wmar/wmar_audio/outputs/wm_generations/wm_eval_clusters_test_1/results_*.csv",
    # "Clustering 12": "/home/wmar/wmar_audio/outputs/wm_generations/wm_eval_clusters_test_12/results_*.csv",
    # "Clustering 123": "/home/wmar/wmar_audio/outputs/wm_generations/wm_eval_clusters_test_123/results_*.csv",
    # "Clustering 1234": "/home/wmar/wmar_audio/outputs/wm_generations/wm_eval_clusters_test/results_*.csv",
    # "Clustering 2": "/home/wmar/wmar_audio/outputs/wm_generations/wm_eval_clusters_test_2/results_*.csv",
    # "Clustering 23": "/home/wmar/wmar_audio/outputs/wm_generations/wm_eval_clusters_test_23/results_*.csv",
    # "Clustering 234": "/home/wmar/wmar_audio/outputs/wm_generations/wm_eval_clusters_test_234/results_*.csv",
    # "Clustering 3": "/home/wmar/wmar_audio/outputs/wm_generations/wm_eval_clusters_test_3/results_*.csv",
    # "Clustering 4": "/home/wmar/wmar_audio/outputs/wm_generations/wm_eval_clusters_test_4/results_*.csv",

    # # These are not valid results
    # # "Meta (noaug)": "/home/wmar/wmar_audio/outputs/wm_generations/wm_eval_meta_noaug_test_h1/results_*.csv",
    # # "Clustering": "/home/wmar/wmar_audio/outputs/wm_generations/wm_eval_clusters_test_h1/results_*.csv",

    # "Meta (noaug) fix (h=1)": "/home/wmar/wmar_audio/outputs/wm_generations/wm_eval_meta_noaug_test_h1_fix/results_*.csv",
    # "Clustering fix (h=1)": "/home/wmar/wmar_audio/outputs/wm_generations/wm_eval_clusters_test_h1_fix/results_*.csv",

    # "Meta (noaug) fix (h=2)": "/home/wmar/wmar_audio/outputs/wm_generations/wm_eval_meta_noaug_test_h2_fix/results_*.csv",
    # "Clustering fix (h=2)": "/home/wmar/wmar_audio/outputs/wm_generations/wm_eval_clusters_test_h2_fix/results_*.csv",

    # "Clustering test-split": "/home/wmar/wmar_audio/outputs/wm_generations/wm_eval_clusters_testsplit/results_*.csv",

    # "Meta h=1 (1st stream ctx)": "/home/wmar/wmar_audio/outputs/wm_generations/wm_eval_meta_noaug_test_h1_ctx_per_channel/results_*.csv",
    # "Clustering h=1 (1st stream ctx)": "/home/wmar/wmar_audio/outputs/wm_generations/wm_eval_clusters_test_h1_ctx_per_channel/results_*.csv",

    # Currently active groups.
    "Select": "/home/wmar/wmar_audio/outputs/wm_generations/wm_eval_clusters_test_select/results_*.csv",
    "Select (different seed)": "/home/wmar/wmar_audio/outputs/wm_generations/wm_eval_clusters_test_select_seed27/results_*.csv",
    "Verify (different seed)": "/home/wmar/wmar_audio/outputs/wm_generations/wm_eval_clusters_test_verify_seed27/results_*.csv"
}

# One summary record per results CSV, gathered across every group in GLOBS.
rows = []

for name, pattern in GLOBS.items():
    # Visit files whose name contains "standard" first, then the rest in
    # natural (number-aware) file-name order.
    matches = sorted(
        glob.glob(pattern),
        key=lambda p: (
            "standard" not in os.path.basename(p),
            natural_key(os.path.basename(p)),
        ),
    )
    for path in matches:
        df = pd.read_csv(path)
        # The method label is the file name without the "results_" / ".csv"
        # parts, with "standard" displayed as "baseline".
        method = (
            os.path.basename(path)
            .replace("results_", "")
            .replace(".csv", "")
            .replace("standard", "baseline")
        )
        rows.append({
            "group": name,
            "method": method,
            "median_pval": df["pval"].median(),
            "mean_logpval": df["logpval"].mean(),
        })

# With no rows there are no "group"/"method" columns, so set_index would fail
# with an opaque KeyError; stop early with an actionable message instead.
if not rows:
    raise SystemExit(
        "No result CSV files matched any pattern in GLOBS; nothing to report."
    )

# Summary table indexed by (group, method). Within each group the methods are
# ordered from highest to lowest mean_logpval.
out = pd.DataFrame(rows).set_index(["group", "method"])
out = (
    out
    .groupby(level=0, group_keys=False)
    # Append ``.head(3)`` to the sorted frame to keep only the top three
    # methods per group.
    .apply(lambda grp: grp.sort_values("mean_logpval", ascending=False))
)

# -------- terminal --------
# Plain-text rendering of the summary: median p-values in scientific notation
# with three decimals, mean log p-values as fixed-point with three decimals.
terminal_formatters = {
    "mean_logpval": lambda value: f"{value:.3f}",
    "median_pval": lambda value: f"{value:.3e}",
}
terminal_text = out.to_string(formatters=terminal_formatters)
print(terminal_text)

# -------- latex --------
def _hyphenate(text):
    """Return *text* with every underscore replaced by a hyphen."""
    return text.replace("_", "-")


# Underscores are swapped for hyphens in the cell text and the column headers
# only. Doing the replacement on the rendered LaTeX string would also rewrite
# the \label (tab:pvals_multi -> tab:pvals-multi) and break references to it.
latex_table = out.reset_index()
for text_column in ("group", "method"):
    latex_table[text_column] = latex_table[text_column].map(_hyphenate)
latex_table = latex_table.rename(columns=_hyphenate)

latex = latex_table.to_latex(
    index=False,
    column_format="llcc",
    caption="Median p-values and mean log p-values per method and directory.",
    label="tab:pvals_multi",
    # Keys use the hyphenated header names produced by the rename above.
    formatters={
        "median-pval": lambda x: f"{x:.3e}",
        "mean-logpval": lambda x: f"{x:.3f}",
    },
)

print("\n" + latex)
