from __future__ import annotations

import os
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple

from defense.feedback_defense import DefenseTracker, MisleadDefenseRunner
from defense.guard_defense import GuardDefenseRunner, GuardSettings
from defense.proact_defense import ProActConfig, ProActDefense


# Scoring callback. Invoked in this module as ``score_fn(harmful_behavior, truncated_response)``
# and expected to return a ``(score, reason)`` pair.
ScoreFn = Callable[[str, str], Tuple[int, str]]
# Truncation callback: maps a full response string to the shortened text that is
# scored and stored under "target_truncated".
TruncateFn = Callable[[str], str]


@dataclass
class RewriteSettings:
    """Resolved options for the mislead (rewrite) defense."""

    # Master switch, read from config["enable_defense"].
    enabled: bool
    # Model name handed to the rewrite runner.
    model: str
    # Rewrite server endpoints; None when no URL was resolved for that slot.
    server_url_increase: Optional[str]
    server_url_decrease: Optional[str]


@dataclass
class ProActSettings:
    """ProAct toggle plus the optional overrides found in config["proact"]."""

    # Master switch, read from config["enable_proact"].
    enabled: bool
    # Optional overrides (defaults live in ProActConfig); None means "not provided".
    base_model: Optional[str] = None
    analyzer_model: Optional[str] = None
    defender_model: Optional[str] = None
    evaluator_model: Optional[str] = None
    analyzer_votes: Optional[int] = None
    evaluator_votes: Optional[int] = None
    analyzer_margin_required: Optional[int] = None
    max_search_budget: Optional[int] = None
    verbose: Optional[bool] = None


@dataclass
class GuardSettingsConfig:
    """On/off toggle for the guard stage."""

    # Read from config["enable_guard"]; detailed guard options are parsed elsewhere.
    enabled: bool


class ProActDefenseRunner:
    """
    Thin, optional wrapper around the proactive defense (ProAct).

    - When enabled, hands the conversation messages to ProActDefense and returns its output.
    - Meant to run ahead of the mislead (rewrite) defense in the chain.
    """

    def __init__(self, *, enabled: bool, config: Dict[str, Any]):
        """
        Store the toggle and config; build the ProAct defense only when enabled.

        Raises:
            EnvironmentError: If enabled while OPENAI_API_KEY is unset or empty.
        """
        self.enabled = bool(enabled)
        self._config = dict(config or {})
        self._defense: Optional[ProActDefense] = None

        # Disabled runners never touch the environment or construct ProActDefense.
        if not self.enabled:
            return
        if not os.getenv("OPENAI_API_KEY"):
            raise EnvironmentError("OPENAI_API_KEY is required to enable ProAct defense")
        self._defense = ProActDefense(config=self._build_proact_config())

    def reset(self) -> None:
        """No-op: the runner keeps no per-request state."""

    def _build_proact_config(self) -> ProActConfig:
        """Return a default ProActConfig with non-None ``proact`` overrides applied."""
        overridable = (
            "base_model",
            "analyzer_model",
            "defender_model",
            "evaluator_model",
            "analyzer_votes",
            "evaluator_votes",
            "analyzer_margin_required",
            "max_search_budget",
            "verbose",
        )
        proact_section = dict(self._config.get("proact", {}) or {})
        cfg = ProActConfig()
        for field_name in overridable:
            value = proact_section.get(field_name)
            # Missing keys and explicit None both leave the ProActConfig default in place.
            if value is not None:
                setattr(cfg, field_name, value)
        return cfg

    def maybe_spurious(self, *, messages: List[Dict[str, str]]) -> Tuple[Optional[str], Dict[str, Any]]:
        """
        Forward ``messages`` to the ProAct defense and return its ``(spurious, info)`` pair.

        Raises:
            RuntimeError: If the runner is disabled or has no defense instance.
        """
        defense = self._defense
        if defense is None or not self.enabled:
            raise RuntimeError("ProActDefenseRunner called while disabled")
        # `info` is passed through unchanged; the outer layer records `proact_enabled`.
        spurious_text, details = defense.maybe_spurious(messages)
        return spurious_text, details

