# run_region_guidance_pairs.py
# -*- coding: utf-8 -*-

import os
import argparse
import pandas as pd
import numpy as np
import torch
import time
from config import *
from region_guidance import *   
from embedding_utils import *
from model_helpers import *

# Best-effort request for torch's "high" float32 matmul precision setting at
# import time. Any failure is deliberately ignored -- presumably so the script
# still runs on torch builds that do not provide this function (TODO confirm).
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

def load_embeddings(h5_path: str):
    """Load every top-level entry of an HDF5 file into a dict.

    The file is opened read-only and each entry is converted with
    ``np.array`` while the file is still open, so the returned arrays can
    be used after the file has been closed.

    Returns:
        dict mapping each top-level key of the file to its array.
    """
    with h5py.File(h5_path, "r") as handle:
        return {key: np.array(handle[key]) for key in handle.keys()}

def _ensure_3d(z: np.ndarray) -> np.ndarray:
    if z.ndim == 2:
        z = z[None, ...]
    if z.ndim != 3:
        raise ValueError(f"Expected (1, L, D) or (L, D); got {z.shape}")
    return z

def _as_2d(z: np.ndarray) -> np.ndarray:
    return z[0] if z.ndim == 3 else z

# ===== fast cosine prefilter (pooled) =====
def _pooled_vec_2d(z2d: np.ndarray, mode: str = "mean") -> np.ndarray:
    """
    z2d: (L, D) or (1, L, D). L2-normalize per residue, then mean pool.
    Returns a L2-normalized vector (D,).
    Used for DEBUG
    """
    z = z2d[0] if z2d.ndim == 3 else z2d
    z = z / (np.linalg.norm(z, axis=1, keepdims=True) + 1e-12)
    v = z.mean(axis=0) if mode == "mean" else np.median(z, axis=0)
    v = v / (np.linalg.norm(v) + 1e-12)
    return v.astype(np.float32)

def _build_pooled_matrix(all_emb: dict[str, np.ndarray]):
    """Pool every embedding into one row of a matrix for the cosine prefilter.

    Each entry of ``all_emb`` is pooled with ``_pooled_vec_2d``. Entries that
    cannot be pooled are skipped (best-effort), but the skip is now reported
    instead of being silently swallowed.

    Args:
        all_emb: mapping from embedding key to its array.

    Returns:
        (names, M): ``names`` lists the keys that were pooled, in iteration
        order; ``M`` stacks their pooled vectors row-wise, so ``M[i]``
        belongs to ``names[i]``. Rows are already L2-normalized by
        ``_pooled_vec_2d``.

    Raises:
        ValueError: if no embedding could be pooled, or (from ``np.stack``)
            if the pooled vectors do not share one shape.
    """
    names, vecs = [], []
    skipped = []  # (key, exception) for every entry that failed to pool
    for key, arr in all_emb.items():
        try:
            vec = _pooled_vec_2d(arr)
        except Exception as exc:
            skipped.append((key, exc))
            continue
        names.append(key)
        vecs.append(vec)

    if skipped:
        first_key, first_exc = skipped[0]
        print(
            f"[prefilter] skipped {len(skipped)} embedding(s) that could not be pooled "
            f"(first: {first_key}: {first_exc})"
        )
    if not vecs:
        # Same exception type np.stack would raise on an empty list, clearer message.
        raise ValueError("No embeddings could be pooled; cannot build the prefilter matrix.")

    M = np.stack(vecs, axis=0)  # (N, D), already L2-normalized
    return names, M

def _cosine_prefilter(query_vec: np.ndarray,
                      names: list[str],
                      M: np.ndarray,
                      top: int | None = 2000,
                      thresh: float | None = None,
                      exclude: set[str] = frozenset()):
    s = M @ query_vec.astype(np.float32)  # (N,)
    pairs = [(n, float(val)) for n, val in zip(names, s) if n not in exclude]
    if thresh is not None:
        pairs = [p for p in pairs if p[1] >= thresh]
    pairs.sort(key=lambda x: x[1], reverse=True)
    return pairs if top is None else pairs[:top]

