import copy
from typing import TYPE_CHECKING

import torch
from transformers import AutoTokenizer

if TYPE_CHECKING:
    from jailbreakbench.llm.llm_wrapper import LLM

import os

from jailbreakbench.classifier import LlamaGuard1JailbreakJudge

from .base_defense import Defense
from .defenses_registry import register_defense


@register_defense
class EraseAndCheck(Defense):
    """Erase and Check (Suffix).

    Paper: Certifying LLM Safety against Adversarial Prompting
    Authors: Aounon Kumar, Chirag Agarwal, Suraj Srinivas, Aaron Jiaxun Li,
        Soheil Feizi, and Himabindu Lakkaraju
    Link: https://arxiv.org/pdf/2309.02705

    The defense queries the target model with the original prompt and with
    copies of it whose trailing tokens have been erased (one more token per
    copy, up to ``hparams["erase_length"]`` copies). If the judge flags any
    of the resulting prompt/response pairs, a fixed refusal is returned in
    place of the model's answer.
    """

    _DEFENSE_NAME = "EraseAndCheck"
    REFUSAL_RESPONSE = "I'm sorry, I cannot assist with that request."

    def __init__(self, target_model: "LLM"):
        """Set up the tokenizer and the jailbreak judge.

        Args:
            target_model: Model wrapper that is queried through
                ``target_model.query_llm``.

        Raises:
            KeyError: If the ``TOGETHER_API_KEY`` environment variable is unset.
        """
        super().__init__(target_model)
        # NOTE(review): `self.hparams` is not assigned in this class; it is
        # presumably populated by `Defense.__init__` -- confirm.
        self.tokenizer = AutoTokenizer.from_pretrained(self.hparams["tokenizer_path"])
        self.judge = LlamaGuard1JailbreakJudge(os.environ["TOGETHER_API_KEY"]).classify_responses

    def refusal_response(self, prompt: list[dict[str, str]]) -> tuple[str, int, int]:
        """Return a refusal for ``prompt``.

        The target model is still queried so that the returned entry keeps the
        bookkeeping the model reports for ``prompt``; only the response text
        and the completion-token count are overwritten.

        Args:
            prompt: Conversation passed unchanged to ``target_model.query_llm``.

        Returns:
            The first entry of the modified model output.
        """
        llm_output = self.target_model.query_llm([prompt])

        # Replace the generated text with the canned refusal.
        llm_output.responses[0] = self.REFUSAL_RESPONSE

        # Keep the completion-token count consistent with the replaced text.
        llm_output.completion_tokens[0] = len(self.tokenize(self.REFUSAL_RESPONSE, as_pt=False))

        return llm_output[0]

    def tokenize(self, sequence: str, as_pt: bool = True) -> "torch.Tensor | list[int]":
        """Return the input_ids of the input sequence.

        Args:
            sequence: Text to encode with ``self.tokenizer``.
            as_pt: If true, return a PyTorch tensor (placed on the GPU when
                CUDA is available); otherwise return the tokenizer's plain
                Python output.

        Returns:
            The encoded token ids, as a tensor or a list depending on ``as_pt``.
        """
        if not as_pt:
            return self.tokenizer.encode(sequence)

        input_ids = self.tokenizer.encode(sequence, return_tensors="pt")
        # Only move to the GPU when one exists so CPU-only hosts do not crash.
        return input_ids.cuda() if torch.cuda.is_available() else input_ids

    def query_with_dicts(self, prompt: list[dict[str, str]]) -> tuple[str, int, int]:
        """Answer ``prompt`` unless it or a suffix-erased copy is judged jailbroken.

        Args:
            prompt: Conversation as a list of message dicts. The message at
                index 1 is the one that gets erased (presumably the user turn
                following a system message -- verify against the caller).

        Returns:
            The target model's output entry for the original prompt, or the
            refusal entry if the judge flags any of the checked copies.
        """
        prompt_string = prompt[1]["content"]

        # Plain token ids are enough here: they are only sliced and decoded.
        # The first id is dropped -- assumes the tokenizer prepends a single
        # BOS token; TODO confirm for the configured tokenizer.
        prompt_tokens = self.tokenize(prompt_string, as_pt=False)[1:]
        max_erased = min(self.hparams["erase_length"], len(prompt_tokens))

        # Index 0 is the original prompt; copy k has its last k tokens erased.
        prompt_list = [prompt]
        for num_erased in range(1, max_erased + 1):
            erased_prompt = copy.deepcopy(prompt)
            erased_prompt[1]["content"] = self.tokenizer.decode(prompt_tokens[:-num_erased])
            prompt_list.append(erased_prompt)

        responses = self.target_model.query_llm(prompt_list)
        string_prompts = [p[1]["content"] for p in prompt_list]
        are_copies_jailbroken = self.judge(string_prompts, responses.responses)

        if any(are_copies_jailbroken):
            return self.refusal_response(prompt)

        # Reuse the already-generated (and already-judged) answer to the
        # original prompt instead of querying the model a second time.
        return responses[0]
