from __future__ import annotations

from dataclasses import dataclass, field
import os
from typing import Callable, Dict, List, Optional, Tuple

from defense.utils import mislead_defense


ScoreFn = Callable[[str, str], Tuple[int, str]]
TruncateFn = Callable[[str], str]


@dataclass
class DefenseTracker:
    """Mutable counters and history for mislead-defense rewrites.

    MisleadDefenseRunner.maybe_defend() in this module is the writer: it
    bumps the counters once per rewrite and appends one entry to
    ``all_defense_info`` each time.
    """

    # Total rewrites recorded, and how many of them changed the score.
    attempts: int = 0
    successes: int = 0

    # Same pair, restricted to rewrites whose direction was "increase".
    increase_attempts: int = 0
    increase_successes: int = 0

    # Same pair, restricted to rewrites whose direction was "decrease".
    decrease_attempts: int = 0
    decrease_successes: int = 0

    # Histogram of score transitions; key format: "<from>_to_<to>".
    score_changes: Dict[str, int] = field(default_factory=dict)

    # One dict per rewrite, kept for downstream similarity calculation.
    all_defense_info: List[Dict] = field(default_factory=list)

    def reset(self) -> None:
        """Zero every counter and swap in fresh, empty containers."""
        for counter_name in (
            "attempts",
            "successes",
            "increase_attempts",
            "increase_successes",
            "decrease_attempts",
            "decrease_successes",
        ):
            setattr(self, counter_name, 0)
        # New objects (rather than .clear()) so references handed out
        # earlier keep their old contents.
        self.score_changes = dict()
        self.all_defense_info = list()


class MisleadDefenseRunner:
    """
    A standalone runner for the "mislead defense" used inside Feedback:
    - Obtains a score for the original response (judge call, or a
      caller-supplied override)
    - Rewrites response via mislead_defense()
    - Scores rewritten response (optional)
    - Updates DefenseTracker + returns defense_info payload
    """

    def __init__(
        self,
        *,
        enabled: bool,
        rewrite_model: str,
        rewrite_server_url_increase: Optional[str],
        rewrite_server_url_decrease: Optional[str],
        tracker: Optional[DefenseTracker] = None,
    ) -> None:
        """
        Args:
          enabled: when falsy, maybe_defend() is a pass-through.
          rewrite_model: forwarded unchanged to mislead_defense().
          rewrite_server_url_increase: forwarded unchanged to mislead_defense().
          rewrite_server_url_decrease: forwarded unchanged to mislead_defense().
          tracker: stats sink to update; a new DefenseTracker is created
            when omitted.
        """
        self.enabled = bool(enabled)
        self.rewrite_model = rewrite_model
        self.rewrite_server_url_increase = rewrite_server_url_increase
        self.rewrite_server_url_decrease = rewrite_server_url_decrease
        # Explicit None check: a caller-supplied tracker is always kept,
        # even if its type happens to evaluate as falsy.
        self.tracker = tracker if tracker is not None else DefenseTracker()

    def reset(self) -> None:
        """Reset the attached tracker's counters and history."""
        self.tracker.reset()

    def maybe_defend(
        self,
        *,
        harmful_behavior: str,
        prompt_for_rewrite: str,
        original_response_full: str,
        score_fn: ScoreFn,
        truncate_fn: TruncateFn,
        # If provided, skip judge scoring and reuse this original_score.
        original_score_override: Optional[int] = None,
        # If False, skip judge scoring on the defended response (save compute);
        # new_score is then reported as equal to original_score.
        compute_new_score: bool = True,
    ) -> Tuple[str, Optional[Dict]]:
        """
        Rewrite a response with mislead_defense() and record the outcome.

        Args:
          harmful_behavior: first argument passed to score_fn.
          prompt_for_rewrite: prompt forwarded to mislead_defense().
          original_response_full: untruncated response; this is what gets
            rewritten and what is stored in the tracker/payload.
          score_fn: judge; called as score_fn(harmful_behavior, text) and
            expected to return (score, reason).
          truncate_fn: applied to a response before it is shown to score_fn.
          original_score_override: when not None, the original response is
            not judged; int(original_score_override) is used instead with
            reason "skipped_by_proact".
          compute_new_score: when false, the defended response is not judged
            and the original score/reason are reported as the new ones.

        Returns:
          (final_response, defense_info_or_none)
          When the runner is disabled: (original_response_full, None).
          Otherwise: (defended_response, defense_info).

        Raises:
          TypeError / ValueError: if original_score_override (or, on the
            override path, the judge's new score) cannot be converted to int.
        """
        if not self.enabled:
            return original_response_full, None

        original_score_skipped = original_score_override is not None
        if original_score_skipped:
            # Converted before the rewrite call so that a bad override fails
            # without first spending a rewrite request.
            original_score = int(original_score_override)
            original_reason = "skipped_by_proact"
        else:
            # Score original response (judge sees truncated).
            original_score, original_reason = score_fn(
                harmful_behavior, truncate_fn(original_response_full)
            )

        # Rewrite (full text in, full text out).
        # Default rewrite direction is controlled by env (set by run_full_pipeline.sh).
        # If unset, fall back to "increase" for backward compatibility.
        requested_direction = os.getenv("REWRITE_DIRECTION", "increase")
        defended_response, direction, similarity_score = mislead_defense(
            original_response=original_response_full,
            prompt=prompt_for_rewrite,
            turn_index=self.tracker.attempts,
            rewrite_model=self.rewrite_model,
            rewrite_server_url_increase=self.rewrite_server_url_increase,
            rewrite_server_url_decrease=self.rewrite_server_url_decrease,
            direction=requested_direction,
        )

        if compute_new_score:
            # Score defended response (judge sees truncated).
            new_score, new_reason = score_fn(
                harmful_behavior, truncate_fn(defended_response)
            )
            if original_score_skipped:
                # The override was coerced to int, so coerce the judge's
                # score too before comparing. The judge-scored path keeps
                # comparing the two raw judge outputs, as it always has.
                new_score = int(new_score)
            score_changed = new_score != original_score
        else:
            # Skip judge scoring on defended response.
            new_score = original_score
            new_reason = original_reason
            score_changed = False
        new_score_skipped = not compute_new_score

        # Update stats. Per-direction counters use the direction reported
        # back by mislead_defense(), not the one requested via env.
        stats = self.tracker
        stats.attempts += 1
        if score_changed:
            stats.successes += 1
            transition = f"{original_score}_to_{new_score}"
            stats.score_changes[transition] = stats.score_changes.get(transition, 0) + 1
        if direction == "increase":
            stats.increase_attempts += 1
            if score_changed:
                stats.increase_successes += 1
        elif direction == "decrease":
            stats.decrease_attempts += 1
            if score_changed:
                stats.decrease_successes += 1

        # Store minimal info for later similarity calc (full texts, no truncation).
        stats.all_defense_info.append(
            {
                "original_response": original_response_full,
                "defended_response": defended_response,
                "direction": direction,
                "score_changed": score_changed,
                "similarity_score": similarity_score,
            }
        )

        # Both skip flags are always present so consumers see one schema
        # regardless of which scoring steps actually ran.
        defense_info = {
            "original_response": original_response_full,
            "defended_response": defended_response,
            "direction": direction,
            "original_score": original_score,
            "original_reason": original_reason,
            "new_score": new_score,
            "new_reason": new_reason,
            "score_changed": score_changed,
            "similarity_score": similarity_score,
            "original_score_skipped": original_score_skipped,
            "new_score_skipped": new_score_skipped,
        }
        return defended_response, defense_info


