import abc
from typing import TYPE_CHECKING

from tqdm import tqdm

from jailbreakbench.defenses.defenselib import defense_hparams
from jailbreakbench.llm.llm_output import LLMOutput

if TYPE_CHECKING:
    from jailbreakbench.llm.llm_wrapper import LLM


class Defense(abc.ABC):
    """Abstract parent of every jailbreaking defense.

    A defense wraps a target model. Concrete subclasses supply the actual
    defense logic by implementing `query_with_dicts`; the batch entry point
    `query` is provided here and delegates to it once per prompt.
    """

    def __init__(self, target_model: "LLM"):
        """Store the wrapped model and register the subclass hyperparameters.

        Args:
            target_model: The model being defended (e.g. an LLMvLLM or
                LLMLiteLLM instance). Subclasses reach the underlying model
                through it, e.g. via target_model.query_llm().
        """

        self.target_model = target_model
        # Hyperparameters are registered against the concrete subclass type,
        # not against this base class.
        self.hparams = defense_hparams.register_hparams(type(self))

    def query(self, prompts: list[str]) -> LLMOutput:
        """Run every prompt through the defense and collect the results.

        Each prompt is first converted into a conversation by the target
        model, then each conversation is passed to `query_with_dicts` while a
        tqdm progress bar tracks the iteration.

        Args:
            prompts: Input prompts, one string per query.

        Returns:
            LLMOutput built from the per-prompt results of
            `query_with_dicts` via `LLMOutput.from_list`.
        """

        # All prompts are converted before the first query is issued.
        conversations = [
            self.target_model.convert_prompt_to_conversation(prompt) for prompt in prompts
        ]
        progress = tqdm(conversations)
        results = [self.query_with_dicts(conversation) for conversation in progress]
        return LLMOutput.from_list(results)

    @abc.abstractmethod
    def query_with_dicts(self, prompt: list[dict[str, str]]) -> tuple[str, int, int]:
        """Answer a single conversation through the defense.

        Must be overridden by subclasses.

        The conversation is a list of role/content messages, for example:
        ```
        [{"role": "system", "content": "lorem ipsum"}, {"role": "user", "content": "dolor sit amet"}]
        ```

        The result is a 3-tuple of the form:
        ```
        (defended model response, prompt_tokens, completion_tokens)
        ```

        `prompt_tokens` and `completion_tokens` are the token counts of the
        prompt and of the completion respectively, as reported by the target
        model.
        """

        raise NotImplementedError
