import unittest

from src.utility import (
    Record,
    API_CAN_NOT_ANSWER,
    API_REQUEST_FOR_INFO,
    API_TOOL_CALL,
    API_INVALID_RESPONSE,
    APICallStatus,
    APICall,
)
from src.when2call_data_conversion import (
    when2call_process_train_data,
    when2call_process_test_data,
    when2call_str_to_api_call,
    when2call_api_call_to_template,
)

# Train-format input sample passed to when2call_process_train_data.
# Its chosen response is a <TOOLCALL>-wrapped call to the "options_stock"
# tool; the matching expected fixture is tool_call_dict.
tool_call_record = {
    "tools": [
        '{"name": "form5500_planname", "description": "Retrieves information about a specific plan based on the provided Employer Identification Number (EIN) using the Form 5500 API.", "parameters": {"type": "dict", "properties": {"ein": {"description": "Employer Identification Number (EIN) for the plan.", "type": "int", "default": "311334685"}}}, "required": ["ein"]}',
        '{"name": "getpowerplantbyradiusandgps", "description": "Fetches information about power plants within a specified radius from given GPS coordinates.", "parameters": {"type": "dict", "properties": {"latitude": {"description": "The latitude coordinate to search around.", "type": "int", "default": "27.6"}, "longitude": {"description": "The longitude coordinate to search around.", "type": "int", "default": "78.5"}, "page_number": {"description": "The page number for paginated results. Defaults to None.", "type": "int, optional", "default": ""}, "radius_km": {"description": "The radius in kilometers to search within. Defaults to 100 km if not provided.", "type": "int, optional", "default": ""}}}, "required": ["latitude", "longitude"]}',
        '{"name": "options_stock", "description": "Fetch option data for a specific stock, ETF, or index.", "parameters": {"type": "dict", "properties": {"symbol": {"description": "The stock symbol to retrieve option data for. Default is \'AAPL\'.", "type": "str", "default": "AAPL"}, "expiration": {"description": "The expiration date for the options in UNIX timestamp format. Default is \'1705622400\'.", "type": "str", "default": "1705622400"}}}, "required": ["symbol", "expiration"]}',
    ],
    "messages": [
        {
            "role": "user",
            "content": "Fetch option data for Tesla (TSLA) with an expiration date of 1672502400.",
        }
    ],
    "chosen_response": {
        "role": "assistant",
        "content": '<TOOLCALL>[{"name": "options_stock", "arguments": {"symbol": "TSLA", "expiration": "1672502400"}}]</TOOLCALL>',
    },
    "rejected_response": {
        "role": "assistant",
        "content": "The option chain for TSLA with an expiration date of 1672502400 includes calls with strike prices ranging from $600 to $1200, and puts with strike prices ranging from $500 to $1100. Please refer to a financial platform for specific option data.",
    },
}

# Train-format input sample passed to when2call_process_train_data.
# Its chosen response asks the user for the missing sequence ID; the matching
# expected fixture (more_info_dict) carries the "request_for_info" status.
more_info_record = {
    "tools": [
        '{"name": "get_dna_sequence", "description": "Retrieves the DNA sequence for a given sequence ID from the NCBI Nucleotide database.", "parameters": {"type": "dict", "properties": {"sequence_id": {"description": "The unique identifier for the DNA sequence.", "type": "str", "default": "fasta"}, "file_format": {"description": "The format of the returned sequence. Allowed values: \\"fasta\\" (default) or \\"gb\\".", "type": "str, optional"}, "upstream_bases": {"description": "The number of bases upstream of the sequence to include. Defaults to 0.", "type": "int, optional", "default": "fasta"}}}, "required": ["sequence_id"]}'
    ],
    "messages": [
        {
            "role": "user",
            "content": "Fetch a sequence in GenBank format, include 200 upstream bases.",
        }
    ],
    "chosen_response": {
        "role": "assistant",
        "content": "To proceed, I'll need the unique identifier for the DNA sequence. Could you please provide the sequence ID?",
    },
    "rejected_response": {
        "role": "assistant",
        "content": "The sequence AY590582 in GenBank format, including 200 upstream bases, has been fetched and is ready for your review.",
    },
}

