# scripts/analyze_feature_bank.py
import json, os, re
import pandas as pd
from collections import Counter, defaultdict

# Path of the feature-bank JSON to analyze; edit this to point at another bank.
BANK_JSON = "results/feature_bank/feature_bank_layer5.json"
# Companion CSV path (optional; presumably the same bank exported as CSV).
BANK_CSV = "results/feature_bank/feature_bank_layer5.csv"

# JSON text is UTF-8 by specification; pin the encoding so that parsing does
# not depend on the platform's locale default.
with open(BANK_JSON, encoding="utf-8") as f:
    bank = json.load(f)

# --- flatten motifs
# Build one row per (feature, motif) pair.  A feature without motifs still
# contributes a single placeholder row so that it is counted in the
# per-feature statistics computed from `df`.
_ROW_COLUMNS = [
    "layer",
    "feature_id",
    "cov_set",
    "motif",
    "smarts",
    "coverage",
    "enrichment",
    "odds",
    "p",
    "sig",
]

# Motif fields used for a feature that has no motifs at all.
_NO_MOTIF = {
    "motif": "",
    "smarts": "",
    "coverage": 0.0,
    "enrichment": 0.0,
    "odds": 0.0,
    "p": 1.0,
    "sig": False,
}

rows = []
for feat in bank:
    # Fields shared by every row of this feature.
    feature_fields = {
        "layer": feat["layer"],
        "feature_id": feat["feature_id"],
        "cov_set": feat["coverage_reached"],
    }
    motifs = feat.get("motifs", [])
    if not motifs:
        rows.append({**feature_fields, **_NO_MOTIF})
        continue
    for m in motifs:
        rows.append(
            {
                **feature_fields,
                "motif": m["name"],
                "smarts": m["smarts"],
                "coverage": m["coverage"],
                "enrichment": m["enrichment"],
                "odds": m["odds_ratio"],
                "p": m["p_value"],
                "sig": bool(m["fdr_significant"]),
            }
        )

# Passing the column list explicitly keeps the schema stable even when `bank`
# is empty; without it an empty bank yields a frame with no columns and every
# later column lookup raises KeyError.
df = pd.DataFrame(rows, columns=_ROW_COLUMNS)
# Keep `sig` boolean in the empty case too, so it can be used as a row mask.
df["sig"] = df["sig"].astype(bool)

# --- 1) coverage / significance stats
# A feature is identified by (layer, feature_id): the same id appearing in two
# layers is two different features, so aggregate per pair rather than per id.
_MONO_COVERAGE = 0.60  # single-motif coverage at or above this is "mono-like"

per_feature = df.groupby(["layer", "feature_id"]).agg(
    has_sig=("sig", "max"),  # True when any motif row of the feature is significant
    max_cov=("coverage", "max"),  # best single-motif coverage of the feature
)
feat_has_sig = per_feature["has_sig"]
feat_single_high = per_feature["max_cov"] >= _MONO_COVERAGE
print(f"Features with ≥1 FDR motif: {feat_has_sig.mean():.1%}")
print(f"Features with single-motif coverage ≥0.60: {feat_single_high.mean():.1%}")

# Per-layer profile derived from the per-feature aggregates above; the result
# is indexed by layer with one integer count per column.
by_layer = pd.DataFrame(
    {
        "n_features": per_feature.groupby(level="layer").size(),
        "n_interpretable": feat_has_sig.groupby(level="layer").sum(),
        "mono_like": feat_single_high.groupby(level="layer").sum(),
    }
)
print("\nLayer profile:\n", by_layer)

# --- 2) top motifs by quality (filter sig)
# Keep only FDR-significant rows and rank them by
# coverage * sqrt(enrichment + 1); ties fall back to coverage, then enrichment.
significant = df[df["sig"]].copy()
significant["score"] = significant["coverage"] * (significant["enrichment"] + 1) ** 0.5
top = significant.sort_values(by=["score", "coverage", "enrichment"], ascending=False)

