from collections.abc import Sequence
from dataclasses import dataclass

from agentgym.agent_pipeline.base_pipeline_element import (
    BasePipelineElement,
    DummyEnv,
)
from agentgym.agent_pipeline.errors import AbortAgentError
from agentgym.functions_engine.functions_engine import Env, Function
from agentgym.types import ChatMessage


@dataclass
class DetectorTask:
    """A single tool message queued for prompt-injection checking."""

    # Position of the message within the conversation it was taken from.
    index: int
    # The message object itself; the detector replaces its "content" in place when it is flagged.
    message: ChatMessage
    # The message's "content" value as captured when the task was created.
    text: str


class PromptInjectionDetector(BasePipelineElement):
    """Pipeline element that screens the trailing tool outputs of a conversation.

    Subclasses provide :meth:`detect`. :meth:`query` applies it to the run of
    ``tool`` messages at the end of the conversation and either replaces the
    content of flagged messages (see :meth:`transform`) or aborts the agent.
    """

    def __init__(self, mode: str = "message", raise_on_injection: bool = False):
        """
        :param mode:
            The mode in which the detector should operate. It can be 'message' or 'full_conversation'.
            In 'message' mode, the detector is applied per tool output message, while in 'full_conversation'
            mode, the detector is applied to the full conversation so far.
        :param raise_on_injection:
            If True, :meth:`query` raises ``AbortAgentError`` as soon as a message is flagged instead of
            replacing the flagged content.
        """
        super().__init__()
        self.mode = mode
        self.raise_on_injection = raise_on_injection

    def detect(self, tool_output: str) -> bool | tuple[bool, float]:
        """
        Returns True if a prompt injection is detected, False otherwise.

        :param tool_output: The text to classify (a single tool output or the joined conversation,
            depending on the configured mode).
        :return:
            If the detector is a probabilistic one, it can return a tuple with the first element being a boolean
            indicating whether a prompt injection was detected, and the second element being a float representing
            the confidence score. If the detector is not probabilistic, it should return a boolean.
        :raises NotImplementedError: Always, in this base class; subclasses must override it.
        """
        raise NotImplementedError("to use PromptInjectionDetector, use a subclass with a detect method")

    def transform(self, tool_output: str) -> str:
        """
        Returns the replacement content for a tool output that was flagged as an injection.

        The base implementation ignores ``tool_output`` and returns a fixed placeholder.
        """
        return "<Data omitted because a prompt injection was detected>"

    def query(
        self,
        query: str,
        tools: Sequence[Function],
        env: Env = DummyEnv(),
        messages: Sequence[ChatMessage] = [],
        extra_args: dict = {},
    ) -> tuple[str, Sequence[Function], Env, Sequence[ChatMessage], dict]:
        """
        Checks the trailing run of tool messages for prompt injections.

        Nothing is checked unless the last message has role ``"tool"``. Otherwise every consecutive
        tool message at the end of ``messages`` is passed to :meth:`detect`, newest first. Flagged
        messages have their ``"content"`` replaced **in place** with the result of :meth:`transform`,
        so the change is visible through the ``messages`` sequence that is returned.

        :return: The five arguments, in the same order they were received.
        :raises ValueError: If ``self.mode`` is neither 'message' nor 'full_conversation' (only checked
            once there is at least one tool message to inspect).
        :raises AbortAgentError: If ``raise_on_injection`` is set and a message is flagged.
        """
        # NOTE: the mutable defaults are kept for interface compatibility; this method never
        # mutates the `messages` sequence itself, `extra_args` or `env`.
        if len(messages) == 0 or messages[-1]["role"] != "tool":
            return query, tools, env, messages, extra_args

        # Walk backwards from the end and collect tool messages until the first non-tool one.
        messages_to_check: list[DetectorTask] = []
        for i in range(len(messages) - 1, -1, -1):
            message = messages[i]
            if message["role"] != "tool":
                break
            messages_to_check.append(DetectorTask(i, message, message["content"]))

        for task in messages_to_check:
            if self.mode == "message":
                text = task.text  # is ["content"] by default
            elif self.mode == "full_conversation":
                # Everything up to and including the message under inspection, one "role: content" per line.
                text = "\n".join(
                    m["role"] + ": " + m["content"]
                    for m in messages[: task.index + 1]
                    if m["content"] is not None
                )
            else:
                raise ValueError(
                    f"Invalid checking mode: {self.mode} (valid options are 'message' and 'full_conversation')"
                )

            detection_result = self.detect(text)
            if isinstance(detection_result, tuple):
                # Probabilistic detector: (is_injection, score).
                is_injection, score = detection_result[0], detection_result[1]
            else:
                is_injection = detection_result
                score = 1.0 if detection_result else 0.0

            if not is_injection:
                continue

            if self.raise_on_injection:
                raise AbortAgentError(
                    f"aborting execution because a prompt injection was detected (score: {score}) in message: {text}",
                    list(messages),
                    env,
                )

            # Mutate the message object itself so the caller's conversation reflects the redaction.
            task.message["content"] = self.transform(task.message["content"] or "")

        return query, tools, env, messages, extra_args


class TransformersBasedPIDetector(PromptInjectionDetector):
    """Prompt injection detector backed by a ``transformers`` text-classification pipeline."""

    def __init__(
        self,
        model_name: str = "protectai/deberta-v3-base-prompt-injection-v2",
        safe_label: str = "SAFE",
        threshold: float = 0.5,
        mode: str = "message",
        raise_on_injection: bool = False,
    ) -> None:
        """
        :param model_name: The name of the model to use for prompt injection detection.
        :param safe_label: The label that indicates a safe prompt.
        :param threshold: The threshold for the model's prediction to be considered a prompt injection, e.g. safety_score < threshold means
                          an input is classified as a prompt injection.
        :param mode: The mode in which the detector should operate. It can be 'message' or 'full_conversation'.
        :param raise_on_injection: Whether to raise an exception if a prompt injection is detected.
        :raises ImportError: If 'transformers' or 'torch' cannot be imported.
        """
        super().__init__(mode=mode, raise_on_injection=raise_on_injection)

        # Imported lazily so the optional dependencies are only needed when this detector is used.
        try:
            import torch
            from transformers import pipeline
        except ImportError as err:
            # Chain explicitly so the original error (which names the missing module) stays attached.
            raise ImportError(
                "TransformersBasedPIDetector requires the 'transformers' and 'torch' libraries to be installed. "
                "Please install them using 'pip install .[transformers]'."
            ) from err

        self.model_name = model_name
        self.safe_label = safe_label
        self.threshold = threshold
        # Use the GPU when one is available, otherwise fall back to the CPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.pipeline = pipeline("text-classification", model=self.model_name, device=device)

    def detect(self, tool_output: str) -> tuple[bool, float]:
        """
        Classifies ``tool_output`` with the configured pipeline.

        :return: A tuple ``(is_injection, safety_score)``. ``safety_score`` is the score of the first
            prediction when its label equals ``safe_label`` and ``1 - score`` otherwise; the input is
            flagged when ``safety_score`` is below ``threshold``.
        """
        result: dict = self.pipeline(tool_output)[0]  # type: ignore
        # Normalise to "how safe is this text", whichever label the model reported.
        safety_score: float = result["score"] if result["label"] == self.safe_label else 1 - result["score"]
        is_injection: bool = safety_score < self.threshold

        return is_injection, safety_score
