# scripts/make_fig3to5.py  (v2: paper-ready styling + cleaner layout)
import pathlib, re, numpy as np, pandas as pd, matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter

# Project root: two directory levels above this file (the script lives one level down).
ROOT = pathlib.Path(__file__).resolve().parents[1]
PROC = ROOT.joinpath("data", "processed")   # where the DE result CSVs are read from
FIGS = ROOT.joinpath("figures")             # output directory; created at import time
FIGS.mkdir(exist_ok=True, parents=True)

# ---------- styling ----------
def apply_style():
    """Install the shared paper-figure rcParams (300 dpi, compact size, light grid, serif)."""
    style = {
        # Resolution for on-screen figures and saved files alike.
        "figure.dpi": 300,
        "savefig.dpi": 300,
        # Compact single-panel size, sized to fit one column of a two-column layout.
        "figure.figsize": (4.2, 3.2),
        # Open frame: hide top/right spines, keep a faint background grid.
        "axes.spines.top": False,
        "axes.spines.right": False,
        "axes.grid": True,
        "grid.alpha": 0.18,
        # Text sizes for titles, axis labels, ticks and legends.
        "axes.titlesize": 12,
        "axes.labelsize": 10,
        "xtick.labelsize": 9,
        "ytick.labelsize": 9,
        "legend.fontsize": 9,
        "font.family": "serif",
    }
    plt.rcParams.update(style)

def _pick(cols, cands):
    m = {c.lower(): c for c in cols}
    for k in cands:
        if k in m: return m[k]
    return None

def load_de(csv_path: pathlib.Path) -> pd.DataFrame:
    """Load a differential-expression CSV into a normalised ``gene``/``lfc``/``p`` table.

    Columns are located case-insensitively by common aliases; the gene column falls
    back to the first column. Gene symbols are stripped and upper-cased, and rows with
    a missing or blank symbol are dropped. ``lfc`` and ``p`` are coerced to numbers
    (unparseable values become NaN). If no p-value-like column exists, ``p`` is all NaN.

    Raises:
        ValueError: if no log-fold-change column can be found.
    """
    df = pd.read_csv(csv_path)
    gene = _pick(df.columns, ["gene","genes","symbol","gene_symbol","id","name"]) or df.columns[0]
    lfc  = _pick(df.columns, ["log2fc","logfc","lfc","log2_fc"])
    if lfc is None: raise ValueError(f"{csv_path.name}: no logFC column")
    # Candidate order prefers adjusted p-values over raw ones.
    pcol = _pick(df.columns, ["fdr","padj","adj_p","adj.p","q","qval","p_adj","p.adj","pval","p"])
    raw_gene = df[gene]
    # astype(str) turns missing symbols into the literal "NAN", which dropna() can
    # never remove; mask missing and blank symbols back to NaN before dropping.
    genes = raw_gene.astype(str).str.strip().str.upper().where(raw_gene.notna())
    genes = genes.mask(genes == "")
    out = pd.DataFrame({
        "gene": genes,
        "lfc":  pd.to_numeric(df[lfc], errors="coerce"),
        "p":    pd.to_numeric(df[pcol], errors="coerce") if pcol else np.nan,
    }).dropna(subset=["gene"])
    return out

