# -*- coding: utf-8 -*-
"""
Runtime biosecurity guard (Mode 3 / Mode 4), faithful to the paper's design.

Decision: lexicographic priority (Eq. 8)
    BLAST  >  LONG-SEQUENCE  >  SEMANTIC  >  FUZZY  >  KEYWORD

Signals
- keyword         : tiered dictionaries (L1/L2/L3)
- fuzzy           : rapidfuzz.partial_ratio ≥ FUZZY_THRESH
- semantic        : Sentence-Transformers cosine ≥ SEM_THRESH (optional)
- long-sequence   : any ≥ MIN_SEQ chars (DNA/AA)
- BLAST           : placeholder (optional, off by default)

Environment overrides (optional; strings are case-insensitive):
- ICLR_GUARD_FZ_THRESH   : int, default 87
- ICLR_GUARD_SEM_THRESH  : float in [0,1], default 0.67
- ICLR_GUARD_MIN_SEQ     : int, default 60
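- ICLR_GUARD_BLAST       : "1"/"true"/"t"/"yes"/"y" turns the BLAST stub on; it only takes
                           effect together with use_blast=True, and the stub then flags
                           every input (no real alignment is performed)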

The function returns:
    {
      "action": "allow" | "warn" | "block",
      "reason": "none" | "keyword" | "fuzzy" | "semantic" | "longseq" | "blast",
      "flags":  {"keyword":bool, "fuzzy":bool, "semantic":bool, "longseq":bool, "blast":bool},
      "redacted_text": str,
      # included for tuning and reproducibility (does not affect caller logic):
      "thresholds": {"fuzzy":int, "semantic":float, "min_seq":int},
      "config": {"level":str, "use_semantic":bool, "use_blast":bool}
    }
"""

import math
import os
import re
from typing import Any, Dict, List, Tuple

# --------------------- keyword tiers (identical to the lists you provided) -----------------------
# L1: the base tier; every other tier starts from these terms.
L1 = [
    "bioweapon",
    "weaponize",
    "gain of function",
    "reverse genetics",
    "viral vector",
    "dual-use",
    "select agent",
    "anthrax",
    "smallpox",
    "ricin",
    "botulinum",
]

# L2: all of L1, followed by the additional terms below.
L2 = [
    *L1,
    "human pathogen",
    "clinical isolate",
    "biosafety level",
    "bsl-3",
    "bsl-4",
    "infectious clone",
    "rescue a virus",
    "serial passage",
    "mers",
    "sars",
    "h5n1",
    "h7n9",
    "highly pathogenic",
]

# L3: all of L2, followed by the additional terms below (the widest list).
L3 = [
    *L2,
    "virus",
    "viral",
    "virion",
    "capsid",
    "covid-19",
    "sars-cov-2",
    "influenza",
    "hepatitis b",
    "hiv",
]

# Maps a filter-level name to its term list.
VIRUS_FILTER_LEVELS = {
    "L1_custom": L1,
    "L2_human": L2,
    "L3_all": L3,
}

# De-duplicated, alphabetically sorted copy of the widest tier.
DEFAULT_RISK_TERMS = sorted(set(L3))

# ------------------------ regex for long sequences ---------------------------
def _mk_seq_re(min_len: int) -> Tuple[re.Pattern, re.Pattern]:
    """Build the two run detectors used for long-sequence screening.

    The minimum run length is interpolated into the patterns on every call,
    so it can be changed at runtime.

    Args:
        min_len: Minimum number of consecutive sequence letters to match.

    Returns:
        ``(nucleotide_pattern, amino_acid_pattern)``. The first matches runs
        of A/C/G/T/U in upper or lower case; the second matches runs of the
        letters ``ACDEFGHIKLMNPQRSTVWY``, ignoring case.
    """
    nucleotide_src = rf"(?:[ACGTUacgtu]{{{min_len},}})"
    amino_acid_src = rf"(?:[ACDEFGHIKLMNPQRSTVWY]{{{min_len},}})"
    return (
        re.compile(nucleotide_src),
        re.compile(amino_acid_src, re.IGNORECASE),
    )

def _extract_first_seq(text: str, min_len: int) -> Tuple[str, str]:
    """Return the first long sequence-like run found in *text*.

    The nucleotide pattern is tried before the amino-acid pattern, so a run
    that satisfies both is reported as nucleotide.

    Args:
        text: Text to scan; a falsy value yields no match.
        min_len: Minimum run length, forwarded to ``_mk_seq_re``.

    Returns:
        ``("nucl", run)`` or ``("aa", run)`` for the first pattern that
        matches, otherwise ``("", "")``.
    """
    nothing_found = ("", "")
    if not text:
        return nothing_found

    nucleotide_re, amino_acid_re = _mk_seq_re(min_len)
    for label, pattern in (("nucl", nucleotide_re), ("aa", amino_acid_re)):
        found = pattern.search(text)
        if found:
            return (label, found.group(0))
    return nothing_found

# ----------------------------- fuzzy matching --------------------------------
def _fuzzy_hit(text: str, keywords: List[str], threshold: int = 87) -> bool:
    """Report whether *text* approximately contains any of *keywords*.

    Matching is case-insensitive and uses rapidfuzz. When rapidfuzz cannot be
    imported the signal is disabled and the function returns False.

    Args:
        text: Text to scan; ``None`` or empty text never matches.
        keywords: Risk terms to look for; empty terms are ignored.
        threshold: Minimum rapidfuzz score (0-100) that counts as a hit.

    Returns:
        True if at least one keyword scores ``>= threshold`` against the text.
    """
    try:
        from rapidfuzz import fuzz
    except Exception:
        # Optional dependency: without it the fuzzy signal is simply off.
        return False

    haystack = (text or "").lower()
    if not haystack:
        return False

    for kw in keywords:
        needle = kw.lower()
        if not needle:
            continue
        # partial_ratio aligns the SHORTER string inside the longer one. If the
        # text is shorter than the keyword, any fragment of the keyword
        # ("hi" vs "hiv", "in" vs "ricin") would score 100 and flag harmless
        # input. In that case compare the whole strings instead, which still
        # catches near-complete terms such as "anthra" vs "anthrax".
        if len(haystack) >= len(needle):
            score = fuzz.partial_ratio(haystack, needle)
        else:
            score = fuzz.ratio(haystack, needle)
        if score >= threshold:
            return True
    return False

# ----------------------------- semantic hit ----------------------------------
# Sentence-Transformers model instance, created on first use and then reused.
_SEM_MODEL = None
def _semantic_hit(text: str, keywords: List[str], sim_threshold: float = 0.67) -> bool:
    """Report whether *text* is semantically close to any of *keywords*.

    Uses Sentence-Transformers embeddings and cosine similarity. The signal is
    optional: the function returns False when the package cannot be imported,
    when *text* is not a non-blank string, or when anything fails while
    loading the model or encoding.

    Args:
        text: Text to compare against the keywords.
        keywords: Risk terms; they are encoded in the same batch as the text.
        sim_threshold: Minimum cosine similarity that counts as a hit.

    Returns:
        True if any keyword's similarity to the text is ``>= sim_threshold``.
    """
    global _SEM_MODEL
    try:
        from sentence_transformers import SentenceTransformer
    except Exception:
        return False

    has_content = isinstance(text, str) and bool(text.strip())
    if not has_content:
        return False

    try:
        if _SEM_MODEL is None:
            _SEM_MODEL = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
        vectors = _SEM_MODEL.encode([text] + keywords, normalize_embeddings=True)
        query_vec = vectors[0]
        keyword_vecs = vectors[1:]
        # The embeddings are unit-normalized, so the dot product is the cosine.
        similarities = query_vec @ keyword_vecs.T
        over_threshold = similarities >= sim_threshold
        return bool(over_threshold.any())
    except Exception:
        return False

# ----------------------------- small utils -----------------------------------
def _to_bool_env(val: str, default: bool = False) -> bool:
    """Interpret an environment-variable value as a boolean.

    Args:
        val: Raw value, or ``None`` when the variable is unset.
        default: Result to use when *val* is ``None``.

    Returns:
        *default* for ``None``; otherwise True exactly when the value, with
        surrounding whitespace removed and lower-cased, is one of
        ``1``, ``true``, ``t``, ``yes`` or ``y``.
    """
    if val is None:
        return default
    normalized = str(val).strip().lower()
    return normalized in {"1", "true", "t", "yes", "y"}

def _clamp_sem(v: float) -> float:
    """Coerce a semantic-similarity threshold into the closed range [0, 1].

    Args:
        v: Candidate threshold; anything ``float()`` accepts, e.g. the raw
            string read from an environment variable.

    Returns:
        The value clamped to ``[0.0, 1.0]``, or the default ``0.67`` when it
        cannot be converted to a float or is NaN.
    """
    default = 0.67
    try:
        value = float(v)
    except (TypeError, ValueError, OverflowError):
        return default
    # NaN passes both "< 0" and "> 1" range checks; as a threshold it would
    # make every ">=" comparison false and silently turn the signal off.
    if math.isnan(value):
        return default
    return min(1.0, max(0.0, value))

# ----------------------------- main API --------------------------------------
def _env_int(name: str, default: int) -> int:
    """Read integer environment variable *name*; use *default* if unset or malformed."""
    raw = os.getenv(name)
    if raw is not None:
        try:
            return int(raw)
        except ValueError:
            # A malformed override must not crash the guard; use the default.
            pass
    return int(default)


def biosecurity_alignment_guard(text: str,
                                where: str = "pre",
                                mode: str = "strict",
                                virus_filter_level: str = "L2_human",
                                use_semantic: bool = False,
                                use_blast: bool = False,
                                min_seq_len: int = 60) -> Dict[str, Any]:
    """Screen *text* for biosecurity risk signals and decide how to handle it.

    Five signals are evaluated (keyword, fuzzy, semantic, long sequence,
    BLAST). If none fires the text is allowed. Otherwise the reported reason
    is the highest-priority signal that fired (Eq. 8):
    BLAST > long-seq > semantic > fuzzy > keyword.

    Environment overrides, read on every call:
        ICLR_GUARD_FZ_THRESH: fuzzy score threshold (int, default 87).
        ICLR_GUARD_SEM_THRESH: semantic threshold, clamped to [0, 1]
            (default 0.67).
        ICLR_GUARD_MIN_SEQ: minimum sequence length; takes precedence over
            *min_seq_len*.
        ICLR_GUARD_BLAST: truthy value enables the BLAST stub (see
            *use_blast*).
        Unset or malformed integer overrides fall back to their defaults.

    Args:
        text: Text to screen; ``None`` is treated as empty for matching.
        where: Accepted for interface compatibility; not read by this
            function.
        mode: ``"strict"`` blocks risky text; any other value only warns.
        virus_filter_level: Key into ``VIRUS_FILTER_LEVELS``; an unknown key
            selects ``DEFAULT_RISK_TERMS``.
        use_semantic: Enable the Sentence-Transformers signal.
        use_blast: Enable the BLAST stub. The flag is set only when this is
            True AND ``ICLR_GUARD_BLAST`` is truthy; no alignment is run, so
            in that configuration every input is flagged.
        min_seq_len: Minimum run length for the long-sequence signal.

    Returns:
        A dict with keys ``action`` ("allow" | "warn" | "block"), ``reason``,
        ``flags`` (one bool per signal), ``redacted_text``, ``thresholds``
        and ``config``. For "allow", ``redacted_text`` is *text* unchanged;
        for "block" it is a fixed placeholder; for "warn" it is the text
        with a warning prefix.
    """
    # ---- thresholds (overridable through environment variables) ----
    fuzzy_thresh = _env_int("ICLR_GUARD_FZ_THRESH", 87)
    sem_thresh = _clamp_sem(os.getenv("ICLR_GUARD_SEM_THRESH", "0.67"))
    min_seq = _env_int("ICLR_GUARD_MIN_SEQ", min_seq_len)
    force_blast = _to_bool_env(os.getenv("ICLR_GUARD_BLAST"), False)

    # ---- terms by level ----
    terms = sorted(set(VIRUS_FILTER_LEVELS.get(virus_filter_level, DEFAULT_RISK_TERMS)))
    low = (text or "").lower()

    # ---- individual signals ----
    kw_hit = any(kw in low for kw in terms)
    fz_hit = _fuzzy_hit(low, terms, threshold=fuzzy_thresh)
    sem_hit = _semantic_hit(text, terms, sim_threshold=sem_thresh) if use_semantic else False

    _seq_kind, seq = _extract_first_seq(text or "", min_seq)
    long_hit = bool(seq) and (len(seq) >= min_seq)

    # BLAST placeholder (off by default). To connect a real database,
    # replace this with an actual alignment call.
    blast_hit = bool(force_blast and use_blast)

    flags = {"keyword": kw_hit, "fuzzy": fz_hit, "semantic": sem_hit, "longseq": long_hit, "blast": blast_hit}

    if not any(flags.values()):
        # ---- quick allow ----
        action, reason, redacted = "allow", "none", text
    else:
        # ---- lexicographic priority: first firing signal in this order ----
        priority = ("blast", "longseq", "semantic", "fuzzy", "keyword")
        reason = next(name for name in priority if flags[name])

        # ---- decision policy ----
        if mode == "strict":
            action = "block"
            redacted = "[REDACTED: BIOSECURITY VIOLATION]"
        else:
            action = "warn"
            redacted = "[⚠️ BIO RISK] " + (text or "")

    return {
        "action": action,
        "reason": reason,
        "flags": flags,
        "redacted_text": redacted,
        "thresholds": {"fuzzy": fuzzy_thresh, "semantic": sem_thresh, "min_seq": min_seq},
        "config": {"level": virus_filter_level, "use_semantic": bool(use_semantic), "use_blast": bool(use_blast)},
    }
