# region_guidance.py
# -*- coding: utf-8 -*-

import os
import argparse
import csv
from typing import Dict, List, Tuple, Optional, Any
import re
from contextlib import nullcontext
import h5py
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from config import *
from embedding_utils import *
from candidate_retrieval.utils import *
from model_helpers import *

def _ensure_3d(z: np.ndarray) -> np.ndarray:
    if z.ndim == 2:
        z = z[None, ...]
    if z.ndim != 3:
        raise ValueError(f"Expected (1, L, D) or (L, D); got {z.shape}")
    return z

def _as_2d(z: np.ndarray) -> np.ndarray:
    return z[0] if z.ndim == 3 else z

# --------- Full pipeline ---------
@torch.no_grad()
def compute_anchor_region_indices(
    model: nn.Module,
    Z_p: np.ndarray,
    Z_pk: np.ndarray,
    device: torch.device,
    precision: str,
    smooth_w: int,
    window_min: int,
    window_max: Optional[int],   # None -> no upper bound on region length
):
    """Select the contiguous region of anchor p_k most strongly predicted to contact p.

    Runs under ``torch.no_grad()``. Steps:
      1) C_hat = cpred(p, p_k) via ``dscript_contact_map`` (rows: p, cols: p_k)
      2) per-residue activation of p_k = max of C_hat over the p rows (axis 0)
      3) smooth that 1-D profile with ``smooth_1d(.., smooth_w)``
      4) pick the contiguous window with the highest average smoothed
         activation via ``select_contiguous_region_max_avg``

    Returns:
        (idx_region, (j0, j1), C_hat) -- idx_region and (j0, j1) are whatever
        ``select_contiguous_region_max_avg`` returns (presumably the selected
        p_k residue indices and the window bounds -- confirm against that
        helper); C_hat is the contact map from step 1.
    """
    contact = dscript_contact_map(model, Z_p, Z_pk, device=device, precision=precision)
    # Strongest predicted contact of each p_k residue with any residue of p.
    per_residue = contact.max(axis=0)
    smoothed = smooth_1d(per_residue, smooth_w)
    region_idx, bounds = select_contiguous_region_max_avg(smoothed, window_min, window_max)
    start, end = bounds
    return region_idx, (start, end), contact

