# scripts/refresh_fig1_from_results.py (v3: auto/tight range + optional bars)
import argparse, pathlib, pandas as pd, numpy as np, matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm

# Project layout, resolved relative to this script (which sits one directory
# below the project root).
ROOT = pathlib.Path(__file__).resolve().parents[1]
PROC = ROOT.joinpath("data", "processed")  # DE result tables are read from here; the score CSV is written here
FIGS = ROOT.joinpath("figures")            # rendered figures are written here
# Create the figure directory up front so the save functions never hit a missing path.
FIGS.mkdir(parents=True, exist_ok=True)

def load_de(csv_path: pathlib.Path) -> float:
    """Collapse one differential-expression table into a single score in [-1, 1].

    Each row contributes ``sign(lfc) * min(|lfc|, 1) * w``, where ``w`` is
    ``-log10(p)`` clipped to ``[0, 5]`` when a p-value-like column exists and
    ``1`` otherwise.  The NaN-ignoring mean of those contributions is then
    squashed with ``tanh(x / 2)``.

    Column headers are matched case-insensitively against a fixed list of
    aliases; the first alias found (in list order) wins.

    Args:
        csv_path: Path to a CSV file readable by ``pandas.read_csv``.

    Returns:
        The composite score as a Python ``float``.

    Raises:
        ValueError: If no log2 fold-change column can be found.
    """
    df = pd.read_csv(csv_path)
    # Lower-cased header -> header exactly as written in the file.
    cols = {c.lower(): c for c in df.columns}

    lfc_aliases = ("log2fc", "logfc", "lfc", "log2_fc")
    lfc_col = next((cols[k] for k in lfc_aliases if k in cols), None)
    if lfc_col is None:
        raise ValueError(f"{csv_path} 未找到 log2FC 列: {df.columns.tolist()}")

    # Resolve the alias back to the real header.  Returning the lower-cased
    # alias itself would make df[p_col] raise KeyError for headers such as
    # "FDR" or "PValue".
    p_aliases = ("fdr", "padj", "adj_p", "adj.p", "q", "qval", "p_adj", "p.adj", "pval", "p")
    p_col = next((cols[k] for k in p_aliases if k in cols), None)

    lfc = pd.to_numeric(df[lfc_col], errors="coerce")
    if p_col is not None:
        # Floor p at 1e-300 so that p == 0 does not produce an infinite weight.
        p = pd.to_numeric(df[p_col], errors="coerce").clip(lower=1e-300)
        w = (-np.log10(p)).clip(0, 5.0)
    else:
        w = pd.Series(1.0, index=df.index)

    # Rows where either lfc or the weight is NaN are ignored by nanmean.
    score_raw = np.nanmean(np.sign(lfc) * np.minimum(np.abs(lfc), 1.0) * w)
    return float(np.tanh(score_raw / 2.0))  # squash into [-1, 1]

def build_scores():
    """Score every intervention whose DE result file is present under PROC.

    Interventions are checked in a fixed order (Metformin, then Rapamycin);
    a missing file is skipped silently.

    Returns:
        A pair ``(labels, scores)`` of equally long lists, in that order.

    Raises:
        SystemExit: If none of the expected result files exists.
    """
    candidates = (
        ("Metformin", PROC / "GSE196343_metformin_de_results.csv"),
        ("Rapamycin", PROC / "GSE102674_rapamycin_vs_TGFb1_de_results.csv"),
    )
    scored = [
        (name, load_de(table))
        for name, table in candidates
        if table.exists()
    ]
    if not scored:
        raise SystemExit("没有可用的 DE 结果，请先运行两个 fetch_*.py")
    names = [name for name, _ in scored]
    values = [value for _, value in scored]
    return names, values

