﻿# scripts/fetch_GSE102674_rapamycin.py  (v5: robust matrix metadata parsing + RAW→GSM mapping)
import os, io, re, gzip, tarfile, pathlib, sys, traceback
import pandas as pd
import numpy as np
import requests

# Project root: the parent of the directory that contains this script.
BASE = pathlib.Path(__file__).resolve().parents[1]
RAW = BASE.joinpath("data", "raw")          # downloads, run log, metadata dumps
PROC = BASE.joinpath("data", "processed")   # DE result tables
FIGS = BASE.joinpath("figures")             # volcano / heatmap images

# Create every output directory up front; existing directories are left as-is.
for _out_dir in (RAW, PROC, FIGS):
    _out_dir.mkdir(parents=True, exist_ok=True)

LOG = RAW / "GSE102674_run.log"


def log(msg):
    """Print *msg* to stdout (flushed) and append it as one line to ``LOG``."""
    print(msg, flush=True)
    with LOG.open("a", encoding="utf-8") as handle:
        handle.write(str(msg) + "\n")

# Project-local helpers: add this script's own directory to sys.path so the
# module `analysis_common` (expected next to this file) can be imported; main()
# uses these three functions for differential expression and the two figures.
sys.path.append(str(pathlib.Path(__file__).resolve().parent))
from analysis_common import differential_expression, volcano_plot, heatmap_top_genes

# --- GEO download locations and local cache paths ---------------------------
MATRIX_URL = (
    "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE102nnn/GSE102674/"
    "matrix/GSE102674_series_matrix.txt.gz"
)
MATRIX_GZ = RAW / "GSE102674_series_matrix.txt.gz"
RAW_TAR_URL = "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE102674&format=file"
RAW_TAR = RAW / "GSE102674_RAW.tar"

# Each comparison group needs at least this many columns before DE is run.
MIN_PER_GROUP = 2

# Substrings searched for (by label_group) in the lower-cased sample text.
KEYS_TGF = [
    "tgf", "tgfβ", "tgfb", "tgf-β", "tgf beta",
    "tgf-beta", "tgf-1", "tgfβ1", "tgfb1",
]
KEYS_RAPA = [
    "rapa", "rapamycin", "sirolimus",
]

def download(url, out_path, timeout=300):
    """Download *url* to *out_path* unless *out_path* already exists.

    The response body is streamed into a sibling ``<name>.part`` file, which is
    renamed onto *out_path* only after the transfer has finished. An
    interrupted run therefore never leaves a truncated file behind that the
    ``exists()`` short-circuit would later mistake for a completed download.
    Streaming also avoids holding the whole (possibly large) RAW tarball in
    memory.

    Args:
        url: Source URL.
        out_path: Destination ``pathlib.Path``.
        timeout: Seconds passed to ``requests.get`` as its timeout.

    Raises:
        requests.HTTPError: on a non-success HTTP status (``raise_for_status``).
    """
    if out_path.exists():
        log(f"[info] exists: {out_path}")
        return
    log(f"[down] {url} -> {out_path}")
    tmp_path = out_path.with_name(out_path.name + ".part")
    try:
        with requests.get(url, timeout=timeout, stream=True) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as fh:
                for chunk in r.iter_content(chunk_size=1 << 20):
                    if chunk:  # skip keep-alive empty chunks
                        fh.write(chunk)
        # Atomic on the same filesystem: out_path appears only when complete.
        tmp_path.replace(out_path)
    except BaseException:
        # Includes KeyboardInterrupt: never leave a half-written .part around.
        tmp_path.unlink(missing_ok=True)
        raise
    log(f"[ok] saved: {out_path}  ({out_path.stat().st_size} bytes)")

def read_text_gz(path: pathlib.Path) -> str:
    """Return the entire decompressed contents of gzip file *path* as text.

    Bytes are decoded as UTF-8; undecodable bytes are dropped (errors="ignore").
    """
    with gzip.open(path, mode="rt", encoding="utf-8", errors="ignore") as stream:
        content = stream.read()
    return content

def _stripq(s: str) -> str:
    """Trim whitespace and remove one enclosing pair of matching quotes.

    Series-matrix cells are usually wrapped in double quotes (``"GSM123"``);
    single quotes are handled the same way. Whitespace just inside the quotes
    is trimmed as well.
    """
    s = s.strip()
    if len(s) >= 2 and s[0] == s[-1] and s[0] in ("'", '"'):
        s = s[1:-1]
    return s.strip()


