# --------- DSCRIPT model loading ---------
import sys, importlib.util
from typing import Optional, Tuple, List, Dict, Any
import os
import numpy as np
import torch
import torch.nn as nn
from utils import *
from config import *

# --------- Math Utils ---------
def _ensure_3d(z: np.ndarray) -> np.ndarray:
    if z.ndim == 2:
        z = z[None, ...]
    if z.ndim != 3:
        raise ValueError(f"Expected (1, L, D) or (L, D); got {z.shape}")
    return z

def _as_2d(z: np.ndarray) -> np.ndarray:
    return z[0] if z.ndim == 3 else z
def to_tensor(z: np.ndarray, device: torch.device) -> torch.Tensor:
    """
    Convert a 2-D or 3-D NumPy array into a float32 tensor on `device`.

    The array is first passed through `_ensure_3d`, so the resulting tensor
    is always 3-D.

    Raises:
        ValueError: propagated from `_ensure_3d` for unsupported ranks.
    """
    batched = _ensure_3d(z)
    as_float = torch.from_numpy(batched).float()
    return as_float.to(device)

# ---------- DSCRIPT model import and building ---------
def import_dscript_model(path_hint: Optional[str] = None):
    """
    Locate and return the `DSCRIPTModel` class.

    Resolution order:
      1. `dscript.models.interaction` from an importable `dscript` package.
      2. A module-level `DSCRIPT_MODEL` name, if one is in scope
         (presumably provided by the star imports at the top of this file;
         verify against `config` / `utils`).
      3. The source file at `path_hint`, executed as a standalone module
         named "dscript_models_interaction".

    Args:
        path_hint: Path to a Python file defining `DSCRIPTModel`.
            Falls back to `DSCRIPT_INTERACTION` when None.

    Returns:
        The `DSCRIPTModel` class object.

    Raises:
        FileNotFoundError: if `path_hint` is not an existing file.
        ImportError: if no import spec/loader can be built for the file
            (e.g. it does not have a Python source suffix) or the loaded
            module does not define `DSCRIPTModel`.
    """
    try:
        from dscript.models.interaction import DSCRIPTModel  # package import
    except Exception:
        # Deliberately broad: any failure while importing the package
        # just means we move on to the next strategy.
        pass
    else:
        return DSCRIPTModel

    try:
        return DSCRIPT_MODEL
    except NameError:
        # No such module-level name; fall through to the file-based import.
        pass

    if path_hint is None:
        path_hint = DSCRIPT_INTERACTION
    if not os.path.isfile(path_hint):
        raise FileNotFoundError(f"interaction.py not found at: {path_hint}")

    spec = importlib.util.spec_from_file_location("dscript_models_interaction", path_hint)
    # Validate BEFORE module_from_spec(): spec_from_file_location returns None
    # when it cannot pick a loader, and module_from_spec(None) would fail with
    # an unrelated AttributeError.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for: {path_hint}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    if not hasattr(mod, "DSCRIPTModel"):
        raise ImportError(f"DSCRIPTModel not found in {path_hint}")
    return getattr(mod, "DSCRIPTModel")

def build_dscript_model(emb_nin: int, **kwargs):
    """
    Instantiate a `DSCRIPTModel` with this project's default hyper-parameters.

    The class is resolved through `import_dscript_model`, using
    `kwargs["path_hint"]` when given and `DSCRIPT_INTERACTION` otherwise.

    Args:
        emb_nin: Forwarded unchanged as the `emb_nin` constructor argument.
        **kwargs: Optional overrides for the constructor arguments listed in
            the table below. Keys not listed there (other than "path_hint")
            are ignored.

    Returns:
        The constructed model instance.
    """
    model_cls = import_dscript_model(
        path_hint=kwargs.get("path_hint", DSCRIPT_INTERACTION)
    )

    # (constructor keyword, value used when the caller does not override it)
    overridable = (
        ("emb_nout", 100),
        ("con_embed_dim", 100),
        ("con_hidden_dim", 50),
        ("con_width", 7),
        ("use_cuda", True),
        ("do_w", True),
        ("do_sigmoid", True),
        ("do_pool", False),
        ("pool_size", 9),
        ("theta_init", 1.0),
        ("lambda_init", 0.0),
        ("gamma_init", 0.0),
    )
    options = {key: kwargs.get(key, fallback) for key, fallback in overridable}

    # Dropout and both activations are fixed here and cannot be overridden.
    return model_cls(
        emb_nin=emb_nin,
        emb_dropout=0.0,
        emb_activation=nn.ReLU(),
        con_activation=nn.Sigmoid(),
        **options,
    )

# Keys under which a checkpoint dict may nest the actual state_dict.
_STATE_DICT_WRAPPER_KEYS = ("state_dict", "model_state_dict", "weights", "net")