# Train-format input sample passed to when2call_process_train_data.
# Its chosen response declines the request; the matching expected fixture
# (can_not_dict) carries the "cannot_answer" status.
can_not_record = {
    "tools": [
        '{"name": "get_range", "description": "Helper function to format the range string.", "parameters": {"type": "dict", "properties": {"start": {"description": "The start of the range.", "type": "int"}, "end": {"description": "The end of the range.", "type": "int"}}}, "required": ["start", "end"]}',
        '{"name": "calculate_grade", "description": "Calculates the weighted average grade based on scores and their corresponding weights.", "parameters": {"type": "dict", "properties": {"scores": {"description": "A list of scores.", "type": "List[float]"}, "weights": {"description": "A list of weights corresponding to each score.", "type": "List[float]"}}}, "required": ["scores", "weights"]}',
    ],
    "messages": [
        {
            "role": "user",
            "content": "Count the occurrences of each fruit in the list ['apple', 'banana', 'apple', 'orange', 'banana', 'banana'].",
        }
    ],
    "chosen_response": {
        "role": "assistant",
        "content": "I'm sorry for the inconvenience, but I'm currently unable to perform tasks such as counting occurrences in a list.",
    },
    "rejected_response": {
        "role": "assistant",
        "content": "In the list, 'apple' appears 2 times, 'banana' appears 3 times, and 'orange' appears 1 time.",
    },
}

# Expected Record, in dict form, for tool_call_record.
# The "id" value is only a placeholder: the test substitutes the id of the
# converted record before comparing.
tool_call_dict = {
    "id": "5b8efae8-0378-4951-b82c-055f424a720f",
    "data_set": "when2call",
    "pre_api": [],
    "api_def": [
        '{"name": "form5500_planname", "description": "Retrieves information about a specific plan based on the provided Employer Identification Number (EIN) using the Form 5500 API.", "parameters": {"type": "dict", "properties": {"ein": {"description": "Employer Identification Number (EIN) for the plan.", "type": "int", "default": "311334685"}}}, "required": ["ein"]}',
        '{"name": "getpowerplantbyradiusandgps", "description": "Fetches information about power plants within a specified radius from given GPS coordinates.", "parameters": {"type": "dict", "properties": {"latitude": {"description": "The latitude coordinate to search around.", "type": "int", "default": "27.6"}, "longitude": {"description": "The longitude coordinate to search around.", "type": "int", "default": "78.5"}, "page_number": {"description": "The page number for paginated results. Defaults to None.", "type": "int, optional", "default": ""}, "radius_km": {"description": "The radius in kilometers to search within. Defaults to 100 km if not provided.", "type": "int, optional", "default": ""}}}, "required": ["latitude", "longitude"]}',
        '{"name": "options_stock", "description": "Fetch option data for a specific stock, ETF, or index.", "parameters": {"type": "dict", "properties": {"symbol": {"description": "The stock symbol to retrieve option data for. Default is \'AAPL\'.", "type": "str", "default": "AAPL"}, "expiration": {"description": "The expiration date for the options in UNIX timestamp format. Default is \'1705622400\'.", "type": "str", "default": "1705622400"}}}, "required": ["symbol", "expiration"]}',
    ],
    "conversation": [
        "Fetch option data for Tesla (TSLA) with an expiration date of 1672502400."
    ],
    "ending": [],
    "output": '{"name": "options_stock", "arguments": {"symbol": "TSLA", "expiration": "1672502400"}}',
    "post_api": [],
    "api_call": {
        "api_call_status": "tool_call",
        "api_calls": [
            {
                "api_name": "options_stock",
                "params": {"symbol": "TSLA", "expiration": "1672502400"},
            }
        ],
    },
    "template_output": [
        "Call the `options_stock` API with following parameters: `symbol` as `TSLA`, `expiration` as `1672502400`"
    ],
}

# Expected Record, in dict form, for more_info_record.
# The "id" value is only a placeholder: the test substitutes the id of the
# converted record before comparing.
more_info_dict = {
    "id": "5b8efae8-0378-4951-b82c-055f424a720f",
    "data_set": "when2call",
    "pre_api": [],
    "api_def": [
        '{"name": "get_dna_sequence", "description": "Retrieves the DNA sequence for a given sequence ID from the NCBI Nucleotide database.", "parameters": {"type": "dict", "properties": {"sequence_id": {"description": "The unique identifier for the DNA sequence.", "type": "str", "default": "fasta"}, "file_format": {"description": "The format of the returned sequence. Allowed values: \\"fasta\\" (default) or \\"gb\\".", "type": "str, optional"}, "upstream_bases": {"description": "The number of bases upstream of the sequence to include. Defaults to 0.", "type": "int, optional", "default": "fasta"}}}, "required": ["sequence_id"]}'
    ],
    "conversation": ["Fetch a sequence in GenBank format, include 200 upstream bases."],
    "ending": [],
    "output": "request_for_info",
    "post_api": [],
    "api_call": {
        "api_call_status": "request_for_info",
        "api_calls": [],
    },
    "template_output": ["request_for_info"],
}

