import os
import tempfile
import unittest
from types import SimpleNamespace
from unittest.mock import Mock

import torch
from transformers import AutoTokenizer

from src.prompt_functions import (
    api_bank_get_prompt_for_strict_api,
    api_bank_get_prompt_for_template_summarize,
    get_prompt_for_strict_api,
    get_prompt_for_template_summarize,
    query_openai,
    query_gemini,
    prompts_to_string,
    get_hf_llm_result,
    prepare_source_and_target,
)
from src.tool_ace_prompt_functions import (
    tool_ace_get_prompt_for_strict_api,
    tool_ace_get_prompt_for_template_summarize,
)
from src.utility import (
    Record,
    default_openai_model_name,
    default_temperature,
    default_top_p,
    ai_role,
    default_gemini_model_name,
    user_role,
    chat_format_template,
    chat_format_general,
    DataMode,
)
from src.when2call_prompt_functions import (
    when2call_get_prompt_for_strict_api,
    when2call_get_prompt_for_template_summarize,
)

# API-Bank fixture record, converted with ``Record.from_dict`` in the tests below.
# ``source_1`` is its prompt fields (pre_api, api_def, conversation, ending), each
# followed by a newline. ``output`` and ``template_output[0]`` (each with "</s>"
# appended) are the expected targets in test_prepare_source_and_target.
# NOTE(review): ``template_output`` says "following parameter:" (singular) while the
# other fixtures use "parameters:"; presumably this mirrors the dataset text - confirm.
api_bank_record_dict = {
    "id": "1",
    "data_set": "api_bank",
    "pre_api": [
        "Generate an API request in the format of [ApiName(key1='value1', key2='value2', ...)] based on the previous dialogue context.",
        "The current year is 2023.",
        "Input:",
        "User: User's utterance",
        "AI: AI's response",
        "Expected output:",
        "API-Request: [ApiName(key1='value1', key2='value2', ...)]",
        "API descriptions:",
    ],
    "api_def": [
        '{"name": "API_1", "input_parameters": {"symptom": {"type": "str"}}, "output_parameters": {"results": {"type": "list"}}}',
    ],
    "conversation": [
        "User: Can you help me find out about shortness of breath?",
    ],
    "ending": ["Generate API Request:"],
    "output": "API-Request: [API_1(symptom='shortness of breath')]",
    "post_api": [],
    "api_call": {
        "api_call_status": "tool_call",
        "api_calls": [
            {
                "api_name": "API_1",
                "params": {"symptom": "shortness of breath"},
            }
        ],
    },
    "template_output": [
        "Call the `API_1` API with following parameter: `symptom` as `shortness of breath`"
    ],
}

# ToolACE fixture record. Its expected chat-template prompts are ``source_3`` and
# ``source_4``. In those strings the conversation's leading "user: " label does not
# appear, and the api_def text (including its trailing ".") is embedded verbatim.
# ``output`` + "</s>" is the expected UtteranceToAPICall target.
tool_ace_record_dict = {
    "id": "2",
    "data_set": "tool_ace",
    "pre_api": [
        "You are an expert in composing functions. You are given a question and a set of possible functions.",
        "Based on the question, you will need to make one or more function/tool calls to achieve the purpose.",
        "If none of the function can be used, point it out. If the given question lacks the parameters required by the function,",
        "also point it out.",
        "Here is a list of functions in JSON format that you can invoke:",
    ],
    "api_def": [
        '[{"name": "Get All Servers", "parameters": {"type": "dict", "properties": {"limit": {"type": "int"}}, "required": ["limit"]}, "required": null}].'
    ],
    "conversation": [
        "user: Could you help me find some popular Minecraft servers to join? Let's set the limit to five servers."
    ],
    "ending": [],
    "output": "[Get All Servers(limit=5)]",
    "post_api": [
        "Should you decide to return the function call(s).",
        "Put it in the format of [func1(params_name=params_value, params_name2=params_value2...), func2(params)]",
        "NO other text MUST be included.",
    ],
    "api_call": {
        "api_call_status": "tool_call",
        "api_calls": [
            {"api_name": "Get All Servers", "params": {"limit": 5}, "responses": None}
        ],
    },
    "template_output": [
        "Call the `Get All Servers` API with following parameters: `limit` as `5`"
    ],
}

