﻿# scripts/fetch_GSE102674_rapamycin.py  (v6: uses MINiML to map GSM→metadata; robust grouping)
import os, io, re, gzip, tarfile, pathlib, sys, traceback, xml.etree.ElementTree as ET
import pandas as pd
import numpy as np
import requests

# Project layout, resolved relative to this script: BASE is the directory one
# level above the folder that contains this file.
BASE = pathlib.Path(__file__).resolve().parents[1]
# Downloaded archives, the run log and diagnostic dumps go here.
RAW  = BASE / "data" / "raw"
# Differential-expression result tables go here.
PROC = BASE / "data" / "processed"
# Volcano / heatmap PNGs go here.
FIGS = BASE / "figures"
# Create the output directories up front; this runs when the module is loaded.
for d in (RAW, PROC, FIGS): d.mkdir(parents=True, exist_ok=True)

# Run log appended to by log(); main() deletes it at the start of each run.
LOG = RAW / "GSE102674_run.log"
def log(msg):
    """Echo *msg* to stdout (flushed immediately) and append it as one line to LOG."""
    print(msg, flush=True)
    line = str(msg) + "\n"
    with LOG.open(mode="a", encoding="utf-8") as sink:
        sink.write(line)

# helpers import
sys.path.append(str(pathlib.Path(__file__).resolve().parent))
from analysis_common import differential_expression, volcano_plot, heatmap_top_genes

# ---- URLs ----
# Remote GEO resources for series GSE102674 and the local paths (under RAW)
# that download() saves them to.
# Series matrix: parsed by parse_series_meta() as best-effort sample metadata.
MATRIX_URL    = "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE102nnn/GSE102674/matrix/GSE102674_series_matrix.txt.gz"
MATRIX_GZ     = RAW / "GSE102674_series_matrix.txt.gz"
# Supplementary RAW archive: merged into one matrix by merge_from_raw_tar().
RAW_TAR_URL   = "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE102674&format=file"
RAW_TAR       = RAW / "GSE102674_RAW.tar"
# MINiML family XML: parsed by parse_miniml_meta(); preferred over the series
# matrix when main() looks up a sample's metadata.
MINIML_URL    = "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE102nnn/GSE102674/miniml/GSE102674_family.xml.tgz"
MINIML_TGZ    = RAW / "GSE102674_family.xml.tgz"

# Minimum number of columns main() requires in each of the two groups
# ("TGFb1_Rapa" and "TGFb1") before it runs differential expression.
MIN_PER_GROUP = 2

# Keywords that label_group() looks for as substrings of the lower-cased
# sample text. Because matching is by substring, "tgf" already matches every
# other TGF entry, and "rapa" already matches "rapamycin".
KEYS_TGF  = ["tgf", "tgfβ", "tgfb", "tgf-β", "tgf beta", "tgf-beta", "tgf-1", "tgfβ1", "tgfb1"]
KEYS_RAPA = ["rapa", "rapamycin", "sirolimus"]

