# -*- coding: utf-8 -*-
"""
Text-only multi-view re-ranking of candidate protein pairs, using four views:
- tfidf_score         : cosine similarity on TF-IDF profiles
- overlap_score       : token Jaccard
- location_score      : Jaccard on subcellular/chromosome cues
- keyterm_score       : Jaccard on curated biomedical key terms

Input:
  ANN_PATH  : TSV with annotations ('name' + Annotation/GO_*...)
  CANDS_IN  : TSV with 'protein1' and 'similar_protein_name' (+ any existing rank_* columns)

Output:
  OUT_TSV   : same TSV + new columns score/rank/delta for each view
"""

import logging, re, os
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from config import *

# Make sure the pipeline's output subfolders exist (this runs at import time).
for _subdir in (LLM_DIR, CLEAN_RESULTS_DIR, SEMANTIC_DIR):
    os.makedirs(_subdir, exist_ok=True)
del _subdir

logging.basicConfig(format="[%(levelname)s] %(message)s", level=logging.INFO)

# -------- Heuristics / helpers --------
# Stop-words (mostly English, a few Italian) plus generic protein-naming words;
# these are dropped when building the token sets for the word-overlap view.
BASIC_STOP = {
    "the","a","an","and","or","of","in","on","for","to","by","with","from","as","at","is","are",
    "this","that","these","those","it","its","their","his","her","be","may","can","plays","play",
    "protein","gene","subunit","member","family","domain","large","small","alpha","beta","gamma",
    "delta","like","similar","involved","involving","interacts","interaction","complex",
    "una","un","il","lo","la","i","gli","le","che","di","da","per","con","su","nel","nella",
}

# Subcellular-location vocabulary (lowercase) looked up in the profile texts.
# NOTE: short abbreviations such as "er" collide with ordinary words
# ("member", "other") if they are matched as raw substrings.
LOCATION_TERMS = [
    "nucleus","nuclear","nucleolus","nucleoplasm",
    "cytosol","cytoplasm","cytoplasmic",
    "mitochondrion","mitochondria","mitochondrial","mitochondrial matrix",
    "inner mitochondrial membrane","outer mitochondrial membrane",
    "endoplasmic reticulum","er","rough er","smooth er",
    "golgi","golgi membrane","golgi apparatus",
    "lysosome","peroxisome","endosome",
    "plasma membrane","cell membrane","membrane","cell surface",
    "extracellular","secreted","vesicle","exosome",
    "ribosome","proteasome","chromatin","centrosome","microtubule","actin","cytoskeleton",
    "perinuclear region",
]

# Chromosome mentions, e.g. "chr17", "chrX", "chromosome 17", "chromosome 17q21.31".
# The character classes deliberately contain no whitespace: with a space in the
# class, a match ran on into following words and carried trailing whitespace
# ("chromosome 17 " instead of "chromosome 17"), so identical cues no longer matched.
CHR_PATTERNS = [
    r"\bchr[0-9xyXYmtMT]+\b",
    r"\bchromosome\s+[0-9XYMTxymt]+(?:[pq][0-9\.]+)?\b",
]

# Curated biomedical keywords (lowercase) for the key-term view.
KEY_TERMS = [
    # role/function
    "receptor","ligand","kinase","phosphatase","ubiquitin","chaperone","transporter",
    "transcription","translation","mrna","splicing","apoptosis","autophagy","stress",
    "immune","inflammatory","antiviral","metabolic","mitochondrial","oxidative",
    # structural/sequence hints
    "signal peptide","transmembrane","tm domain","coiled-coil","zinc finger","leucine zipper",
]

def normalize_name(x):
    """Return *x* converted to ``str`` with leading/trailing whitespace removed."""
    as_text = str(x)
    return as_text.strip()

