# test_ce_ppi.py
# -*- coding: utf-8 -*-
import os
# --- "safe" env (before torch) ---
# These variables are assigned before torch / sentence_transformers are
# imported further down, so any library that reads them at import time
# sees the values set here.
# NOTE(review): the intent appears to be (a) keeping transformers from
# loading TensorFlow/Flax backends, (b) silencing the tokenizers
# fork/parallelism warning, and (c) forcing a non-fused attention path.
# The three PYTORCH_* names in particular should be checked against the
# installed PyTorch version — confirm they are actually honoured.
os.environ["TRANSFORMERS_NO_TF"] = "1"
os.environ["TRANSFORMERS_NO_FLAX"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_SDP_FORCE_FALLBACK"] = "1"
os.environ["PYTORCH_DISABLE_FAST_SDPA"] = "1"
os.environ["PYTORCH_FUSED_SDPA_DISABLE"] = "1"

import logging, numpy as np, pandas as pd, torch
from pathlib import Path
from sentence_transformers import CrossEncoder
from config import *

# Root logger: INFO level, messages rendered as "[LEVEL] message".
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")

# Ensure subfolders exist
# LLM_DIR, CLEAN_RESULTS_DIR and MODEL_DIR are not defined in this file;
# presumably they are provided by `from config import *` — verify there.
os.makedirs(LLM_DIR, exist_ok=True)
os.makedirs(CLEAN_RESULTS_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)

# -------- Config --------
# Device handed to CrossEncoder: first CUDA device if one is available, else CPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"   # respects CUDA_VISIBLE_DEVICES
# Passed to CrossEncoder as `max_length` in main().
MAX_LEN = 512
# Default character cap applied by shrink().
KEEP_CHARS = 2000
# Passed to ce.predict as `batch_size` in main().
PRED_BATCH = 32   # batch size for ce.predict (reduce to 16/8 if needed)

# -------- Helpers --------
def shrink(s, n=KEEP_CHARS):
    """Return ``s`` as stripped text, cut to its first ``n`` characters.

    Values that ``pd.isna`` reports as missing become ``""``; anything else
    is converted with ``str`` and stripped of surrounding whitespace before
    the length cap is applied.
    """
    if pd.isna(s):
        text = ""
    else:
        text = str(s).strip()
    if len(text) <= n:
        return text
    return text[:n]

def norm_basic(x):
    """Return ``str(x)`` without surrounding whitespace, or ``""`` if ``x`` is missing.

    "Missing" means whatever ``pd.isna`` reports for a scalar
    (``None``, ``NaN``, ``pd.NA``, ...).
    """
    if pd.isna(x):
        return ""
    return str(x).strip()

def norm_hard(s: str) -> str:
    """Aggressively normalise a protein name read from a candidates table.

    Steps, in order:
      1. missing values (per ``pd.isna``) become ``""``;
      2. Unicode NFKC normalisation;
      3. zero-width space / BOM characters are removed and a no-break
         space is turned into a regular space;
      4. straight and curly single/double quotes are removed;
      5. every whitespace run is collapsed to a single space;
      6. wrapping brackets ``[](){}`` and spaces are stripped from both ends.

    Whitespace is trimmed *together with* the wrapper characters at the very
    end, so inputs such as ``"[ TP53 ]"`` or ``"' TP53 '"`` no longer come
    back with a leading/trailing space (which would make the name miss every
    exact-match lookup).

    Args:
        s: Raw cell value; non-string values are converted with ``str``.

    Returns:
        The normalised name, possibly empty.
    """
    import re
    import unicodedata

    if pd.isna(s):
        return ""
    text = unicodedata.normalize("NFKC", str(s))
    text = text.replace("\u200b", "").replace("\ufeff", "").replace("\xa0", " ")
    for quote in ("'", '"', "’", "‘", "“", "”"):
        text = text.replace(quote, "")
    text = re.sub(r"\s+", " ", text)
    # Strip brackets and spaces in one pass so whitespace that was hidden
    # behind a quote or bracket cannot survive at either end.
    return text.strip("[](){} ")

def build_profile_text(r: pd.Series) -> str:
    """Compose a single-line text profile from the annotation row ``r``.

    The text starts with ``Protein: <name>.`` and is followed by one
    ``<label>: <value>`` fragment for every listed column whose value is
    non-empty after ``norm_basic``. Fragments are joined with single spaces
    and the result is passed through ``shrink``.
    """
    protein = norm_basic(r.get("name", ""))
    header = f"Protein: {protein}."

    # (source column, label used in the text), emitted in this order.
    # Note that the "Notes" column is rendered under the label "Subcellular".
    labelled_columns = (
        ("alias", "Alias"),
        ("Annotation", "Function/Desc"),
        ("GO_MF", "GO_MF"),
        ("GO_BP", "GO_BP"),
        ("GO_CC", "GO_CC"),
        ("Domains", "Domains"),
        ("Pathways", "Pathways"),
        ("Notes", "Subcellular"),
    )

    fragments = [header]
    for column, label in labelled_columns:
        value = norm_basic(r.get(column, ""))
        if not value:
            continue
        fragments.append(f"{label}: {value}")

    profile = " ".join(fragments).strip()
    if not profile:
        profile = header
    return shrink(profile)

def is_hf_model_dir(p: Path) -> bool:
    """Tell whether ``p`` holds a ``config.json`` plus a weights file.

    A weights file is either ``pytorch_model.bin`` or ``model.safetensors``
    located directly inside ``p``.
    """
    if not (p / "config.json").exists():
        return False
    weight_files = ("pytorch_model.bin", "model.safetensors")
    return any((p / fname).exists() for fname in weight_files)

def find_model_root(model_dir: str) -> str:
    """Locate the directory under ``model_dir`` that holds loadable weights.

    Search order:
      1. the well-known subdirectories ``best_model``, ``0_CrossEncoder``,
         ``1_CrossEncoder``, ``ce``, ``model`` (in that order, if present);
      2. ``model_dir`` itself;
      3. every other subdirectory, in sorted order so the choice does not
         depend on filesystem listing order.

    The first candidate accepted by ``is_hf_model_dir`` wins.

    Args:
        model_dir: Path of the directory to search.

    Returns:
        The path of the first matching directory, as a string.

    Raises:
        RuntimeError: If ``model_dir`` does not exist, is not a directory,
            or contains no directory accepted by ``is_hf_model_dir``. In the
            last case the message lists the directory contents for debugging.
    """
    d = Path(model_dir)
    if not d.exists():
        raise RuntimeError(f"MODEL_DIR does not exist: {model_dir}")
    if not d.is_dir():
        # Without this guard the iterdir() calls below would escape as a
        # NotADirectoryError instead of the documented RuntimeError.
        raise RuntimeError(f"MODEL_DIR is not a directory: {model_dir}")

    # typical preferences
    ordered = ["best_model", "0_CrossEncoder", "1_CrossEncoder", "ce", "model"]
    candidates = [d / name for name in ordered if (d / name).is_dir()]
    if is_hf_model_dir(d):
        candidates.append(d)
    # Sorted: iterdir() order is arbitrary, and with several checkpoint
    # folders the selected model must be reproducible between runs.
    candidates += sorted(p for p in d.iterdir() if p.is_dir() and p not in candidates)

    for sd in candidates:
        if is_hf_model_dir(sd):
            return str(sd)

    # debug info
    entries = []
    for p in d.iterdir():
        if p.is_dir():
            try:
                sub = ", ".join(sorted(q.name for q in p.iterdir()))[:200]
            except OSError:
                # Unreadable subdirectory: still report it rather than
                # hiding the real "no weights found" error.
                sub = "<unreadable>"
            entries.append(f"{p.name}/ -> {sub}")
        else:
            # Regular files and anything else that is not a directory
            # (e.g. broken symlinks) are listed by name only.
            entries.append(p.name)
    raise RuntimeError("Could not find HF weights in {0}.\nContents:\n- {1}".format(model_dir, "\n- ".join(sorted(entries))))

def _scores_to_1d(raw, n_pairs: int) -> np.ndarray:
    """Reduce the output of ``ce.predict`` to one float score per pair.

    1-D output is used as is. For 2-D (or higher) output, column 1 is taken
    when there is more than one column, otherwise column 0. Anything else
    (a 0-d result) yields ``n_pairs`` NaNs.
    """
    raw = np.asarray(raw)
    if raw.ndim == 1:
        return raw.astype(float)
    if raw.ndim >= 2:
        col = 1 if raw.shape[1] > 1 else 0
        return raw[:, col].astype(float)
    return np.full((n_pairs,), np.nan, dtype=float)


def _softmax_finite(scores: pd.Series) -> pd.Series:
    """Softmax over the finite entries of ``scores``; other entries stay NaN."""
    values = scores.to_numpy(dtype=float)
    finite = np.isfinite(values)
    result = np.full_like(values, np.nan, dtype=float)
    if finite.any():
        # Subtract the maximum before exponentiating for numerical stability.
        shifted = np.exp(values[finite] - values[finite].max())
        result[finite] = shifted / shifted.sum()
    return pd.Series(result, index=scores.index, dtype=float)


def _rank_candidates(gr: pd.DataFrame) -> pd.Series:
    """Return the ``rank_CE_p1`` values for one ``protein1`` group.

    The returned Series is aligned to ``gr.index``; rows are never reordered.

    * All ``CE_raw`` missing: ranks 1..N in row order.
    * All scores (numerically) equal: fall back to ``rank_cosine_p1`` if that
      column exists (missing cosine ranks go last), else to alphabetical
      order of ``similar_protein_name``.
    * Otherwise: descending ``CE_raw`` with ties sharing the minimum rank;
      missing scores rank last.
    """
    n = len(gr)
    if gr["CE_raw"].isna().all():
        return pd.Series(np.arange(1, n + 1), index=gr.index)

    cr = (
        pd.to_numeric(gr["CE_raw"], errors="coerce")
        .replace([np.inf, -np.inf], np.nan)
        .fillna(float("-inf"))
    )
    if np.isclose(cr.max(), cr.min()):
        if "rank_cosine_p1" in gr.columns:
            # na_option="bottom" keeps astype(int) from failing on NaN ranks.
            return (
                gr["rank_cosine_p1"]
                .rank(method="min", ascending=True, na_option="bottom")
                .astype(int)
            )
        order = gr["similar_protein_name"].sort_values(kind="mergesort").index
        return pd.Series(np.arange(1, n + 1), index=order).reindex(gr.index)
    return cr.rank(method="min", ascending=False).astype(int)


def main():
    """Score every (protein1, candidate) pair with the CrossEncoder and write a TSV.

    Pipeline:
      1. load the CrossEncoder found under ``MODEL_DIR`` and run one probe pair;
      2. read annotations (``ANN_PATH``) and candidates (``CANDS_IN``), both
         tab-separated, and build one text profile per protein name;
      3. per ``protein1``, predict a score for each unique candidate and add
         ``CE_raw``, ``CE_softmax_p1`` and ``rank_CE_p1``;
      4. add rank deltas against any of ``rank_cosine_p1`` / ``rank_IS_p1`` /
         ``rank_pdockq_p1`` present in the input, then write ``OUT_TSV``.

    Output rows keep the order in which ``protein1`` groups first appear in
    the candidates table.

    Raises:
        RuntimeError: If the model cannot be located or a required column
            (``name`` in annotations; ``protein1`` / ``similar_protein_name``
            in candidates) is missing.
    """
    logging.info(f"torch.cuda.is_available={torch.cuda.is_available()} device={DEVICE}")

    # 1) Load model
    model_path = find_model_root(MODEL_DIR)
    logging.info(f"Loading CrossEncoder from: {model_path}")
    ce = CrossEncoder(
        model_path,
        device=DEVICE,
        max_length=MAX_LEN,
        automodel_args={"torch_dtype": torch.float32}
    )

    # probe: one fixed pair, logged so a broken checkpoint shows up early
    probe = ce.predict([["Protein of interest: TP53. Context: tumor suppressor.", "DNA-binding; apoptosis regulation."]],
                       convert_to_numpy=True)
    logging.info(f"[PROBE] logits: {np.asarray(probe).ravel().tolist()}")

    # 2) Data loading
    logging.info("Loading annotations…")
    ann = pd.read_csv(ANN_PATH, sep="\t")
    if "name" not in ann.columns:
        raise RuntimeError("Missing column: name")
    ann["name"] = ann["name"].map(norm_basic)
    # With duplicate names, ann_idx.loc[n] would return a DataFrame instead of
    # a row and build_profile_text would fail; keep the first occurrence.
    n_dup = int(ann["name"].duplicated().sum())
    if n_dup:
        logging.warning(f"Annotations contain {n_dup} duplicate name rows; keeping the first occurrence of each.")
        ann = ann.drop_duplicates(subset="name", keep="first")
    ann_idx = ann.set_index("name", drop=False)

    logging.info("Loading candidates…")
    df = pd.read_csv(CANDS_IN, sep="\t")
    for c in ("protein1", "similar_protein_name"):
        if c not in df.columns:
            raise RuntimeError(f"Missing column: {c}")

    df["protein1"]             = df["protein1"].map(norm_hard)
    df["similar_protein_name"] = df["similar_protein_name"].map(norm_hard)
    df = df[(df["protein1"]!="") & (df["similar_protein_name"]!="") & (df["similar_protein_name"].str.upper()!="NAN")]

    # NOTE(review): candidate names are normalised with norm_hard while
    # annotation names only go through norm_basic, so an annotation name
    # containing quotes/brackets would not match here — confirm intended.
    names = sorted(set(df["protein1"]).union(df["similar_protein_name"]))
    texts = {n: (build_profile_text(ann_idx.loc[n]) if n in ann_idx.index else f"Protein: {n}.") for n in names}

    # 3) Predictions and ranking per p1
    rows = []
    for p1, g in df.groupby("protein1", sort=False):
        gg = g.copy()
        # Score each distinct candidate once, then map scores back onto rows.
        cand_names = list(gg["similar_protein_name"].unique())

        p1_text = texts.get(p1, f"Protein: {p1}.")
        # NOTE(review): p1_text may already be KEEP_CHARS long, in which case
        # shrink() cuts off the trailing "Task: ..." instruction. Left as is
        # because changing the query would change what the model is fed.
        query   = shrink(f"Protein of interest: {p1}. Context: {p1_text} Task: rank candidate proteins by likelihood of being a functional/interaction partner of {p1}.")

        pairs = [[query, texts.get(c, f"Protein: {c}.")] for c in cand_names]

        # predict in batches to save memory
        raw = ce.predict(pairs, batch_size=PRED_BATCH, convert_to_numpy=True)
        raw_1d = _scores_to_1d(raw, len(pairs))

        if np.isfinite(raw_1d).any():
            logging.info(f"[{p1}] logits: min={np.nanmin(raw_1d):.6f} max={np.nanmax(raw_1d):.6f} std={np.nanstd(raw_1d):.6f} N={len(raw_1d)}")
        else:
            logging.warning(f"[{p1}] all logits are NaN (N={len(raw_1d)}).")

        # merge via map (non-finite scores are stored as NaN)
        raw_map = {c: (float(s) if np.isfinite(s) else np.nan) for c, s in zip(cand_names, raw_1d)}
        gg["CE_raw"] = gg["similar_protein_name"].map(raw_map)

        # softmax for readability (only on finite values)
        gg["CE_softmax_p1"] = _softmax_finite(gg["CE_raw"])

        # 4) ranking — done here because gg is already exactly one protein1
        # group; this avoids a second groupby().apply() whose handling of the
        # grouping column differs between pandas versions.
        gg["rank_CE_p1"] = _rank_candidates(gg)
        rows.append(gg)

    if rows:
        out = pd.concat(rows, ignore_index=True)
    else:
        # pd.concat([]) raises; still emit a header-only file with the
        # expected columns when no candidate rows survive filtering.
        logging.warning("No candidate rows left after filtering; writing an empty result.")
        out = df.copy()
        for col in ("CE_raw", "CE_softmax_p1", "rank_CE_p1"):
            out[col] = np.nan

    # deltas (positive = CE ranks the candidate higher than the other method)
    if "rank_cosine_p1" in out.columns:
        out["delta_rank_CE_vs_cosine"] = out["rank_cosine_p1"] - out["rank_CE_p1"]
    if "rank_IS_p1" in out.columns:
        out["delta_rank_CE_vs_IS"] = out["rank_IS_p1"] - out["rank_CE_p1"]
    if "rank_pdockq_p1" in out.columns:
        out["delta_rank_CE_vs_pdockQ"] = out["rank_pdockq_p1"] - out["rank_CE_p1"]

    out.to_csv(OUT_TSV, sep="\t", index=False)
    logging.info(f"[OK] wrote {OUT_TSV}")

# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()