def rank_candidates_interpretability_guided(
    model: nn.Module,
    Z_p: np.ndarray,
    anchors: Dict[str, np.ndarray],
    candidates: Dict[str, np.ndarray],
    device: torch.device,
    precision: str = "fp32",
    smooth_w: int = 5,
    window_min: int = 6,
    window_max: Optional[int] = 35,   # None -> no upper bound on region length
):
    """Rank candidates by their best region-restricted similarity to any anchor.

    For each anchor p_k:
      - compute the region I_{p_k} via the contact map with p
        (``compute_anchor_region_indices``).
    For each candidate p_c:
      - sim(p_c, p_k) = ``cosine_flattened_max_over_windows`` between the
        anchor's region embeddings Z_{p_k}[I_{p_k}] and Z_{p_c}.
    Then:
      - R_p(p_c) = max_k sim(p_c, p_k)

    Anchors whose selected region is empty are skipped when scoring. A
    candidate that no anchor could score keeps the sentinel score -2.0,
    an empty ``argmax_anchor`` and the window (0, -1). Ties between anchors
    are resolved in favour of the first anchor in ``anchors`` iteration order.

    Returns:
        rows: list of dicts with keys ``candidate_id``, ``score_Rp``,
            ``argmax_anchor``, ``best_window_j0``, ``best_window_j1``, sorted
            by ``score_Rp`` descending. The window bounds are the j0/j1 that
            ``cosine_flattened_max_over_windows`` returned for the winning anchor.
        meta: dict with per-anchor ``idx_region``/``bounds``/``C_hat`` under
            ``anchor_regions`` plus the hyperparameters used.
    """
    # Precompute anchor regions.
    anchor_regions: Dict[str, Dict[str, Any]] = {}
    # The region embeddings do not depend on the candidate, so slice them once
    # per anchor here rather than once per (candidate, anchor) pair.
    # Insertion order follows `anchors`, which keeps tie-breaking unchanged.
    scorable_regions: List[Tuple[str, np.ndarray]] = []
    for pk_id, Z_pk in anchors.items():
        idx_region, (j0, j1), C_hat = compute_anchor_region_indices(
            model=model,
            Z_p=Z_p,
            Z_pk=Z_pk,
            device=device,
            precision=precision,
            smooth_w=smooth_w,
            window_min=window_min,
            window_max=window_max,
        )
        anchor_regions[pk_id] = {
            "idx_region": idx_region,
            "bounds": (int(j0), int(j1)),
            "C_hat": C_hat,
        }
        if np.asarray(idx_region).size == 0:
            continue  # no usable region -> this anchor cannot score candidates
        scorable_regions.append((pk_id, _as_2d(Z_pk)[idx_region]))

    # Rank candidates by max over anchors.
    rows: List[Dict[str, Any]] = []
    for cid, Zc in candidates.items():
        best_over_anchors = -2.0   # sentinel below any cosine similarity
        best_anchor: Optional[str] = None
        best_win = (0, -1)

        for pk_id, Z_pk_region in scorable_regions:
            sc, j0, j1 = cosine_flattened_max_over_windows(Z_pk_region, Zc)
            if sc > best_over_anchors:   # strict '>' keeps the first anchor on ties
                best_over_anchors = sc
                best_anchor = pk_id
                best_win = (int(j0), int(j1))

        rows.append({
            "candidate_id": cid,
            "score_Rp": float(best_over_anchors),           # R_p(p_c)
            "argmax_anchor": "" if best_anchor is None else best_anchor,
            "best_window_j0": best_win[0],
            "best_window_j1": best_win[1],
        })

    # sort() is stable, so equal scores keep candidate iteration order.
    rows.sort(key=lambda r: r["score_Rp"], reverse=True)

    meta = {
        "anchor_regions": anchor_regions,   # per-anchor C_hat and region bounds/indices
        "window_min": int(window_min),
        "window_max": (None if window_max is None else int(window_max)),
        "smooth_w": int(smooth_w),
    }
    return rows, meta

def _read_id_list(path: str) -> List[str]:
    """Return the whitespace-stripped, non-blank lines of the UTF-8 text file at *path*."""
    with open(path, "r", encoding="utf-8") as f:
        return [ln.strip() for ln in f if ln.strip()]