# ===== v12 targets helper (rediscovery metric) =====
def v12_targets_tsv_path_for(p_name: str):
    """Return the path ``<OUTPUT_DIR>/<p_name>/v12_targets.tsv``."""
    protein_dir = os.path.join(OUTPUT_DIR, p_name)
    return os.path.join(protein_dir, "v12_targets.tsv")

def load_v12_targets_for_p(p_name: str, *, alias_path: str, h5_keys: set):
    """
    Load the 'new v12 targets' of p and return the ones that map to an H5 key.

    Reads v12_targets.tsv for p_name and takes the names from the first of
    these columns that is present: 'target', 'protein_external_id',
    'preferred_name', 'protein2', 'protein2_name'. Each name is resolved with
    resolve_h5_key; when that raises KeyError, the '9606.'-prefixed form is
    accepted if it is in h5_keys. Unresolvable names are dropped.

    Returns a set of resolved IDs; an empty set when the file is missing, has
    none of the accepted columns, or nothing resolves.
    """
    path = v12_targets_tsv_path_for(p_name)
    if not os.path.isfile(path):
        print(f"[v12] {path} not found: no v12 targets for {p_name}.")
        return set()

    table = pd.read_csv(path, sep="\t")
    cols = set(table.columns)
    accepted = ("target", "protein_external_id", "preferred_name", "protein2", "protein2_name")
    name_col = next((col for col in accepted if col in cols), None)
    if name_col is None:
        print(f"{path}: unrecognized columns. Columns: {sorted(cols)}")
        return set()

    resolved = set()
    for raw in table[name_col].dropna().astype(str).tolist():
        try:
            hit = resolve_h5_key(raw, h5_keys, alias_path=alias_path, strict_id_only=False)
        except KeyError:
            # Last resort: the bare ID may exist under the '9606.' prefix.
            prefixed = f"9606.{raw}"
            if raw.startswith("9606.") or prefixed not in h5_keys:
                continue
            hit = prefixed
        resolved.add(hit)

    if not resolved:
        print(f"No v12 targets resolved to IDs for {p_name}.")
    return resolved

# ===== partners helpers =====
def partners_tsv_path_for(p_name: str):
    """Return the path ``<OUTPUT_DIR>/<p_name>/v11_partners.tsv``."""
    protein_dir = os.path.join(OUTPUT_DIR, p_name)
    return os.path.join(protein_dir, "v11_partners.tsv")

def fallback_partners_from_v11_named(p_name: str, top_n: int = 50):
    """
    Collect partners of p_name from {OUTPUT_DIR}/v11_named.tsv.

    Rows where p_name appears in 'protein1_name' or 'protein2_name' are
    selected and the counterpart name of each row is taken. If a
    'combined_score' (preferred) or 'score' column exists, rows are ordered by
    it, highest first; ties keep file order. Duplicates and p_name itself are
    dropped, and at most top_n names are returned (none when top_n <= 0).

    Returns a list of partner names; an empty list when the file is missing,
    lacks one of the two name columns, or never mentions p_name.
    """
    v11_named_path = os.path.join(OUTPUT_DIR, "v11_named.tsv")
    if not os.path.isfile(v11_named_path):
        print(f"[fallback] {v11_named_path} is missing.")
        return []

    df = pd.read_csv(v11_named_path, sep="\t")

    # Both name columns are required. Without this check a missing column made the
    # row mask a plain bool and the selection below failed with a KeyError.
    missing = [c for c in ("protein1_name", "protein2_name") if c not in df.columns]
    if missing:
        print(f"[fallback] {v11_named_path} lacks column(s) {missing}. Columns: {list(df.columns)}")
        return []

    score_col = next((c for c in ("combined_score", "score") if c in df.columns), None)
    mask = (df["protein1_name"] == p_name) | (df["protein2_name"] == p_name)
    sub = df.loc[mask].copy()
    if sub.empty:
        print(f"[fallback] No occurrences of {p_name} in v11_named.tsv.")
        return []

    sub["partner"] = np.where(sub["protein1_name"] == p_name, sub["protein2_name"], sub["protein1_name"])
    sub = sub.dropna(subset=["partner"])
    if score_col is not None:
        # Stable sort so equally-scored rows keep a deterministic (file) order.
        sub = sub.sort_values(by=score_col, ascending=False, kind="stable")

    partners = []
    seen = {p_name}  # pre-seeding excludes self-interactions
    for x in sub["partner"].astype(str).tolist():
        # Check the cap before appending so top_n <= 0 yields no partners.
        if len(partners) >= top_n:
            break
        if x in seen:
            continue
        seen.add(x)
        partners.append(x)
    print(f"[fallback] Found {len(partners)} fallback partners for {p_name}.")
    return partners

