"""Utility functions for defense mechanisms."""

import os
import threading
import numpy as np
from typing import Optional, Tuple, Sequence
from sentence_transformers import SentenceTransformer
from .dj_defense import mislead_defense as _dj_mislead_defense


# Held by get_similarity_model() while the shared model is being created.
_SIMILARITY_MODEL_INIT_LOCK = threading.Lock()
# Held by _encode_np() around each call to `model.encode(...)`.
_SIMILARITY_ENCODE_LOCK = threading.Lock()
# Process-wide model instance; stays None until get_similarity_model() first creates it.
_SIMILARITY_MODEL: Optional[SentenceTransformer] = None


def get_similarity_model() -> SentenceTransformer:
    """
    Return the shared SentenceTransformer, creating it on first use.

    The module-level instance is checked once without the lock (fast path)
    and once more while holding `_SIMILARITY_MODEL_INIT_LOCK`, so concurrent
    first callers construct the model only once.

    Environment overrides (read only when the model is created):
        - SIMILARITY_MODEL_NAME
          (default: "sentence-transformers/paraphrase-mpnet-base-v2")
        - SIMILARITY_DEVICE (default: "cpu")

    Returns:
        The single shared SentenceTransformer instance.
    """
    global _SIMILARITY_MODEL

    if _SIMILARITY_MODEL is None:
        with _SIMILARITY_MODEL_INIT_LOCK:
            # Re-check under the lock: another thread may have finished
            # initialization while this one was waiting.
            if _SIMILARITY_MODEL is None:
                chosen_name = os.environ.get(
                    "SIMILARITY_MODEL_NAME",
                    "sentence-transformers/paraphrase-mpnet-base-v2",
                )
                chosen_device = os.environ.get("SIMILARITY_DEVICE", "cpu")
                _SIMILARITY_MODEL = SentenceTransformer(chosen_name, device=chosen_device)

    return _SIMILARITY_MODEL


def _encode_np(model: SentenceTransformer, texts: Sequence[str]) -> np.ndarray:
    """
    Encode `texts` with `model` while holding the module-wide encode lock.

    Args:
        model: Model whose `encode` method is called.
        texts: Strings to embed; copied into a list before encoding.

    Returns:
        Whatever `model.encode(..., convert_to_numpy=True)` returns.
    """
    _SIMILARITY_ENCODE_LOCK.acquire()
    try:
        batch = list(texts)
        embeddings = model.encode(batch, convert_to_numpy=True)
    finally:
        # Always release, even if encoding raises.
        _SIMILARITY_ENCODE_LOCK.release()
    return embeddings


def _l2_normalize_rows(x: np.ndarray) -> np.ndarray:
    """
    L2-normalize a 2D array row-wise.
    Assumes embeddings are non-zero (true for SentenceTransformer in practice).
    """
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def _cosine(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.dot(a, b))


def calculate_similarity(model: SentenceTransformer, text1: str, text2: str) -> float:
    """
    Compute the cosine similarity between the embeddings of two texts.

    Both texts are encoded in a single batch, the embeddings are
    L2-normalized row-wise, and the dot product of the two rows is returned.

    Args:
        model: SentenceTransformer model used to embed the texts.
        text1: First text.
        text2: Second text.

    Returns:
        0.0 if either text is empty/falsy; otherwise the cosine similarity
        of the two embeddings (nominally in [-1, 1]).
    """
    if text1 and text2:
        raw_embeddings = _encode_np(model, [text1, text2])
        unit_embeddings = _l2_normalize_rows(raw_embeddings)
        first_vec = unit_embeddings[0]
        second_vec = unit_embeddings[1]
        return _cosine(first_vec, second_vec)

    # At least one input is empty: skip encoding and report no similarity.
    return 0.0


def mislead_defense(
    original_response: str,
    prompt: str,
    client: Optional[object] = None,
    rewrite_model: str = "increase",
    max_attempts: int = 5,
    turn_index: Optional[int] = None,
    rewrite_server_url: Optional[str] = None,
    rewrite_server_url_increase: Optional[str] = None,
    rewrite_server_url_decrease: Optional[str] = None,
    direction: str = "increase",
    similarity_model: Optional[SentenceTransformer] = None,
    similarity_threshold: float = 0.8,
) -> Tuple[str, str, Optional[float]]:
    """
    Backward-compatible entry point for the mislead (rewrite) defense.

    This wrapper does no work itself: it forwards its arguments to the
    implementation imported from `.dj_defense` and returns that result
    unchanged. `include_query=False` is always passed to the implementation.

    Args:
        original_response: Original response string to be rewritten.
        prompt: Original prompt.
        client: Ignored; accepted only so older call sites keep working.
        rewrite_model: Rewrite model name, forwarded as-is.
        max_attempts: Maximum retry attempts, forwarded as-is.
        turn_index: Turn index (optional), forwarded as-is.
        rewrite_server_url: Server API URL (optional), forwarded as-is.
        rewrite_server_url_increase: Server API URL for the "increase"
            direction (optional), forwarded as-is.
        rewrite_server_url_decrease: Server API URL for the "decrease"
            direction (optional), forwarded as-is.
        direction: Rewrite direction ("increase" or "decrease"), forwarded
            as-is. NOTE(review): any fallback to an environment variable
            would happen inside the delegated implementation, not here.
        similarity_model: Ignored; accepted only so older call sites keep
            working.
        similarity_threshold: Similarity threshold, forwarded as-is.

    Returns:
        The delegated implementation's return value; per this function's
        annotation, a tuple of (defended response, rewrite direction,
        similarity score or None).
    """
    # These two parameters exist only to preserve the historical signature.
    del client, similarity_model

    forwarded_kwargs = dict(
        original_response=original_response,
        prompt=prompt,
        rewrite_model=rewrite_model,
        max_attempts=max_attempts,
        turn_index=turn_index,
        rewrite_server_url=rewrite_server_url,
        rewrite_server_url_increase=rewrite_server_url_increase,
        rewrite_server_url_decrease=rewrite_server_url_decrease,
        direction=direction,
        similarity_threshold=similarity_threshold,
        include_query=False,
    )
    return _dj_mislead_defense(**forwarded_kwargs)

