# scripts/make_pathway_radars.py  (v3: output 3 radars incl. consensus)
import re, pathlib, math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

try:
    import gseapy as gp
except ModuleNotFoundError:
    raise SystemExit("缺少 gseapy，请先: pip install gseapy matplotlib")

# Per the header this file is scripts/make_pathway_radars.py, so the project root is parents[1].
ROOT = pathlib.Path(__file__).resolve().parents[1]
PROC = ROOT.joinpath("data", "processed")  # DE result tables are read here; the score table is written here
FIGS = ROOT.joinpath("figures")            # radar PNG output directory
FIGS.mkdir(exist_ok=True, parents=True)

def _detect_col(cols, candidates):
    cl = {c.lower(): c for c in cols}
    for k in candidates:
        if k in cl: return cl[k]
    return None

def load_de_table(path: pathlib.Path):
    df = pd.read_csv(path)
    gene_col = _detect_col(df.columns, ["gene","genes","symbol","gene_symbol","id","name"]) or df.columns[0]
    lfc_col  = _detect_col(df.columns, ["log2fc","logfc","lfc","log2_fc"])
    if lfc_col is None:
        raise ValueError(f"{path.name}: 未找到 log2FC 列（现有列: {list(df.columns)}）")
    p_col = _detect_col(df.columns, ["fdr","padj","adj_p","adj.p","q","qval","p_adj","p.adj","pval","p"])
    out = pd.DataFrame({
        "gene": df[gene_col].astype(str).str.strip().str.upper(),
        "lfc":  pd.to_numeric(df[lfc_col], errors="coerce")
    })
    out["p"] = pd.to_numeric(df[p_col], errors="coerce") if p_col else np.nan
    return out.dropna(subset=["gene"])

def gene_score_table(de: pd.DataFrame) -> pd.Series:
    """Turn a normalised DE table into one signed score per gene.

    Per row: ``clip(lfc, -1, 1) * w`` with ``w = clip(-log10(max(p, 1e-300)), 0, 5)``.
    Rows whose p is NaN (including tables with no p column at all) get w = 1.0.
    Rows are then averaged per gene.

    NOTE(review): when a p column exists but has gaps (e.g. filtered genes),
    those rows also get w = 1.0, i.e. the same weight as p = 0.1 -- confirm
    this is intended.

    Returns:
        Series named "score", indexed by gene (index name "gene"), sorted by gene.
    """
    p_floor = de["p"].clip(lower=1e-300)
    weight = (-np.log10(p_floor)).clip(0, 5).fillna(1.0)
    # sign(lfc) * min(|lfc|, 1) is exactly lfc clipped to [-1, 1].
    direction = de["lfc"].clip(-1.0, 1.0)
    per_row = (direction * weight).rename("score")
    return per_row.groupby(de["gene"]).mean()

# Enrichr library names to try per library group, newest first. The first one
# that gp.get_library returns as a non-empty dict is used (see fetch_library_any).
# Enrichr's WikiPathways naming is inconsistent across releases
# ("WikiPathway_2021_Human" / "WikiPathway_2023_Human" vs "WikiPathways_2019_Human"),
# so both spellings are listed for 2021/2023; names that don't exist just fall through.
LIB_CANDIDATES = {
    "KEGG":     ["KEGG_2021_Human", "KEGG_2019_Human", "KEGG_2016"],
    "REACTOME": ["Reactome_2022", "Reactome_2016"],
    "WIKIPATH": [
        "WikiPathway_2023_Human", "WikiPathways_2023_Human",
        "WikiPathway_2021_Human", "WikiPathways_2021_Human",
        "WikiPathways_2019_Human",
    ],
    "GO_BP":    ["GO_Biological_Process_2023", "GO_Biological_Process_2021", "GO_Biological_Process_2018"],
}

# Radar axes, in plotting order. Each axis lists (library group, regex) specs;
# union_genesets pools the genes of every term in that library group whose name
# the regex matches (case-insensitively).
AXIS_DEF = dict(
    mTOR=[
        ("KEGG", r"\bmTOR\b|PI3K.*MTOR"),
    ],
    AMPK=[
        ("KEGG", r"\bAMPK signaling\b"),
        ("GO_BP", r"\bAMP-activated protein kinase\b"),
    ],
    SIRT=[
        ("WIKIPATH", r"\bSirtu(in|ins)\b"),
        ("REACTOME", r"\bSIRT"),
        ("GO_BP", r"\bsirtuin"),
    ],
    FOXO=[
        ("KEGG", r"\bFOXO signaling\b"),
        ("GO_BP", r"\bFOXO"),
    ],
    Autophagy=[
        ("KEGG", r"\bAutophagy\b"),
        ("REACTOME", r"\bAutophagy\b"),
        ("GO_BP", r"\bautophagy"),
    ],
    MitoBio=[
        ("GO_BP", r"\bmitochondrial biogenesis\b|positive regulation of mitochondrial biogenesis|mitochondrial gene expression"),
    ],
    Inflammation=[
        ("GO_BP", r"\binflammatory response\b|regulation of inflammatory"),
    ],
    Senescence=[
        ("KEGG", r"\bCellular senescence\b"),
        ("REACTOME", r"\bCellular Senescence\b"),
        ("GO_BP", r"\bsenescence"),
    ],
)