def composite_score(de: pd.DataFrame, fc_cap=1.0, pcap=5.0, frac=1.0, use_p=True) -> float:
    """Collapse a DE table into one signed effect score in [-1, 1].

    Each gene contributes ``sign(lfc) * min(|lfc|, fc_cap) * w``. Here ``lfc`` is first
    clipped to [-3, 3], and ``w = clip(-log10(p), 0, pcap)`` when p-weighting is
    active (otherwise ``w = 1``). The score is ``tanh(nanmean(contributions) / 2)``.

    Args:
        de: table with an ``lfc`` column and optionally a ``p`` column.
        fc_cap: cap on each gene's |lfc|.
        pcap: cap on the ``-log10(p)`` weight.
        frac: if strictly between 0 and 1, keep only the top ``frac`` share of rows
            by |lfc| (at least 10 rows).
        use_p: weight by ``-log10(p)``. Falls back to unweighted scoring when ``p``
            is absent or entirely NaN, e.g. a table loaded without a p-value column.

    Returns:
        The score as a float; NaN if no row yields a finite contribution.
    """
    dx = de
    if 0 < frac < 1.0:
        k = max(10, int(len(dx) * frac))
        # Positional selection (NaN |lfc| sorts last), so duplicate index labels are safe.
        order = np.argsort(-dx["lfc"].abs().to_numpy(), kind="stable")
        dx = dx.iloc[order[:k]]
    lfc = dx["lfc"].clip(-3, 3)
    # An all-NaN p column would make every weight, and thus the score, NaN.
    has_p = "p" in dx and bool(dx["p"].notna().any())
    if use_p and has_p:
        w = (-np.log10(dx["p"].clip(lower=1e-300))).clip(0, pcap)
    else:
        w = pd.Series(1.0, index=dx.index)
    s = np.sign(lfc) * np.minimum(np.abs(lfc), fc_cap) * w
    return float(np.tanh(np.nanmean(s) / 2.0))      # [-1,1]

# collect DE
met = PROC / "GSE196343_metformin_de_results.csv"
rap = PROC / "GSE102674_rapamycin_vs_TGFb1_de_results.csv"
# Load whichever DE tables are present, keeping Metformin before Rapamycin.
pairs = [(drug, load_de(csv)) for drug, csv in (("Metformin", met), ("Rapamycin", rap)) if csv.exists()]
if not pairs:
    raise SystemExit("No DE results found. Run the fetch_* scripts first.")

apply_style()

# ---------- Fig-3: Sensitivity ----------
def fig3(path):
    """Draw Fig-3: how each scoring-prior tweak shifts the mean composite score.

    For every tweak, ``composite_score`` is averaged over all tables in the
    module-level ``pairs``. The result is plotted as a horizontal bar showing its
    difference from the default-parameter baseline. The figure is saved to *path*.
    """
    def mean_score(**kw):
        # Mean composite score over all loaded DE tables for the given overrides.
        return np.mean([composite_score(de, **kw) for _, de in pairs])

    base = mean_score()
    delta = [
        ("Bioavailability\n(|log2FC| cap 0.5)",  mean_score(fc_cap=0.5) - base),
        ("tmax\n(p-weight cap 2)",               mean_score(pcap=2.0) - base),
        ("Exposure window\ntop-50% genes",        mean_score(frac=0.5) - base),
    ]
    labels, vals = zip(*delta)
    fig, ax = plt.subplots()
    y = np.arange(len(labels))
    bars = ax.barh(y, vals)
    ax.set_yticks(y); ax.set_yticklabels(labels)
    ax.axvline(0, color="k", lw=0.7, alpha=0.6)
    ax.set_xlabel("Change in composite effect (Δ, −1…+1)")
    ax.set_title("A) Sensitivity to scoring priors")
    # Value labels sit just past each bar's free end. The gap scales with the data
    # span: a fixed data-unit offset is invisible for large deltas and huge for tiny
    # ones. Extra x-margin keeps labels inside the axes instead of colliding with
    # the y tick labels.
    finite = [abs(v) for v in vals if np.isfinite(v)]
    pad = 0.02 * (max(finite, default=0.0) or 1.0)
    ax.margins(x=0.25)
    for b, v in zip(bars, vals):
        ax.text(b.get_width() + (pad if v >= 0 else -pad), b.get_y() + b.get_height() / 2,
                f"{v:+.3f}", va="center", ha="left" if v >= 0 else "right")
    fig.tight_layout(pad=0.5); fig.savefig(path, bbox_inches="tight"); plt.close(fig)