# Expected Record, in dict form, for can_not_record.
# The "id" value is only a placeholder: the test substitutes the id of the
# converted record before comparing.
can_not_dict = {
    "id": "5b8efae8-0378-4951-b82c-055f424a720f",
    "data_set": "when2call",
    "pre_api": [],
    "api_def": [
        '{"name": "get_range", "description": "Helper function to format the range string.", "parameters": {"type": "dict", "properties": {"start": {"description": "The start of the range.", "type": "int"}, "end": {"description": "The end of the range.", "type": "int"}}}, "required": ["start", "end"]}',
        '{"name": "calculate_grade", "description": "Calculates the weighted average grade based on scores and their corresponding weights.", "parameters": {"type": "dict", "properties": {"scores": {"description": "A list of scores.", "type": "List[float]"}, "weights": {"description": "A list of weights corresponding to each score.", "type": "List[float]"}}}, "required": ["scores", "weights"]}',
    ],
    "conversation": [
        "Count the occurrences of each fruit in the list ['apple', 'banana', 'apple', 'orange', 'banana', 'banana']."
    ],
    "ending": [],
    "output": "cannot_answer",
    "post_api": [],
    "api_call": {
        "api_call_status": "cannot_answer",
        "api_calls": [],
    },
    "template_output": ["cannot_answer"],
}

