import random
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from jailbreakbench.llm_wrapper import LLM

from nltk.corpus import wordnet

from .base_defense import Defense
from .defenses_registry import register_defense


@register_defense
class SynonymSubstitution(Defense):
    """Defense that randomly replaces prompt words with WordNet synonyms before querying the model.

    The substitution rate is read from ``self.hparams["substitution_q"]`` and is
    interpreted as a percentage (it is divided by 100 before being compared with
    ``random.random()``).
    NOTE(review): ``hparams`` is not set in this class; presumably the ``Defense``
    base class populates it -- confirm.
    """

    _DEFENSE_NAME = "SynonymSubstitution"

    def __init__(self, target_model: "LLM"):
        super().__init__(target_model=target_model)

    def query_with_dicts(self, prompt):
        """Perturb the message at index 1 of ``prompt`` and query the target model with it.

        Args:
            prompt: A sequence of message mappings, each with a ``"content"`` entry.
                Only ``prompt[1]`` is perturbed.
                NOTE(review): presumably index 1 is the user turn following a
                system message -- verify against the caller.

        Returns:
            The first element of ``self.target_model.query_llm`` called on a
            single-conversation batch.
        """
        perturbed_content = self.substitute_synonyms(prompt[1]["content"], q=self.hparams["substitution_q"])
        # Query with a perturbed copy so the caller's conversation is not modified in place.
        perturbed_prompt = list(prompt)
        perturbed_prompt[1] = {**prompt[1], "content": perturbed_content}
        return self.target_model.query_llm([perturbed_prompt])[0]

    @staticmethod
    def get_synonyms(word: str) -> "list[str]":
        """Return the unique WordNet lemma names for ``word``.

        Underscores in lemma names are replaced by spaces. The result may contain
        ``word`` itself. Names are deduplicated in the order WordNet yields them,
        so the result does not depend on string hash ordering.

        Args:
            word: The token to look up.

        Returns:
            A list of unique lemma names; empty when WordNet has no synset for ``word``.
        """
        lemma_names = (
            lemma.name().replace("_", " ")
            for synset in wordnet.synsets(word)
            for lemma in synset.lemmas()
        )
        # dict.fromkeys removes duplicates while keeping first-seen order (unlike set()).
        return list(dict.fromkeys(lemma_names))

    def substitute_synonyms(self, sentence: str, q: float) -> str:
        """Replace each word of ``sentence`` with a synonym with probability ``q`` percent.

        The sentence is split on whitespace and re-joined with single spaces, so
        runs of whitespace (including newlines) are collapsed. Words for which no
        differing synonym is found are kept unchanged rather than dropped.

        Args:
            sentence: The text to perturb.
            q: Substitution probability per eligible word, as a percentage (0-100).

        Returns:
            The perturbed sentence.
        """
        threshold = q / 100
        new_sentence = []
        for word in sentence.split():
            # Pick the first candidate that actually differs from the word, so a
            # "substitution" is never a no-op.
            replacement = next(
                (candidate for candidate in self.get_synonyms(word) if candidate.lower() != word.lower()),
                None,
            )
            if replacement is not None and random.random() <= threshold:
                new_sentence.append(replacement)
            else:
                # Keep the original token, including tokens unknown to WordNet.
                new_sentence.append(word)
        return " ".join(new_sentence)