# ---------- Fig-4: Ablation PK/PD ----------
def fig4(path):
    """Draw Fig-4: composite score per drug, with vs without p-value weighting.

    "Full" bars use ``composite_score(use_p=True)`` and "No PK/PD" bars use
    ``use_p=False``, one pair per entry of the module-level ``pairs``. The figure
    is saved to *path*.
    """
    full  = [composite_score(de, use_p=True)  for _, de in pairs]
    nop   = [composite_score(de, use_p=False) for _, de in pairs]
    labels = [name for name, _ in pairs]
    x = np.arange(len(labels)); w = 0.38
    fig, ax = plt.subplots()
    ax.bar(x - w/2, full, width=w, label="Full")
    ax.bar(x + w/2, nop,  width=w, label="No PK/PD")
    ax.axhline(0, color="k", lw=0.7, alpha=0.6)
    ax.set_xticks(x); ax.set_xticklabels(labels)
    ax.set_ylabel("Composite effect (−1…+1)")
    ax.set_title("B) Ablation — remove PK/PD prior")
    ax.legend(frameon=False, ncol=2, loc="upper left")
    # Headroom above and below so the value labels stay inside the axes.
    ax.margins(y=0.15)
    for xi, v in [*zip(x - w/2, full), *zip(x + w/2, nop)]:
        # Put each label beyond the bar's free end: above positive bars, below
        # negative ones. A label at v + offset with va="bottom" would sit inside
        # a negative bar.
        neg = v < 0
        ax.text(xi, v - 0.012 if neg else v + 0.012, f"{v:+.3f}",
                ha="center", va="top" if neg else "bottom", fontsize=8)
    fig.tight_layout(pad=0.5); fig.savefig(path, bbox_inches="tight"); plt.close(fig)