_LIB_CACHE: dict[str, dict] = {}  # group key -> gene-set dict; {} also records a failed lookup
def fetch_library_any(group_key: str) -> dict:
    """Return the first usable Enrichr library for *group_key*, memoised.

    Each name in ``LIB_CANDIDATES[group_key]`` is tried in order with
    ``gp.get_library``; the first result that is a non-empty dict is cached
    and returned. Errors from individual candidates are swallowed so the
    next fallback can be tried. If none works, a warning is printed and an
    empty dict is cached, so later calls for that key return ``{}`` without
    retrying. Keys absent from LIB_CANDIDATES return ``{}`` and are not cached.
    """
    if group_key not in LIB_CANDIDATES:
        return {}
    if group_key in _LIB_CACHE:
        return _LIB_CACHE[group_key]
    last_err = None
    for lib_name in LIB_CANDIDATES[group_key]:
        try:
            library = gp.get_library(name=lib_name)
            usable = isinstance(library, dict) and bool(library)
            if usable:
                print(f"[lib] using {lib_name}")
                _LIB_CACHE[group_key] = library
                return library
        except Exception as err:  # best effort: fall through to the next candidate
            last_err = err
    print(f"[warn] no available library for '{group_key}' (last err: {last_err})")
    _LIB_CACHE[group_key] = {}
    return {}

def union_genesets(specs):
    """Pool the genes of every library term matched by *specs*.

    Args:
        specs: iterable of ``(library group key, regex)`` pairs. Each regex is
            applied with ``re.search`` (case-insensitive) to term names of the
            library returned by ``fetch_library_any``; groups whose library
            could not be loaded are skipped.

    Returns:
        ``(genes, matched)``: the union of upper-cased gene symbols, and a list
        of ``(group key, term name)`` for every matched term.
    """
    collected = set()
    hits = []
    for lib_key, pattern in specs:
        library = fetch_library_any(lib_key)
        if not library:
            continue
        matcher = re.compile(pattern, flags=re.IGNORECASE)
        for term_name, members in library.items():
            if not matcher.search(term_name):
                continue
            collected.update(gene.upper() for gene in members)
            hits.append((lib_key, term_name))
    return collected, hits

def pathway_panel_scores(gene_scores: pd.Series):
    """Collapse per-gene scores into one value per radar axis in AXIS_DEF.

    For each axis, the gene set is the union of all library terms matched by
    its specs (via union_genesets). The axis value is the mean score of the
    genes present in both that set and ``gene_scores.index``, squashed into
    (-1, 1) with ``tanh(mean / (2 * base_scale))``. ``base_scale`` is the 95th
    percentile of ``|gene_scores|`` (NaNs ignored), or 1.0 if that is not a
    positive finite number.

    Args:
        gene_scores: per-gene scores indexed by gene symbol (make_one passes
            the output of gene_score_table).

    Returns:
        ``(axes, vals, meta)``: axis names in AXIS_DEF order, the matching float
        values, and per-axis diagnostics ``{"n_lib_terms", "n_genes_total",
        "n_hit"}``. An axis with no overlapping genes -- or whose overlapping
        genes only have NaN scores -- gets 0.0.
    """
    base_scale = np.nanpercentile(np.abs(gene_scores.values), 95)
    if not np.isfinite(base_scale) or base_scale <= 0:
        base_scale = 1.0
    scored = set(gene_scores.index)  # loop-invariant; build once
    axes, vals, meta = [], [], {}
    for axis, specs in AXIS_DEF.items():
        gs, matched = union_genesets(specs)
        # Sorted so the summation order -- and therefore the last bits of the
        # float mean -- does not depend on str-hash randomisation between runs.
        inter = sorted(gs & scored)
        mean = gene_scores.loc[inter].mean() if inter else np.nan
        # No overlap or all-NaN overlap -> neutral 0.0 (keeps the radar polygon drawable).
        val = float(np.tanh(mean / (2.0 * base_scale))) if np.isfinite(mean) else 0.0
        axes.append(axis)
        vals.append(val)
        meta[axis] = {"n_lib_terms": len(matched), "n_genes_total": len(gs), "n_hit": len(inter)}
    return axes, vals, meta