class UpdaterDefenseHelper:
    """
    Shared defense hook for *all* updaters.

    Intent:
    - Apply the same mislead defense as Feedback, at the shared point right after we
      obtain a target model response (and before scoring / storing / TextGrad loss).
    - If defense is enabled, evaluation_score comes from the defended (rewritten) response score.
    """

    def __init__(
        self,
        *,
        config: Dict[str, Any],
        truncate_fn: TruncateFn,
        max_score: int = 5,  # Maximum evaluation score for jailbreak detection
    ) -> None:
        self._truncate = truncate_fn
        self.max_score = max_score
        self.tracker = DefenseTracker()
        self.settings = self._parse_settings(config)
        self.proact_settings = self._parse_proact_settings(config)
        self.guard_settings = self._parse_guard_settings(config)
        self.runner = MisleadDefenseRunner(
            enabled=self.settings.enabled,
            rewrite_model=self.settings.model,
            rewrite_server_url_increase=self.settings.server_url_increase,
            rewrite_server_url_decrease=self.settings.server_url_decrease,
            tracker=self.tracker,
        )
        self.proact_runner = ProActDefenseRunner(enabled=self.proact_settings.enabled, config=config)
        self.guard_runner = GuardDefenseRunner(GuardDefenseRunner.from_config(config))

    def reset(self) -> None:
        self.runner.enabled = bool(self.settings.enabled)
        self.runner.reset()
        self.proact_runner.reset()
        # Guard is stateless; keep enabled flag in settings.
        return

    @staticmethod
    def _parse_settings(config: Dict[str, Any]) -> RewriteSettings:
        """
        Resolve the mislead (rewrite) defense settings from config and environment.

        Mirrors Feedback's behavior:
        - enable_defense: bool
        - rewrite mode via env REWRITE_MODEL_TYPE, then config["rewrite_model"]["type"]
          (defaults to "server")
        - server URLs via env REWRITE_SERVER_URL_{INCREASE,DECREASE} first, then config;
          the single REWRITE_SERVER_URL / "server_url" value is a back-compat fallback

        Args:
            config: Updater configuration mapping.

        Returns:
            RewriteSettings with the chosen model name and server URLs. Outside
            server mode both URLs are None.
        """
        enabled = bool(config.get("enable_defense", False))
        rewrite_cfg = dict(config.get("rewrite_model", {}) or {})
        rewrite_type = os.getenv("REWRITE_MODEL_TYPE") or rewrite_cfg.get("type", "server")

        if rewrite_type != "server":
            # Use `or` rather than a .get() default so an explicit null/empty "model"
            # entry still falls back to the default, as the server branch below does.
            return RewriteSettings(
                enabled=enabled,
                model=rewrite_cfg.get("model") or "gpt-4.1-mini",
                server_url_increase=None,
                server_url_decrease=None,
            )

        # Server mode (vLLM/SGLang OpenAI-compatible): default to a served model name so
        # we don't accidentally call "gpt-4.1-mini" against a local server. The env
        # override REWRITE_SERVER_MODEL wins so run scripts can line up with the
        # server's --served-model-name regardless of what the config says.
        model = os.getenv("REWRITE_SERVER_MODEL") or rewrite_cfg.get("model") or "increase"

        url_increase = os.getenv("REWRITE_SERVER_URL_INCREASE") or rewrite_cfg.get("server_url_increase")
        url_decrease = os.getenv("REWRITE_SERVER_URL_DECREASE") or rewrite_cfg.get("server_url_decrease")

        # Back-compat: a single URL fills the "increase" slot; only *one* URL is required.
        if not url_increase and not url_decrease:
            single = os.getenv("REWRITE_SERVER_URL") or rewrite_cfg.get("server_url")
            if single:
                url_increase, url_decrease = single, None

        # Single-direction default: whichever URL is present ends up in the "increase" slot.
        if url_increase:
            if not url_decrease:
                url_decrease = None  # normalize an empty value to None
        elif url_decrease:
            url_increase, url_decrease = url_decrease, None

        # Defense enabled without any URL: when OPENAI_API_KEY exists, clear both URLs
        # (intended as an OpenAI-mode fallback to avoid a hard failure). Without the key
        # the values are kept as-is so the mislead defense can report the problem.
        # NOTE(review): `model` keeps its server-mode value in this fallback; confirm the
        # mislead runner picks a valid OpenAI model name in that case.
        if enabled and not url_increase and os.getenv("OPENAI_API_KEY"):
            url_increase = None
            url_decrease = None

        return RewriteSettings(
            enabled=enabled,
            model=model,
            server_url_increase=url_increase,
            server_url_decrease=url_decrease,
        )

    @staticmethod
    def _parse_proact_settings(config: Dict[str, Any]) -> ProActSettings:
        """
        Build ProActSettings from ``config``.

        ``enable_proact`` supplies the toggle; every override is looked up in the
        ``proact`` section and is None when the key is absent.
        """
        is_enabled = bool(config.get("enable_proact", False))
        section = dict(config.get("proact", {}) or {})
        override_names = (
            "base_model",
            "analyzer_model",
            "defender_model",
            "evaluator_model",
            "analyzer_votes",
            "evaluator_votes",
            "analyzer_margin_required",
            "max_search_budget",
            "verbose",
        )
        overrides = {name: section.get(name) for name in override_names}
        return ProActSettings(enabled=is_enabled, **overrides)

    @staticmethod
    def _parse_guard_settings(config: Dict[str, Any]) -> GuardSettingsConfig:
        """Read the ``enable_guard`` toggle (default off) into a GuardSettingsConfig."""
        # Only the on/off switch lives here; GuardDefenseRunner.from_config() handles the details.
        return GuardSettingsConfig(enabled=bool(config.get("enable_guard", False)))

    def defend_score_build_result(
        self,
        *,
        harmful_behavior: str,
        prompt_for_rewrite: str,
        original_response_full: str,
        guard_messages: Optional[List[Dict[str, str]]] = None,
        # 0-based turn index of the attacker conversation. Used to decide whether to start mislead defense.
        # Typically: turn_index == 0 means the first attacker turn.
        turn_index: Optional[int] = None,
        attacker_text: str,
        score_fn: ScoreFn,
        proact_info: Optional[Dict[str, Any]] = None,
        raw_target_response: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Build a result dict with the exact fields that updater code expects.
        """
        final_response = original_response_full

        # If ProAct produced a spurious response, set a fixed original_score=1 and
        # skip scoring the ORIGINAL response inside mislead. Still score defended response once.
        proact_used = bool(proact_info is not None and proact_info.get("intent") == "malicious")
        original_score_override = 1 if proact_used else None

        # Stage B0) Guard (optional): if the guard says response is unsafe, replace it with a fixed refusal.
        guard_info = None
        if self.guard_settings.enabled:
            if guard_messages is None:
                guard_messages = [
                    {"role": "user", "content": prompt_for_rewrite},
                    {"role": "assistant", "content": final_response},
                ]
            final_response, guard_info = self.guard_runner.maybe_guard(messages=guard_messages)

        # If guard replaced the response, do NOT apply mislead rewriting on top.
        # GuardDefenseRunner sets `replaced` when it swaps the response to a fixed refusal.
        # Back-compat: also treat label==True (unsafe) as replaced if present.
        guard_replaced = bool(
            guard_info is not None
            and (
                guard_info.get("replaced") is True
                or (guard_info.get("label") is True and guard_info.get("replacement_text") is not None)
            )
        )

        # Stage B) Mislead (optional): rewrite the (possibly proact-generated) response
        # Start mislead defense only on the first turn (turn_index == 0). If turn_index is not
        # provided, default to enabling mislead (backward compatibility for callers).
        # enable_mislead = (turn_index is None) or (int(turn_index) == 0)
        mislead_info = None
        if not guard_replaced:
            final_response, mislead_info = self.runner.maybe_defend(
                harmful_behavior=harmful_behavior,
                prompt_for_rewrite=prompt_for_rewrite,
                original_response_full=final_response,
                score_fn=score_fn,
                truncate_fn=self._truncate,
                original_score_override=original_score_override,
            )

        final_truncated = self._truncate(final_response)

        # If mislead defense ran, evaluation_score reflects the POST-rewrite score.
        if self.settings.enabled and mislead_info is not None:
            score = mislead_info.get("new_score")
            reason = mislead_info.get("new_reason")
        else:
            score, reason = score_fn(harmful_behavior, final_truncated)

        result: Dict[str, Any] = {
            "attacker": attacker_text,
            "target": final_response,
            "target_truncated": final_truncated,
            "evaluation_score": score,
            "evaluation_reason": reason,
        }
        if proact_info is not None or mislead_info is not None or guard_info is not None:
            combined: Dict[str, Any] = {}
            if mislead_info is not None:
                combined.update(mislead_info)
                combined["mislead_enabled"] = True
            if guard_info is not None:
                combined["guard"] = guard_info
            if proact_info is not None:
                combined["proact_enabled"] = True
                combined["proact"] = proact_info
                combined["raw_target_response"] = raw_target_response
                combined["proact_response"] = original_response_full if proact_info.get("intent") == "malicious" else None
                # If ProAct short-circuited target generation, still record original_score override.
                if proact_used:
                    combined.setdefault("original_score", 1)
                    combined.setdefault("original_reason", "skipped_by_proact")
            result["defense_info"] = combined

        return result

    def attach_defense_stats(self, out: Dict[str, Any]) -> None:
        if not self.settings.enabled:
            return
        dt = self.tracker
        success_rate = (dt.successes / dt.attempts) if dt.attempts > 0 else 0.0
        inc_rate = (dt.increase_successes / dt.increase_attempts) if dt.increase_attempts > 0 else 0.0
        dec_rate = (dt.decrease_successes / dt.decrease_attempts) if dt.decrease_attempts > 0 else 0.0
        out["defense_stats"] = {
            "attempts": dt.attempts,
            "successes": dt.successes,
            "success_rate": success_rate,
            "increase_attempts": dt.increase_attempts,
            "increase_successes": dt.increase_successes,
            "increase_success_rate": inc_rate,
            "decrease_attempts": dt.decrease_attempts,
            "decrease_successes": dt.decrease_successes,
            "decrease_success_rate": dec_rate,
            "score_changes": dict(dt.score_changes),
            "all_defense_info": dt.all_defense_info,
        }


