from typing import TYPE_CHECKING

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from jailbreakbench import dataset

if TYPE_CHECKING:
    from jailbreakbench.llm.llm_wrapper import LLM

from .base_defense import Defense
from .defenses_registry import register_defense


@register_defense
class PerplexityFilter(Defense):
    """Perplexity filter.

    Scores each incoming prompt with a separate causal language model and
    refuses it when its mean per-token negative log-likelihood exceeds a
    threshold; otherwise the prompt is forwarded to the target model.
    The threshold is computed at construction time as the maximum log
    perplexity over the goals returned by ``dataset.read_dataset()``.

    Paper: Baseline Defenses for Adversarial Attacks Against Aligned Language Models
    Authors: Neel Jain, Avi Schwarzschild, Yuxin Wen, Gowthami Somepalli, John Kirchenbauer,
        Ping-yeh Chiang, Micah Goldblum, Aniruddha Saha, Jonas Geiping, Tom Goldstein
    Link: https://arxiv.org/abs/2309.00614
    """

    _DEFENSE_NAME = "PerplexityFilter"
    REFUSAL_RESPONSE = "I'm sorry, I cannot assist with that request."

    def __init__(self, target_model: "LLM"):
        """Load the perplexity model and calibrate the filtering threshold.

        Args:
            target_model: The model being defended; passed to the base class.

        Note:
            ``self.hparams["threshold"]`` is overwritten here, so any value
            that was configured beforehand is discarded.
        """
        super().__init__(target_model)
        model_path = self.hparams["perplexity_model_path"]

        # Prefer the GPU, but fall back to the CPU so the defense can still be
        # constructed on machines without CUDA instead of crashing.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.ppl_tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.ppl_model = AutoModelForCausalLM.from_pretrained(model_path).to(device)
        # reduction="none" keeps one loss value per token (used by `log_prob`).
        self.cn_loss = torch.nn.CrossEntropyLoss(reduction="none")

        # Calibrate on the benchmark goals: the largest log perplexity seen
        # there becomes the acceptance threshold.
        goals = dataset.read_dataset().goals
        self.hparams["threshold"] = self.max_log_perplexity(goals)

    def refusal_response(self, prompt):
        """Return a refusal in the same format as a regular model response.

        The target model is still queried so that the returned object has the
        structure of a normal response; its first response text is then
        replaced by ``REFUSAL_RESPONSE`` and its first completion-token count
        by the length of that text under the perplexity model's tokenizer.

        Args:
            prompt: The conversation that was rejected by the filter.

        Returns:
            The first entry of the patched response object.
        """
        output = self.target_model.query_llm([prompt])

        # Update the text in the response
        output.responses[0] = self.REFUSAL_RESPONSE

        # Update the number of response tokens
        refusal_token_ids = self.tokenize(self.REFUSAL_RESPONSE, as_pt=False)
        output.completion_tokens[0] = len(refusal_token_ids)

        return output[0]

    def tokenize(self, sequence: str, as_pt: bool = True) -> "torch.Tensor | list[int]":
        """Return the input_ids of the input sequence.

        Args:
            sequence: Text to encode with the perplexity model's tokenizer.
            as_pt: If true, return a tensor placed on the perplexity model's
                device; otherwise return the tokenizer's plain encoding.

        Returns:
            The token ids, as a tensor (``as_pt=True``) or as returned by
            ``encode`` without tensor conversion (``as_pt=False``).
        """
        if not as_pt:
            return self.ppl_tokenizer.encode(sequence)
        token_ids = self.ppl_tokenizer.encode(sequence, return_tensors="pt")
        # Follow the model rather than assuming CUDA, so inputs and weights
        # always live on the same device.
        return token_ids.to(self.ppl_model.device)

    def log_perplexity(self, sequence: str) -> float:
        """Get the log perplexity of a sequence.

        Returns the ``loss`` reported by the perplexity model when the
        sequence is used as both input and labels, as a Python float.
        """
        input_ids = self.tokenize(sequence)
        with torch.no_grad():
            loss = self.ppl_model(input_ids, labels=input_ids).loss
        return loss.item()

    def max_log_perplexity(self, sequences: list[str]) -> float:
        """Get the maximum log perplexity over a list of sequences.

        Raises:
            ValueError: If ``sequences`` is empty.
        """
        return max(self.log_perplexity(s) for s in sequences)

    def log_prob(self, sequence: str) -> torch.Tensor:
        """Get the per-token negative log-likelihoods of a sequence.

        Each token after the first is scored against the model's prediction
        from the preceding position.

        Returns:
            A 1-D tensor with one cross-entropy value per predicted token
            (one fewer than the number of tokens in the sequence).
        """
        input_ids = self.tokenize(sequence)
        with torch.no_grad():
            # Only the logits are needed; no labels are passed so the model
            # does not compute a loss that would be thrown away.
            logits = self.ppl_model(input_ids).logits
        # Align predictions with targets: position t predicts token t + 1.
        logits = logits[:, :-1, :].contiguous()
        targets = input_ids[:, 1:].contiguous()
        return self.cn_loss(logits.view(-1, logits.size(-1)), targets.view(-1))

    def query_with_dicts(self, prompt: list[dict[str, str]]) -> tuple[str, int, int]:
        """Filter prompt based on log perplexity.

        Args:
            prompt: The conversation as a list of message dicts; the text to
                score is read from ``prompt[1]["content"]``.

        Returns:
            The target model's response when the prompt passes the filter,
            otherwise the refusal response.
        """
        # NOTE(review): assumes the message at index 1 holds the text to score
        # (presumably the user turn following a system message) - confirm
        # against callers.
        prompt_string = prompt[1]["content"]

        log_probs = self.log_prob(prompt_string)
        # A prompt with fewer than two tokens yields an empty tensor whose
        # mean is NaN; NaN fails this comparison, so such prompts are refused.
        if log_probs.mean() <= self.hparams["threshold"]:
            return self.target_model.query_llm([prompt])[0]
        return self.refusal_response(prompt)