def _load_state_dict_from_checkpoint(model: nn.Module, obj: Any) -> bool:
    """
    Load `obj` into `model` if it is, or wraps, a state_dict for that model.

    Candidates are tried in order: the dict stored under each wrapper key,
    then `obj` itself. A candidate is accepted only if it shares at least
    one key with `model.state_dict()`; `load_state_dict(strict=False)`
    raises nothing for a dict with no matching keys, so without this check
    an unrelated dict would be "loaded" while leaving the model untouched.

    Returns:
        True if a candidate was loaded, False if none qualified.
    """
    if not isinstance(obj, dict):
        return False
    expected = set(model.state_dict())
    nested = [obj[k] for k in _STATE_DICT_WRAPPER_KEYS if isinstance(obj.get(k), dict)]
    for state in (*nested, obj):
        # A model with no state entries cannot overlap with anything; keep
        # the permissive behaviour in that degenerate case.
        if expected and expected.isdisjoint(state):
            continue
        model.load_state_dict(state, strict=False)
        return True
    return False


def load_dscript_weights(model: nn.Module, weights_path: str):
    """
    Load weights from `weights_path` into `model` (in place) and return it.

    Supported checkpoint layouts:
      * a plain state_dict;
      * a dict nesting the state_dict under "state_dict",
        "model_state_dict", "weights" or "net";
      * a pickled object exposing `.state_dict()` (e.g. a whole model).

    The file is first read with `weights_only=True`. Only if that fails is
    it re-read with `weights_only=False`. Loading uses `strict=False`, so
    partially matching state_dicts are accepted.

    Args:
        model: Module receiving the weights.
        weights_path: Path to a file written by `torch.save`.

    Returns:
        The same `model` instance.

    Raises:
        RuntimeError: if the file content matches none of the layouts above,
            or if `load_state_dict` rejects the tensors (e.g. size mismatch).
    """
    from torch import serialization as torch_serial

    try:
        obj = torch.load(weights_path, map_location="cpu", weights_only=True)
    except Exception:
        # The restricted loader could not read the file (or raised for any
        # other reason); fall back to full unpickling.
        allow = [type(model)]
        try:
            import dscript.models.interaction as _ds_interaction
        except Exception:
            _ds_interaction = None
        for cls_name in ("DSCRIPTModel", "ModelInteraction"):
            cls = getattr(_ds_interaction, cls_name, None)
            if cls is not None:
                allow.append(cls)
        try:
            # NOTE: the allowlist only influences weights_only=True loads
            # (including later calls in this process); it does not restrict
            # the load below.
            torch_serial.add_safe_globals(allow)
        except Exception:
            pass
        # SECURITY: weights_only=False runs the full pickle machinery and can
        # execute arbitrary code. Only use with checkpoints you trust.
        obj = torch.load(weights_path, map_location="cpu", weights_only=False)

    if _load_state_dict_from_checkpoint(model, obj):
        return model

    try:
        model.load_state_dict(obj.state_dict(), strict=False)
        return model
    except Exception as e:
        raise RuntimeError(
            f"Unrecognized weights format in {weights_path}. "
            "Provide a plain state_dict (torch.save(model.state_dict(), ...))."
        ) from e

# --------- DSCRIPT contact map usage ---------
@torch.no_grad()
def dscript_contact_map(model: nn.Module,
                        Zp: np.ndarray,
                        Zpk: np.ndarray,
                        device: Optional[torch.device] = None,
                        precision: str = "fp32"):
    """
    Run `model.cpred` on two embeddings and return the first predicted map.

    Returns C_hat in [0,1]^{n_p x n_{p_k}} as a NumPy array; values are
    clipped to [0, 1] before returning.

    Args:
        model: Module exposing `cpred`; it is moved to `device` and put in
            eval mode as a side effect.
        Zp, Zpk: Embeddings accepted by `to_tensor` (2-D or 3-D arrays).
        device: Target device; CUDA when available, otherwise CPU, if omitted.
        precision: Forwarded to `pick_autocast` to choose the context the
            forward pass runs under.

    Raises:
        ValueError: if `cpred` returns a tensor that is neither 3-D nor 4-D.
    """
    if not device:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device).eval()
    tzp = to_tensor(Zp, device)
    tzpk = to_tensor(Zpk, device)

    with pick_autocast(precision, device):
        pred = model.cpred(tzp, tzpk)  # (B,1,Lp,Lk) or (B,Lp,Lk)

    if pred.ndim not in (3, 4):
        raise ValueError(f"Unexpected C shape: {tuple(pred.shape)}")

    # Keep only the first batch entry (and first channel for 4-D output).
    first = pred[0, 0] if pred.ndim == 4 else pred[0]
    contact = first.detach().float().cpu().numpy()
    return np.clip(contact, 0.0, 1.0)