def build_profile_text(row: pd.Series) -> str:
    """Assemble a one-paragraph text profile from an annotation row.

    Each known descriptive field that is present, non-missing and non-blank is
    rendered as "<label>: <value>." (underscores in the field name become
    spaces), and the sentences are joined with single spaces. If none of those
    fields has content, the profile falls back to "Protein <name>.".
    """
    profile_fields = (
        "alias", "Annotation", "Function", "Description", "Summary",
        "GO_MF", "GO_BP", "GO_CC", "Domains", "Pathways", "Notes", "comment", "text",
    )
    sentences = []
    for field in profile_fields:
        if field not in row:
            continue
        value = row[field]
        if pd.isna(value):
            continue
        cleaned = str(value).strip()
        if cleaned:
            sentences.append(f"{field.replace('_', ' ')}: {cleaned}.")
    if not sentences:
        return f"Protein {str(row.get('name', '')).strip()}."
    return " ".join(sentences)

def tokenize_info(text: str) -> set:
    """Return the set of informative lowercase alphanumeric tokens in *text*.

    Every character other than a-z, 0-9 or whitespace (after lowercasing)
    becomes a space. Tokens are then discarded if they are 2 characters or
    shorter, appear in BASIC_STOP, or consist only of digits. Non-string or
    blank input gives an empty set.
    """
    if isinstance(text, str) and text.strip():
        cleaned = re.sub(r"[^a-z0-9\s]", " ", text.lower())
        return {
            word
            for word in cleaned.split()
            if len(word) > 2 and word not in BASIC_STOP and not word.isdigit()
        }
    return set()

def _term_regex(term: str, whole_word_max_len: int = 3) -> str:
    """Return a regex that finds the curated *term* starting at a word boundary.

    Anchoring the start stops a term from firing inside an unrelated word
    (e.g. "actin" inside "interacting"). Prefix matches are still allowed, so
    "cytosol" hits "cytosolic" and "transcription" hits "transcriptional".
    Terms of at most *whole_word_max_len* characters (abbreviations such as
    "er") must also end on a word boundary. Otherwise they fire on ordinary
    words ("member", "other") or unrelated prefixes ("erk").
    """
    pattern = r"\b" + re.escape(term)
    if len(term) <= whole_word_max_len:
        pattern += r"\b"
    return pattern


def _matching_terms(text, terms) -> set:
    """Return the subset of (lowercase) *terms* found in *text*, case-insensitively.

    Non-string or blank *text* gives an empty set.
    """
    if not isinstance(text, str) or not text.strip():
        return set()
    lowered = text.lower()
    return {term for term in terms if re.search(_term_regex(term), lowered)}


def extract_locations(text: str, terms=None, chr_patterns=None) -> set:
    """Collect subcellular-location terms and normalized chromosome cues from *text*.

    Args:
        text: free text. Non-string or blank input gives an empty set.
        terms: lowercase location vocabulary; defaults to LOCATION_TERMS.
        chr_patterns: regexes for chromosome mentions, applied to the
            lowercased text; defaults to CHR_PATTERNS.

    Returns:
        The matched location terms, plus chromosome mentions rewritten as
        "chr<id>..." (e.g. "chromosome 17" -> "chr17"), lowercased and stripped.
    """
    if not isinstance(text, str) or not text.strip():
        return set()
    if terms is None:
        terms = LOCATION_TERMS
    if chr_patterns is None:
        chr_patterns = CHR_PATTERNS
    locs = _matching_terms(text, terms)
    tl = text.lower()
    for pat in chr_patterns:
        for m in re.finditer(pat, tl):
            s = re.sub(r"chromosome\s+([0-9xyXYmtMT]+)", r"chr\1", m.group(0), flags=re.I)
            # Strip so that a match carrying surrounding whitespace ("chr17 ")
            # is the same cue as a clean one ("chr17") in the Jaccard view.
            locs.add(s.lower().strip())
    return locs


def keyterm_hits(text: str, terms=None) -> set:
    """Return the curated key terms (default KEY_TERMS) that occur in *text*.

    Matching is case-insensitive and anchored at word starts (see _term_regex).
    Non-string or blank input gives an empty set.
    """
    if terms is None:
        terms = KEY_TERMS
    return _matching_terms(text, terms)