# Test-format input sample passed to when2call_process_test_data.
# "correct_answer" is "tool_call"; the matching expected fixture is
# tool_call_test_dict.
tool_call_test_record = {
    "uuid": "48359a3e-d6e3-410b-b5df-3dd61180b40b",
    "source": "BFCL v2 Live Multiple",
    "source_id": "live_multiple_177-74-0",
    "question": "I'm starting from New York, NY. Could you provide me with details for a trip, including potential destinations, estimated costs, and any travel advisories?",
    "correct_answer": "tool_call",
    "answers": {
        "direct": "Sure! Consider Washington D.C., a 4-hour drive. Budget $200/day for lodging and meals. Visit the Smithsonian museums for free. No travel advisories currently. Enjoy your trip!",
        "tool_call": '{"name": "get_trip", "arguments": {"location": "New York, NY"}}',
        "request_for_info": "To help you better, could you please specify the type of trip you're interested in? For example, are you looking for a beach vacation, a city break, or a nature retreat?",
        "cannot_answer": "I'm sorry for the inconvenience, but I'm unable to provide detailed trip information. I recommend using a dedicated travel planning service for accurate details on destinations, costs, and advisories.",
    },
    "target_tool": '{"name": "get_trip", "description": "Retrieve trip details for a specified location, including possible destinations, estimated costs, and travel advisories.", "parameters": {"type": "dict", "required": ["location"], "properties": {"location": {"type": "string", "description": "The starting location for the trip, in the format of \'City, State\', such as \'Berkeley, CA\' or \'New York, NY\'."}}}}',
    "tools": [
        '{"name": "sort_array", "description": "Sorts an array of integers in ascending order.", "parameters": {"type": "dict", "required": ["array"], "properties": {"array": {"type": "array", "items": {"type": "integer"}, "description": "An array of integers to be sorted."}, "reverse": {"type": "boolean", "description": "Determines if the array should be sorted in descending order. Defaults to false (ascending order).", "default": false}}}}',
        '{"name": "get_trip", "description": "Retrieve trip details for a specified location, including possible destinations, estimated costs, and travel advisories.", "parameters": {"type": "dict", "required": ["location"], "properties": {"location": {"type": "string", "description": "The starting location for the trip, in the format of \'City, State\', such as \'Berkeley, CA\' or \'New York, NY\'."}}}}',
    ],
}
# Test-format input sample passed to when2call_process_test_data.
# "correct_answer" is "cannot_answer" and "target_tool" is None: the
# "api.weather" tool appears only under "orig_tools", not under "tools".
# The matching expected fixture is can_not_test_dict.
can_not_test_record = {
    "uuid": "c38c6471-c320-451a-bfe5-e85dbc1c49b2",
    "source": "BFCL v2 Live Multiple",
    "source_id": "live_multiple_3-2-0",
    "question": "Get weather of Ha Noi for me",
    "correct_answer": "cannot_answer",
    "answers": {
        "direct": "The current weather in Hanoi is partly cloudy with a temperature of 28°C (82°F).",
        "tool_call": '{"name": "api.weather", "arguments": {"loc": "Ha Noi, Vietnam"}}',
        "request_for_info": "Sure, I can help with that. To provide the most accurate information, could you please specify if you're interested in the current weather or the forecast for a specific day?",
        "cannot_answer": "I'm sorry, I'm unable to provide real-time weather information as I don't have that capability. Please check a reliable weather service for the current weather in Hanoi.",
    },
    "target_tool": None,
    "tools": [
        '{"name": "uber.ride", "description": "Finds a suitable Uber ride for the customer based on the starting location, the desired ride type, and the maximum wait time the customer is willing to accept.", "parameters": {"type": "dict", "required": ["loc", "type", "time"], "properties": {"loc": {"type": "string", "description": "The starting location for the Uber ride, in the format of \'Street Address, City, State\', such as \'123 Main St, Springfield, IL\'."}, "type": {"type": "string", "description": "The type of Uber ride the user is ordering.", "enum": ["plus", "comfort", "black"]}, "time": {"type": "integer", "description": "The maximum amount of time the customer is willing to wait for the ride, in minutes."}}}}'
    ],
    "orig_tools": [
        '{"name": "uber.ride", "description": "Finds a suitable Uber ride for the customer based on the starting location, the desired ride type, and the maximum wait time the customer is willing to accept.", "parameters": {"type": "dict", "required": ["loc", "type", "time"], "properties": {"loc": {"type": "string", "description": "The starting location for the Uber ride, in the format of \'Street Address, City, State\', such as \'123 Main St, Springfield, IL\'."}, "type": {"type": "string", "description": "The type of Uber ride the user is ordering.", "enum": ["plus", "comfort", "black"]}, "time": {"type": "integer", "description": "The maximum amount of time the customer is willing to wait for the ride, in minutes."}}}}',
        '{"name": "api.weather", "description": "Retrieve current weather information for a specified location.", "parameters": {"type": "dict", "required": ["loc"], "properties": {"loc": {"type": "string", "description": "The location for which weather information is to be retrieved, in the format of \'City, Country\' (e.g., \'Paris, France\')."}}}}',
    ],
}
# Test-format input sample passed to when2call_process_test_data.
# "correct_answer" is "request_for_info": "question" omits the donation
# amount that "orig_question" contains, and "held_out_param" names "b".
# The matching expected fixture is more_info_test_dict.
more_info_test_record = {
    "uuid": "25f7d42b-3b35-417d-916f-32bb2068cf93",
    "source": "BFCL v2 Live Multiple",
    "source_id": "live_multiple_27-7-0",
    "question": "If I have 100$ and made a donation, how much do I have now?",
    "orig_question": "If I have 100$ and I donated 40. How much do I have now?",
    "correct_answer": "request_for_info",
    "answers": {
        "direct": "You now have $60 remaining.",
        "tool_call": '{"name": "add", "arguments": {"a": 100, "b": -40}}',
        "request_for_info": "To help you with that, I need to know the amount of the donation you made. Could you please provide that?",
        "cannot_answer": "Apologies, but I'm unable to perform calculations. Based on your statement, if you started with $100 and donated $40, you would have $60 remaining.",
    },
    "target_tool": '{"name": "add", "description": "Performs the addition of two integers and returns the result.", "parameters": {"type": "dict", "required": ["a", "b"], "properties": {"a": {"type": "integer", "description": "The first integer to be added."}, "b": {"type": "integer", "description": "The second integer to be added."}}}}',
    "tools": [
        '{"name": "multiply", "description": "Multiplies two integers and returns the product.", "parameters": {"type": "dict", "required": ["a", "b"], "properties": {"a": {"type": "integer", "description": "The first integer to be multiplied."}, "b": {"type": "integer", "description": "The second integer to be multiplied."}}}}',
        '{"name": "add", "description": "Performs the addition of two integers and returns the result.", "parameters": {"type": "dict", "required": ["a", "b"], "properties": {"a": {"type": "integer", "description": "The first integer to be added."}, "b": {"type": "integer", "description": "The second integer to be added."}}}}',
    ],
    "held_out_param": "b",
}
# Expected Record, in dict form, for tool_call_test_record.
# The "id" value is only a placeholder: the test substitutes the id of the
# converted record before comparing.
tool_call_test_dict = {
    "id": "5b8efae8-0378-4951-b82c-055f424a720f",
    "data_set": "when2call",
    "pre_api": [],
    "api_def": [
        '{"name": "sort_array", "description": "Sorts an array of integers in ascending order.", "parameters": {"type": "dict", "required": ["array"], "properties": {"array": {"type": "array", "items": {"type": "integer"}, "description": "An array of integers to be sorted."}, "reverse": {"type": "boolean", "description": "Determines if the array should be sorted in descending order. Defaults to false (ascending order).", "default": false}}}}',
        '{"name": "get_trip", "description": "Retrieve trip details for a specified location, including possible destinations, estimated costs, and travel advisories.", "parameters": {"type": "dict", "required": ["location"], "properties": {"location": {"type": "string", "description": "The starting location for the trip, in the format of \'City, State\', such as \'Berkeley, CA\' or \'New York, NY\'."}}}}',
    ],
    "conversation": [
        "I'm starting from New York, NY. Could you provide me with details for a trip, including potential destinations, estimated costs, and any travel advisories?"
    ],
    "ending": [],
    "output": '{"name": "get_trip", "arguments": {"location": "New York, NY"}}',
    "post_api": [],
    "api_call": {
        "api_call_status": "tool_call",
        "api_calls": [
            {
                "api_name": "get_trip",
                "params": {"location": "New York, NY"},
            }
        ],
    },
    "template_output": [
        "Call the `get_trip` API with following parameters: `location` as `New York, NY`"
    ],
}