def smooth_1d(x: np.ndarray, w: int) -> np.ndarray:
    """
    Moving-average smoothing of a 1-D signal with reflect padding.

    The output always has the same length as `x`. For odd `w` the window
    is centred on each sample; for even `w` it covers `w // 2` samples
    before and `w // 2 - 1` samples after.

    Args:
        x: 1-D array to smooth.
        w: Window width in samples. Values <= 1 disable smoothing.

    Returns:
        `x` itself when `w <= 1` or `x` is empty; otherwise a new array
        of length `len(x)`.
    """
    if w <= 1 or x.shape[0] == 0:
        # np.pad cannot reflect-extend an empty axis; nothing to smooth anyway.
        return x
    # Pad w - 1 samples in total so the 'valid' convolution returns exactly
    # len(x) values. Padding w // 2 on both sides would yield len(x) + 1
    # values whenever w is even.
    left = w // 2
    right = w - 1 - left
    xp = np.pad(x, (left, right), mode='reflect')
    ker = np.ones(w, dtype=np.float32) / float(w)
    return np.convolve(xp, ker, mode='valid')

def select_contiguous_region_max_avg(
    activation: np.ndarray,
    window_min: int = 6,
    window_max: Optional[int] = 35 # it can be set to None for no max
):
    """
    Find the contiguous window of `activation` with the highest mean value.

    Every window width from `window_min` to `window_max` (both clamped to
    the signal length) is scanned. A window replaces the current best when
    its mean is strictly greater, or when the mean is equal and the window
    is wider. If window_max is None or <= 0, widths up to the full length
    are allowed.

    Returns:
        (idx, (j0, j1)): `idx` holds the integer positions j0..j1 inclusive.
        For an empty input the result is an empty index array and (0, -1).
    """
    n_pos = int(activation.shape[0])
    if n_pos == 0:
        return np.array([], dtype=int), (0, -1)

    shortest = max(1, min(window_min, n_pos))
    unbounded = window_max is None or window_max <= 0
    longest = n_pos if unbounded else max(shortest, min(window_max, n_pos))

    # cumulative[k] is the sum of the first k activations, so any window sum
    # is a difference of two entries.
    cumulative = np.cumsum(np.r_[0.0, activation])

    top_score, top_start, top_width = -np.inf, 0, shortest
    for width in range(shortest, longest + 1):
        means = (cumulative[width:] - cumulative[:-width]) / float(width)
        start = int(np.argmax(means))
        score = float(means[start])
        strictly_better = score > top_score
        wider_tie = score == top_score and width > top_width
        if strictly_better or wider_tie:
            top_score, top_start, top_width = score, start, width

    last = top_start + top_width - 1
    return np.arange(top_start, last + 1, dtype=int), (top_start, last)

def _flatten_and_l2norm(seg_2d: np.ndarray) -> np.ndarray:
    """
    Flatten a segment to 1-D and scale it to (approximately) unit L2 norm.

    The input is first reduced with `_as_2d`. A small constant (1e-12) is
    added to the norm so an all-zero segment does not divide by zero.
    """
    vec = _as_2d(seg_2d).reshape(-1)          # (L*D,)
    return vec / (np.linalg.norm(vec) + 1e-12)

def cosine_flattened_max_over_windows(
    Z_pk_region: np.ndarray,
    Z_cand: np.ndarray):
    """
    Implements:
      sim(p_c, p_k) = max_{i=0..n_c-|I_pk|}  < z_k[I_pk], z_c[i:i+|I_pk|] > /
                                            ( ||z_k[I_pk]||_2 * ||z_c[i:i+|I_pk|]||_2 )
    i.e. the best cosine similarity between the flattened region and every
    equally long window of the candidate.

    When the candidate is shorter than the region, the roles are swapped:
    the whole candidate is compared against every candidate-length window
    of the region and the best score is reported. The flattened vectors
    then have equal length, which the dot product requires.

    Args:
        Z_pk_region: Region embedding, (R, D) or (1, R, D).
        Z_cand: Candidate embedding, (Lc, D) or (1, Lc, D).

    Returns:
        (best_score, best_j0, best_j1), where j0..j1 (inclusive) are
        positions in the candidate. For a short candidate this is
        (score, 0, Lc - 1). If the region or the candidate is empty, the
        sentinel (-2.0, 0, -1) is returned.
    """
    region = _as_2d(Z_pk_region)
    Zc2d = _as_2d(Z_cand)
    R = region.shape[0]
    Lc = Zc2d.shape[0]

    if R <= 0 or Lc <= 0:
        return -2.0, 0, -1

    # candidate shorter than region -> compare whole candidate against each
    # candidate-length window of the region
    if Lc < R:
        cw = _flatten_and_l2norm(Zc2d)                       # (Lc*D,)
        best_score = -2.0
        for i0 in range(0, R - Lc + 1):
            sub = _flatten_and_l2norm(region[i0:i0 + Lc])    # (Lc*D,)
            sc = float(sub @ cw)
            if sc > best_score:
                best_score = sc
        return best_score, 0, Lc - 1

    pk = _flatten_and_l2norm(region)                         # (R*D,)

    best_score, best_j0 = -2.0, 0
    for j0 in range(0, Lc - R + 1):
        win = Zc2d[j0:j0 + R]              # (R, D)
        cw = _flatten_and_l2norm(win)      # (R*D,)
        sc = float(pk @ cw)                # cosine
        if sc > best_score:
            best_score, best_j0 = sc, j0
    return best_score, best_j0, best_j0 + R - 1