def radar_plot(axes, vals, title, out_png):
    """Draw one closed radar (polar) chart of *vals* over *axes* and save it.

    Args:
        axes: spoke labels, placed at evenly spaced angles starting at 0.
        vals: one number per spoke. The radial axis is fixed to [-1, 1], so
            values outside that range fall off the plotted area.
        title: figure title.
        out_png: output path passed to ``Figure.savefig``.
    """
    n_spokes = len(axes)
    spoke_ang = np.linspace(0, 2 * np.pi, n_spokes, endpoint=False)
    vals = np.array(vals, dtype=float)
    # Repeat the first point so the outline and fill close on themselves.
    ring_vals = np.concatenate([vals, vals[:1]])
    ring_ang = np.concatenate([spoke_ang, spoke_ang[:1]])
    fig = plt.figure(figsize=(7, 7), dpi=200)
    try:
        # Draw on and save this figure explicitly rather than pyplot's
        # "current" figure/axes state.
        ax = fig.add_subplot(111, polar=True)
        ax.plot(ring_ang, ring_vals, lw=2, color="#f39c12")
        ax.fill(ring_ang, ring_vals, alpha=0.18, color="#f39c12")
        ax.set_xticks(spoke_ang)
        ax.set_xticklabels(axes, fontsize=12)
        ax.set_yticks([-0.5, 0.0, 0.5])
        ax.set_yticklabels(["-0.5", "0.00", "+0.5"], fontsize=10)
        ax.set_ylim(-1.0, 1.0)
        ax.grid(True, alpha=0.35)
        ax.set_title(title, fontsize=15, pad=18)
        fig.tight_layout()
        fig.savefig(out_png, bbox_inches="tight")
    finally:
        # Always release the figure, even if drawing or saving raised.
        plt.close(fig)

def make_one(csv_path: pathlib.Path, title: str, out_png: pathlib.Path):
    """Run the pipeline for one DE table and write its radar chart.

    Steps: load_de_table -> gene_score_table -> pathway_panel_scores ->
    radar_plot(title, out_png).

    Returns:
        ``({axis: value}, meta)`` where meta is the per-axis diagnostics dict
        from pathway_panel_scores.
    """
    per_gene = gene_score_table(load_de_table(csv_path))
    axis_names, axis_vals, meta = pathway_panel_scores(per_gene)
    radar_plot(axis_names, axis_vals, title, out_png)
    scores = {name: value for name, value in zip(axis_names, axis_vals)}
    return scores, meta

# --- run for Metformin & Rapamycin ---
# Each pair: (DE table, figure title, output PNG, label used in the score table).
# NOTE(review): the PNG names carry "NMN_PO"/"NMN_IV", which looks like a leftover
# from an earlier figure set; kept verbatim because other documents may reference
# these paths -- confirm before renaming.
results = []  # (label, {axis: value}) for every DE table that was found
pairs = [
    (PROC / "GSE196343_metformin_de_results.csv",
     "Fig-2A: Metformin Pathway Panel (real data)",
     FIGS / "fig2a_radar_NMN_PO.png",
     "Metformin"),
    (PROC / "GSE102674_rapamycin_vs_TGFb1_de_results.csv",
     "Fig-2B: Rapamycin Pathway Panel (real data)",
     FIGS / "fig2b_radar_NMN_IV.png",
     "Rapamycin"),
]

for de_csv, fig_title, png_path, label in pairs:
    if not de_csv.exists():
        print(f"[skip] {de_csv} 不存在")
        continue
    print(f"[run] {de_csv.name} -> {png_path.name}")
    axis_scores, _meta = make_one(de_csv, fig_title, png_path)
    results.append((label, axis_scores))
    print("  axes:", ", ".join(f"{name}:{score:+.3f}" for name, score in axis_scores.items()))

if not results:
    raise SystemExit("没有找到可用的 DE 结果（先运行 fetch_*.py）")

# --- Fig-2C: consensus (mean) ---
# Average each axis score over the interventions that were actually run above.
# Labels are built from `results`, so a skipped DE table no longer appears in the
# consensus title / table row (with both present the strings are unchanged).
# NOTE(review): "fig2c_radar_NMN_TD.png" kept verbatim in case it is referenced elsewhere.
cons_members = [lbl for lbl, _ in results]
if len(cons_members) < 2:
    print(f"[warn] consensus uses a single intervention only: {cons_members[0]}")
axes = list(results[0][1].keys())
cons_vals = {ax: float(np.mean([r[1][ax] for r in results])) for ax in axes}
cons_png = FIGS / "fig2c_radar_NMN_TD.png"
radar_plot(axes, list(cons_vals.values()),
           f"Fig-2C: Consensus Pathway Panel ({' + '.join(cons_members)}, real data)",
           cons_png)
print(f"[ok] Fig-2C consensus -> {cons_png.name}")

# Export the summary table: one row per intervention plus the consensus row.
rows = [{"intervention": lbl, **vals} for (lbl, vals) in results]
rows.append({"intervention": f"Consensus({'+'.join(cons_members)})", **cons_vals})
scores_csv = PROC / "pathway_panel_scores.csv"
pd.DataFrame(rows).to_csv(scores_csv, index=False)
print("[ok] scores table:", scores_csv)