def main():
    """Command-line entry point for interpretability-guided candidate ranking.

    Parses arguments, loads embeddings from HDF5 and the DSCRIPT weights, runs
    ``rank_candidates_interpretability_guided``, then writes into ``--outdir``:
      - ``<p_key>__<pk_id>__C_hat.npy`` and ``<p_key>__<pk_id>__pk_region_idx.npy``
        for every anchor,
      - ``<p_key>__Rp_ranked_candidates.tsv`` with the full ranking,
    and prints the top ``--top`` candidates.

    Exits through ``argparse`` (status 2) when an explicitly given anchors or
    candidates file does not exist, or when no anchors were supplied.
    """
    ap = argparse.ArgumentParser(
        description="Interpretability-Guided Retrieval (matches LaTeX exactly)."
    )
    ap.add_argument("--p_name", required=True,
                    help="Target protein p: ID or alias resolvable in HDF5.")
    group = ap.add_mutually_exclusive_group(required=True)
    group.add_argument("--kp_names", default="",
                       help="Comma-separated anchors (IDs/aliases).")
    group.add_argument("--anchors_file", default="",
                       help="Text file with one anchor per line (IDs/aliases).")
    ap.add_argument("--cands_file", default="",
                    help="Optional file with candidate IDs (one per line). If omitted: CP(p)=all except p and KP.")
    ap.add_argument("--h5", default=EMBEDDING_FILE,
                    help=f"HDF5 embeddings path (default: {EMBEDDING_FILE})")
    ap.add_argument("--weights", default=DSCRIPT_MODEL,
                    help=f"DSCRIPT weights path (default: {DSCRIPT_MODEL})")
    ap.add_argument("--alias", default="",
                    help="STRING info TSV (e.g., 9606.protein.info.v11.0.txt) for alias resolution.")
    ap.add_argument("--outdir", default=os.path.join(OUTPUT_DIR, "region_guided_exact"),
                    help="Output directory")
    # region selection hyperparams (for contiguous region with highest average)
    ap.add_argument("--window_min", type=int, default=6)
    ap.add_argument(
        "--window_max", type=int, default=35,
        help="Max region length for anchors; use 0 for no max (up to full length)."
    )
    ap.add_argument("--smooth_w", type=int, default=5)
    # performance / runtime
    ap.add_argument("--top", type=int, default=50, help="How many candidates to print.")
    ap.add_argument("--device", default="", help="e.g., cuda:0 or cpu (default: auto).")
    ap.add_argument("--precision", default="fp32", choices=["fp32", "bf16", "fp16"],
                    help="AMP precision on GPU (fp32 disables AMP).")
    args = ap.parse_args()

    # Anchors. An explicitly named file must exist: silently falling back to an
    # empty anchor list would leave every candidate unscored.
    if args.anchors_file:
        if not os.path.isfile(args.anchors_file):
            ap.error(f"--anchors_file not found: {args.anchors_file}")
        kp_names = _read_id_list(args.anchors_file)
    else:
        kp_names = [s.strip() for s in args.kp_names.split(",") if s.strip()]
    if not kp_names:
        ap.error("no anchors given (--kp_names / --anchors_file contains no IDs)")

    # Candidate subset (optional). A named-but-missing file is an error rather
    # than a silent switch to scoring every protein in the HDF5 file.
    cand_list: Optional[List[str]] = None
    if args.cands_file:
        if not os.path.isfile(args.cands_file):
            ap.error(f"--cands_file not found: {args.cands_file}")
        cand_list = _read_id_list(args.cands_file)

    os.makedirs(args.outdir, exist_ok=True)

    # load embeddings
    p_key, Z_p, anchors, candidates = get_embeddings_from_h5(
        args.h5, args.p_name, kp_names, cand_list=cand_list, alias_path=args.alias
    )

    # build & load DSCRIPT model (input dimension taken from Z_p)
    emb_nin = int(_as_2d(Z_p).shape[1])
    device = pick_device(args.device)
    model = build_dscript_model(emb_nin=emb_nin, use_cuda=(device.type == "cuda"))
    model = load_dscript_weights(model, args.weights)

    # interpret 0 or negative as "no max"
    win_max = None if args.window_max <= 0 else args.window_max

    # run pipeline
    rows, meta = rank_candidates_interpretability_guided(
        model=model,
        Z_p=Z_p,
        anchors=anchors,
        candidates=candidates,
        device=device,
        precision=args.precision,
        smooth_w=args.smooth_w,
        window_min=args.window_min,
        window_max=win_max,
    )

    # Save C_hat and region indices per anchor
    for pk_id, info in meta["anchor_regions"].items():
        base = os.path.join(args.outdir, f"{p_key}__{pk_id}")
        np.save(base + "__C_hat.npy", info["C_hat"])
        np.save(base + "__pk_region_idx.npy", info["idx_region"])

    # Save ranking TSV (R_p)
    df = pd.DataFrame(rows)
    tsv_path = os.path.join(args.outdir, f"{p_key}__Rp_ranked_candidates.tsv")
    df.to_csv(tsv_path, sep="\t", index=False, quoting=csv.QUOTE_NONE)

    # Print top candidates; a negative --top prints nothing instead of
    # "all but the last |top|" (Python negative-slice semantics).
    print("\n=== TOP CANDIDATES (R_p) ===")
    for r in rows[:max(args.top, 0)]:
        print(f"{r['candidate_id']:<25s} R_p={r['score_Rp']:.4f}  "
              f"argmax_anchor={r['argmax_anchor']:<25s}  "
              f"win=[{r['best_window_j0']},{r['best_window_j1']}]")

    print(f"\nSaved outputs to: {args.outdir}")
    print(f"- Per-anchor: {p_key}__<pk_id>__C_hat.npy / __pk_region_idx.npy")
    print(f"- Ranking: {tsv_path}")

if __name__ == "__main__":
    main()