def load_known_partners_for_p(p_name: str):
    """
    Return the known partners of p_name as a list of strings.

    Primary source is {OUTPUT_DIR}/{p}/v11_partners.tsv, column
    'known_partner' (NaN entries dropped). When that file is absent, lacks the
    column, or yields no names, the result of
    fallback_partners_from_v11_named(p_name, top_n=50) is returned instead.
    """
    path = partners_tsv_path_for(p_name)
    if os.path.isfile(path):
        df = pd.read_csv(path, sep="\t")
        if "known_partner" not in df.columns:
            print(f"{path} missing 'known_partner' column. Columns: {list(df.columns)}")
        else:
            listed = df["known_partner"].dropna().astype(str).tolist()
            if listed:
                return listed

    # Nothing usable in the per-protein file: use the global table.
    print(f"No partners in {path}. Falling back to v11_named.tsv ...")
    return fallback_partners_from_v11_named(p_name, top_n=50)

# optional auto-extraction
def ensure_partners_ready(p_name: str, auto_extract: bool = True):
    """
    Make a best-effort attempt to have OUTPUT_DIR/<p>/v11_partners.tsv on disk.

    Does nothing when the file already exists. Otherwise, with
    auto_extract=True it imports and calls
    extract_known_partners.extract_known_partners(); any exception from that
    is printed, not raised. With auto_extract=False it only prints a warning.
    Returns None in every case.
    """
    tsv = partners_tsv_path_for(p_name)
    if not os.path.isfile(tsv):
        if auto_extract:
            print(f"[auto_extract] {tsv} not found. Running extract_known_partners()...")
            try:
                # Imported lazily so the module is only needed when extraction runs.
                from extract_known_partners import extract_known_partners as _run_extraction
                _run_extraction()
            except Exception as e:
                print("Failed to generate partners file:", e)
        else:
            print(f"[warn] Missing partners file: {tsv} (auto_extract=OFF).")

# ===== selection of p from TOP_PROTEINS =====
def choose_p_name(top_proteins_tsv: str, n_candidates: int, index: int) -> str:
    """Pick one protein name from the first ``n_candidates`` rows of a TSV.

    Returns:
        The 'protein1_name' value of row ``index`` (0-based), as a string.

    Raises:
        IndexError: if ``index`` is negative or not below the number of rows kept.
    """
    df = pd.read_csv(top_proteins_tsv, sep="\t")
    df = df.head(n_candidates)
    if index < 0 or index >= len(df):
        raise IndexError(f"--p_index={index} out of range (0..{len(df)-1})")
    chosen_row = df.iloc[index]
    return str(chosen_row["protein1_name"])

# ===== main routine for a list of anchors (partners) =====
def _resolve_key_or_prefixed(name: str, keys: set, alias_path: str):
    """Resolve ``name`` to an H5 key, trying ``9606.<name>`` when direct resolution fails.

    Returns the resolved key, or None when neither form is available.
    """
    try:
        return resolve_h5_key(name, keys, alias_path=alias_path)
    except KeyError:
        prefixed = f"9606.{name}"
        return prefixed if prefixed in keys else None


