from collections.abc import Sequence
from dataclasses import dataclass
from typing import Literal

from agentdojo.agent_pipeline.base_pipeline_element import BasePipelineElement
from agentdojo.agent_pipeline.errors import AbortAgentError
from agentdojo.functions_runtime import EmptyEnv, Env, FunctionsRuntime
from agentdojo.types import ChatMessage, MessageContentBlock, get_text_content_as_str, text_content_block_from_string


@dataclass
class DetectorTask:
    """One tool message queued for prompt-injection detection."""

    # Position of the message within the conversation's message sequence.
    index: int
    # The tool message itself; the detector may overwrite its "content" in place.
    message: ChatMessage
    # Text extracted from the message content (empty string if there is none).
    text: str


class PromptInjectionDetector(BasePipelineElement):
    """Detects whether a prompt injection is in the tools outputs.

    This base class drives the detection loop in `query`; subclasses supply the actual
    classifier by overriding `detect` (and optionally `transform`).

    Args:
        mode: The mode in which the detector should operate. It can be "message" or "full_conversation".
            In "message" mode, the detector is applied per tool output message, while in "full_conversation"
            mode, the detector is applied to the full conversation so far.
        raise_on_injection: Whether to raise a [AbortAgentError][agentdojo.agent_pipeline.AbortAgentError] if
            an injection is detected.
    """

    def __init__(self, mode: Literal["message", "full_conversation"] = "message", raise_on_injection: bool = False):
        super().__init__()
        self.mode = mode
        self.raise_on_injection = raise_on_injection

    def detect(self, tool_output: str) -> bool | tuple[bool, float]:
        """Decide whether `tool_output` contains a prompt injection.

        Must be overridden by subclasses; the base implementation always raises.

        Args:
            tool_output: The text to classify. In "message" mode this is the text of a single tool
                message; in "full_conversation" mode it is the rendered conversation up to and
                including that tool message.

        Returns:
            If the detector is a probabilistic one, it can return a tuple with the first element being a boolean
                indicating whether a prompt injection was detected, and the second element being a float representing
                the confidence score. If the detector is not probabilistic, it should return a boolean.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("to use PromptInjectionDetector, use a subclass with a detect method")

    def transform(self, tool_output: list[MessageContentBlock]) -> list[MessageContentBlock]:
        """Sanitize the content of a tool message flagged as an injection.

        Every text block is replaced by a fixed placeholder; non-text blocks are kept unchanged.

        Args:
            tool_output: The content blocks of the flagged tool message.

        Returns:
            A new list of content blocks with the text blocks replaced.
        """
        return [
            text_content_block_from_string("<Data omitted because a prompt injection was detected>")
            if block["type"] == "text"
            else block
            for block in tool_output
        ]

    def query(
        self,
        query: str,
        runtime: FunctionsRuntime,
        env: Env = EmptyEnv(),
        messages: Sequence[ChatMessage] = [],
        extra_args: dict = {},
    ) -> tuple[str, FunctionsRuntime, Env, Sequence[ChatMessage], dict]:
        """Run the detector over the trailing tool messages of the conversation.

        Only the run of consecutive tool messages at the end of `messages` is inspected. Flagged
        messages have their `"content"` replaced **in place** with the result of `transform`, so the
        returned `messages` (the same sequence that was passed in) reflects the sanitized content.

        Args:
            query: The user query, passed through unchanged.
            runtime: The functions runtime, passed through unchanged.
            env: The environment, passed through unchanged (and attached to the abort error).
            messages: The conversation so far.
            extra_args: Extra pipeline arguments, passed through unchanged.

        Returns:
            The `(query, runtime, env, messages, extra_args)` tuple.

        Raises:
            AbortAgentError: if an injection is detected and `self.raise_on_injection` is `True`.
            ValueError: if `self.mode` is not one of the supported modes.
        """
        # Nothing to inspect unless the conversation currently ends with a tool output.
        if not messages or messages[-1]["role"] != "tool":
            return query, runtime, env, messages, extra_args

        # Collect the trailing run of consecutive tool messages, newest first. Each task carries
        # the text of *its own* message (not the last message's text).
        messages_to_check: list[DetectorTask] = []
        for i in range(len(messages) - 1, -1, -1):
            message = messages[i]
            if message["role"] != "tool":
                break
            messages_to_check.append(DetectorTask(i, message, get_text_content_as_str(message["content"]) or ""))

        for task in messages_to_check:
            message = task.message
            if self.mode == "message":
                text = task.text
            elif self.mode == "full_conversation":
                # Render everything up to and including this tool message as "role: text" lines,
                # skipping messages that have no content.
                text = "\n".join(
                    f"{m['role']}: {get_text_content_as_str(m['content'])}"
                    for m in messages[: task.index + 1]
                    if m["content"] is not None
                )
            else:
                raise ValueError(
                    f"Invalid checking mode: {self.mode} (valid options are 'message' and 'full_conversation')"
                )

            detection_result = self.detect(text)
            if isinstance(detection_result, tuple):
                is_injection, score = detection_result
            else:
                # Non-probabilistic detectors only return a verdict; synthesize a score from it.
                is_injection = detection_result
                score = 1.0 if detection_result else 0.0

            if self.raise_on_injection and is_injection:
                raise AbortAgentError(
                    f"aborting execution because a prompt injection was detected (score: {score}) in message: {text}",
                    list(messages),
                    env,
                )

            if is_injection:
                # The dict is shared with `messages`, so this edit is visible in the returned sequence.
                message["content"] = self.transform(message["content"] or [])

        return query, runtime, env, messages, extra_args


class TransformersBasedPIDetector(PromptInjectionDetector):
    """Uses [`protectai/deberta-v3-base-prompt-injection-v2`](https://huggingface.co/protectai/deberta-v3-base-prompt-injection-v2)
    to detect prompt injections.

    Args:
        model_name: The name of the model to use for prompt injection detection.
        safe_label: The label that indicates a safe prompt.
        threshold: The threshold for the model's prediction to be considered a prompt injection, e.g. safety_score < threshold means
            an input is classified as a prompt injection.
        mode: The mode in which the detector should operate. It can be 'message' or 'full_conversation'.
        raise_on_injection: Whether to raise an exception if a prompt injection is detected.

    Raises:
        ImportError: if the optional `torch` / `transformers` dependencies are not installed.
    """

    def __init__(
        self,
        model_name: str = "protectai/deberta-v3-base-prompt-injection-v2",
        safe_label: str = "SAFE",
        threshold: float = 0.5,
        mode: Literal["message", "full_conversation"] = "message",
        raise_on_injection: bool = False,
    ) -> None:
        super().__init__(mode=mode, raise_on_injection=raise_on_injection)

        # Imported lazily so the optional dependencies are only needed when this detector is used.
        try:
            import torch  # type: ignore
            from transformers import pipeline  # type: ignore
        except ImportError as err:
            # Chain the original error so the real cause (which package is missing) stays visible.
            raise ImportError(
                "TransformersBasedPIDetector requires the 'transformers' and 'torch' libraries to be installed. "
                "Please install them using 'pip install agentdojo[transformers]'."
            ) from err

        self.model_name = model_name
        self.safe_label = safe_label
        self.threshold = threshold
        # Prefer the GPU when one is available.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.pipeline = pipeline("text-classification", model=self.model_name, device=device)

    def detect(self, tool_output: str) -> tuple[bool, float]:
        """Classify `tool_output` with the text-classification pipeline.

        Args:
            tool_output: The text to classify.

        Returns:
            A tuple `(is_injection, safety_score)`. `safety_score` is the pipeline's score when the
                predicted label equals `self.safe_label`, and `1 - score` otherwise, so lower values
                mean "more likely an injection". `is_injection` is `True` when
                `safety_score < self.threshold`.
        """
        result: dict = self.pipeline(tool_output)[0]  # type: ignore
        # Normalize to "score for the safe label" regardless of which label was predicted.
        safety_score: float = result["score"] if result["label"] == self.safe_label else 1 - result["score"]
        is_injection: bool = safety_score < self.threshold

        return is_injection, safety_score


class GrokTemplatedPIDetector(TransformersBasedPIDetector):
    """A TransformersBasedPIDetector whose `transform` wraps text in a template rather than replacing it.

    Args:
        template_prefix: String placed in front of each text block's content.
        template_suffix: String placed after each text block's content.
        model_name: The name of the model to use for prompt injection detection.
        safe_label: The label that indicates a safe prompt.
        threshold: The threshold for the model's prediction to be considered a prompt injection.
        mode: The mode in which the detector should operate.
        raise_on_injection: Whether to raise an exception if a prompt injection is detected.
    """

    def __init__(
        self,
        template_prefix: str = "<|separator|>\n\nSystem: ",
        template_suffix: str = "<|separator|>",
        model_name: str = "protectai/deberta-v3-base-prompt-injection-v2",
        safe_label: str = "SAFE",
        threshold: float = 0.5,
        mode: Literal["message", "full_conversation"] = "message",
        raise_on_injection: bool = False,
    ) -> None:
        # All classifier-related settings are handled by the parent detector.
        parent_settings = {
            "model_name": model_name,
            "safe_label": safe_label,
            "threshold": threshold,
            "mode": mode,
            "raise_on_injection": raise_on_injection,
        }
        super().__init__(**parent_settings)

        # Template pieces used by `transform`.
        self.template_prefix = template_prefix
        self.template_suffix = template_suffix

    def transform(self, tool_output: list[MessageContentBlock]) -> list[MessageContentBlock]:
        """Wrap each text block's content between the template prefix and suffix.

        Non-text blocks are passed through untouched; text blocks are rebuilt as new blocks.
        """
        wrapped: list[MessageContentBlock] = []
        for block in tool_output:
            if block["type"] != "text":
                wrapped.append(block)
                continue
            templated = f"{self.template_prefix}{block['content']}{self.template_suffix}"
            wrapped.append(text_content_block_from_string(templated))
        return wrapped