def download(url, out_path, timeout=300):
    """Download *url* to *out_path* unless that file already exists.

    The body is streamed to a sibling ``<name>.part`` file and renamed into
    place only after the transfer finished, so a large archive is never held
    in memory and an interrupted run cannot leave a truncated *out_path* that
    a later run would mistake for a finished download.

    Args:
        url: Address to fetch.
        out_path: ``pathlib.Path`` of the destination file.
        timeout: Timeout in seconds passed to ``requests.get``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    if out_path.exists():
        log(f"[info] exists: {out_path}")
        return
    log(f"[down] {url} -> {out_path}")
    tmp_path = out_path.with_name(out_path.name + ".part")
    try:
        with requests.get(url, timeout=timeout, stream=True) as resp:
            resp.raise_for_status()
            with open(tmp_path, "wb") as fh:
                for chunk in resp.iter_content(chunk_size=1 << 20):
                    if chunk:
                        fh.write(chunk)
        # Rename only after a complete transfer.
        tmp_path.replace(out_path)
    finally:
        # Remove the partial file if the transfer or the rename failed.
        tmp_path.unlink(missing_ok=True)
    log(f"[ok] saved: {out_path}  ({out_path.stat().st_size} bytes)")

def read_text_gz(path: pathlib.Path) -> str:
    """Return the full decompressed content of a gzip file as text.

    The data is decoded as UTF-8; bytes that cannot be decoded are dropped.
    """
    stream = gzip.open(path, mode="rt", encoding="utf-8", errors="ignore")
    try:
        content = stream.read()
    finally:
        stream.close()
    return content

def parse_series_meta(text: str):
    """Extract sample order and per-sample metadata from series-matrix text.

    Lines starting with ``!Sample_title``, ``!Sample_geo_accession`` and
    ``!Sample_characteristics_ch1`` are read in either of two layouts: the
    tab-separated "vector" form (one value per sample on a single line) or the
    ``key = value`` form (one sample per line).

    Returns:
        A ``(gsm_order, meta)`` pair. ``meta`` maps each accession to
        ``{"title": ..., "chars": ...}``. ``([], {})`` is returned when neither
        titles nor accessions were found. Placeholder ids ``GSM_1``,
        ``GSM_2``, ... are used when no accession list matches the sample count.
    """
    def _unquote(raw: str) -> str:
        # Trim whitespace, then drop one pair of matching surrounding quotes.
        trimmed = raw.strip()
        is_quoted = (
            len(trimmed) >= 2
            and trimmed[0] == trimmed[-1]
            and trimmed[0] in ("'", '"')
        )
        return trimmed[1:-1].strip() if is_quoted else trimmed

    def _vector_cells(line: str) -> list:
        # Everything after the first tab-separated field (the key itself).
        return [_unquote(cell) for cell in line.split("\t")[1:]]

    def _scalar_value(line: str):
        # Value of a "key = value" line, or None when the line has no " = ".
        halves = re.split(r"\s=\s", line, maxsplit=1)
        return _unquote(halves[1]) if len(halves) == 2 else None

    title_vector, title_singles = [], []
    gsm_vector, gsm_singles = [], []
    char_rows = []

    for line in text.splitlines():
        is_vector = "\t" in line
        if line.startswith("!Sample_title"):
            if is_vector:
                title_vector = _vector_cells(line)
            else:
                value = _scalar_value(line)
                if value is not None:
                    title_singles.append(value)
        elif line.startswith("!Sample_geo_accession"):
            if is_vector:
                gsm_vector = [cell.upper() for cell in _vector_cells(line)]
            else:
                value = _scalar_value(line)
                if value is not None:
                    gsm_singles.append(value.upper())
        elif line.startswith("!Sample_characteristics_ch1"):
            if is_vector:
                char_rows.append(_vector_cells(line))
            else:
                value = _scalar_value(line)
                if value is not None:
                    char_rows.append([value])

    sample_count = max(
        len(gsm_vector), len(title_vector), len(gsm_singles), len(title_singles), 0
    )
    if not sample_count:
        return [], {}

    # Prefer the vector form, then the one-per-line form, then placeholders.
    gsm_order = next(
        (cand for cand in (gsm_vector, gsm_singles) if len(cand) == sample_count),
        None,
    )
    if gsm_order is None:
        gsm_order = [f"GSM_{pos + 1}" for pos in range(sample_count)]

    titles = next(
        (cand for cand in (title_vector, title_singles) if len(cand) == sample_count),
        None,
    )
    if titles is None:
        titles = [""] * sample_count

    chars = [""] * sample_count
    if char_rows:
        if all(len(row) == sample_count for row in char_rows):
            # Vector form: join each sample's non-empty values column-wise.
            chars = [" | ".join(filter(None, column)) for column in zip(*char_rows)]
        elif len(char_rows) == sample_count and all(len(row) == 1 for row in char_rows):
            # One-per-line form: exactly one characteristics line per sample.
            chars = [row[0] for row in char_rows]

    meta = {}
    for gsm, title, char_text in zip(gsm_order, titles, chars):
        meta[gsm] = {"title": title, "chars": char_text}
    return gsm_order, meta

def parse_miniml_meta(tgz_path: pathlib.Path):
    """Parse a MINiML ``.tgz`` archive into a GSM -> ``{title, chars}`` mapping.

    The first archive member whose name ends in ``.xml`` is parsed. For every
    ``Sample`` element the ``Accession`` and ``Title`` are read from its direct
    children, while ``Characteristics`` elements are collected from *all*
    descendants, because MINiML nests them inside ``Channel``. A
    ``Characteristics`` element carrying a ``tag`` attribute is rendered as
    ``"tag: value"``; all entries of one sample are joined with ``" | "``.

    Args:
        tgz_path: Path of the gzip-compressed MINiML tar archive.

    Returns:
        Dict mapping upper-cased accession -> ``{"title": str, "chars": str}``.
        Samples without an accession are skipped.

    Raises:
        RuntimeError: If the archive has no readable ``.xml`` member.
    """
    with tarfile.open(tgz_path, "r:gz") as tf:
        xml_member = next((m for m in tf.getmembers() if m.name.endswith(".xml")), None)
        if xml_member is None:
            raise RuntimeError("MINiML tarball has no XML member.")
        xml_file = tf.extractfile(xml_member)
        if xml_file is None:
            # extractfile() returns None for members that are not file-like.
            raise RuntimeError(f"MINiML member is not a readable file: {xml_member.name}")
        xml_bytes = xml_file.read()

    # Compare tags by local name so the MINiML XML namespace can be ignored.
    def strip_ns(tag):
        return tag.split("}", 1)[1] if "}" in tag else tag

    root = ET.fromstring(xml_bytes)
    meta = {}
    for sample in root.iter():
        if strip_ns(sample.tag) != "Sample":
            continue

        gsm = None
        title = ""
        for child in sample:
            tag = strip_ns(child.tag)
            if tag == "Accession":
                gsm = (child.text or "").strip().upper()
            elif tag == "Title":
                title = (child.text or "").strip()

        # Characteristics live under Sample/Channel, so search every descendant
        # rather than only the direct children.
        chars_list = []
        for node in sample.iter():
            if strip_ns(node.tag) != "Characteristics":
                continue
            value = (node.text or "").strip()
            if not value:
                continue
            label = (node.get("tag") or "").strip()
            chars_list.append(f"{label}: {value}" if label else value)

        if gsm:
            meta[gsm] = {"title": title, "chars": " | ".join(chars_list)}
    return meta

def label_group(text: str) -> str:
    """Classify sample text by keyword into "TGFb1_Rapa", "TGFb1" or "Other".

    The text is lower-cased and searched for any KEYS_TGF / KEYS_RAPA entry as
    a substring. Text without a TGF keyword is always "Other", even when a
    rapamycin keyword is present.
    """
    haystack = (text or "").lower()
    if not any(keyword in haystack for keyword in KEYS_TGF):
        return "Other"
    rapa_found = any(keyword in haystack for keyword in KEYS_RAPA)
    return "TGFb1_Rapa" if rapa_found else "TGFb1"

def _read_delimited_table(txt: str) -> pd.DataFrame | None:
    """Parse *txt* as a tab- or comma-separated table (lines starting with '#' are comments).

    Returns the first parse that yields a non-empty frame with at least two
    columns (an identifier column plus at least one value column), or ``None``
    when neither separator produces one.
    """
    for sep in ("\t", ","):
        try:
            df = pd.read_csv(io.StringIO(txt), sep=sep, comment="#")
        except Exception:
            # Best-effort sniffing: a failed parse means "try the next separator".
            continue
        # A comma-separated file read with sep="\t" parses without error into a
        # single column, so a one-column result is treated as the wrong
        # separator instead of being accepted.
        if not df.empty and df.shape[1] >= 2:
            return df
    return None


def merge_from_raw_tar(tar_path: pathlib.Path) -> pd.DataFrame:
    """Merge the per-file tables of a RAW tar archive into one matrix.

    Every member with a text-table extension (optionally ``.gz``) is parsed as
    a tab- or comma-separated table. One column is taken as the gene
    identifier (the first matching well-known name, else the first column).
    The remaining columns are coerced to numbers and averaged row-wise, giving
    one value column per archive member, named after the member. Members that
    cannot be parsed are skipped, and the reason is logged.

    Args:
        tar_path: Path of the tar archive.

    Returns:
        DataFrame indexed by ``gene`` with one column per parsed member,
        sorted by index.

    Raises:
        RuntimeError: If no member could be parsed.
    """
    table_exts = (".txt", ".tab", ".tsv", ".csv",
                  ".txt.gz", ".tab.gz", ".tsv.gz", ".csv.gz")
    gene_col_candidates = ["gene", "Gene", "symbol", "Symbol", "GeneSymbol", "Gene Symbol",
                           "gene_id", "GeneID", "Gene ID", "ID_REF", "ID", "Name", "GeneName"]
    mats = []
    with tarfile.open(tar_path, "r") as tf:
        members = [m for m in tf.getmembers() if m.name.lower().endswith(table_exts)]
        log(f"[raw] found {len(members)} candidate files")
        for m in members:
            try:
                fobj = tf.extractfile(m)
                if fobj is None:
                    # Not a regular file (e.g. a directory entry).
                    continue
                raw = fobj.read()
                if m.name.lower().endswith(".gz"):
                    raw = gzip.decompress(raw)
                df = _read_delimited_table(raw.decode("utf-8", errors="ignore"))
                if df is None:
                    log(f"[raw] skip {m.name}: no tab/comma table with >=2 columns")
                    continue
                gene_col = next((c for c in gene_col_candidates if c in df.columns),
                                df.columns[0])
                df = df.rename(columns={gene_col: "gene"}).dropna(subset=["gene"])
                num = df.drop(columns=["gene"], errors="ignore").apply(pd.to_numeric, errors="coerce")
                num = num.loc[:, [np.issubdtype(t, np.number) for t in num.dtypes]]
                if num.shape[1] == 0:
                    log(f"[raw] skip {m.name}: no numeric columns")
                    continue
                # Collapse all numeric columns of this file into one value per row.
                vals = num.mean(axis=1)
                out = pd.DataFrame({"gene": df["gene"].astype(str), m.name: vals}).set_index("gene")
                mats.append(out)
            except Exception as exc:
                # Deliberately best-effort: one bad member must not abort the
                # merge, but the reason is recorded instead of being swallowed.
                log(f"[raw] skip {m.name}: {type(exc).__name__}: {exc}")
                continue
    log(f"[raw] parsed {len(mats)} single-sample tables")
    if not mats:
        raise RuntimeError("RAW.tar contains no parsable tables")
    return pd.concat(mats, axis=1).sort_index()

def extract_gsm(name: str) -> str | None:
    m = re.search(r"GSM\d+", name, flags=re.IGNORECASE)
    return m.group(0).upper() if m else None

def main():
    """Run the GSE102674 pipeline end to end.

    Steps: download the series matrix, MINiML archive and RAW tar; merge the
    RAW tables into one matrix; map columns to GSM accessions; label each
    column "TGFb1_Rapa", "TGFb1" or "Other" from its metadata text; then run
    differential expression and write the result CSV plus volcano and heatmap
    figures.

    If either group has fewer than MIN_PER_GROUP columns, diagnostic files are
    written under RAW and the function returns without running the analysis.
    """
    # Start every run with a fresh log file.
    if LOG.exists():
        LOG.unlink()
    log("=== GSE102674 (Rapamycin) v6 start ===")

    # 0) Series-matrix metadata (best effort) + MINiML metadata (authoritative).
    download(MATRIX_URL, MATRIX_GZ)
    txt = read_text_gz(MATRIX_GZ)
    gsm_order, meta_series = parse_series_meta(txt)
    log(f"[meta] series-matrix samples: {len(gsm_order)}")

    download(MINIML_URL, MINIML_TGZ)
    meta_miniml = parse_miniml_meta(MINIML_TGZ)
    log(f"[meta] miniml samples: {len(meta_miniml)}")

    # 1) Merge RAW
    download(RAW_TAR_URL, RAW_TAR)
    expr_raw = merge_from_raw_tar(RAW_TAR)
    log(f"[raw] merged matrix shape: {expr_raw.shape}")

    # 2) Rename RAW columns to their GSM accession when one can be extracted,
    #    otherwise keep the name as is. first_raw_for_new remembers the first
    #    original file name per final column, used as fallback labeling text.
    raw_cols = list(expr_raw.columns)
    new_cols = [extract_gsm(c) or c for c in raw_cols]
    first_raw_for_new = {}
    for rc, nc in zip(raw_cols, new_cols):
        first_raw_for_new.setdefault(nc, rc)

    expr = expr_raw.copy()
    expr.columns = new_cols

    # Collapse several files of the same GSM into one column (mean).
    # groupby(axis=1) is deprecated in pandas >= 2.1, so group the transposed
    # frame by its index and transpose back.
    if len(set(new_cols)) < len(new_cols):
        expr = expr.T.groupby(level=0).mean().T

    # 3) Group assignment: prefer MINiML metadata, then the series matrix; the
    #    original file name is always appended as additional text.
    def text_for_col(col):
        md = meta_miniml.get(col) or meta_series.get(col)
        base = []
        if md:
            base += [md.get("title", ""), md.get("chars", "")]
        base.append(first_raw_for_new.get(col, col))
        return " ".join([b for b in base if b])

    groups = {c: label_group(text_for_col(c)) for c in expr.columns}
    keep_cols = [c for c, g in groups.items() if g in ("TGFb1_Rapa", "TGFb1")]
    grp_counts = pd.Series([groups[c] for c in keep_cols]).value_counts()
    log(f"[group] counts: {grp_counts.to_dict()}")

    # Each group needs at least MIN_PER_GROUP columns.
    ok = (grp_counts.get("TGFb1_Rapa", 0) >= MIN_PER_GROUP) and (grp_counts.get("TGFb1", 0) >= MIN_PER_GROUP)
    if not ok:
        (RAW / "GSE102674_columns.txt").write_text("\n".join(raw_cols), encoding="utf-8")
        # Also dump a column -> label -> text table for manual inspection.
        pd.DataFrame({"final_col": list(expr.columns),
                      "label": [groups[c] for c in expr.columns],
                      "text_used": [text_for_col(c) for c in expr.columns]}
                     ).to_csv(RAW / "GSE102674_label_text_map.tsv", sep="\t", index=False)
        # Name only the diagnostic files that are actually written above.
        log("[fatal] Not enough labeled columns. See GSE102674_columns.txt and GSE102674_label_text_map.tsv to extend KEYS_TGF/KEYS_RAPA.")
        return

    # 4) Differential expression + figures
    expr2 = expr[keep_cols].copy()
    group_labels = [groups[c] for c in expr2.columns]
    log(f"[use] {len(expr2.columns)} columns: " + ", ".join(f"{c}:{groups[c]}" for c in expr2.columns))

    res = differential_expression(expr2, group_labels)
    out_csv = PROC / "GSE102674_rapamycin_vs_TGFb1_de_results.csv"
    res.to_csv(out_csv)
    log(f"[ok] DE: {out_csv}")

    volcano = FIGS / "GSE102674_rapamycin_volcano.png"
    heatmap = FIGS / "GSE102674_rapamycin_heatmap.png"
    volcano_plot(res, out_png=str(volcano), title="GSE102674 TGFb1+Rapamycin vs TGFb1")
    heatmap_top_genes(expr2, res, group_labels, out_png=str(heatmap),
                      title="GSE102674 top DE (TGFb1+Rapamycin vs TGFb1)")
    log(f"[ok] Figures:\n  {volcano}\n  {heatmap}")
    log("=== Done ===")

if __name__ == "__main__":
    try:
        main()
    except Exception:
        # Record the traceback in the run log, then re-raise so the process
        # still fails visibly.
        crash_report = "\n[uncaught]\n" + traceback.format_exc() + "\n"
        with open(RAW / "GSE102674_run.log", "a", encoding="utf-8") as log_file:
            log_file.write(crash_report)
        raise