def _split_meta_line(line: str) -> tuple[bool, list[str]]:
    """Split one ``!Sample_*`` metadata line into ``(vector_style, values)``.

    * Tab-separated (vector) style ``!Sample_x<TAB>v1<TAB>v2...`` returns
      ``(True, [v1, v2, ...])`` with quotes stripped from every value.
    * ``!Sample_x = value`` (one-per-line) style returns ``(False, [value])``,
      or ``(False, [])`` when no whitespace-surrounded ``=`` is present.
    """
    if "\t" in line:
        return True, [_stripq(p) for p in line.split("\t")[1:]]
    parts = re.split(r"\s=\s", line, maxsplit=1)
    return False, ([_stripq(parts[1])] if len(parts) == 2 else [])


def parse_series_meta_v5(text: str):
    """Parse per-sample metadata from GEO Series Matrix text.

    Both metadata layouts are supported:

    - vector-per-field:   ``!Sample_title<TAB>"t1"<TAB>"t2"...``
    - one-per-line:       ``!Sample_title = t1`` (repeated per sample)

    GSM accessions are also read from the first row after
    ``!series_matrix_table_begin`` (the data-table header); GEO quotes those
    cells, so the quotes are removed before use.

    Args:
        text: Full decompressed series-matrix file contents.

    Returns:
        ``(gsm_order, meta_by_gsm)``: ``gsm_order`` lists upper-cased sample
        IDs; ``meta_by_gsm`` maps each ID to ``{"title": str, "chars": str}``,
        where several characteristics rows are joined with ``" | "``. If no
        accession source has exactly the detected sample count, placeholder
        IDs ``GSM_1``, ``GSM_2``, ... are returned instead.

    Raises:
        RuntimeError: if the sample count cannot be determined.
    """
    lines = text.splitlines()

    # --- table header GSMs ---
    # Only the header row is needed, so stop at the begin marker. This also
    # keeps the header usable when the end marker is missing (truncated file).
    header_gsms: list[str] = []
    for i, ln in enumerate(lines):
        if ln.strip() != "!series_matrix_table_begin":
            continue
        if i + 1 < len(lines) and lines[i + 1].strip() != "!series_matrix_table_end":
            tokens = re.split(r"\t+", lines[i + 1].rstrip())
            if len(tokens) > 1:
                # Header cells look like "ID_REF"\t"GSM123"; strip the quotes so
                # IDs match !Sample_geo_accession values and RAW file names.
                header_gsms = [_stripq(t).upper() for t in tokens[1:]]
        break

    # --- extract fields ---
    titles_vec: list[str] = []        # vector style: one line carries all samples
    titles_list: list[str] = []       # one-per-line style: collected line by line
    geo_vec: list[str] = []
    geo_list: list[str] = []
    chars_rows: list[list[str]] = []  # one entry per characteristics line

    for ln in lines:
        if ln.startswith("!Sample_title"):
            is_vec, vals = _split_meta_line(ln)
            if is_vec:
                titles_vec = vals
            else:
                titles_list.extend(vals)
        elif ln.startswith("!Sample_geo_accession"):
            is_vec, vals = _split_meta_line(ln)
            vals = [v.upper() for v in vals]
            if is_vec:
                geo_vec = vals
            else:
                geo_list.extend(vals)
        elif ln.startswith("!Sample_characteristics_ch1"):
            _, vals = _split_meta_line(ln)
            if vals:
                chars_rows.append(vals)

    # --- sample count: the widest available source wins ---
    n = max(len(header_gsms), len(geo_vec), len(titles_vec), len(geo_list), len(titles_list))
    if n == 0:
        raise RuntimeError("cannot detect sample count in series matrix")

    # --- unify GSM order: table header, then vector, then one-per-line ---
    gsm_order = next(
        (src for src in (header_gsms, geo_vec, geo_list) if len(src) == n),
        [f"GSM_{i + 1}" for i in range(n)],
    )

    # --- unify titles ---
    titles = next((src for src in (titles_vec, titles_list) if len(src) == n), [""] * n)

    # --- unify characteristics: fold multiple rows into per-sample strings ---
    chars = [""] * n
    if chars_rows:
        if all(len(r) == n for r in chars_rows):
            # vector-style rows, each holding one value per sample
            for row in chars_rows:
                for i, val in enumerate(row):
                    if val:
                        chars[i] = f"{chars[i]} | {val}" if chars[i] else val
        elif len(chars_rows) == n and all(len(r) == 1 for r in chars_rows):
            # one-per-line style: exactly one characteristics line per sample
            chars = [r[0] for r in chars_rows]

    meta = {gsm: {"title": title, "chars": ch} for gsm, title, ch in zip(gsm_order, titles, chars)}
    return gsm_order, meta

def label_group(text: str) -> str:
    """Classify a sample's free-text description into a treatment group.

    Matching is a case-insensitive substring search for KEYS_TGF / KEYS_RAPA.
    Returns "TGFb1_Rapa" when both a TGF and a rapamycin keyword occur,
    "TGFb1" when only a TGF keyword occurs, and "Other" otherwise (this
    includes rapamycin-only descriptions).
    """
    lowered = text.lower()
    if not any(key in lowered for key in KEYS_TGF):
        return "Other"
    if any(key in lowered for key in KEYS_RAPA):
        return "TGFb1_Rapa"
    return "TGFb1"

