# main.py
# -*- coding: utf-8 -*-

import os
import argparse
import pandas as pd
import numpy as np
import torch
from config import *
from prepare_named_data import prepare_v11_named
from extract_known_partners import extract_known_partners
from embedding_utils import load_embeddings, resolve_h5_key
from model_helpers import build_dscript_model, load_dscript_weights
from region_guidance import rank_candidates_interpretability_guided

# ---------------- utilities ----------------

def ensure_dir(p: str):
    """Create directory ``p`` (including missing parents); no-op if it exists."""
    os.makedirs(name=p, exist_ok=True)

def load_known_partners_file(p1_name: str) -> list:
    """Return the known-partner names stored for ``p1_name``.

    Reads ``<OUTPUT_DIR>/<p1_name>/v11_partners.tsv`` (tab-separated). The
    ``known_partner`` column is used when present; otherwise the first column
    of the file is used.

    Args:
        p1_name: Name of the query protein; also the name of its output
            sub-directory under ``OUTPUT_DIR``.

    Returns:
        The non-missing partner names as strings, in file order. An empty list
        is returned when the file does not exist or contains no data at all.
    """
    path = os.path.join(OUTPUT_DIR, p1_name, "v11_partners.tsv")
    if not os.path.isfile(path):
        return []
    try:
        # dtype=str keeps identifiers verbatim; without it pandas may infer a
        # numeric dtype and rewrite names such as "1E5" or "007".
        df = pd.read_csv(path, sep="\t", dtype=str)
    except pd.errors.EmptyDataError:
        # A zero-byte file (not even a header) means "no partners", not a crash.
        return []
    col = "known_partner" if "known_partner" in df.columns else df.columns[0]
    return df[col].dropna().astype(str).tolist()

def pooled_vec_2d(z2d: np.ndarray) -> np.ndarray:
    """Collapse a per-row embedding matrix into one unit-length float32 vector.

    A 3-D input is reduced to its first element along axis 0. Each row is
    L2-normalised, the rows are averaged, and the average is L2-normalised
    again. A small epsilon is added to every norm to avoid division by zero.
    """
    eps = 1e-12
    rows = z2d if z2d.ndim != 3 else z2d[0]
    row_norms = np.linalg.norm(rows, axis=1, keepdims=True)
    unit_rows = rows / (row_norms + eps)
    pooled = unit_rows.mean(axis=0)
    pooled_norm = np.linalg.norm(pooled)
    return (pooled / (pooled_norm + eps)).astype(np.float32)

# ---------------- main pipeline ----------------

def _resolve_embedding_key(name, keys, alias_path):
    """Map a protein name to a key present in ``keys``; return None if unresolved."""
    try:
        key = resolve_h5_key(name, keys, alias_path=alias_path, strict_id_only=False)
        if key in keys:
            return key
    except KeyError:
        pass
    # Fallback: treat the name as a bare ID and try the "9606."-prefixed form.
    fallback = f"9606.{name}"
    return fallback if fallback in keys else None


def _pool_candidates(all_emb, excluded):
    """Pool every non-excluded embedding; return (keys, matrix) kept in lockstep."""
    names, vecs, n_failed = [], [], 0
    for k, arr in all_emb.items():
        if k in excluded:
            continue
        try:
            vec = pooled_vec_2d(arr)
        except Exception:
            # Best-effort: one malformed embedding must not abort the run.
            n_failed += 1
            continue
        # Append the key only after pooling succeeded, so names[i] always
        # describes row i of the stacked matrix.
        names.append(k)
        vecs.append(vec)
    if n_failed:
        print(f"[region] skipped {n_failed} embeddings that could not be pooled")
    M = np.stack(vecs, axis=0) if vecs else np.zeros((0, 1), dtype=np.float32)
    return names, M


