import json
import os
import unittest

from jailbreakbench.llm.llm_wrapper import LLMDummy


class TestLLM(unittest.TestCase):
    """Tests for ``LLMDummy.query`` and the JSON log file it is expected to leave on disk."""

    def setUp(self):
        """Create the dummy model and the prompts/behavior/phase shared by every test."""
        self.model = LLMDummy("llama-2-7b-chat-hf")
        self.prompts = ["What is 1+1?", "What color is an apple?"]
        self.behavior = "unittests"
        self.phase = "dev"

    @property
    def _log_path(self):
        """Return the path where these tests expect ``query`` to write its JSON log."""
        return f"logs/{self.phase}/{self.model.model_name.value}/{self.behavior}.json"

    def _remove_log_file(self):
        """Delete the log file if it exists; a missing file is not an error."""
        try:
            os.remove(self._log_path)
        except FileNotFoundError:
            pass

    def test_query(self):
        """``query`` returns one ``Response to `<prompt>``` string per prompt, in order."""
        responses = self.model.query(self.prompts, self.behavior, self.phase).responses
        correct_responses = ["Response to `What is 1+1?`", "Response to `What color is an apple?`"]
        self.assertEqual(correct_responses, responses)

    def test_json_file_generation(self):
        """``query`` creates a JSON log whose first record holds one entry per prompt."""
        # Start from a clean slate so the existence check below proves that
        # this call to ``query`` created the file.
        self._remove_log_file()
        self.model.query(self.prompts, self.behavior, self.phase)
        self.assertTrue(os.path.exists(self._log_path))

        # Only parsing needs the file handle; the assertions run after it is closed.
        with open(self._log_path, encoding="utf-8") as file:
            data = json.load(file)

        # The first record is a two-element sequence whose second element is
        # the list of per-prompt entries.
        record = data[0]
        self.assertEqual(len(record), 2)
        llm_outputs = record[1]
        # Check the count explicitly: ``zip`` would otherwise silently ignore
        # missing or extra entries.
        self.assertEqual(len(llm_outputs), len(self.prompts))
        for prompt, output in zip(self.prompts, llm_outputs):
            self.assertEqual(len(output), 4)  # prompt, response, prompt_tokens, response_tokens
            self.assertEqual(output["prompt"], prompt)
            self.assertEqual(output["response"], f"Response to `{prompt}`")

    def tearDown(self):
        """Remove the log file so tests do not leak state into each other."""
        self._remove_log_file()


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