# File suffixes treated as candidate per-sample tables inside RAW.tar.
_RAW_TABLE_EXTS = (".txt", ".tab", ".tsv", ".csv", ".txt.gz", ".tab.gz", ".tsv.gz", ".csv.gz")
# Column names tried, in order, as the gene identifier; falls back to column 0.
_GENE_COL_CANDIDATES = (
    "gene", "Gene", "symbol", "Symbol", "GeneSymbol", "Gene Symbol",
    "gene_id", "GeneID", "Gene ID", "ID_REF", "ID", "Name", "GeneName",
)


def _read_delimited(txt: str) -> pd.DataFrame | None:
    """Parse *txt* as a tab- then comma-separated table; None if neither fits."""
    for sep in ("\t", ","):
        try:
            df = pd.read_csv(io.StringIO(txt), sep=sep, comment="#")
        except ValueError:  # pandas ParserError / EmptyDataError subclass ValueError
            continue
        # A wrong separator still "parses" into a single wide column, so only
        # accept a result with an identifier column plus at least one value.
        if not df.empty and df.shape[1] >= 2:
            return df
    return None


def _sample_series(df: pd.DataFrame, name: str) -> pd.Series | None:
    """Collapse one per-sample table to a gene-indexed Series called *name*.

    The value per row is the mean of all columns that coerce to numbers;
    returns None when no such column exists. Duplicate gene IDs are averaged
    so that the later column-wise concat can align the indexes.
    """
    gene_col = next((c for c in _GENE_COL_CANDIDATES if c in df.columns), df.columns[0])
    df = df.rename(columns={gene_col: "gene"}).dropna(subset=["gene"])
    num = df.drop(columns=["gene"], errors="ignore").apply(pd.to_numeric, errors="coerce")
    num = num.loc[:, [np.issubdtype(t, np.number) for t in num.dtypes]]
    if num.shape[1] == 0:
        return None
    vals = num.mean(axis=1)
    vals.index = df["gene"].astype(str)
    series = vals.groupby(level=0).mean()
    series.name = name
    return series


def merge_from_raw_tar(tar_path: pathlib.Path) -> pd.DataFrame:
    """Merge the per-sample tables in a GEO RAW tarball into one matrix.

    Every tar member whose name ends in one of ``_RAW_TABLE_EXTS`` (``.gz``
    members are decompressed) is parsed as a tab- or comma-separated table.
    Each parsable file contributes one column, named after the tar member,
    holding the row-wise mean of its numeric columns, indexed by the gene
    identifier column. Members that cannot be parsed are skipped with a log line.

    Returns:
        DataFrame of genes (rows, sorted) x tar member names (columns).

    Raises:
        RuntimeError: if no member yields a usable table.
    """
    mats = []
    with tarfile.open(tar_path, "r") as tf:
        members = [m for m in tf.getmembers()
                   if m.isfile() and m.name.lower().endswith(_RAW_TABLE_EXTS)]
        log(f"[raw] found {len(members)} candidate files")
        for m in members:
            try:
                fobj = tf.extractfile(m)
                if fobj is None:
                    continue
                raw = fobj.read()
                if m.name.lower().endswith(".gz"):
                    raw = gzip.decompress(raw)
                df = _read_delimited(raw.decode("utf-8", errors="ignore"))
                if df is None:
                    log(f"[raw] skip (no delimited table): {m.name}")
                    continue
                series = _sample_series(df, m.name)
                if series is None:
                    log(f"[raw] skip (no numeric columns): {m.name}")
                    continue
                mats.append(series)
            except Exception as exc:
                # Best-effort per file, but say why instead of swallowing it.
                log(f"[raw] skip {m.name}: {type(exc).__name__}: {exc}")
    log(f"[raw] parsed {len(mats)} single-sample tables")
    if not mats:
        raise RuntimeError("RAW.tar contains no parsable tables")
    return pd.concat(mats, axis=1).sort_index()

def extract_gsm(name: str) -> str | None:
    m = re.search(r"GSM\d+", name, flags=re.IGNORECASE)
    return m.group(0).upper() if m else None