# Show the 20 best-scoring rows with the columns that identify them.
leaderboard_columns = ["motif", "coverage", "enrichment", "feature_id", "layer"]
leaderboard = top.head(20)[leaderboard_columns]
print("\nTop aligned motifs:\n", leaderboard)

# --- 3) redundancy: how many features per motif
# Count distinct (layer, feature_id) pairs per significant motif.  Counting
# distinct feature_id values alone would merge features that share an id but
# live in different layers.
motif_counts = (
    df[df.sig]
    .drop_duplicates(subset=["motif", "layer", "feature_id"])
    .groupby("motif")
    .size()
    .rename("feature_id")  # keep the original Series name / CSV column header
    .sort_values(ascending=False)
)
print("\nMotif redundancy (features per motif):\n", motif_counts.head(20))


# --- 4) “feature cards” for the top 5 motifs (examples / counterexamples)
def feature_card(motif_name, k=5, *, ranked=None, features=None):
    """Print a short "feature card" for the best-ranked features of a motif.

    For each of the first ``k`` rows of the ranked table whose ``motif``
    equals ``motif_name``, print the feature id, layer, coverage and
    enrichment, followed by up to three examples and two counterexamples
    taken from the matching feature-bank entry.

    Args:
        motif_name: Motif name to select rows by.
        k: Maximum number of rows (features) to print.
        ranked: Ranked motif table with ``motif``, ``layer``, ``feature_id``,
            ``coverage`` and ``enrichment`` columns.  Defaults to the
            module-level ``top`` frame.
        features: Iterable of feature-bank entries (dicts with ``layer`` and
            ``feature_id``, optionally ``examples`` / ``counterexamples``).
            Defaults to the module-level ``bank``.

    Raises:
        KeyError: If a selected row has no matching feature-bank entry.
    """
    ranked = top if ranked is None else ranked
    features = bank if features is None else features

    # Index the bank once by (layer, feature_id).  Matching on feature_id
    # alone can pick an entry from another layer that reuses the same id.
    # setdefault keeps the first entry if a key occurs more than once.
    by_key = {}
    for entry in features:
        by_key.setdefault((entry["layer"], entry["feature_id"]), entry)

    subset = ranked[ranked["motif"] == motif_name].head(k)
    print(f"\n=== {motif_name} ===")
    for _, r in subset.iterrows():
        feat = by_key.get((r["layer"], r["feature_id"]))
        if feat is None:
            raise KeyError(
                f"feature {r['feature_id']} (layer {r['layer']}) "
                "not found in feature bank"
            )
        # Treat a missing or null list as empty.
        ex = (feat.get("examples") or [])[:3]
        cex = (feat.get("counterexamples") or [])[:2]
        print(
            f"  f{r.feature_id} L{r.layer}  cov={r.coverage:.2f}  enr={r.enrichment:.1f}"
        )
        # str() keeps join from failing if entries are not plain strings
        # (the element type is not visible here).
        print("    examples:       ", "; ".join(map(str, ex)))
        print("    counterexamples:", "; ".join(map(str, cex)))


# Print a card for each of the first five motifs in the redundancy ranking.
for m in motif_counts.index[:5]:
    feature_card(m)

# --- 5) save summaries
# Summaries are written to a "summaries" folder next to the bank JSON.
bank_dir = os.path.dirname(BANK_JSON)
outdir = os.path.join(bank_dir, "summaries")
os.makedirs(outdir, exist_ok=True)

layer_profile_path = os.path.join(outdir, "layer_profile.csv")
top_motifs_path = os.path.join(outdir, "top_motifs.csv")
redundancy_path = os.path.join(outdir, "motif_redundancy.csv")

by_layer.to_csv(layer_profile_path)
# Only the 200 best-ranked motif rows are exported, without the row index.
top.head(200).to_csv(top_motifs_path, index=False)
motif_counts.to_csv(redundancy_path)
