# -*- coding: utf-8 -*-
import logging, math
import pandas as pd
import numpy as np
from pathlib import Path
from collections import defaultdict
from config_base import *
import torch
import os
os.environ["TRANSFORMERS_NO_TF"] = "1"
os.environ["TRANSFORMERS_NO_FLAX"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# If missing: pip install -q sentence-transformers
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from config import *

# Device selection. CUDA_VISIBLE_DEVICES is hard-coded to "0" earlier in this
# file, which overrides any value exported in the shell (e.g.
# `export CUDA_VISIBLE_DEVICES=5`); edit that assignment to use another GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"   
# NOTE: the label says "gpu" even when the "cpu" fallback was selected.
print("using gpu ", device)
# Log at INFO level using a "[LEVEL] message" line format.
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")

# Make sure the output folders exist before anything is written.
# LLM_DIR and CLEAN_RESULTS_DIR are not defined in this file; presumably they
# come from the star-imported config modules (TODO confirm).
os.makedirs(LLM_DIR, exist_ok=True)
os.makedirs(CLEAN_RESULTS_DIR, exist_ok=True)

# -------- Models --------
# Sentence-embedding checkpoint to load; switch between BioBERT and BioRoBERTa
# by changing MODEL_NAME to one of the identifiers listed below.
MODEL_NAME = "pritamdeka/S-BioBert-snli-multinli-stsb"
# pritamdeka/S-BioBert-snli-multinli-stsb -> Sentence BioBert
# pritamdeka/S-Biomed-Roberta-snli-multinli-stsb -> BioMedRoBERTa
# Passed as `batch_size` to model.encode() in main().
BATCH_SIZE = 64

# -------- Helpers --------
def _field_text(row, key: str) -> str:
    """Return ``row[key]`` as stripped text, or "" when absent or missing.

    Missing means None / NaN / pd.NA / NaT. Without this check ``str(NaN)``
    yields the literal "nan", which would leak into the profile text.
    """
    value = row.get(key, "")
    # pd.isna returns an array for list-likes, so only test scalar values.
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return ""
    return str(value).strip()


def build_profile_text(row: pd.Series) -> str:
    """Compose a compact textual profile from the annotation table row.

    Each non-empty annotation field is emitted as "<label>: <value>" and the
    pieces are joined with single spaces, in a fixed order. Fields that are
    absent, empty, or missing (None/NaN) are skipped.

    Args:
        row: One annotation row; any object with a dict-like ``get`` works.
            Columns read: alias, Annotation, GO_MF, GO_BP, GO_CC, Domains,
            Pathways, Notes, and name (fallback only).

    Returns:
        The profile text, or "Protein <name>." when no field has content.
    """
    # (column, prefix, suffix) in output order; prefer concise but
    # informative fields. The "Notes" column is labelled "Subcellular".
    field_specs = (
        ("alias", "Alias: ", "."),
        ("Annotation", "Function/Desc: ", ""),
        ("GO_MF", "GO_MF: ", ""),
        ("GO_BP", "GO_BP: ", ""),
        ("GO_CC", "GO_CC: ", ""),
        ("Domains", "Domains: ", ""),
        ("Pathways", "Pathways: ", ""),
        ("Notes", "Subcellular: ", ""),
    )

    parts = []
    for column, prefix, suffix in field_specs:
        value = _field_text(row, column)
        if value:
            parts.append(f"{prefix}{value}{suffix}")

    if parts:
        return " ".join(parts)

    # Fallback to name if everything is empty
    return f"Protein {_field_text(row, 'name')}."

def normalize_name(x):
    """Return *x* converted to ``str`` with surrounding whitespace removed."""
    as_text = str(x)
    return as_text.strip()

def rank_within_groups(series_scores: pd.Series, group_ids: pd.Series, ascending=False):
    """Return 1-based integer ranks of the scores within each group id.

    Ties share the lowest rank of the tied block (``method="min"``). Missing
    (NaN) scores are ranked after every valid score of their group, so the
    integer cast cannot fail on them.

    Args:
        series_scores: Score per row.
        group_ids: Group key per row; aligned with ``series_scores`` by index.
        ascending: If False (default) the highest score gets rank 1;
            if True the lowest score gets rank 1.

    Returns:
        An integer Series named "rank", indexed like the inputs.
    """
    frame = pd.DataFrame({"gid": group_ids, "score": series_scores})
    # na_option="bottom" gives NaN scores the worst rank instead of NaN,
    # which would make astype(int) raise.
    ranks = frame.groupby("gid")["score"].rank(
        method="min", ascending=ascending, na_option="bottom"
    )
    return ranks.astype(int).rename("rank")

# -------- Main pipeline --------
def main():
    """Score candidate protein pairs by embedding similarity and write a ranked TSV.

    Steps:
      1. Read the annotation table (``ANN_PATH``) and the candidate pairs
         (``CANDS_IN``), both tab-separated.
      2. Build one profile text per unique protein name found in the
         candidates; names absent from the annotations get "Protein <name>.".
      3. Encode the texts with the ``MODEL_NAME`` SentenceTransformer,
         requesting L2-normalized embeddings.
      4. Store the cosine similarity of each pair in ``LLM_score`` and its
         rank within the ``protein1`` group in ``rank_LLM_p1``.
      5. Add delta-rank columns against whichever of ``rank_cosine_p1``,
         ``rank_IS_p1`` and ``rank_pdockq_p1`` are present.
      6. Write everything to ``OUT_TSV`` and log the top 5 per ``protein1``.

    ``ANN_PATH``, ``CANDS_IN`` and ``OUT_TSV`` are not defined in this file;
    presumably they come from the star-imported config modules.

    Raises:
        RuntimeError: If the candidates table lacks 'protein1' or
            'similar_protein_name'.
    """
    logging.info("Loading annotations...")
    ann = pd.read_csv(ANN_PATH, sep="\t")
    # Index by 'name' (preferred key used in your sample)
    ann["name"] = ann["name"].map(normalize_name)
    # A duplicated index label makes .loc[name] return a DataFrame instead of
    # a Series, which would corrupt the profile text; keep the first row only.
    duplicated = ann["name"].duplicated(keep="first")
    if duplicated.any():
        logging.warning(
            f"Dropping {int(duplicated.sum())} duplicate annotation rows (keeping first per name)"
        )
        ann = ann[~duplicated]
    ann_idx = ann.set_index("name", drop=False)

    logging.info("Loading candidates...")
    df = pd.read_csv(CANDS_IN, sep="\t")

    # Harmonize column names we need
    # protein1 (symbol/name), similar_protein_name (symbol/name)
    pair_cols = ("protein1", "similar_protein_name")
    if any(col not in df.columns for col in pair_cols):
        raise RuntimeError("Required columns not found: 'protein1' and 'similar_protein_name'")
    for col in pair_cols:
        df[col] = df[col].map(normalize_name)

    # Collect all unique protein names we must embed (sorted for a stable order)
    all_names = sorted(set(df["protein1"]) | set(df["similar_protein_name"]))
    logging.info(f"Unique proteins to embed: {len(all_names)}")

    # Build profile texts for each protein from annotations; if a name is
    # missing in the annotation table, fall back to the name only.
    annotated = set(ann_idx.index)
    corpus = [
        build_profile_text(ann_idx.loc[name]) if name in annotated else f"Protein {name}."
        for name in all_names
    ]
    logging.info("Built profile texts.")

    # Encode all texts with Sentence-BioBERT (or BioMedRoBERTa)
    logging.info(f"Loading model: {MODEL_NAME}")
    model = SentenceTransformer(MODEL_NAME, device=device)
    logging.info("Encoding profiles...")
    with torch.inference_mode():
        embeds = model.encode(
            corpus,
            batch_size=BATCH_SIZE,
            convert_to_numpy=True,
            show_progress_bar=True,
            normalize_embeddings=True  # normalize L2: cosine = dot
        )
    name2vec = dict(zip(all_names, embeds))

    # Compute LLM_score (cosine similarity) for each pair
    # (protein1, similar_protein_name). Every name is in name2vec because
    # all_names is the union of both columns.
    logging.info("Scoring pairs with cosine similarity of BioBERT embeddings...")
    df["LLM_score"] = [
        # sklearn cosine_similarity expects 2D inputs
        float(cosine_similarity(name2vec[p1].reshape(1, -1), name2vec[p2].reshape(1, -1))[0, 0])
        for p1, p2 in zip(df["protein1"], df["similar_protein_name"])
    ]

    # Rank within each protein1 group (descending: higher similarity = better)
    df["rank_LLM_p1"] = rank_within_groups(df["LLM_score"], df["protein1"], ascending=False)

    # Compare with existing ranks to see movement (positive = better under LLM)
    rank_comparisons = (
        ("rank_cosine_p1", "delta_rank_LLM_vs_cosine"),
        ("rank_IS_p1", "delta_rank_LLM_vs_IS"),
        ("rank_pdockq_p1", "delta_rank_LLM_vs_pdockQ"),
    )
    for rank_col, delta_col in rank_comparisons:
        if rank_col in df.columns:
            df[delta_col] = df[rank_col] - df["rank_LLM_p1"]

    # Save (keeps every input column plus the new ones)
    df.to_csv(OUT_TSV, sep="\t", index=False)
    logging.info(f"Wrote: {OUT_TSV}")
    # Quick per-p1 preview in logs
    for p1, g in df.groupby("protein1"):
        prev = g.sort_values("rank_LLM_p1").head(5)[["similar_protein_name","LLM_score","rank_LLM_p1"]]
        logging.info(f"\n[p1={p1}] Top by LLM:\n{prev.to_string(index=False)}")

if __name__ == "__main__":
    main()