def run_region_guidance_for_p1(
    p1_name: str,
    all_emb: dict,
    alias_path: str,
    weights_path: str,
    out_dir: str,
    window_min: int = 6,
    window_max: int = 0,   # 0 -> no max
    smooth_w: int = 5,
    device_str: str = "",
    precision: str = "fp32",
    prefilter_top: int = 2000,
    sim_thresh: float = 0.5,   # threshold
) -> None:
    """
    Interpretability-guided R_p ranking for p1 using its known partners as anchors.
    Candidates: all proteins except p1 and anchors, prefiltered by pooled-cosine top-K.

    Args:
        p1_name: Name of the query protein; resolved to an embedding key.
        all_emb: Mapping from embedding key to embedding array.
        alias_path: Tab-separated alias table with ``#string_protein_id`` and
            ``preferred_name`` columns; also passed to ``resolve_h5_key``.
        weights_path: Path handed to ``load_dscript_weights``.
        out_dir: Output directory for p1 (created if missing).
        window_min, window_max, smooth_w, precision: Passed through to
            ``rank_candidates_interpretability_guided``; a ``window_max`` of
            ``None`` or <= 0 is forwarded as ``None`` (no maximum).
        device_str: Torch device string; empty selects CUDA when available,
            otherwise CPU.
        prefilter_top: Number of candidates kept after ranking by pooled cosine
            similarity to the FIRST resolved anchor.
        sim_thresh: Minimum pooled cosine for a row in ``top_similars.tsv``.

    Saves:
      - {out_dir}/top_similars.tsv  (rows with cosine >= sim_thresh for each kp vs all candidates)
      - {out_dir}/region_guided/Rp_ranked_candidates.tsv
      - {out_dir}/region_guided/<p1_key>__<pk_key>__C_hat.npy / __pk_region_idx.npy

    Returns early (after printing a message) when p1 has no embedding, has no
    known partners, none of its partners resolve to an embedding, or no
    candidate embeddings remain.
    """
    ensure_dir(out_dir)
    rg_dir = os.path.join(out_dir, "region_guided")
    ensure_dir(rg_dir)

    # alias map for preferred names
    aliases_v12 = pd.read_csv(alias_path, sep="\t")
    id2name = dict(zip(aliases_v12["#string_protein_id"], aliases_v12["preferred_name"]))

    def key_to_name(k: str) -> str:
        # fallbacks if key not present: strip the "9606." prefix, else keep the key
        return id2name.get(k, k.split("9606.")[-1] if k.startswith("9606.") else k)

    keys = set(all_emb.keys())

    # resolve p1
    p_key = _resolve_embedding_key(p1_name, keys, alias_path)
    if p_key is None:
        print(f"[region] {p1_name} missing in embeddings: skip.")
        return
    Z_p = all_emb[p_key]

    # anchors from v11_partners.tsv
    kp_names = load_known_partners_file(p1_name)
    if not kp_names:
        print(f"no known partners for {p1_name}: skip.")
        return

    anchors = {}
    for nm in kp_names:
        k = _resolve_embedding_key(nm, keys, alias_path)
        if k is not None:
            anchors[k] = all_emb[k]

    if not anchors:
        print(f"none of the partners for {p1_name} resolved to embeddings: skip.")
        return

    # candidates: all except p1 + anchors (matrix of pooled vectors, one row per name)
    names, M = _pool_candidates(all_emb, excluded=set(anchors) | {p_key})
    have_candidates = len(names) > 0

    # Pool each anchor once; reused for the similarity table and the prefilter.
    anchor_vecs = {}
    if have_candidates:
        anchor_vecs = {ak: pooled_vec_2d(Zk) for ak, Zk in anchors.items()}

    # write top_similars.tsv from per-anchor cosine
    sim_columns = ["p1_name", "known_partner", "similar_protein", "similarity_score"]
    sim_rows = []
    for ak, a_vec in anchor_vecs.items():
        sims = M @ a_vec  # cosine with all candidates (rows are unit vectors)
        partner_name = key_to_name(ak)
        for i in np.flatnonzero(sims >= sim_thresh):
            sim_rows.append({
                "p1_name": p1_name,
                "known_partner": partner_name,
                "similar_protein": key_to_name(names[i]),
                "similarity_score": float(sims[i]),
            })
    # Explicit columns so the header is written even when no row passes the threshold.
    pd.DataFrame(sim_rows, columns=sim_columns).to_csv(
        os.path.join(out_dir, "top_similars.tsv"),
        sep="\t", index=False
    )
    print(f"{p1_name}: saved top_similars.tsv (n={len(sim_rows)})  thresh={sim_thresh}")

    if not have_candidates:
        # Nothing to rank; also avoids multiplying the (0, 1) placeholder matrix.
        print(f"[region] {p1_name}: no candidate embeddings available: skip ranking.")
        return

    # ------------- prefilter for region-guided ranking (cheap heuristic) -------------
    # quick prefilter w.r.t. the FIRST anchor's pooled vector
    first_anchor_key = next(iter(anchors))
    scores = M @ anchor_vecs[first_anchor_key]
    order = np.argsort(-scores)[:prefilter_top]
    candidates = {names[i]: all_emb[names[i]] for i in order}
    win_max = None if (window_max is None or window_max <= 0) else window_max

    # Build the model only now, after every cheap early-exit check has passed.
    emb_nin = (Z_p[0].shape[1] if Z_p.ndim == 3 else Z_p.shape[1])
    device = torch.device(device_str) if device_str else torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_dscript_model(emb_nin=emb_nin, use_cuda=(device.type == "cuda"))
    model = load_dscript_weights(model, weights_path)

    rows, meta = rank_candidates_interpretability_guided(
        model=model,
        Z_p=Z_p,
        anchors=anchors,
        candidates=candidates,
        device=device,
        precision=precision,
        smooth_w=smooth_w,
        window_min=window_min,
        window_max=win_max,
    )

    # save per-anchor diagnostics
    for pk_id, info in meta["anchor_regions"].items():
        base = os.path.join(rg_dir, f"{p_key}__{pk_id}")
        np.save(base + "__C_hat.npy", info["C_hat"])
        np.save(base + "__pk_region_idx.npy", info["idx_region"])

    # save ranking (an empty result has no "score_Rp" column to sort on)
    df = pd.DataFrame(rows)
    if "score_Rp" in df.columns:
        df = df.sort_values("score_Rp", ascending=False).reset_index(drop=True)
    df.to_csv(os.path.join(rg_dir, "Rp_ranked_candidates.tsv"), sep="\t", index=False)
    print(f"{p1_name}: saved Rp_ranked_candidates.tsv (n={len(df)})")