def _write_v12_rediscovery_metrics(p_name: str, out_root: str, v12_targets_ids: set, redisc_frames: list):
    """Write the per-p v12 rediscovery tables from the frames collected during this run.

    ``redisc_frames`` holds one DataFrame (columns 'candidate_id', 'rank') per
    anchor that rediscovered at least one v12 target. Always writes
    ``<p>__v12_metrics.tsv``; also writes ``<p>__v12_rediscovery_best.tsv``
    when something was rediscovered.
    """
    metrics_path = os.path.join(out_root, f"{p_name}__v12_metrics.tsv")
    total_v12 = len(v12_targets_ids)

    if not redisc_frames:
        pd.DataFrame([{
            "p": p_name,
            "n_v12_targets_resolvable": total_v12,
            "n_v12_rediscovered_any": 0,
            "hit_ratio": 0.0,
            "hit@1": 0.0, "hit@5": 0.0, "hit@10": 0.0,
            "mean_best_rank": np.nan,
        }]).to_csv(metrics_path, sep="\t", index=False)
        return

    R = pd.concat(redisc_frames, ignore_index=True)
    # Best (lowest) rank reached by each rediscovered target across all anchors.
    best = (R[["candidate_id", "rank"]]
            .groupby("candidate_id", as_index=False)["rank"].min()
            .rename(columns={"rank": "best_rank"}))
    best["is_hit@1"] = (best["best_rank"] <= 1).astype(int)
    best["is_hit@5"] = (best["best_rank"] <= 5).astype(int)
    best["is_hit@10"] = (best["best_rank"] <= 10).astype(int)
    # NOTE(review): hit@k is averaged over the rediscovered targets only, while
    # hit_ratio is relative to all resolvable v12 targets -- confirm this is intended.
    hits = {
        "p": p_name,
        "n_v12_targets_resolvable": total_v12,
        "n_v12_rediscovered_any": int(best.shape[0]),
        "hit_ratio": (best.shape[0] / total_v12) if total_v12 > 0 else 0.0,
        "hit@1": best["is_hit@1"].mean() if total_v12 > 0 else 0.0,
        "hit@5": best["is_hit@5"].mean() if total_v12 > 0 else 0.0,
        "hit@10": best["is_hit@10"].mean() if total_v12 > 0 else 0.0,
        "mean_best_rank": float(best["best_rank"].mean()) if best.shape[0] else np.nan,
    }
    best.to_csv(os.path.join(out_root, f"{p_name}__v12_rediscovery_best.tsv"), sep="\t", index=False)
    pd.DataFrame([hits]).to_csv(metrics_path, sep="\t", index=False)