# When2Call fixture record. It is used only by the parity tests, which check that the
# generic prompt builders match the when2call-specific ones; no exact prompt string
# is pinned for it.
when2call_record_dict = {
    "id": "3",
    "data_set": "when2call",
    "pre_api": [],
    "api_def": [
        '{"name": "options_stock", "description": "Fetch option data for a specific stock, ETF, or index.", "parameters": {"type": "dict", "properties": {"symbol": {"description": "The stock symbol to retrieve option data for. Default is \'AAPL\'.", "type": "str", "default": "AAPL"}, "expiration": {"description": "The expiration date for the options in UNIX timestamp format. Default is \'1705622400\'.", "type": "str", "default": "1705622400"}}}, "required": ["symbol", "expiration"]}',
    ],
    "conversation": [
        "Fetch option data for Tesla (TSLA) with an expiration date of 1672502400."
    ],
    "ending": [],
    "output": '{"name": "options_stock", "arguments": {"symbol": "TSLA", "expiration": "1672502400"}}',
    "post_api": [],
    "api_call": {
        "api_call_status": "tool_call",
        "api_calls": [
            {
                "api_name": "options_stock",
                "params": {"symbol": "TSLA", "expiration": "1672502400"},
            }
        ],
    },
    "template_output": [
        "Call the `options_stock` API with following parameters: `symbol` as `TSLA`, `expiration` as `1672502400`"
    ],
}
# Expected prompt sources for test_prepare_source_and_target:
#   source_1 - api_bank record, DataMode.UtteranceToAPICall, chat_format_general
#   source_2 - api_bank record, DataMode.UtteranceToSummary, chat_format_general
#   source_3 - tool_ace record, DataMode.UtteranceToAPICall, chat_format_template
#   source_4 - tool_ace record, DataMode.UtteranceToSummary, chat_format_template
# With chat_format_template the text is wrapped in the tokenizer's "<s>[INST] ... [/INST]"
# markers; with chat_format_general it ends with a trailing newline.
# NOTE(review): in source_4, "use '\n' to connect" contains a real line break between
# the quotes (the escape is interpreted). Presumably this mirrors the prompt text in the
# tool_ace prompt functions - confirm it is intended.
source_1 = 'Generate an API request in the format of [ApiName(key1=\'value1\', key2=\'value2\', ...)] based on the previous dialogue context.\nThe current year is 2023.\nInput:\nUser: User\'s utterance\nAI: AI\'s response\nExpected output:\nAPI-Request: [ApiName(key1=\'value1\', key2=\'value2\', ...)]\nAPI descriptions:\n{"name": "API_1", "input_parameters": {"symptom": {"type": "str"}}, "output_parameters": {"results": {"type": "list"}}}\nUser: Can you help me find out about shortness of breath?\nGenerate API Request:\n'
source_2 = 'Summarize the next action to take based on conversation history.\nIf the action can be fulfilled with API in API descriptions, summarized result should contain all necessary information defined in corresponding API descriptions.\nSummarize in the format like following: \nIf there\'s no parameter: Call the `API Name` API with no parameter\nIf there\'s one or more parameters: Call the `API Name` API with following parameters: `parameter1 name` as `parameter1 value`, ...\nThe current year is 2023.\nInput:\nUser: User\'s utterance\nAI: AI\'s response\nExpected output:\nA sentence describing next action with all necessary information to call API\nAPI descriptions:\n{"name": "API_1", "input_parameters": {"symptom": {"type": "str"}}, "output_parameters": {"results": {"type": "list"}}}\nUser: Can you help me find out about shortness of breath?\nSummarize the next action:\n'
source_3 = '<s>[INST] You are an expert in composing functions. You are given a question and a set of possible functions.\nBased on the question, you will need to make one or more function/tool calls to achieve the purpose.\nIf none of the function can be used, point it out. If the given question lacks the parameters required by the function,\nalso point it out.\nHere is a list of functions in JSON format that you can invoke:\n[{"name": "Get All Servers", "parameters": {"type": "dict", "properties": {"limit": {"type": "int"}}, "required": ["limit"]}, "required": null}].\nShould you decide to return the function call(s).\nPut it in the format of [func1(params_name=params_value, params_name2=params_value2...), func2(params)]\nNO other text MUST be included.\n\nCould you help me find some popular Minecraft servers to join? Let\'s set the limit to five servers.[/INST]'
source_4 = '<s>[INST] You are an expert in composing functions. You are given a question and a set of possible functions.\nBased on the question, summarize the next action to take.\nSummarized result should contain all necessary information defined in corresponding functions.\nSummarize in the format like following: \nIf there\'s no parameter: Call the `function name` API with no parameter\nIf there\'s one or more parameters: Call the `function name` API with following parameters: `parameter1 name` as `parameter1 value`, ...\nIf there\'s multiple function call, use \'\n\' to connect the results, for example: \'Call the `function name` API with ...\nCall the `function name` API with ...\'\nIf none of the function can be used, point it out. If the given question lacks the parameters required by the function,\nalso point it out.\nHere is a list of functions in JSON format that you can invoke:\n[{"name": "Get All Servers", "parameters": {"type": "dict", "properties": {"limit": {"type": "int"}}, "required": ["limit"]}, "required": null}].\n\nCould you help me find some popular Minecraft servers to join? Let\'s set the limit to five servers.[/INST]'
# Tokenizer shared by the prompt-rendering tests (prompts_to_string,
# get_hf_llm_result, prepare_source_and_target). It is loaded once at import time, so
# importing this module requires the tokenizer files for ``model_id`` to be available
# to ``AutoTokenizer.from_pretrained`` (network access or a local Hugging Face cache).
model_id = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse EOS as the padding token. NOTE(review): presumably because this tokenizer
# ships without a pad token and batched encoding needs one - confirm.
tokenizer.pad_token = tokenizer.eos_token


