import unittest

from jailbreakbench.config import CHAT_TEMPLATES, HF_MODEL_NAMES, LITELLM_CUSTOM_PROMPTS, SYSTEM_PROMPTS, Model
from litellm.llms.prompt_templates.factory import custom_prompt
from transformers import AutoTokenizer

# Expected rendering of VICUNA_MESSAGES; the tests below compare against it for exact equality.
# Reference format taken from
# https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#example-prompt-weights-v11-v13-v15
# The backslash after the opening quotes keeps the string from starting with a newline.
EXPECTED_VICUNA_OUTPUT = """\
A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Hello! ASSISTANT: Hello!</s>USER: How are you? ASSISTANT:"""


# Conversation fixture for Vicuna: a system prompt, one completed user/assistant exchange,
# and a trailing user message that is still awaiting a reply.
VICUNA_MESSAGES = [
    {
        "role": "system",
        # NOTE(review): the member is looked up by value here but referenced as Model.vicuna
        # elsewhere in this file; presumably both name the same member, confirm against the config.
        "content": SYSTEM_PROMPTS[Model("vicuna-13b-v1.5")],
    },
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "How are you?"},
]


# Expected rendering of LLAMA_MESSAGES; the tests below compare against it for exact equality.
# Adapted from https://huggingface.co/blog/llama2#how-to-prompt-llama-2; relative to the original
# blog post, one space before "Your answers" in the system prompt had to be removed.
# The backslash after the opening quotes keeps the string from starting with a newline.
EXPECTED_LLAMA_OUTPUT = """\
<s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>

There's a llama in my garden 😱 What should I do? [/INST]"""


# Conversation fixture for Llama-2: a system prompt followed by a single user message.
LLAMA_MESSAGES = [
    {
        "role": "system",
        # NOTE(review): the member is looked up by value here but referenced as Model.llama_2
        # elsewhere in this file; presumably both name the same member, confirm against the config.
        "content": SYSTEM_PROMPTS[Model("llama-2-7b-chat-hf")],
    },
    {"role": "user", "content": "There's a llama in my garden 😱 What should I do?"},
]


# Conversation fixture to render for each model under test, keyed by Model member.
MESSAGES = {Model.vicuna: VICUNA_MESSAGES, Model.llama_2: LLAMA_MESSAGES}


# Expected rendered prompt for each model under test, keyed by Model member.
EXPECTED_OUTPUTS = {Model.vicuna: EXPECTED_VICUNA_OUTPUT, Model.llama_2: EXPECTED_LLAMA_OUTPUT}


class TestConfig(unittest.TestCase):
    """Check that the configured chat templates reproduce the reference prompts.

    Every test renders the conversation stored in ``MESSAGES`` for each model in
    ``self.models`` and compares the result with the matching entry of
    ``EXPECTED_OUTPUTS``.
    """

    def setUp(self):
        """Select the models under test and disable diff truncation."""
        self.models = [Model.llama_2, Model.vicuna]
        # The expected prompts are long; None removes the length cap on failure diffs.
        self.maxDiff = None

    def test_hf_chat_templates(self):
        """The tokenizer's chat template renders the expected prompt for each model."""
        for model in self.models:
            # subTest keeps the loop going after a failure and labels it with the model,
            # so one broken template cannot hide the result for the other model.
            with self.subTest(model=model):
                tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAMES[model])
                # Only models with an entry in CHAT_TEMPLATES get their template overridden.
                if model in CHAT_TEMPLATES:
                    tokenizer.chat_template = CHAT_TEMPLATES[model]
                prompt = tokenizer.apply_chat_template(
                    MESSAGES[model],
                    tokenize=False,
                    add_generation_prompt=True,
                )
                self.assertEqual(prompt, EXPECTED_OUTPUTS[model])

    def test_litellm_chat_templates(self):
        """litellm's ``custom_prompt`` renders the expected prompt for each model."""
        for model in self.models:
            # Same reasoning as above: report each model independently.
            with self.subTest(model=model):
                role_dict = LITELLM_CUSTOM_PROMPTS[model]
                prompt = custom_prompt(
                    role_dict=role_dict,
                    messages=MESSAGES[model],
                )
                self.assertEqual(prompt, EXPECTED_OUTPUTS[model])
