import copy
import os
import random
from typing import TYPE_CHECKING

import numpy as np

from jailbreakbench.classifier import LlamaGuard1JailbreakJudge

if TYPE_CHECKING:
    from jailbreakbench.llm.llm_wrapper import LLM

from .base_defense import Defense
from .defenselib import perturbations
from .defenses_registry import register_defense


@register_defense
class SmoothLLM(Defense):
    """SmoothLLM.

    Queries the target model on several randomly perturbed copies of the
    prompt, judges every response, and returns one response that agrees with
    the majority verdict.

    Paper: SmoothLLM: Defending Large Language Models Against Jailbreaking Attacks
    Authors: Alexander Robey, Eric Wong, Hamed Hassani, George J. Pappas
    Link: https://arxiv.org/abs/2310.03684
    """

    # Position of the message that gets perturbed and judged.
    # NOTE(review): assumes index 0 is a system message and index 1 is the
    # user turn -- confirm against how callers build the conversation.
    _USER_MESSAGE_INDEX = 1

    def __init__(self, target_model: "LLM"):
        """Set up the perturbation function and the jailbreak judge.

        Reads ``perturbation_type`` and ``perturbation_pct`` from
        ``self.hparams`` and the ``TOGETHER_API_KEY`` environment variable.

        Raises:
            KeyError: if a hyperparameter is missing, if ``perturbation_type``
                does not name an attribute of the ``perturbations`` module, or
                if ``TOGETHER_API_KEY`` is not set.
        """
        super().__init__(target_model=target_model)

        perturbation_type = self.hparams["perturbation_type"]
        # Look the perturbation up by name and instantiate it with strength q.
        self.perturbation_fn = vars(perturbations)[perturbation_type](q=self.hparams["perturbation_pct"])
        self.judge = LlamaGuard1JailbreakJudge(os.environ["TOGETHER_API_KEY"]).classify_responses

    def query_with_dicts(self, prompt: list[dict[str, str]]) -> tuple[str, int, int]:
        """Query the target model through the SmoothLLM majority vote.

        Args:
            prompt: Conversation as a list of message dicts; the message at
                ``_USER_MESSAGE_INDEX`` is the one that gets perturbed.

        Returns:
            One element of ``responses.iterate_elements()`` chosen uniformly
            at random among the copies whose verdict matches the majority.
            NOTE(review): presumably (response text, prompt tokens,
            completion tokens) for that single copy -- confirm against the
            LLM output type.

        Raises:
            ValueError: if ``number_of_copies`` yields no perturbed prompts.
        """
        # Perturb the input prompts and obtain responses
        perturbed_prompts = self.perturb_copies(prompt)
        if not perturbed_prompts:
            # Fail before spending a model call; otherwise np.mean([]) would be
            # NaN and random.choice([]) would raise an opaque IndexError.
            raise ValueError(
                f"SmoothLLM needs at least one perturbed copy, got number_of_copies="
                f"{self.hparams['number_of_copies']!r}"
            )
        responses = self.target_model.query_llm(perturbed_prompts)

        # Extract the (perturbed) user prompt strings from the conversations
        user_prompts = [p[self._USER_MESSAGE_INDEX]["content"] for p in perturbed_prompts]

        # Determine whether each perturbed copy is jailbroken
        are_copies_jailbroken = self.judge(user_prompts, responses.responses)

        # Strict majority: a tie (exactly 50%) counts as "not jailbroken".
        jb_percentage = np.mean(are_copies_jailbroken)
        is_smoothllm_jailbroken = bool(jb_percentage > 0.5)

        # Randomly select a response consistent with the majority vote
        consistent_responses = [
            response
            for response, jb in zip(responses.iterate_elements(), are_copies_jailbroken)
            if jb == is_smoothllm_jailbroken
        ]
        return random.choice(consistent_responses)

    def perturb_copies(self, prompt: list[dict[str, str]]) -> list[list[dict[str, str]]]:
        """Create perturbed copies of the supplied input prompt.

        Produces ``hparams["number_of_copies"]`` independent deep copies of
        ``prompt``; in each one the content of the message at
        ``_USER_MESSAGE_INDEX`` is replaced by ``self.perturbation_fn`` applied
        to it. The supplied ``prompt`` itself is not modified.
        """
        message_index = self._USER_MESSAGE_INDEX

        perturbed_prompts = []
        for _ in range(self.hparams["number_of_copies"]):
            # Deep copy so every perturbation starts from the pristine prompt.
            prompt_copy = copy.deepcopy(prompt)
            message = prompt_copy[message_index]
            message["content"] = self.perturbation_fn(message["content"])
            perturbed_prompts.append(prompt_copy)

        return perturbed_prompts