# Expected Record, in dict form, for more_info_test_record.
# The "id" value is only a placeholder: the test substitutes the id of the
# converted record before comparing.
more_info_test_dict = {
    "id": "5b8efae8-0378-4951-b82c-055f424a720f",
    "data_set": "when2call",
    "pre_api": [],
    "api_def": [
        '{"name": "multiply", "description": "Multiplies two integers and returns the product.", "parameters": {"type": "dict", "required": ["a", "b"], "properties": {"a": {"type": "integer", "description": "The first integer to be multiplied."}, "b": {"type": "integer", "description": "The second integer to be multiplied."}}}}',
        '{"name": "add", "description": "Performs the addition of two integers and returns the result.", "parameters": {"type": "dict", "required": ["a", "b"], "properties": {"a": {"type": "integer", "description": "The first integer to be added."}, "b": {"type": "integer", "description": "The second integer to be added."}}}}',
    ],
    "conversation": ["If I have 100$ and made a donation, how much do I have now?"],
    "ending": [],
    "output": "request_for_info",
    "post_api": [],
    "api_call": {
        "api_call_status": "request_for_info",
        "api_calls": [],
    },
    "template_output": ["request_for_info"],
}

# Expected Record, in dict form, for can_not_test_record.
# The "id" value is only a placeholder: the test substitutes the id of the
# converted record before comparing.
can_not_test_dict = {
    "id": "5b8efae8-0378-4951-b82c-055f424a720f",
    "data_set": "when2call",
    "pre_api": [],
    "api_def": [
        '{"name": "uber.ride", "description": "Finds a suitable Uber ride for the customer based on the starting location, the desired ride type, and the maximum wait time the customer is willing to accept.", "parameters": {"type": "dict", "required": ["loc", "type", "time"], "properties": {"loc": {"type": "string", "description": "The starting location for the Uber ride, in the format of \'Street Address, City, State\', such as \'123 Main St, Springfield, IL\'."}, "type": {"type": "string", "description": "The type of Uber ride the user is ordering.", "enum": ["plus", "comfort", "black"]}, "time": {"type": "integer", "description": "The maximum amount of time the customer is willing to wait for the ride, in minutes."}}}}'
    ],
    "conversation": ["Get weather of Ha Noi for me"],
    "ending": [],
    "output": "cannot_answer",
    "post_api": [],
    "api_call": {
        "api_call_status": "cannot_answer",
        "api_calls": [],
    },
    "template_output": ["cannot_answer"],
}


