# -*- coding: utf-8 -*-
import os
# Avoid TF/Flax dependencies and noisy tokenizer parallelism.
# These are presumably read when transformers / tokenizers / torch are imported
# or initialised, so they are kept ahead of those imports below.
os.environ["TRANSFORMERS_NO_TF"] = "1"
os.environ["TRANSFORMERS_NO_FLAX"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Default to GPU 0, but respect a device selection that the caller or a job
# scheduler has already exported instead of silently overriding it.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

import logging
import torch
import pandas as pd
import numpy as np
from sentence_transformers import CrossEncoder
from scipy.special import expit as sigmoid
from scipy.special import softmax
from config import *

logging.basicConfig(format="[%(levelname)s] %(message)s", level=logging.INFO)

# -------- Cross-encoder --------
# NOTE(review): this id looks like a base (MLM-pretrained) PubMedBERT checkpoint
# rather than a fine-tuned cross-encoder. If so, CrossEncoder attaches a freshly
# initialised classification head and the scores carry no relevance signal
# until the model is fine-tuned on labelled pairs — confirm.
CE_MODEL = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

# Use the first visible CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
print("Using device:", device)

# -------- Helpers --------
KEEP_CHARS = 2000  # keep text short to respect seq len (~512 tokens)

def shrink_text(txt: str, max_chars: int = KEEP_CHARS) -> str:
    """Strip surrounding whitespace and cut the text to at most ``max_chars`` chars.

    ``None`` (or any falsy value) yields an empty string.
    """
    cleaned = "" if not txt else txt.strip()
    # Slicing past the end is a no-op, so no explicit length check is needed.
    return cleaned[:max_chars]

def build_profile_text(row: pd.Series) -> str:
    """Render one annotation row as a short, labelled text profile.

    Non-empty fields are emitted in a fixed order (alias, annotation, GO MF/BP/CC,
    domains, pathways, notes), joined with spaces and truncated by ``shrink_text``.
    Missing cells (NaN/None, which is what ``pd.read_csv`` yields for empty TSV
    fields) count as empty, so they never appear in the text as the string "nan".

    Args:
        row: Annotation record; fields are read with ``row.get`` so absent
            columns are tolerated.

    Returns:
        The profile text, or ``"Protein <name>."`` when every field is empty.
    """

    def _field(column: str) -> str:
        # str(NaN) == "nan" is truthy, so missing cells must be caught explicitly.
        value = row.get(column, "")
        if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
            return ""
        return str(value).strip()

    parts = []
    alias = _field("alias")
    if alias:
        parts.append(f"Alias: {alias}.")

    # (label shown in the text, source column), emitted in this order.
    labelled_columns = (
        ("Function/Desc", "Annotation"),
        ("GO_MF", "GO_MF"),
        ("GO_BP", "GO_BP"),
        ("GO_CC", "GO_CC"),
        ("Domains", "Domains"),
        ("Pathways", "Pathways"),
        # NOTE(review): the Notes column is labelled "Subcellular" — presumably it
        # holds localisation notes; confirm against the annotation file.
        ("Subcellular", "Notes"),
    )
    for label, column in labelled_columns:
        value = _field(column)
        if value:
            parts.append(f"{label}: {value}")

    txt = " ".join(parts).strip()
    if not txt:
        txt = f"Protein {_field('name')}."
    return shrink_text(txt)

def normalize_name(x) -> str:
    """Coerce a protein identifier to ``str`` and trim surrounding whitespace."""
    as_text = str(x)
    return as_text.strip()

def build_p1_query(p1_row_text: str, p1_name: str) -> str:
    """Compose the cross-encoder query for ranking partners of ``p1_name``.

    The query names the protein, embeds its profile text as context and states
    the ranking task; each fragment carries its own trailing separator.
    """
    pieces = (
        f"Protein of interest: {p1_name}. ",
        f"Context: {p1_row_text} ",
        f"Task: rank candidate proteins by likelihood of being a functional/interaction partner of {p1_name}.",
    )
    return "".join(pieces)

def normalize_group(scores_1d, mode="softmax", temperature=0.5):
    """Rescale one group's scores so they are comparable within the group.

    Args:
        scores_1d: Array-like of scores; flattened to a 1-D float array.
        mode: ``"softmax"`` (temperature-scaled, sums to 1), ``"minmax"``
            (mapped onto roughly [0, 1]) or ``"zscore"`` (centred on the mean,
            divided by the std). Any other value returns the flattened scores as-is.
        temperature: Softmax temperature, floored at 1e-6; lower values give a
            more aggressive spread.

    Returns:
        A 1-D float ``np.ndarray`` with one entry per input score.
    """
    flat = np.asarray(scores_1d, dtype=float).reshape(-1)

    if mode == "minmax":
        low = float(np.min(flat))
        high = float(np.max(flat))
        # epsilon keeps the all-equal case finite (every score maps to 0)
        return (flat - low) / (high - low + 1e-9)

    if mode == "zscore":
        centre = float(np.mean(flat))
        scale = float(np.std(flat) + 1e-9)
        return (flat - centre) / scale

    if mode == "softmax":
        safe_temperature = max(1e-6, temperature)
        return softmax(flat / safe_temperature)

    return flat

# -------- Main --------
def _profile_texts(names, ann_idx: pd.DataFrame) -> dict:
    """Map each protein name to its profile text; unannotated names get a stub."""
    return {
        name: (build_profile_text(ann_idx.loc[name]) if name in ann_idx.index else f"Protein {name}.")
        for name in names
    }


def _positive_scores(pred) -> np.ndarray:
    """Reduce CrossEncoder output to one score per pair.

    Single-logit output (1-D, or 2-D with one column) is flattened. For
    multi-column output, column 1 is used — assumed to be the positive class.
    """
    raw = pred.detach().cpu().numpy() if hasattr(pred, "detach") else np.asarray(pred)
    if raw.ndim == 1 or (raw.ndim == 2 and raw.shape[1] == 1):
        return raw.reshape(-1)
    return raw[:, min(1, raw.shape[1] - 1)]


def _rerank_group(ce, p1: str, g: pd.DataFrame, texts: dict) -> pd.DataFrame:
    """Score every candidate of one protein1 group; return a copy with CE_raw/CE_score."""
    cand_names = list(g["similar_protein_name"].unique())
    query = shrink_text(build_p1_query(texts.get(p1, f"Protein {p1}."), p1))
    pairs = [[query, texts.get(c, f"Protein {c}.")] for c in cand_names]
    raw_1d = _positive_scores(ce.predict(pairs))

    # Per-group min-max to [0, 1]; the epsilon keeps ties finite (all map to 0).
    mn, mx = float(raw_1d.min()), float(raw_1d.max())
    norm = (raw_1d - mn) / (mx - mn + 1e-9)

    # Map by candidate name so duplicate candidate rows all receive the score.
    raw_map = {c: float(s) for c, s in zip(cand_names, raw_1d)}
    score_map = {c: float(s) for c, s in zip(cand_names, norm)}

    out = g.copy()
    out["CE_raw"] = out["similar_protein_name"].map(raw_map)
    out["CE_score"] = out["similar_protein_name"].map(score_map)
    return out


def main():
    """Rerank each protein1's candidate partners with the cross-encoder and write a TSV.

    The steps are:
    1. Read the annotation table (ANN_PATH) and the candidate table (CANDS_IN).
       These paths are presumably provided by the star-imported ``config``.
    2. Build a text profile for every protein.
    3. Score each protein1 group's candidates.
    4. Add a per-group rank ``rank_CE_p1`` (1 = highest CE_raw).
    5. Add rank deltas against whichever other rank columns exist.
    6. Write everything to OUT_TSV.

    Raises:
        RuntimeError: If the candidate table lacks 'protein1' or
            'similar_protein_name'.
    """
    logging.info("Loading annotations...")
    ann = pd.read_csv(ANN_PATH, sep="\t")
    ann["name"] = ann["name"].map(normalize_name)
    # A name present on several rows would make ``.loc[name]`` return a DataFrame
    # instead of a single row, and the profile would be built from stringified
    # Series. Keep the first occurrence of each name.
    n_dup = int(ann["name"].duplicated().sum())
    if n_dup:
        logging.warning(f"Dropping {n_dup} duplicate annotation rows (keeping first per name)")
        ann = ann.drop_duplicates(subset="name", keep="first")
    ann_idx = ann.set_index("name", drop=False)

    logging.info("Loading candidates...")
    df = pd.read_csv(CANDS_IN, sep="\t")
    if "protein1" not in df.columns or "similar_protein_name" not in df.columns:
        raise RuntimeError("Missing required columns: 'protein1', 'similar_protein_name'")

    df["protein1"] = df["protein1"].map(normalize_name)
    df["similar_protein_name"] = df["similar_protein_name"].map(normalize_name)

    # Build textual profiles
    all_names = sorted(set(df["protein1"]).union(df["similar_protein_name"]))
    texts = _profile_texts(all_names, ann_idx)

    # Load the CE ONCE. Cap the pair length explicitly at BERT's 512 positions
    # instead of relying on the tokenizer's model_max_length metadata; each side
    # of a pair can be up to KEEP_CHARS characters.
    logging.info(f"Loading CrossEncoder: {CE_MODEL}")
    ce = CrossEncoder(CE_MODEL, device=device, max_length=512)

    # Rerank per p1
    rows = [_rerank_group(ce, p1, g, texts) for p1, g in df.groupby("protein1")]
    if rows:
        df_out = pd.concat(rows, ignore_index=True)
    else:
        # pd.concat([]) raises; emit the (empty) input with the score columns instead.
        logging.warning("No candidate rows to rerank; writing empty output")
        df_out = df.assign(CE_raw=np.nan, CE_score=np.nan)

    df_out["rank_CE_p1"] = df_out.groupby("protein1")["CE_raw"].rank(method="min", ascending=False).astype(int)

    # Deltas vs other methods. A positive delta means CE ranks the candidate
    # better than the other method does (assuming those ranks also use 1 = best).
    for other_rank_col, delta_col in (
        ("rank_cosine_p1", "delta_rank_CE_vs_cosine"),
        ("rank_IS_p1", "delta_rank_CE_vs_IS"),
        ("rank_pdockq_p1", "delta_rank_CE_vs_pdockQ"),
    ):
        if other_rank_col in df_out.columns:
            df_out[delta_col] = df_out[other_rank_col] - df_out["rank_CE_p1"]

    df_out.to_csv(OUT_TSV, sep="\t", index=False)
    logging.info(f"Wrote: {OUT_TSV}")

if __name__ == "__main__":
    main()