# ---------- Fig-5: Ablation Causal ----------
def fig5(path):
    """Draw Fig-5: mean |pathway-axis score| using multi-library vs KEGG-only gene sets.

    For each DE table in the module-level ``pairs``:

    * Every gene is scored as ``sign(lfc) * min(|clip(lfc, -3, 3)|, 1) * w``, with
      ``w = clip(-log10 p, 0, 5)`` and missing p weighted 1.0. Scores are averaged
      over duplicate symbols.
    * For each axis in ``AXIS_PAT``, the genes of all pathway terms whose name
      matches the axis regex are pooled.
    * The mean score of those genes is squashed with ``tanh(v / (2 * base))``,
      where ``base`` is the table's 95th-percentile |score|.

    The mean |axis score| is averaged over tables. This is done once with the
    union of ``LIBS_FULL`` and once with ``LIBS_KEGG`` only.

    ``gseapy`` is optional. If it cannot be imported, a placeholder figure with an
    install hint is saved to *path* instead. Libraries that fail to load are
    skipped with a warning.
    """
    try:
        import gseapy as gp
    except Exception:  # optional dependency: any import failure -> placeholder figure
        gp = None
    fig, ax = plt.subplots()
    if gp is None:
        ax.text(0.5, 0.5, "Install gseapy to enable\ncausal ablation", ha="center", va="center")
        ax.axis("off"); fig.savefig(path, bbox_inches="tight"); plt.close(fig); return

    import warnings

    # Simplified re-use of the definitions in make_pathway_radars.py:
    # FULL = union of several libraries; KEGG_ONLY = KEGG alone.
    LIBS_FULL = ["KEGG_2021_Human","Reactome_2022","GO_Biological_Process_2023","WikiPathways_2021_Human"]
    LIBS_KEGG = ["KEGG_2021_Human"]
    AXIS_PAT = {
        "mTOR": r"\bmTOR\b|PI3K.*MTOR", "AMPK": r"\bAMPK signaling\b|\bAMP-activated",
        "SIRT": r"\bSirtu(in|ins)\b|\bSIRT\b", "FOXO": r"\bFOXO",
        "Autophagy": r"\bAutophagy\b", "MitoBio": r"mitochondrial biogenesis|mitochondrial gene expression",
        "Inflammation": r"\binflammatory response\b|regulation of inflammatory", "Senescence": r"\bsenescence\b",
    }
    # Compile each axis regex once; they are reused for every library and table.
    axis_rx = {axis: re.compile(pat, flags=re.IGNORECASE) for axis, pat in AXIS_PAT.items()}

    _cache = {}

    def getlib(n):
        """Return gene-set library *n* as {term: genes}, memoised; {} if the fetch fails."""
        if n not in _cache:
            try:
                _cache[n] = gp.get_library(name=n)
            except Exception as err:
                # Best effort, but not silent: a missing library quietly pulls
                # every axis score toward 0.
                warnings.warn(f"fig5: could not load gene-set library {n!r} ({err}); skipping it")
                _cache[n] = {}
        return _cache[n]

    def axis_genes(libs):
        """Map each axis to the upper-cased genes of every term (across *libs*) it matches."""
        allsets = {}
        for lib in libs:
            for term, glist in getlib(lib).items():
                allsets.setdefault(term, set()).update(g.upper() for g in glist)
        return {axis: set().union(*(g for t, g in allsets.items() if rx.search(t)))
                for axis, rx in axis_rx.items()}

    def gene_scores(de):
        """Per-gene weighted, capped score, averaged over duplicate gene symbols."""
        lfc = de["lfc"].clip(-3, 3)
        w = (-np.log10(de["p"].clip(lower=1e-300))).clip(0, 5).fillna(1.0)
        s = np.sign(lfc) * np.minimum(np.abs(lfc), 1.0) * w
        # Genes with a NaN score (missing lfc) carry no signal. Dropping them means an
        # axis made only of such genes scores 0 instead of NaN.
        return pd.DataFrame({"gene": de["gene"], "score": s}).groupby("gene")["score"].mean().dropna()

    def scale(gseries):
        """95th percentile of |score| used to normalise axis means; 1.0 if empty or zero."""
        if gseries.empty:
            return 1.0
        q = float(np.percentile(np.abs(gseries.to_numpy()), 95))
        return q if np.isfinite(q) and q > 0 else 1.0

    # Gene scores do not depend on the library set, so compute them once.
    scored = [gene_scores(de) for _, de in pairs]

    def mean_abs_axis(mode):
        """Mean over tables of the mean |axis score|; "FULL" uses LIBS_FULL, else LIBS_KEGG."""
        genes_by_axis = axis_genes(LIBS_FULL if mode == "FULL" else LIBS_KEGG)
        vals = []
        for gseries in scored:
            base = scale(gseries)
            present = set(gseries.index)
            axis_scores = []
            for genes in genes_by_axis.values():
                inter = genes & present
                if not inter:
                    axis_scores.append(0.0)
                    continue
                # sorted() keeps the float mean independent of set iteration order.
                v = gseries.loc[sorted(inter)].mean()
                axis_scores.append(float(np.tanh(v / (2.0 * base))))
            vals.append(np.mean(np.abs(axis_scores)))
        return float(np.mean(vals))

    full = mean_abs_axis("FULL"); kegg = mean_abs_axis("KEGG")
    ax.bar(["Full (multi-libs)", "No Causal (KEGG only)"], [full, kegg])
    for i, v in enumerate([full, kegg]):
        ax.text(i, v + 0.01, f"{v:.3f}", ha="center", va="bottom", fontsize=9)
    ax.set_ylabel(f"Mean |pathway score| across {len(AXIS_PAT)} axes")
    ax.set_title("C) Ablation — remove causal propagation")
    fig.tight_layout(pad=0.5); fig.savefig(path, bbox_inches="tight"); plt.close(fig)

if __name__ == "__main__":
    # Render every figure in order, then report all output paths on one line.
    jobs = (
        (fig3, FIGS / "fig3_sensitivity.png"),
        (fig4, FIGS / "fig4_ablation_pkpd.png"),
        (fig5, FIGS / "fig5_ablation_causal.png"),
    )
    for render, target in jobs:
        render(target)
    print("Wrote:", *(target for _, target in jobs))