def jaccard(a: set, b: set) -> float:
    """Return the Jaccard index |a & b| / |a | b|; NaN when both sets are empty."""
    union = a | b
    if not union:
        return np.nan
    return len(a & b) / len(union)

def rank_within_groups(scores: pd.Series, gids: pd.Series, ascending=False) -> pd.Series:
    """Rank *scores* separately within each group given by *gids*.

    With ``ascending=False`` (the default) the highest score in a group gets
    rank 1. Ties share the smallest rank of the tie (``method="min"``).
    Missing (NaN) scores are placed after every real score in their group
    (``na_option="bottom"``); previously they produced NaN ranks, which made
    the integer cast raise. Rows whose group id is itself missing are still
    dropped by ``groupby`` and therefore still get NaN ranks.

    Returns:
        Integer ranks, indexed like the inputs.
    """
    df = pd.DataFrame({"gid": gids, "score": scores})
    df["rank"] = df.groupby("gid")["score"].rank(
        method="min", ascending=ascending, na_option="bottom"
    )
    return df["rank"].astype(int)

def add_delta(df: pd.DataFrame, new_rank_col: str):
    """Add rank-shift columns comparing *new_rank_col* with known baseline ranks.

    For each baseline rank column present in *df* (checked in a fixed order),
    a column ``delta_<new_rank_col>_vs_<baseline>`` = baseline - new is
    written. Assuming smaller rank numbers are better, a positive delta means
    the new view ranks the row higher. *df* is modified in place and returned.
    """
    baselines = (
        "rank_cosine_p1", "rank_IS_p1", "rank_pdockq_p1",
        "rank_hybrid_p1", "original_rank", "rank_p1",
    )
    present = [b for b in baselines if b in df.columns]
    for baseline in present:
        df["delta_" + new_rank_col + "_vs_" + baseline] = df[baseline] - df[new_rank_col]
    return df

def _pair_set_scores(df: pd.DataFrame, per_protein: dict) -> list:
    """Jaccard of the per-protein sets for every (protein1, similar_protein_name) row.

    Proteins absent from *per_protein* contribute an empty set; two empty sets
    give NaN (via ``jaccard``).
    """
    return [
        jaccard(per_protein.get(p1, set()), per_protein.get(p2, set()))
        for p1, p2 in zip(df["protein1"], df["similar_protein_name"])
    ]


def _tfidf_pair_scores(df: pd.DataFrame, names: list, texts: dict) -> np.ndarray:
    """Cosine similarity of TF-IDF profiles for every (protein1, similar_protein_name) row.

    *names* must contain every protein that appears in either column of *df*.
    """
    vec = TfidfVectorizer(
        lowercase=True,
        token_pattern=r"[a-zA-Z0-9\-]{3,}",
        ngram_range=(1, 2),  # unigrams + bigrams, so phrases like "cell membrane" count
        min_df=1,
        max_df=0.98,
        norm="l2",  # the default, made explicit: the row-wise dot below relies on unit rows
    )
    X = vec.fit_transform([texts[n] for n in names])  # sparse, shape (N, V)
    name2row = {n: i for i, n in enumerate(names)}
    rows1 = [name2row[p] for p in df["protein1"]]
    rows2 = [name2row[p] for p in df["similar_protein_name"]]
    # On L2-normalized rows the cosine is just the dot product. One sparse
    # element-wise product + row sum replaces a cosine_similarity call per pair.
    # All-zero rows (no surviving terms) give 0, as cosine_similarity does.
    sims = X[rows1].multiply(X[rows2]).sum(axis=1)
    return np.asarray(sims, dtype=float).ravel()


def _attach_view(df: pd.DataFrame, score_col: str, rank_col: str, scores) -> pd.DataFrame:
    """Store one view's scores (NaN -> 0.0), rank them within protein1, and add deltas."""
    df[score_col] = scores
    df[score_col] = df[score_col].fillna(0.0)
    df[rank_col] = rank_within_groups(df[score_col], df["protein1"], ascending=False)
    return add_delta(df, rank_col)