def run_interpretability_guided_for_pairs(
    p_name: str,
    partners: list[str],
    *,
    h5: str = EMBEDDING_FILE,
    weights: str = DSCRIPT_MODEL,
    out_root: str = os.path.join(OUTPUT_DIR, "region_guided_pairs"),
    limit: int | None = None,
    # interpretability-guided hyperparams
    window_min: int = 6,
    window_max: int = 0,   # use 0 for "no max"
    smooth_w: int = 5,
    alias_path: str = "",
    device: str = "",
    precision: str = "fp32",
    # prefilter
    prefilter_top: int | None = 2000,
    prefilter_thresh: float | None = None,
):
    """Rank candidates for every (p, known partner) pair and write the results.

    For each partner (anchor) that resolves to an embedding key, a shortlist
    is built with the pooled-cosine prefilter (excluding p and the anchor),
    ranked with ``rank_candidates_interpretability_guided``, and written to
    ``<out_root>/<p>__<partner>/``: ``C_hat.npy``, ``pk_region_idx.npy``,
    ``ranked_candidates.tsv`` and, when v12 targets of p appear in the
    ranking, ``rediscovered_v12.tsv``. Per-p files ``<p>__summary.tsv`` and the
    v12 metric tables are written to ``out_root``. A failure on one pair is
    printed and the remaining pairs are still processed.

    Args:
        p_name: name/ID of the protein p.
        partners: names/IDs of the known partners used as anchors.
        h5: HDF5 file with the embeddings.
        weights: weights passed to ``load_dscript_weights``.
        out_root: output root directory (created if needed).
        limit: keep only the first ``limit`` partners (None keeps all).
        window_min, window_max, smooth_w: forwarded to the ranking routine;
            ``window_max <= 0`` is forwarded as None ("no max").
        alias_path: alias table forwarded to ``resolve_h5_key``.
        device: torch device spec; empty selects CUDA when available, else CPU.
        precision: forwarded to the ranking routine.
        prefilter_top, prefilter_thresh: shortlist size / score threshold of
            the prefilter.

    Returns:
        None. All results are written to disk.

    Raises:
        KeyError: if ``p_name`` cannot be resolved to an embedding key.
    """
    os.makedirs(out_root, exist_ok=True)

    dev = torch.device(device) if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dev_label = dev.type
    if dev.type != "cpu" and dev.index is not None:
        dev_label += ":" + str(dev.index)
    print(f"Using {dev_label} | precision={precision}")

    # load all embeddings once
    all_emb = load_embeddings(h5)
    keys = set(all_emb.keys())

    # resolve p_name -> H5 key
    p_key = _resolve_key_or_prefixed(p_name, keys, alias_path)
    if p_key is None:
        some = ", ".join(list(keys)[:5])
        raise KeyError(f"{p_name} not found in {h5}. Example keys: {some}")

    # build DSCRIPT once; the input width is the embedding dimension of p
    emb_nin = int(_as_2d(all_emb[p_key]).shape[1])
    model = build_dscript_model(emb_nin=emb_nin, use_cuda=(dev.type == "cuda"))
    model = load_dscript_weights(model, weights)

    # pooled matrix over all embeddings, computed once for the cosine prefilter
    pooled_names, pooled_M = _build_pooled_matrix(all_emb)
    pooled_index = {n: i for i, n in enumerate(pooled_names)}

    if limit is not None:
        partners = partners[:limit]

    # resolve partner names, keeping the display name for logs and directory names
    resolved_partners: list[tuple[str, str]] = []
    for pk in partners:
        pk_key = _resolve_key_or_prefixed(pk, keys, alias_path)
        if pk_key is None:
            print(f"pk='{pk}' not found in embeddings: SKIP")
            continue
        resolved_partners.append((pk, pk_key))

    if not resolved_partners:
        print(f"No resolvable partners for {p_name}. Skipping.")
        return

    # Optional v12 rediscovery metric: load the targets once, up front. A failure
    # here only disables the metric; it must not abort the ranking itself.
    try:
        v12_targets_ids = load_v12_targets_for_p(p_name, alias_path=alias_path, h5_keys=keys)
    except Exception as e:
        print(f"[v12] could not load v12 targets for {p_name}: {e}")
        v12_targets_ids = set()

    summary_rows = []
    # Rediscovery tables of THIS run only; aggregating from memory keeps files
    # left over from earlier runs out of the metrics.
    redisc_frames = []

    # window_max: allow "no max" with 0 or negative
    win_max = None if window_max <= 0 else window_max
    n_pairs = len(resolved_partners)

    for i, (pk_disp, pk_key) in enumerate(resolved_partners, 1):
        try:
            if pk_key not in all_emb:
                print(f"[{i}/{n_pairs}] p={p_name} pk={pk_disp} -> SKIP (missing embedding)")
                continue

            # --- PREFILTER: cosine on pooled embeddings ---
            if pk_key in pooled_index:
                q_vec = pooled_M[pooled_index[pk_key]]
            else:
                q_vec = _pooled_vec_2d(_as_2d(all_emb[pk_key]))

            pref = _cosine_prefilter(
                q_vec,
                pooled_names,
                pooled_M,
                top=prefilter_top,
                thresh=prefilter_thresh,
                exclude={p_key, pk_key},
            )
            cand_keys = [n for (n, _) in pref]
            if not cand_keys:
                print(f"[{i}/{n_pairs}] p={p_name} pk={pk_disp} -> no candidates after prefilter.")
                continue

            candidates = {k: all_emb[k] for k in cand_keys}
            Z_p = all_emb[p_key]
            Z_pk = all_emb[pk_key]

            # ===== TIMER START =====
            t0 = time.perf_counter()

            # --- INTERPRETABILITY-GUIDED ---
            rows, meta = rank_candidates_interpretability_guided(
                model=model,
                Z_p=Z_p,
                anchors={pk_key: Z_pk},
                candidates=candidates,
                device=dev,
                precision=precision,
                smooth_w=smooth_w,
                window_min=window_min,
                window_max=win_max,
            )

            elapsed = time.perf_counter() - t0
            print(f"p={p_name} pk={pk_disp} took {elapsed:.2f}s "
                f"({len(cand_keys)} candidates)")

            # outputs for pair (p, pk)
            pair_out = os.path.join(out_root, f"{p_name}__{pk_disp}")
            os.makedirs(pair_out, exist_ok=True)

            # save per-anchor diagnostics (C_hat + region)
            # meta['anchor_regions'] contains a dict per anchor; here only pk_key
            info = meta["anchor_regions"][pk_key]
            np.save(os.path.join(pair_out, "C_hat.npy"), info["C_hat"])
            np.save(os.path.join(pair_out, "pk_region_idx.npy"), info["idx_region"])

            df = pd.DataFrame(rows)
            if df.empty:
                # An empty frame has no 'score_Rp' column; sorting it would raise
                # KeyError and the pair would be misreported as an error.
                print(f"[{i}/{n_pairs}] p={p_name}  pk={pk_disp} -> no candidates scored.")
                continue

            # save ranking
            # rows already contain candidate_id, score_Rp, argmax_anchor, best_window_*
            df = df.sort_values("score_Rp", ascending=False).reset_index(drop=True)
            df["rank"] = np.arange(1, len(df) + 1)
            out_tsv = os.path.join(pair_out, "ranked_candidates.tsv")
            df.to_csv(out_tsv, sep="\t", index=False)

            # rediscovery report on shortlist
            redisc_path = os.path.join(pair_out, "rediscovered_v12.tsv")
            found = df[df["candidate_id"].isin(v12_targets_ids)].copy() if v12_targets_ids else df.iloc[0:0]
            if not found.empty:
                found.to_csv(redisc_path, sep="\t", index=False)
                redisc_frames.append(found[["candidate_id", "rank"]])
            elif os.path.isfile(redisc_path):
                # A report from an earlier run would contradict the ranking just written.
                os.remove(redisc_path)

            # terse log
            top1 = df.iloc[0].to_dict()
            cand_val = top1.get("candidate_id", "NA")
            score_val = top1.get("score_Rp", float("nan"))
            j0, j1 = info["bounds"]
            print(
                f"[{i}/{n_pairs}] p={p_name}  pk={pk_disp}  "
                f"region=[{j0},{j1}]  TOP1={cand_val}  R_p={score_val:.4f}  (shortlist={len(cand_keys)})"
            )
            summary_rows.append({
                "p": p_name,
                "p_known": pk_disp,
                "pk_region_j0": j0,
                "pk_region_j1": j1,
                "top1_candidate": cand_val,
                "top1_Rp": score_val,
                "shortlist_size": len(cand_keys),
                "ranked_file": out_tsv,
            })

        except Exception as e:
            # Best-effort batch: report the failing pair and keep going.
            print(f"[error] p={p_name} pk={pk_disp}: {e}  (continuing)")

    # summary per p
    if summary_rows:
        pd.DataFrame(summary_rows).to_csv(
            os.path.join(out_root, f"{p_name}__summary.tsv"), sep="\t", index=False
        )

    # aggregate v12 metrics across anchors
    if v12_targets_ids:
        _write_v12_rediscovery_metrics(p_name, out_root, v12_targets_ids, redisc_frames)