def save_heatmap(labels, scores, max_abs, out_png):
    """Draw the scores as a one-column annotated heatmap and write it to *out_png*.

    Args:
        labels: Row labels, one per score.
        scores: Numeric scores, one per row.
        max_abs: Half-width of the symmetric colour scale; colours span
            ``[-max_abs, +max_abs]`` with the midpoint pinned at 0.  Should be
            positive, since the norm needs vmin < vcenter < vmax.
        out_png: Destination path handed to ``Figure.savefig``.
    """
    M = np.array(scores).reshape(-1, 1)
    norm = TwoSlopeNorm(vmin=-max_abs, vcenter=0.0, vmax=max_abs)
    fig, ax = plt.subplots(figsize=(6.2, 4.6), dpi=200)
    # Operate on `fig` explicitly rather than on pyplot's "current figure",
    # and close it in `finally` so a failure while drawing or saving does not
    # leave the figure registered with pyplot.
    try:
        im = ax.imshow(M, aspect="auto", norm=norm, cmap="coolwarm")
        ax.set_yticks(range(len(labels)))
        ax.set_yticklabels(labels)
        ax.set_xticks([0])
        ax.set_xticklabels(["RNA-seq Composite Effect"])
        # Print each score in the centre of its cell.
        for i, v in enumerate(scores):
            ax.text(0, i, f"{v:+.3f}", ha="center", va="center", fontsize=12, color="black")
        for sp in ax.spines.values():
            sp.set_visible(False)
        ax.set_title("Fig-1: Intervention Composite Score (from public RNA-seq)")
        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label("Normalized effect (−1 … +1)")
        fig.tight_layout()
        fig.savefig(out_png, bbox_inches="tight")
    finally:
        plt.close(fig)

def save_bars(labels, scores, max_abs, out_png):
    """Draw the scores as a colour-coded bar chart and write it to *out_png*.

    Args:
        labels: Bar labels, one per score.
        scores: Numeric scores, one per bar.
        max_abs: Half-width of the colour scale.  Scores are mapped linearly
            from ``[-max_abs, +max_abs]`` onto the "coolwarm" colormap, and
            the value labels are offset by 2 % of this amount.  Must be
            non-zero because it is used as a divisor.
        out_png: Destination path handed to ``Figure.savefig``.
    """
    fig, ax = plt.subplots(figsize=(6.2, 3.6), dpi=200)
    # Operate on `fig` explicitly rather than on pyplot's "current figure",
    # and close it in `finally` so a failure while drawing or saving does not
    # leave the figure registered with pyplot.
    try:
        cmap = plt.get_cmap("coolwarm")
        colors = [cmap((s + max_abs) / (2 * max_abs)) for s in scores]
        ax.bar(labels, scores, color=colors)
        # Put the value just beyond the bar tip: above for non-negative
        # scores, below for negative ones.
        for i, s in enumerate(scores):
            ax.text(
                i,
                s + (0.02 if s >= 0 else -0.02) * max_abs,
                f"{s:+.3f}",
                ha="center",
                va="bottom" if s >= 0 else "top",
                fontsize=11,
            )
        ax.axhline(0, color="k", lw=0.6, alpha=0.6)
        ax.set_ylabel("Normalized effect (−1 … +1)")
        ax.set_title("Intervention Composite Scores (bar view)")
        fig.tight_layout()
        fig.savefig(out_png, bbox_inches="tight")
    finally:
        plt.close(fig)

if __name__ == "__main__":
    # Command-line entry point: score the available DE tables, render the
    # heatmap (and optionally the bar chart), then write the scores to CSV.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--range",
        type=float,
        default=None,
        help="固定色标半幅值，例如 0.02 表示 ±0.02",
    )
    parser.add_argument(
        "--tight",
        action="store_true",
        help="根据数据自动收紧范围（±1.2×最大绝对值，至少 ±0.01）",
    )
    parser.add_argument(
        "--bars",
        action="store_true",
        help="另外导出柱状图 fig1b_bars.png",
    )
    opts = parser.parse_args()

    labels, scores = build_scores()
    peak = max(abs(value) for value in scores)

    # Half-width of the colour scale.  An explicit --range takes precedence
    # over --tight; each mode has its own lower bound.
    if opts.range is None:
        if opts.tight:
            half_width = max(peak * 1.2, 0.01)
        else:
            # Original default: a visual range of at least ±0.05.
            half_width = max(peak, 0.05)
    else:
        half_width = float(max(opts.range, 0.005))

    heatmap_path = FIGS / "fig1_matrix.png"
    bars_path = FIGS / "fig1b_bars.png"

    # Main figure (heatmap).
    save_heatmap(labels, scores, half_width, heatmap_path)

    # Optional bar chart.
    if opts.bars:
        save_bars(labels, scores, half_width, bars_path)

    # Keep the CSV of scores in sync with the figures.
    score_table = pd.DataFrame({"intervention": labels, "composite_score": scores})
    score_table.to_csv(PROC / "intervention_composite_scores.csv", index=False)

    print("Wrote:", heatmap_path, "range=±", round(half_width, 3))
    if opts.bars:
        print("Wrote:", bars_path)