class PromptFunctionsTestCase(unittest.TestCase):
    """Tests for the prompt-building, querying and data-preparation helpers.

    The generic ``get_prompt_for_*`` builders must agree with the dataset-specific
    builders for every fixture record. Model and API clients are replaced by
    ``Mock`` objects, and the exact prompt strings produced by
    ``prepare_source_and_target`` are pinned against ``source_1``..``source_4``.
    """

    def _assert_matches_dataset_builders(self, generic_func, cases):
        """Assert ``generic_func`` equals each dataset-specific builder on its record.

        Args:
            generic_func: The generic prompt builder under test.
            cases: Iterable of ``(record_dict, dataset_func)`` pairs.
        """
        for record_dict, dataset_func in cases:
            # subTest reports every mismatching dataset instead of stopping at the first.
            with self.subTest(data_set=record_dict["data_set"]):
                record = Record.from_dict(record_dict)
                result = generic_func(record=record, prepare_for_chat_template=True)
                expected = dataset_func(record=record, prepare_for_chat_template=True)
                self.assertEqual(result, expected)

    def test_get_prompt_for_strict_api(self):
        """The generic strict-API prompt matches every dataset-specific builder."""
        self._assert_matches_dataset_builders(
            get_prompt_for_strict_api,
            (
                (api_bank_record_dict, api_bank_get_prompt_for_strict_api),
                (tool_ace_record_dict, tool_ace_get_prompt_for_strict_api),
                (when2call_record_dict, when2call_get_prompt_for_strict_api),
            ),
        )

    def test_get_prompt_for_template_summarize(self):
        """The generic summarize prompt matches every dataset-specific builder."""
        self._assert_matches_dataset_builders(
            get_prompt_for_template_summarize,
            (
                (api_bank_record_dict, api_bank_get_prompt_for_template_summarize),
                (tool_ace_record_dict, tool_ace_get_prompt_for_template_summarize),
                (when2call_record_dict, when2call_get_prompt_for_template_summarize),
            ),
        )

    def test_query_openai(self):
        """query_openai returns a result carrying the id, data set and response text."""
        expected_id = "1"
        expected_dataset = "api_bank"
        expected_text = "result"
        client = Mock()
        # Fake completion shaped as response.choices[i].message.content.
        client.chat.completions.create.return_value = SimpleNamespace(
            choices=[SimpleNamespace(message=SimpleNamespace(content=expected_text))]
        )
        prompt = [
            {
                "role": user_role,
                "content": "hello",
            },
            {
                "role": ai_role,
                "content": 'API-Request: [Get_All_Sessions()]->{"data": [{"session_name": "Hatha yoga"}]}',
            },
        ]
        result = query_openai(
            expected_id,
            expected_dataset,
            prompt,
            default_openai_model_name,
            client,
            default_temperature,
            default_top_p,
        )
        client.chat.completions.create.assert_called()
        self.assertEqual(result.id, expected_id)
        self.assertEqual(result.data_set, expected_dataset)
        self.assertEqual(result.response, expected_text)

    def test_query_gemini(self):
        """query_gemini returns a result carrying the id, data set and response text."""
        expected_id = "1"
        expected_dataset = "api_bank"
        expected_text = "result"
        client = Mock()
        # Fake Gemini response exposing the generated text as ``.text``.
        client.models.generate_content.return_value = SimpleNamespace(
            text=expected_text
        )
        prompt = [
            {
                "role": ai_role,
                "content": 'API-Request: [Get_All_Sessions()]->{"data": [{"session_name": "Hatha yoga"}]}',
            }
        ]
        result = query_gemini(
            expected_id,
            expected_dataset,
            prompt,
            default_gemini_model_name,
            client,
            default_temperature,
            default_top_p,
        )
        client.models.generate_content.assert_called()
        self.assertEqual(result.id, expected_id)
        self.assertEqual(result.data_set, expected_dataset)
        self.assertEqual(result.response, expected_text)

    def test_prompts_to_string(self):
        """Both chat formats render a single user turn as expected."""
        conversation = [
            {"role": "user", "content": "What's the weather like in Paris?"}
        ]
        cases = (
            (chat_format_template, "<s>[INST] What's the weather like in Paris?[/INST]"),
            (chat_format_general, "What's the weather like in Paris?\n"),
        )
        for chat_format, expected in cases:
            with self.subTest(chat_format=chat_format):
                result = prompts_to_string(conversation, chat_format, tokenizer)
                self.assertEqual(result, expected)

    def test_get_hf_llm_result(self):
        """get_hf_llm_result runs end to end against a mocked model and calls generate."""
        model = Mock()
        model.device = "cpu"
        model.name_or_path = "test_model"
        model.generate.return_value = torch.tensor([[1, 1, 1, 1]])
        record = Record.from_dict(api_bank_record_dict)
        # Use a path inside a TemporaryDirectory, not the name of a still-open
        # NamedTemporaryFile(delete=True): the latter cannot be reopened by name on
        # Windows, and the directory is cleaned up either way.
        with tempfile.TemporaryDirectory() as tmp_dir:
            output_file = os.path.join(tmp_dir, "hf_llm_output")
            # Pre-create an empty output file, the same precondition the
            # NamedTemporaryFile provided.
            with open(output_file, "w", encoding="utf-8"):
                pass
            get_hf_llm_result(
                output_file=output_file,
                record_keys=["1"],
                records_with_key={"1": record},
                model=model,
                tokenizer=tokenizer,
                prompt_func=get_prompt_for_strict_api,
                batch_size=4,
                temperature=default_temperature,
                top_p=default_top_p,
                chat_format=chat_format_general,
                do_sample=True,
            )
        # Without this the test only proved "no exception was raised".
        model.generate.assert_called()

    def test_prepare_source_and_target(self):
        """prepare_source_and_target yields the pinned prompt sources and EOS-terminated targets."""
        cases = (
            (
                api_bank_record_dict,
                DataMode.UtteranceToAPICall,
                chat_format_general,
                source_1,
                "API-Request: [API_1(symptom='shortness of breath')]</s>",
            ),
            (
                api_bank_record_dict,
                DataMode.UtteranceToSummary,
                chat_format_general,
                source_2,
                "Call the `API_1` API with following parameter: `symptom` as `shortness of breath`</s>",
            ),
            (
                tool_ace_record_dict,
                DataMode.UtteranceToAPICall,
                chat_format_template,
                source_3,
                "[Get All Servers(limit=5)]</s>",
            ),
            (
                tool_ace_record_dict,
                DataMode.UtteranceToSummary,
                chat_format_template,
                source_4,
                "Call the `Get All Servers` API with following parameters: `limit` as `5`</s>",
            ),
        )
        for record_dict, mode, chat_format, expected_source, expected_target in cases:
            with self.subTest(data_set=record_dict["data_set"], mode=mode):
                record = Record.from_dict(record_dict)
                source, target = prepare_source_and_target(
                    mode, {"1": record}, tokenizer, chat_format
                )
                self.assertEqual(source, [expected_source])
                self.assertEqual(target, [expected_target])


if __name__ == "__main__":
    # Allow running this module directly; verbosity=1 is unittest's default and
    # command-line flags such as -v still override it.
    unittest.main(verbosity=1)