def main():
    """Score candidate pairs with four text-only views and write OUT_TSV.

    Loads annotations (ANN_PATH) and candidate pairs (CANDS_IN) and builds a
    text profile per protein. Each pair is then scored by TF-IDF cosine, token
    Jaccard, location Jaccard and key-term Jaccard. Each view is ranked within
    protein1 groups and compared against any pre-existing rank columns. The
    result is saved, and a top-3 TF-IDF preview is logged per protein1.

    Raises:
        RuntimeError: if the annotations lack a 'name' column, or the
            candidates lack 'protein1' / 'similar_protein_name'.
    """
    logging.info("Loading annotations...")
    ann = pd.read_csv(ANN_PATH, sep="\t", low_memory=False)
    if "name" not in ann.columns:
        raise RuntimeError("Annotations must contain a 'name' column as key.")
    ann["name"] = ann["name"].map(normalize_name)
    # A repeated name makes ``ann_idx.loc[name]`` return a DataFrame rather than
    # a single row, and build_profile_text then fails on an ambiguous truth value.
    n_dup = int(ann["name"].duplicated().sum())
    if n_dup:
        logging.warning(f"Dropping {n_dup} duplicate annotation rows (keeping the first per name).")
        ann = ann.drop_duplicates(subset="name", keep="first")
    ann_idx = ann.set_index("name", drop=False)

    logging.info("Loading candidates...")
    df = pd.read_csv(CANDS_IN, sep="\t", low_memory=False)
    if "protein1" not in df.columns or "similar_protein_name" not in df.columns:
        raise RuntimeError("Candidates must contain 'protein1' and 'similar_protein_name'.")

    df["protein1"] = df["protein1"].map(normalize_name)
    df["similar_protein_name"] = df["similar_protein_name"].map(normalize_name)

    # ----- Build per-protein texts & sets -----
    names = sorted(set(df["protein1"]).union(df["similar_protein_name"]))
    texts, tokens, locs, keys = {}, {}, {}, {}
    logging.info(f"Preparing profiles for {len(names)} proteins...")
    for n in names:
        if n in ann_idx.index:
            txt = build_profile_text(ann_idx.loc[n])
        else:
            txt = f"Protein {n}."
        texts[n] = txt
        tokens[n] = tokenize_info(txt)
        locs[n] = extract_locations(txt)
        keys[n] = keyterm_hits(txt)

    # ----- View 1: TF-IDF cosine on profiles -----
    logging.info("Computing TF-IDF cosine similarities...")
    df = _attach_view(df, "tfidf_score", "rank_tfidf_p1", _tfidf_pair_scores(df, names, texts))

    # ----- Views 2-4: Jaccard on per-protein sets (tokens, locations, key terms) -----
    set_views = [
        ("Scoring overlap_words (Jaccard tokens)...", tokens, "overlap_score", "rank_overlap_p1"),
        ("Scoring location (Jaccard on location cues)...", locs, "location_score", "rank_location_p1"),
        ("Scoring key_terms (Jaccard on curated terms)...", keys, "keyterm_score", "rank_keyterm_p1"),
    ]
    for message, per_protein, score_col, rank_col in set_views:
        logging.info(message)
        df = _attach_view(df, score_col, rank_col, _pair_set_scores(df, per_protein))

    # ----- Save -----
    Path(OUT_TSV).parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(OUT_TSV, sep="\t", index=False)
    logging.info(f"Wrote: {OUT_TSV}")

    # Quick per-p1 preview
    for p1, g in df.groupby("protein1"):
        top = g.sort_values("rank_tfidf_p1").head(3)[["similar_protein_name", "tfidf_score", "rank_tfidf_p1"]]
        logging.info(f"\n[p1={p1}] Top by TF-IDF:\n{top.to_string(index=False)}")


if __name__ == "__main__":
    main()