# ===== iterate over TOP_PROTEINS =====
def _iter_top_proteins(top_proteins_tsv: str, start_index: int, count: int | None):
    df = pd.read_csv(top_proteins_tsv, sep="\t")
    if count is None:
        df = df.iloc[start_index:]
    else:
        df = df.iloc[start_index:start_index+count]
    for _, row in df.iterrows():
        yield str(row["protein1_name"])

# ===== choose parameters to run
def main():
    """Command-line entry point: pick the protein(s) p and run the pair ranking.

    With ``--all`` every protein of TOP_PROTEINS (optionally a range given by
    ``--start_index`` / ``--count``) is processed; otherwise a single p is
    taken from ``--p_name`` or, if that is empty, from row ``--p_index`` of
    TOP_PROTEINS. For each p the known partners are loaded (optionally
    auto-extracted first) and ``run_interpretability_guided_for_pairs`` is
    called with the parsed options. Proteins without partners are skipped.
    """
    ap = argparse.ArgumentParser(
        description="Interpretability-guided ranking for (p, p_known) using DSCRIPT Ĉ"
    )
    ap.add_argument("--p_name", default="", help="Name/ID of p (as in HDF5). If empty, use TOP_PROTEINS.")
    ap.add_argument("--p_index", type=int, default=0, help=f"Index in TOP_PROTEINS (0..{N_CANDIDATES-1}) when using a single p.")
    ap.add_argument("--all", action="store_true", help="Process all (or a range) from TOP_PROTEINS.")
    ap.add_argument("--start_index", type=int, default=0, help="Start index when using --all (default 0).")
    ap.add_argument("--count", type=int, default=None, help="How many proteins to process from start_index (default: to the end).")

    ap.add_argument("--alias", default="", help="STRING info TSV (e.g., 9606.protein.info.v11.0.txt) for alias resolution.")
    ap.add_argument("--limit", type=int, default=None, help="Limit number of known partners per p (default: all).")

    ap.add_argument("--h5", default=EMBEDDING_FILE, help=f"HDF5 embeddings (default: {EMBEDDING_FILE})")
    ap.add_argument("--weights", default=DSCRIPT_MODEL, help=f"DSCRIPT weights (default: {DSCRIPT_MODEL})")
    ap.add_argument("--out_root", default=os.path.join(OUTPUT_DIR, "region_guided_pairs"), help="Output root directory.")

    # interpretability-guided hyperparameters (region on p_k with highest avg activation)
    ap.add_argument("--window_min", type=int, default=6)
    ap.add_argument("--window_max", type=int, default=35, help="Use 0 for no max (full length allowed).")
    ap.add_argument("--smooth_w", type=int, default=5)

    ap.add_argument("--no_auto_extract", action="store_true", help="Do not attempt automatic partner extraction.")

    ap.add_argument("--device", default="", help="e.g., 'cuda:0' or 'cpu' (default: auto)")
    ap.add_argument("--precision", default="fp32", choices=["fp32","fp16","bf16"],
                    help="AMP precision for DSCRIPT forward (default fp32=off).")

    ap.add_argument("--prefilter_top", type=int, default=2000,
                    help="How many candidates to keep after pooled-cosine prefilter (default 2000).")
    ap.add_argument("--prefilter_thresh", type=float, default=None,
                    help="Cosine threshold [0..1] for prefilter (default: None).")

    args = ap.parse_args()
    auto_extract = not args.no_auto_extract

    # choose proteins
    if args.all:
        p_list = list(_iter_top_proteins(TOP_PROTEINS, args.start_index, args.count))
    else:
        p_name = args.p_name or choose_p_name(TOP_PROTEINS, N_CANDIDATES, args.p_index)
        p_list = [p_name]

    for idx, p_name in enumerate(p_list, 1):
        print(f"\n=== [{idx}/{len(p_list)}] p={p_name} ===")
        ensure_partners_ready(p_name, auto_extract=auto_extract)

        partners = load_known_partners_for_p(p_name)
        if not partners:
            print(f"[skip] No known partners for {p_name}. Next protein...")
            continue

        run_interpretability_guided_for_pairs(
            p_name=p_name,
            partners=partners,
            h5=args.h5,
            weights=args.weights,
            out_root=args.out_root,
            limit=args.limit,
            window_min=args.window_min,
            # Honour the parsed CLI values: --window_max and --device used to be
            # overridden by hard-coded constants (0 and GPU index 6).
            window_max=args.window_max,
            smooth_w=args.smooth_w,
            alias_path=args.alias,
            device=args.device,
            precision=args.precision,
            prefilter_top=args.prefilter_top,
            prefilter_thresh=args.prefilter_thresh,
        )

# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()