import unittest

from src.utility import user_role, Record
from src.when2call_prompt_functions import (
    when2call_parse_conversation_role,
    when2call_get_prompt_for_strict_api,
    when2call_get_prompt_for_summarize,
    when2call_get_prompt_for_template_summarize,
)
from tests.test_when2call_data_conversion import tool_call_dict

# Expected system-prompt fixtures for the When2Call prompt builders.
#
# All three expected prompts end with the same rendering of the fixture
# record's tool catalogue (one JSON object per line) followed by "\nInput:".
# That shared tail is written once below, and each prompt only spells out its
# own instruction header. These strings must match the production templates
# byte-for-byte, so wording quirks (e.g. "an tool") have to be changed together
# with the template, never here alone.

_TOOL_FORM5500 = (
    '{"name": "form5500_planname", '
    '"description": "Retrieves information about a specific plan based on the provided '
    'Employer Identification Number (EIN) using the Form 5500 API.", '
    '"parameters": {"type": "dict", "properties": {"ein": {'
    '"description": "Employer Identification Number (EIN) for the plan.", '
    '"type": "int", "default": "311334685"}}}, '
    '"required": ["ein"]}'
)
_TOOL_POWER_PLANT = (
    '{"name": "getpowerplantbyradiusandgps", '
    '"description": "Fetches information about power plants within a specified radius '
    'from given GPS coordinates.", '
    '"parameters": {"type": "dict", "properties": {'
    '"latitude": {"description": "The latitude coordinate to search around.", '
    '"type": "int", "default": "27.6"}, '
    '"longitude": {"description": "The longitude coordinate to search around.", '
    '"type": "int", "default": "78.5"}, '
    '"page_number": {"description": "The page number for paginated results. '
    'Defaults to None.", "type": "int, optional", "default": ""}, '
    '"radius_km": {"description": "The radius in kilometers to search within. '
    'Defaults to 100 km if not provided.", "type": "int, optional", "default": ""}}}, '
    '"required": ["latitude", "longitude"]}'
)
_TOOL_OPTIONS_STOCK = (
    '{"name": "options_stock", '
    '"description": "Fetch option data for a specific stock, ETF, or index.", '
    '"parameters": {"type": "dict", "properties": {'
    '"symbol": {"description": "The stock symbol to retrieve option data for. '
    'Default is \'AAPL\'.", "type": "str", "default": "AAPL"}, '
    '"expiration": {"description": "The expiration date for the options in UNIX '
    'timestamp format. Default is \'1705622400\'.", "type": "str", '
    '"default": "1705622400"}}}, '
    '"required": ["symbol", "expiration"]}'
)

# Tool catalogue (one tool per line) plus the trailing input marker; this tail
# is shared verbatim by every expected system prompt below.
_TOOLS_AND_INPUT = (
    "\n".join([_TOOL_FORM5500, _TOOL_POWER_PLANT, _TOOL_OPTIONS_STOCK]) + "\nInput:"
)

# Expected system prompt from when2call_get_prompt_for_strict_api.
content_1 = (
    "Generate an tool calling request based on the input.\n"
    "The tool calling request should be in JSON format like: "
    '{"name": "tool name", "arguments": {"argument name 1": "argument value 1", '
    '"argument name 2": "argument value 2", ...}.\n'
    "If the request can be fulfilled with a tool provided but there's information "
    'missing, return "request_for_info".\n'
    'If the request cannot be fulfilled with tools provided, return "cannot_answer".\n'
    "Available tools:\n" + _TOOLS_AND_INPUT
)
# Expected system prompt from when2call_get_prompt_for_summarize when the
# caller-supplied instruction is "Do sum".
content_2 = "Do sum\n" + _TOOLS_AND_INPUT
# Expected system prompt from when2call_get_prompt_for_template_summarize.
content_3 = (
    "Based on the input, summarize the next action to take.\n"
    "If the action can be fulfilled with tools provided, summarized result should "
    "contain all necessary information defined in corresponding tool.\n"
    "Summarize in the format like following: \n"
    "If there is no parameter: Call the `tool name` API with no parameter.\n"
    "If there is one or more parameters: Call the `tool name` API with following "
    "parameters: `parameter name` as `parameter value`, ....\n"
    "If the action can be fulfilled with a tool provided but there is information "
    'missing, return "request_for_info".\n'
    'If the action cannot be fulfilled with tools provided, return "cannot_answer".\n'
    "Available tools:\n" + _TOOLS_AND_INPUT
)
# User message expected as the second element of every built prompt.
utterance = "Fetch option data for Tesla (TSLA) with an expiration date of 1672502400."


class When2CallPromptFunctionsTestCases(unittest.TestCase):
    """Unit tests for the When2Call prompt-construction helpers.

    Each ``when2call_get_prompt_for_*`` builder is called with both values of
    its boolean flag and is expected to yield exactly two messages: a system
    message holding the expected instruction text, then the fixture record's
    user utterance.
    """

    @staticmethod
    def _load_record():
        """Build a fresh fixture ``Record`` from ``tool_call_dict``."""
        return Record.from_dict(tool_call_dict)

    def _assert_prompt(self, messages, expected_system_text, check_user_role):
        """Check that *messages* is a [system, user] pair with the expected contents.

        When ``check_user_role`` is false, the second message's ``"role"`` is
        not asserted; only its ``"content"`` is compared against ``utterance``.
        """
        self.assertEqual(len(messages), 2)
        system_message = messages[0]
        user_message = messages[1]
        self.assertEqual(system_message["role"], "system")
        self.assertEqual(system_message["content"], expected_system_text)
        if check_user_role:
            self.assertEqual(user_message["role"], "user")
        self.assertEqual(user_message["content"], utterance)

    def test_when2call_parse_conversation_role(self):
        expected = {"role": user_role, "content": "hello"}
        self.assertEqual(when2call_parse_conversation_role("hello"), expected)

    def test_when2call_get_prompt_for_strict_api(self):
        record = self._load_record()
        # The user-message role is only pinned for the flag=True variant.
        for flag in (True, False):
            self._assert_prompt(
                when2call_get_prompt_for_strict_api(record, flag),
                content_1,
                check_user_role=flag,
            )

    def test_when2call_get_prompt_for_summarize(self):
        record = self._load_record()
        # A new instruction list is built per call in case the builder mutates it.
        for flag in (True, False):
            self._assert_prompt(
                when2call_get_prompt_for_summarize(record, flag, ["Do sum"]),
                content_2,
                check_user_role=flag,
            )

    def test_when2call_get_prompt_for_template_summarize(self):
        record = self._load_record()
        for flag in (True, False):
            self._assert_prompt(
                when2call_get_prompt_for_template_summarize(record, flag),
                content_3,
                check_user_role=flag,
            )


if __name__ == "__main__":
    # Executing this file directly runs every TestCase defined in this module.
    unittest.main(module="__main__")
