import unittest

from src.data_conversion import (
    str_to_api_call,
    api_bank_api_call_status_to_api_str,
    tool_ace_api_call_status_to_api_str,
    when2call_api_call_status_to_api_str,
    api_call_status_to_api_str,
    api_bank_tool_ace_parse_template,
    when2call_parse_template,
    template_to_api_call,
)
from src.utility import (
    dataset_api_bank,
    API_TOOL_CALL,
    API_INVALID_RESPONSE,
    dataset_tool_ace,
    dataset_when2call,
    APICallStatus,
    APICall,
    API_CAN_NOT_ANSWER,
    API_REQUEST_FOR_INFO,
)


class DataConversionTestCases(unittest.TestCase):
    """Tests for ``src.data_conversion``.

    Exercises three conversion directions for the API-Bank, ToolACE and
    When2Call datasets: raw model text -> ``APICallStatus``,
    ``APICallStatus`` -> dataset-specific text, and
    "Call the `X` API ..." templates -> ``APICallStatus``.
    """

    # Template describing a call to API_2 with two parameters; reused by several tests.
    _API_2_TEMPLATE = (
        "Call the `API_2` API with following parameters: `param1` as `value1`, `param2` as `value2`"
    )
    # Parameters that the two-parameter templates in these tests should parse into.
    _TWO_PARAMS = {"param1": "value1", "param2": "value2"}

    # ------------------------------------------------------------------ helpers

    def _make_status(self, status):
        """Return a fresh ``APICallStatus`` with ``status`` and a single ``API_1`` call."""
        return APICallStatus(
            api_call_status=status,
            api_calls=[APICall(api_name="API_1", params={"param": "value"})],
        )

    def _check_call(self, call, name, params, *, exact=False):
        """Assert ``call`` is named ``name`` and holds every key/value in ``params``.

        When ``exact`` is true, the parameter count must match as well.
        """
        self.assertEqual(call.api_name, name)
        if exact:
            self.assertEqual(len(call.params), len(params))
        for key, value in params.items():
            self.assertEqual(call.params[key], value)

    def _check_invalid(self, text, dataset):
        """Assert malformed ``text`` raises by default.

        Also assert that ``text`` is reported as an invalid response when
        ``str_to_api_call`` is given ``False`` as its third argument.
        """
        with self.assertRaises(RuntimeError):
            str_to_api_call(text, dataset)
        lenient = str_to_api_call(text, dataset, False)
        self.assertEqual(lenient.api_call_status, API_INVALID_RESPONSE)

    # -------------------------------------------------------- text -> APICall

    def test_str_to_api_call(self):
        """``str_to_api_call`` parses each dataset's format and rejects bad input."""
        # API-Bank style: "API-Request: [Name(key='value')]".
        parsed = str_to_api_call(
            "API-Request: [EmergencyKnowledge(symptom='Fatigue')]", dataset_api_bank
        )
        self.assertEqual(parsed.api_call_status, API_TOOL_CALL)
        self._check_call(parsed.api_calls[0], "EmergencyKnowledge", {"symptom": "Fatigue"})
        self._check_invalid("API-Request: [EmergencyKnowledge(symptom='Fatigue]", dataset_api_bank)

        # ToolACE style: '[Name(key="value", ...)]'.
        parsed = str_to_api_call('[API_1(name="ETH", date="2010")]', dataset_tool_ace)
        self.assertEqual(parsed.api_call_status, API_TOOL_CALL)
        self._check_call(parsed.api_calls[0], "API_1", {"name": "ETH", "date": "2010"})
        self._check_invalid("[API_1(name", dataset_tool_ace)

        # When2Call style: a JSON object with "name" and "arguments".
        json_text = '{"name": "options_stock", "arguments": {"symbol": "TSLA"}}'
        parsed = str_to_api_call(json_text, dataset_when2call)
        self.assertEqual(parsed.api_call_status, API_TOOL_CALL)
        self._check_call(parsed.api_calls[0], "options_stock", {"symbol": "TSLA"})

        # An unrecognised dataset name is rejected.
        with self.assertRaises(RuntimeError):
            str_to_api_call(json_text, "unknown")

    # -------------------------------------------------------- APICall -> text

    def test_api_bank_api_call_status_to_api_str(self):
        """API-Bank output uses the "API-Request:" prefix and single-quoted values."""
        rendered = api_bank_api_call_status_to_api_str(self._make_status(API_TOOL_CALL))
        self.assertEqual(rendered, "API-Request: [API_1(param='value')]")

    def test_tool_ace_api_call_status_to_api_str(self):
        """ToolACE output is a bracketed call list with double-quoted values."""
        rendered = tool_ace_api_call_status_to_api_str(self._make_status(API_TOOL_CALL))
        self.assertEqual(rendered, '[API_1(param="value")]')

    def test_when2call_api_call_status_to_api_str(self):
        """When2Call renders tool calls as JSON and other statuses as the status itself."""
        self.assertEqual(
            when2call_api_call_status_to_api_str(self._make_status(API_TOOL_CALL)),
            '{"name": "API_1", "arguments": {"param": "value"}}',
        )
        # A non-tool-call status is returned verbatim, even when api_calls is populated.
        for status in (API_CAN_NOT_ANSWER, API_REQUEST_FOR_INFO):
            self.assertEqual(
                when2call_api_call_status_to_api_str(self._make_status(status)), status
            )

    def test_api_call_status_to_api_str(self):
        """The generic serialiser delegates to the dataset-specific one."""
        status = self._make_status(API_TOOL_CALL)
        formatter_by_dataset = (
            (dataset_api_bank, api_bank_api_call_status_to_api_str),
            (dataset_tool_ace, tool_ace_api_call_status_to_api_str),
            (dataset_when2call, when2call_api_call_status_to_api_str),
        )
        for dataset, formatter in formatter_by_dataset:
            self.assertEqual(api_call_status_to_api_str(dataset, status), formatter(status))

    # ---------------------------------------------------- template -> APICall

    def test_api_bank_tool_ace_parse_template(self):
        """Templates with zero, several, or multi-line calls parse correctly."""
        parsed = api_bank_tool_ace_parse_template("Call the `API_1` API with no parameter")
        self.assertEqual(parsed.api_call_status, API_TOOL_CALL)
        self.assertEqual(len(parsed.api_calls), 1)
        self._check_call(parsed.api_calls[0], "API_1", {}, exact=True)

        parsed = api_bank_tool_ace_parse_template(
            "Call the `API_1` API with following parameters: `param1` as `value1`, `param2` as `value2`"
        )
        self.assertEqual(parsed.api_call_status, API_TOOL_CALL)
        self.assertEqual(len(parsed.api_calls), 1)
        self._check_call(parsed.api_calls[0], "API_1", self._TWO_PARAMS, exact=True)

        # Each template line becomes its own call, kept in input order.
        parsed = api_bank_tool_ace_parse_template(
            "Call the `API_1` API with no parameter\n" + self._API_2_TEMPLATE
        )
        self.assertEqual(parsed.api_call_status, API_TOOL_CALL)
        self.assertEqual(len(parsed.api_calls), 2)
        self._check_call(parsed.api_calls[0], "API_1", {}, exact=True)
        self._check_call(parsed.api_calls[1], "API_2", self._TWO_PARAMS, exact=True)

    def test_when2call_parse_template(self):
        """"Result:<status>" yields that status; call templates yield a tool call."""
        for status in (API_CAN_NOT_ANSWER, API_REQUEST_FOR_INFO):
            parsed = when2call_parse_template(f"Result:{status}")
            self.assertEqual(parsed.api_call_status, status)
            self.assertEqual(len(parsed.api_calls), 0)

        parsed = when2call_parse_template(self._API_2_TEMPLATE)
        self.assertEqual(parsed.api_call_status, API_TOOL_CALL)
        self.assertEqual(len(parsed.api_calls), 1)
        self._check_call(parsed.api_calls[0], "API_2", self._TWO_PARAMS, exact=True)

    def test_template_to_api_call(self):
        """The generic template parser delegates to the dataset-specific one."""
        template = self._API_2_TEMPLATE
        expected = api_bank_tool_ace_parse_template(template)
        for dataset in (dataset_api_bank, dataset_tool_ace):
            self.assertEqual(expected, template_to_api_call(template, dataset))
        expected = when2call_parse_template(template)
        self.assertEqual(expected, template_to_api_call(template, dataset_when2call))


if __name__ == "__main__":
    # Allow running this file directly (python <file>) in addition to test discovery.
    unittest.main()