def _build_arg_parser():
    """Assemble the command-line interface of the ranking script."""
    parser = argparse.ArgumentParser(description="PPI Candidate Ranking")
    parser.add_argument(
        "--prepare_once", action="store_true",
        help="Run prepare_v11_named() before extracting partners.",
    )
    parser.add_argument(
        "--run_region_guidance", action="store_true",
        help="Also compute interpretability-guided ranking per p1.",
    )
    parser.add_argument(
        "--rg_prefilter_top", type=int, default=2000,
        help="Candidates kept after pooled-cosine prefilter (region guidance).",
    )
    parser.add_argument("--rg_window_min", type=int, default=6)
    parser.add_argument("--rg_window_max", type=int, default=0, help="0 = no max length.")
    parser.add_argument("--rg_smooth_w", type=int, default=5)
    parser.add_argument("--rg_device", default="", help="e.g., cuda:0 or cpu")
    parser.add_argument("--rg_precision", default="fp32", choices=["fp32","bf16","fp16"])
    parser.add_argument(
        "--top_n", type=int, default=N_CANDIDATES,
        help="How many p1 from TOP_PROTEINS.",
    )
    parser.add_argument(
        "--sim_thresh", type=float, default=0.5,
        help="Cosine threshold for writing top_similars.tsv (per p1, per known partner).",
    )
    return parser


def main():
    """Command-line entry point.

    Optionally runs ``prepare_v11_named()``, always runs
    ``extract_known_partners()``, then walks the first ``--top_n`` rows of
    ``TOP_PROTEINS``: each protein gets its output directory created and, when
    ``--run_region_guidance`` is set, a region-guided ranking computed from
    embeddings that are loaded a single time up front.
    """
    opts = _build_arg_parser().parse_args()

    print("Starting PPI Candidate Ranking\n")

    # Step 0: optional one-off preparation of the v11 named data.
    if opts.prepare_once:
        prepare_v11_named()

    # Step 1: partner extraction for all top proteins.
    extract_known_partners()

    # Step 2: embeddings are only needed (and only loaded) for region guidance.
    all_embeddings = None
    if opts.run_region_guidance:
        print("Loading embeddings once for region guidance…")
        all_embeddings = load_embeddings(EMBEDDING_FILE)

    # Options derived purely from the CLI are identical for every protein.
    rg_options = {
        "window_min": opts.rg_window_min,
        "window_max": opts.rg_window_max,
        "smooth_w": opts.rg_smooth_w,
        "device_str": opts.rg_device,
        "precision": opts.rg_precision,
        "prefilter_top": opts.rg_prefilter_top,
        "sim_thresh": opts.sim_thresh,
    }

    # Step 3: per-protein processing.
    top_proteins = pd.read_csv(TOP_PROTEINS, sep="\t").head(opts.top_n)
    total = len(top_proteins)
    for idx, record in top_proteins.iterrows():
        p1_name = record["protein1_name"]
        out_dir = os.path.join(OUTPUT_DIR, p1_name)
        ensure_dir(out_dir)

        if opts.run_region_guidance:
            run_region_guidance_for_p1(
                p1_name=p1_name,
                all_emb=all_embeddings,
                alias_path=ALIAS_V12,
                weights_path=DSCRIPT_MODEL,
                out_dir=out_dir,
                **rg_options,
            )

        print(f"[{idx+1}/{total}] {p1_name}: done")

    print("\nDone.")

# Script entry point: run the pipeline only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()