def main():
    """Run the GSE102674 TGFb1+Rapamycin vs TGFb1 pipeline end to end.

    Steps:
      0. Download the Series Matrix and write per-sample metadata
         (GSM, title, characteristics) to ``GSE102674_sample_metadata.tsv``.
      1. Download RAW.tar and merge its per-sample tables (genes x files).
      2. Rename columns to GSM accessions (file name kept when none is found),
         averaging columns that map to the same GSM.
      3. Label each column from its metadata via ``label_group`` and keep the
         "TGFb1_Rapa" and "TGFb1" columns.
      4. Run differential expression; write the CSV, volcano plot and heatmap.

    If either group has fewer than ``MIN_PER_GROUP`` columns, the raw column
    names are written to ``GSE102674_columns.txt``, a "[fatal]" line is logged
    and the function returns without running DE.
    """
    # Start every run with a fresh log file.
    LOG.unlink(missing_ok=True)
    log("=== GSE102674 (Rapamycin) v5 start ===")

    # 0) Parse metadata from Series Matrix (robust)
    download(MATRIX_URL, MATRIX_GZ)
    txt = read_text_gz(MATRIX_GZ)
    gsm_order, meta = parse_series_meta_v5(txt)
    md_out = RAW / "GSE102674_sample_metadata.tsv"
    md_rows = [
        {"GSM": gsm, "title": meta[gsm].get("title", ""), "characteristics": meta[gsm].get("chars", "")}
        for gsm in gsm_order
    ]
    pd.DataFrame(md_rows).to_csv(md_out, sep="\t", index=False)
    log(f"[meta] parsed {len(gsm_order)} samples -> {md_out}")

    # 1) Merge RAW.tar into a matrix (genes x files)
    download(RAW_TAR_URL, RAW_TAR)
    expr_raw = merge_from_raw_tar(RAW_TAR)
    log(f"[raw] merged matrix shape: {expr_raw.shape}")

    # 2) Map RAW columns -> GSM; columns without an accession keep their name.
    cols = list(expr_raw.columns)
    col2gsm = {c: (extract_gsm(c) or c) for c in cols}
    expr = expr_raw.copy()
    expr.columns = [col2gsm[c] for c in cols]

    # Average duplicates mapping to the same GSM. DataFrame.groupby(axis=1) is
    # deprecated since pandas 2.1 and removed in 3.0, so group the transpose.
    if expr.columns.duplicated().any():
        expr = expr.T.groupby(level=0).mean().T

    # 3) Build groups from metadata
    def label_for_gsm(gsm):
        # Without series-matrix metadata, label from the column name itself.
        md = meta.get(gsm, {"title": gsm, "chars": ""})
        return label_group((md.get("title", "") + " " + md.get("chars", "")).strip())

    unmapped = [c for c in expr.columns if c not in meta]
    if unmapped:
        log(f"[warn] {len(unmapped)} column(s) have no series-matrix metadata; "
            f"labelled from column name: {unmapped}")

    groups = {c: label_for_gsm(c) for c in expr.columns}
    keep_cols = [c for c, g in groups.items() if g in ("TGFb1_Rapa", "TGFb1")]
    # Explicit dtype keeps pandas quiet when no column was labelled.
    grp_counts = pd.Series([groups[c] for c in keep_cols], dtype=object).value_counts()
    log(f"[group] counts: {grp_counts.to_dict()}")

    # require replicates in both groups
    ok = (grp_counts.get("TGFb1_Rapa", 0) >= MIN_PER_GROUP) and (grp_counts.get("TGFb1", 0) >= MIN_PER_GROUP)
    if not ok:
        (RAW / "GSE102674_columns.txt").write_text("\n".join(cols), encoding="utf-8")
        log("[fatal] Not enough labeled columns. Inspect sample_metadata.tsv (titles/characteristics) and columns.txt, then extend KEYS_TGF/KEYS_RAPA.")
        return

    # 4) DE + figures
    expr2 = expr[keep_cols].copy()
    group_labels = [groups[c] for c in expr2.columns]
    log(f"[use] {len(expr2.columns)} columns: " + ", ".join(f"{c}:{groups[c]}" for c in expr2.columns))

    res = differential_expression(expr2, group_labels)
    out_csv = PROC / "GSE102674_rapamycin_vs_TGFb1_de_results.csv"
    res.to_csv(out_csv)
    log(f"[ok] DE: {out_csv}")

    volcano = FIGS / "GSE102674_rapamycin_volcano.png"
    heatmap = FIGS / "GSE102674_rapamycin_heatmap.png"
    volcano_plot(res, out_png=str(volcano), title="GSE102674 TGFb1+Rapamycin vs TGFb1")
    heatmap_top_genes(expr2, res, group_labels, out_png=str(heatmap),
                      title="GSE102674 top DE (TGFb1+Rapamycin vs TGFb1)")
    log(f"[ok] Figures:\n  {volcano}\n  {heatmap}")
    log("=== Done ===")

if __name__ == "__main__":
    try:
        main()
    except Exception:
        # Append the traceback to the same run log that log() writes (LOG),
        # then re-raise so the process still exits with an error status.
        with open(LOG, "a", encoding="utf-8") as f:
            f.write("\n[uncaught]\n" + traceback.format_exc() + "\n")
        raise
