# train_ce_ppi.py
# -*- coding: utf-8 -*-
import os, logging, random
from pathlib import Path
from typing import List

import numpy as np
import pandas as pd
import torch
from sentence_transformers import CrossEncoder, InputExample
from sklearn.model_selection import GroupKFold
from torch.utils.data import DataLoader, WeightedRandomSampler
from config import *

# ---------- Numerical stability / runtime environment ----------
# GPU selection: keep a CUDA_VISIBLE_DEVICES the caller already exported
# (e.g. `CUDA_VISIBLE_DEVICES=2 python train_ce_ppi.py`) and only fall back to
# GPU 0 otherwise. Check `nvidia-smi` for a truly free device.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
# NOTE(review): the assignments below run after `torch` and
# `sentence_transformers` have been imported at the top of the file, so any
# variable a library reads only at import time will not take effect from here —
# move this section above the imports if that matters. The three *SDPA* names
# are kept as-is; confirm the installed PyTorch version actually reads them.
os.environ["TRANSFORMERS_NO_TF"] = "1"
os.environ["TRANSFORMERS_NO_FLAX"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_SDP_FORCE_FALLBACK"] = "1"
os.environ["PYTORCH_DISABLE_FAST_SDPA"] = "1"
os.environ["PYTORCH_FUSED_SDPA_DISABLE"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:128"
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")

# ---------- Backbone BIO ----------
# Hugging Face model id loaded as the CrossEncoder backbone in main().
BACKBONE = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
# "cuda:0" is the first device *visible* under CUDA_VISIBLE_DEVICES.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

KEEP_CHARS = 2000    # default max characters kept per text by shrink()
RANDOM_SEED = 42     # seed for set_seed() and for negative downsampling
BATCH_SIZE = 8
EPOCHS = 5           # change according to your needs
LR = 1e-5
NEG_RATIO = 3        # up to 3 negatives per positive (to balance); None to use all negatives

# ---------- Utils ----------
def set_seed(seed=RANDOM_SEED):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed``.

    When CUDA is available, the generators of all visible GPUs are seeded too.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def norm(x):
    """Return ``x`` as a whitespace-stripped string; missing values (None/NaN/NA) become ""."""
    if pd.isna(x):
        return ""
    return str(x).strip()
def shrink(s, n=KEEP_CHARS):
    """Normalise ``s`` with ``norm`` and truncate it to at most ``n`` characters."""
    text = norm(s)
    if len(text) <= n:
        return text
    return text[:n]

def build_profile_text(r: pd.Series) -> str:
    """Render one annotation row as a compact text profile.

    The profile opens with ``"Protein: <name>."``. It is followed, in a fixed
    order, by ``"<label>: <value>"`` for each annotation column whose
    normalised value is non-empty; absent columns count as empty. The joined
    text is truncated with ``shrink`` (default length limit).
    """
    protein = norm(r.get("name", ""))
    header = f"Protein: {protein}."
    column_labels = (
        ("alias", "Alias"),
        ("Annotation", "Function/Desc"),
        ("GO_MF", "GO_MF"),
        ("GO_BP", "GO_BP"),
        ("GO_CC", "GO_CC"),
        ("Domains", "Domains"),
        ("Pathways", "Pathways"),
        ("Notes", "Subcellular"),  # the Notes column is emitted under the "Subcellular" label
    )
    segments = [header]
    for column, label in column_labels:
        value = norm(r.get(column, ""))
        if value:
            segments.append(f"{label}: {value}")
    joined = " ".join(segments).strip()
    # `joined` always contains the header, so this fallback is purely defensive.
    return shrink(joined if joined else header)

def _load_inputs():
    """Read the annotation table and the candidate-pair table.

    Returns:
        (ann_idx, df): ``ann_idx`` is the ``ANN_PATH`` table indexed by its
        normalised ``name`` column, with one row per name. ``df`` is the
        ``CANDS_IN`` table with ``protein1`` and ``similar_protein_name``
        normalised via ``norm``.

    Raises:
        RuntimeError: if ``df`` lacks a required column.
    """
    logging.info("Load annotations and candidates…")
    ann = pd.read_csv(ANN_PATH, sep="\t")
    ann["name"] = ann["name"].map(norm)
    # A duplicated name would make `ann_idx.loc[name]` return a DataFrame
    # instead of a single row. build_profile_text/norm cannot handle that
    # (ambiguous truth value), so keep only the first row per name.
    n_dup = int(ann["name"].duplicated().sum())
    if n_dup:
        logging.warning(f"Dropping {n_dup} annotation rows with a duplicated name.")
        ann = ann.drop_duplicates(subset="name", keep="first")
    ann_idx = ann.set_index("name", drop=False)

    df = pd.read_csv(CANDS_IN, sep="\t")
    for c in ("protein1", "similar_protein_name", "rediscovered_flag"):
        if c not in df.columns:
            raise RuntimeError(f"Missing col: {c}")
    df["protein1"] = df["protein1"].map(norm)
    df["similar_protein_name"] = df["similar_protein_name"].map(norm)
    return ann_idx, df


def _build_examples(df, texts):
    """Create labelled (query, candidate) pairs, one protein1 group at a time.

    Rows with ``rediscovered_flag == True`` get label 1.0; all other rows get
    0.0. When ``NEG_RATIO`` is set, a group's negatives are randomly
    downsampled to at most ``NEG_RATIO`` x its positives. Groups without any
    positive keep all their negatives.

    Returns:
        (examples, group_ids, pos, neg): the pairs, the protein1 of each pair
        (aligned index-by-index with ``examples``), and the numbers of
        positive and negative pairs.
    """
    examples: List[InputExample] = []
    group_ids: List[str] = []
    pos, neg = 0, 0
    for p1, g in df.groupby("protein1"):
        q = shrink(f"Protein of interest: {p1}. Context: {texts[p1]}")
        pos_rows = g[g["rediscovered_flag"] == True]  # element-wise mask, not an identity test
        neg_rows = g[g["rediscovered_flag"] != True]

        # downsampling: at most NEG_RATIO * #positives
        if NEG_RATIO is not None and len(pos_rows) > 0 and len(neg_rows) > NEG_RATIO * len(pos_rows):
            neg_rows = neg_rows.sample(n=NEG_RATIO * len(pos_rows), random_state=RANDOM_SEED)

        for label, rows in ((1.0, pos_rows), (0.0, neg_rows)):
            for partner in rows["similar_protein_name"]:
                examples.append(InputExample(texts=[q, texts[partner]], label=label))
                # Record the group directly. Re-parsing it from the query text
                # would truncate ids that contain "." (e.g. "9606.ENSP...").
                group_ids.append(p1)
        pos += len(pos_rows)
        neg += len(neg_rows)
    return examples, group_ids, pos, neg


def _split_by_protein(examples, group_ids, n_splits=5):
    """Split pairs into train/validation so that no protein1 lands in both.

    Uses the first fold of ``GroupKFold``. ``n_splits`` is capped at the
    number of distinct groups, because GroupKFold cannot have more splits
    than groups.

    Raises:
        RuntimeError: if there are fewer than two distinct protein1 values.
    """
    n_groups = len(set(group_ids))
    if n_groups < 2:
        raise RuntimeError(f"Need at least 2 distinct protein1 values for a grouped split, got {n_groups}.")
    gkf = GroupKFold(n_splits=min(n_splits, n_groups))
    train_idx, val_idx = next(gkf.split(np.zeros(len(examples)), groups=group_ids))
    return [examples[i] for i in train_idx], [examples[i] for i in val_idx]


def _train_cross_encoder(train_samples):
    """Fine-tune a single-output CrossEncoder on ``train_samples`` in FP32 and return it."""
    # Simple shuffling; for severe imbalance a WeightedRandomSampler could be used.
    train_dl = DataLoader(train_samples, batch_size=BATCH_SIZE, shuffle=True)

    # CrossEncoder in FP32 (NO AMP) to avoid NaN
    ce = CrossEncoder(
        BACKBONE,
        num_labels=1,               # BCEWithLogits
        device=DEVICE,
        max_length=512,
        model_kwargs={"torch_dtype": torch.float32},
    )

    logging.info(f"Start training on {DEVICE} (fp32)…")
    ce.fit(
        train_dataloader=train_dl,
        # NOTE(review): the held-out validation split is not evaluated during training.
        evaluator=None,
        epochs=EPOCHS,
        # ~10% of one epoch's batches, but at least 10 steps.
        warmup_steps=max(10, int(0.1 * len(train_samples) / BATCH_SIZE)),
        # explicitly use torch.optim.AdamW (no correct_bias)
        optimizer_class=torch.optim.AdamW,
        optimizer_params={"lr": LR, "eps": 1e-8, "weight_decay": 0.01},
        scheduler="WarmupLinear",
        show_progress_bar=True,
        use_amp=False,              # stay in FP32 to avoid NaN
        output_path=OUT_DIR,
    )
    return ce


def _save_and_probe(ce):
    """Save ``ce`` to OUT_DIR (HF + ST format), log the saved files and run one sample prediction."""
    out = Path(OUT_DIR)
    out.mkdir(parents=True, exist_ok=True)
    ce.model.save_pretrained(OUT_DIR)
    ce.tokenizer.save_pretrained(OUT_DIR)
    ce.save(OUT_DIR)

    root_files = sorted(p.name for p in out.glob("*"))
    best_dir = out / "best_model"
    best_files = sorted(p.name for p in best_dir.glob("*")) if best_dir.exists() else []
    logging.info(f"Save root: {root_files}")
    if best_files:
        logging.info(f"Save best_model/: {best_files}")

    # Best-effort sanity check that the model produces numeric output. It only
    # logs on failure because the model has already been saved.
    # NOTE(review): predict() may apply an output activation, so these values
    # are not necessarily raw logits — confirm for the installed version.
    try:
        probe = ce.predict([["Protein of interest: TP53. Context: tumor suppressor.", "DNA-binding; apoptosis regulation."]])
        logging.info(f"sample logits: {probe}")
    except Exception as e:
        logging.warning(f"failed: {e}")


def main():
    """Train and save the protein-pair cross-encoder end to end.

    Steps: seed the RNGs, load the inputs and build per-protein text profiles.
    Then build labelled pairs, split them by protein1, fine-tune, and finally
    save to OUT_DIR and probe the model.

    Raises:
        RuntimeError: on missing candidate columns, when no positive or no
            negative pairs exist, or when fewer than two protein1 groups exist.
    """
    set_seed()

    # 1) Load annotations and candidates
    ann_idx, df = _load_inputs()

    # 2) Build textual profiles (bare name for proteins without annotation)
    names = sorted(set(df["protein1"]).union(df["similar_protein_name"]))
    texts = {
        n: build_profile_text(ann_idx.loc[n]) if n in ann_idx.index else f"Protein: {n}."
        for n in names
    }

    # 3) Build training examples (with optional negative downsampling per p1)
    examples, group_ids, pos, neg = _build_examples(df, texts)
    if not examples or pos == 0 or neg == 0:
        raise RuntimeError(f"Not enough training pairs (total={len(examples)}, pos={pos}, neg={neg}).")
    logging.info(f"Training pairs: {len(examples)} (pos={pos}, neg={neg}, ratio={neg/max(1,pos):.2f})")

    # 4) Split by protein1
    train_samples, val_samples = _split_by_protein(examples, group_ids)
    logging.info(f"Split sizes -> train: {len(train_samples)}, val: {len(val_samples)}")

    # 5-7) DataLoader, model and training
    ce = _train_cross_encoder(train_samples)

    # 8-9) Save and probe
    _save_and_probe(ce)

if __name__ == "__main__":
    # Run training only when executed as a script (e.g. `python train_ce_ppi.py`), not on import.
    main()