import os
import glob
import re
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


def natural_key(s):
    """Return a sort key that orders embedded digit runs numerically.

    The string is split on runs of decimal digits; each digit run becomes an
    ``int`` and everything else stays a ``str``, so ``"res_10"`` sorts after
    ``"res_9"``.

    Args:
        s: The string to build a key for.

    Returns:
        A list of alternating ``str`` and ``int`` tokens that always starts
        and ends with a (possibly empty) ``str``.
    """
    # ``isdecimal`` is true for exactly the characters ``\d`` matches and
    # ``int`` can parse.  ``isdigit`` is also true for characters such as "²",
    # which never form a ``\d`` run and make ``int`` raise ValueError.
    return [int(t) if t.isdecimal() else t for t in re.split(r"(\d+)", s)]


# Glob pattern of the result CSVs for each ablated channel, keyed by the label
# used in the printed table and the plot titles.
# An earlier set of runs (channels 1-8) lived under
# /home/wmar/wmar_audio/outputs/wm_generations/wm_ablation_c<N>/ (no "_fix"
# suffix); those entries are no longer active.
_RESULTS_ROOT = "/home/wmar/wmar_audio/outputs/wm_generations_new/music"

GLOBS = {
    f"Channel {idx}": f"{_RESULTS_ROOT}/wm_ablation_c{idx}_fix/results_*.csv"
    for idx in range(4)
}

# -------- per-file summary --------
# One row per results CSV: the median of its "pval" column and the mean of its
# "logpval" column, labelled with the channel and the method taken from the
# file name.
rows = []

for name, pattern in GLOBS.items():
    # Files whose name contains "standard" come first, the rest follow in
    # natural (numeric-aware) order.
    paths = sorted(
        glob.glob(pattern),
        key=lambda p: (
            "standard" not in os.path.basename(p),
            natural_key(os.path.basename(p)),
        ),
    )
    if not paths:
        print(f"[WARN] no files for {name} (pattern: {pattern})")

    for path in paths:
        df = pd.read_csv(path)
        missing = {"pval", "logpval"} - set(df.columns)
        if missing:
            raise RuntimeError(f"{path} missing column(s): {sorted(missing)}")
        rows.append({
            "group": name,
            "method": os.path.basename(path).replace("results_", "").replace(".csv", "").replace("standard", "baseline"),
            "median_pval": df["pval"].median(),
            "mean_logpval": df["logpval"].mean(),
        })

if not rows:
    # Without rows the frame has no "group"/"method" columns and set_index
    # would fail with an unhelpful KeyError.
    raise SystemExit("[ERROR] no result files matched any pattern in GLOBS")

out = pd.DataFrame(rows).set_index(["group", "method"])
# Keep the 10 rows with the highest mean_logpval inside every channel.
out = (
    out
    .groupby(level=0, group_keys=False)
    .apply(lambda grp: grp.sort_values("mean_logpval", ascending=False).head(10))
)

# -------- terminal --------
print(out.to_string(formatters={
    "median_pval": lambda x: f"{x:.3e}",
    "mean_logpval": lambda x: f"{x:.3f}",
}))

# Stop after the summary table: the heatmap stages further down are currently
# switched off.  Delete this line to run them.  SystemExit ends the script
# cleanly, whereas a bare `raise` with no active exception fails with
# "RuntimeError: No active exception to reraise".
raise SystemExit(0)

# Regex to parse method names like: leiden_10_res0.8
# group(1) is used as the min_count column, group(2) as the resolution row.
PAT = re.compile(r"leiden_(\d+)_res([0-9.]+)")

OUTDIR = "/home/wmar/wmar_audio/outputs/ablation"
os.makedirs(OUTDIR, exist_ok=True)

# channel label -> DataFrame of mean logpval (rows: resolution, cols: min_count)
channel_heatmaps = {}


def _load_leiden_grid(files):
    """Build the (resolution x min_count) table of mean ``logpval`` values.

    Args:
        files: Paths of result CSVs. Files whose method name (the base name
            without ``results_`` and ``.csv``) does not start with the
            ``leiden_<count>_res<resolution>`` pattern are ignored.

    Returns:
        A float DataFrame indexed by sorted resolution with sorted min_count
        columns and NaN for combinations that have no file, or ``None`` when
        no file matched the pattern.

    Raises:
        RuntimeError: If a matching CSV has no ``logpval`` column.
    """
    table = {}
    for p in files:
        method = os.path.basename(p).replace("results_", "").replace(".csv", "")
        m = PAT.match(method)
        if not m:
            # baseline / non-leiden entries are not part of the grid
            continue

        cnt = int(m.group(1))
        res = float(m.group(2))

        df = pd.read_csv(p)
        if "logpval" not in df.columns:
            raise RuntimeError(f"{p} missing 'logpval' column")

        if (res, cnt) in table:
            # PAT.match tolerates trailing text and float() collapses e.g.
            # "0.80" and "0.8", so two files can land on the same cell.  The
            # later file still wins, but no longer silently.
            print(f"[WARN] {p} overwrites an earlier file for resolution={res}, min_count={cnt}")
        table[(res, cnt)] = float(df["logpval"].mean())

    if not table:
        return None

    ress = sorted({res for res, _ in table})
    cnts = sorted({cnt for _, cnt in table})
    heat_df = pd.DataFrame(np.nan, index=ress, columns=cnts, dtype=float)
    for (res, cnt), v in table.items():
        heat_df.at[res, cnt] = v
    return heat_df