class When2CallDataConversionTestCases(unittest.TestCase):
    """Tests for the when2call data-conversion helpers.

    Covers conversion of train-format and test-format samples into ``Record``
    objects, parsing of a response string into an API-call status object, and
    rendering of an ``APICallStatus`` as template text.
    """

    def _assert_conversion(self, convert, raw_record, expected_dict):
        """Convert ``raw_record`` and compare it with the expected ``Record``.

        The id of the produced record is not checked against a fixed value:
        it is copied into the expected data before comparison. The copy is
        made on a fresh dict so the module-level fixture is never mutated and
        tests cannot influence each other through shared state.

        Args:
            convert: Conversion function under test.
            raw_record: Input sample handed to ``convert``.
            expected_dict: Dict form of the expected ``Record``; its ``"id"``
                entry is a placeholder.
        """
        record = convert(raw_record)
        expected = Record.from_dict(dict(expected_dict, id=record.id))
        self.assertEqual(record, expected)

    def test_when2call_process_train_data(self):
        """Train-format samples convert to the expected records."""
        cases = (
            ("tool_call", tool_call_record, tool_call_dict),
            ("request_for_info", more_info_record, more_info_dict),
            ("cannot_answer", can_not_record, can_not_dict),
        )
        for label, raw_record, expected_dict in cases:
            with self.subTest(case=label):
                self._assert_conversion(
                    when2call_process_train_data, raw_record, expected_dict
                )

    def test_when2call_process_test_data(self):
        """Test-format samples convert to the expected records."""
        cases = (
            ("tool_call", tool_call_test_record, tool_call_test_dict),
            ("request_for_info", more_info_test_record, more_info_test_dict),
            ("cannot_answer", can_not_test_record, can_not_test_dict),
        )
        for label, raw_record, expected_dict in cases:
            with self.subTest(case=label):
                self._assert_conversion(
                    when2call_process_test_data, raw_record, expected_dict
                )

    def test_when2call_str_to_api_call(self):
        """Response strings parse into the matching status and API calls."""
        # Bare status strings map to that status with no API calls, whatever
        # the value of the boolean second argument.
        for status in (API_CAN_NOT_ANSWER, API_REQUEST_FOR_INFO):
            for flag in (True, False):
                with self.subTest(status=status, flag=flag):
                    result = when2call_str_to_api_call(status, flag)
                    self.assertEqual(result.api_call_status, status)
                    self.assertEqual(len(result.api_calls), 0)

        # A well-formed JSON tool call yields exactly one parsed API call.
        text = '{"name": "options_stock", "arguments": {"symbol": "TSLA", "expiration": "1672502400"}}'
        for flag in (True, False):
            with self.subTest(case="valid tool call", flag=flag):
                result = when2call_str_to_api_call(text, flag)
                self.assertEqual(result.api_call_status, API_TOOL_CALL)
                self.assertEqual(len(result.api_calls), 1)
                self.assertEqual(result.api_calls[0].api_name, "options_stock")
                self.assertEqual(
                    result.api_calls[0].params,
                    {"symbol": "TSLA", "expiration": "1672502400"},
                )

        # Truncated JSON: a RuntimeError is expected when the flag is True,
        # and an "invalid response" status when it is False.
        text = '{"name": "options_stock", "arguments": {"symbol": "TSLA"'
        with self.assertRaises(RuntimeError):
            when2call_str_to_api_call(text, True)
        result = when2call_str_to_api_call(text, False)
        self.assertEqual(result.api_call_status, API_INVALID_RESPONSE)

    def test_when2call_api_call_to_template(self):
        """API-call statuses render to the expected template strings."""
        # An invalid-response status cannot be rendered.
        with self.assertRaises(RuntimeError):
            when2call_api_call_to_template(
                APICallStatus(api_call_status=API_INVALID_RESPONSE)
            )

        # Statuses without API calls render as the status string itself.
        for status in (API_CAN_NOT_ANSWER, API_REQUEST_FOR_INFO):
            with self.subTest(status=status):
                result = when2call_api_call_to_template(
                    APICallStatus(api_call_status=status)
                )
                self.assertEqual(result, [status])

        with self.subTest(case="tool call without parameters"):
            result = when2call_api_call_to_template(
                APICallStatus(
                    api_call_status=API_TOOL_CALL,
                    api_calls=[APICall(api_name="API_1", params={})],
                )
            )
            self.assertEqual(result, ["Call the `API_1` API with no parameter"])

        with self.subTest(case="tool call with parameters"):
            result = when2call_api_call_to_template(
                APICallStatus(
                    api_call_status=API_TOOL_CALL,
                    api_calls=[
                        APICall(
                            api_name="API_2",
                            params={"param1": "value1", "param2": "value2"},
                        )
                    ],
                )
            )
            self.assertEqual(
                result,
                [
                    "Call the `API_2` API with following parameters: `param1` as `value1`, `param2` as `value2`"
                ],
            )


# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
