import re
from functools import lru_cache
from typing import TYPE_CHECKING

from wordfreq import iter_wordlist

if TYPE_CHECKING:
    from jailbreakbench.llm_wrapper import LLM

from .base_defense import Defense
from .defenses_registry import register_defense


@register_defense
class RemoveNonDictionary(Defense):
    """Defense that drops non-dictionary words from a prompt before querying.

    Word tokens whose lowercase form is not in the wordfreq "best" English
    word list are removed; punctuation tokens are always kept.
    """

    _DEFENSE_NAME = "RemoveNonDictionary"

    def __init__(self, target_model: "LLM"):
        super().__init__(target_model=target_model)

    def query_with_dicts(self, prompt):
        """Filter the prompt's message content and query the target model.

        The ``"content"`` entry of ``prompt[1]`` is replaced by its filtered
        version. NOTE: this modifies the caller's ``prompt`` object in place.
        (Presumably ``prompt[1]`` is the user message following a system
        message -- confirm against the caller.)

        Returns:
            The first element of the result of ``target_model.query_llm``
            called with a single-item batch containing ``prompt``.
        """
        prompt_content = prompt[1]["content"]
        prompt[1]["content"] = self.remove_non_dictionary_words(prompt_content)
        return self.target_model.query_llm([prompt])[0]

    @staticmethod
    @lru_cache(maxsize=1)
    def _english_words() -> frozenset:
        """Return the English vocabulary as a frozenset, built only once."""
        # A set gives O(1) membership tests; caching avoids re-materialising
        # the whole word list on every call.
        return frozenset(iter_wordlist("en", wordlist="best"))

    @staticmethod
    def remove_non_dictionary_words(text: str) -> str:
        """Return ``text`` with non-dictionary word tokens removed.

        The text is split into word tokens (``\\w+``) and single punctuation
        characters. A word token is kept, with its original capitalization,
        only if its lowercase form is in the English word list. Remaining
        tokens are joined with single spaces, after which the space before
        ``,``, ``.``, ``;`` and ``:`` is removed and a split possessive
        (``' s``) is re-attached as ``'s``.
        """
        english_words = RemoveNonDictionary._english_words()

        # Split the input text into words and punctuation marks; whitespace is
        # discarded and capitalization of the tokens is left untouched.
        tokens = re.findall(r"\w+|[^\w\s]", text)

        # Punctuation tokens never match ``\w+`` and are always kept; word
        # tokens are looked up case-insensitively.
        filtered_tokens = [
            token
            for token in tokens
            if not re.match(r"\w+", token) or token.lower() in english_words
        ]

        # Join the remaining tokens back into a string, then undo the spaces
        # the join put in front of punctuation and inside possessives.
        filtered_text = " ".join(filtered_tokens)
        for spaced, tight in (
            (" ,", ","),
            (" .", "."),
            (" ;", ";"),
            (" :", ":"),
            (" ' s", "'s"),
            ("' s", "'s"),
        ):
            filtered_text = filtered_text.replace(spaced, tight)

        return filtered_text