def _save_channel_heatmap(channel_name, heat_df, out_path):
    """Render *heat_df* as an annotated heatmap and write it to *out_path*."""
    fig, ax = plt.subplots(
        figsize=(max(4, len(heat_df.columns) * 0.6), max(3, len(heat_df.index) * 0.6))
    )
    try:
        sns.heatmap(
            heat_df,
            ax=ax,
            cmap="viridis",
            annot=True,
            fmt=".3f",
            linewidths=0.5,
            linecolor="gray",
            cbar_kws={"label": "mean_logpval"},
            square=False,
            xticklabels=list(heat_df.columns),
            yticklabels=list(heat_df.index),
        )
        ax.set_xlabel("min_count")
        ax.set_ylabel("resolution")
        ax.set_title(f"{channel_name} — mean_logpval (leiden)")

        fig.tight_layout()
        fig.savefig(out_path, dpi=200)
    finally:
        # Release the figure even when rendering or saving fails.
        plt.close(fig)


for channel_name, pattern in GLOBS.items():
    files = sorted(glob.glob(pattern))
    if not files:
        print(f"[WARN] no files for {channel_name} (pattern: {pattern})")
        continue

    heat_df = _load_leiden_grid(files)
    if heat_df is None:
        print(f"[WARN] no leiden entries parsed for {channel_name}")
        continue

    channel_heatmaps[channel_name] = heat_df

    stem = channel_name.replace(" ", "_")
    out_path = os.path.join(OUTDIR, f"heatmap_{stem}.png")
    _save_channel_heatmap(channel_name, heat_df, out_path)

    # Also save the numeric table for inspection
    csv_out = os.path.join(OUTDIR, f"table_{stem}.csv")
    heat_df.to_csv(csv_out)

    print(f"Saved heatmap: {out_path} and table: {csv_out}")


# Channel selections to aggregate, paired with the suffix of their output
# files.  The suffix is explicit so the two selections never share a file name,
# even when GLOBS holds exactly four channels.
chosen_options = [
    list(GLOBS.keys())[:4],  # first 4 channels
    list(GLOBS.keys())       # all channels
]
option_suffixes = ["first4", "all"]


def _aggregate_channels(dfs):
    """Average per-channel tables cell by cell.

    Every table is first reindexed onto the union of all row and column
    labels, so channels with different grids are aligned by label rather than
    by position.  A cell that is missing (NaN) in any channel is NaN in the
    result.

    Args:
        dfs: Non-empty list of DataFrames with comparable row/column labels.

    Returns:
        A DataFrame on the union grid holding the mean across ``dfs``.
    """
    index, columns = dfs[0].index, dfs[0].columns
    for df in dfs[1:]:
        index = index.union(df.index)
        columns = columns.union(df.columns)

    stack = np.stack(
        [df.reindex(index=index, columns=columns).values for df in dfs]
    )  # (n_channels, n_res, n_cnt)
    return pd.DataFrame(stack.mean(axis=0), index=index, columns=columns)


topk = 8
for suffix, option in zip(option_suffixes, chosen_options):
    # Channels skipped by the per-channel stage have no entry to aggregate.
    used = [c for c in option if c in channel_heatmaps]
    skipped = [c for c in option if c not in channel_heatmaps]
    if skipped:
        print(f"[WARN] no heatmap available for {skipped}; leaving them out")
    if not used:
        print(f"[WARN] nothing to aggregate for '{suffix}'")
        continue

    agg_df = _aggregate_channels([channel_heatmaps[c] for c in used])

    # print top-k combos (NaN cells are never candidates)
    best = agg_df.stack().dropna().sort_values(ascending=False).head(topk)
    print(f"\nAggregated over {len(used)} channels: {used}")
    print(best.to_string())

    # plot heatmap and mark top-k
    fig, ax = plt.subplots(
        figsize=(max(4, len(agg_df.columns) * 0.6), max(3, len(agg_df.index) * 0.6))
    )
    try:
        sns.heatmap(agg_df, ax=ax, cmap="viridis", cbar_kws={"label": "mean_logpval"})
        ax.set_xlabel("min_count")
        ax.set_ylabel("resolution")
        ax.set_title(f"Aggregated mean_logpval — {len(used)} channels")

        # mark top-k with red X, centred on the cell
        for res, cnt in best.index:
            row = agg_df.index.get_loc(res)
            col = agg_df.columns.get_loc(cnt)
            ax.scatter(col + 0.5, row + 0.5, color="red", marker="x", s=80, linewidths=2)

        fig.tight_layout()
        out_path = os.path.join(OUTDIR, f"heatmap_agg_{suffix}.png")
        fig.savefig(out_path, dpi=200)
    finally:
        # Release the figure even when rendering or saving fails.
        plt.close(fig)

    agg_df.to_csv(os.path.join(OUTDIR, f"table_agg_{suffix}.csv"))
    print(f"Saved: {out_path}")
