import random
import string
import unittest

from src.tool_ace_data_conversion import (
    tool_ace_parse_api_request,
    parse_system_prompt,
    tool_ace_str_to_api_call,
    tool_ace_api_call_to_template,
    parse_conversation,
    raw_dict_to_record,
)
from src.utility import APICall, generate_id


class ToolAceDataConversionTestCases(unittest.TestCase):
    """Unit tests for the ToolACE data-conversion helpers and ``generate_id``."""

    # Shared fixtures: several tests exercise the same system prompt and the
    # same single-turn conversation, so they are defined once here.
    SYSTEM_PROMPT = (
        "You are an expert in composing functions.\n"
        "Here is a list of functions in JSON format that you can invoke:\n"
        '[{"name": "newAddress"}]\n'
        "NO other text MUST be included.\n"
    )
    USER_TURN = (
        "I'm considering investing and I'd like to know what's happening in the "
        "market right now. Could you get me the top market trends in the US?"
    )
    ASSISTANT_TURN = '[Market Trends API(trend_type="MARKET_INDEXES", country="us")]'
    EXPECTED_TEMPLATE_OUTPUT = [
        "Call the `Market Trends API` API with following parameters: "
        "`trend_type` as `MARKET_INDEXES`, `country` as `us`"
    ]

    def _conversation(self):
        """Return a fresh copy of the shared conversation fixture.

        A new list is built on every call so a helper that mutates its input
        cannot leak state into another test.
        """
        return [
            {"from": "user", "value": self.USER_TURN},
            {"from": "assistant", "value": self.ASSISTANT_TURN},
            {"from": "tool", "value": "NOT_USED"},
        ]

    @staticmethod
    def _expected_api_calls():
        """Return the API calls expected from ``ASSISTANT_TURN``."""
        return [
            APICall(
                api_name="Market Trends API",
                params={"trend_type": "MARKET_INDEXES", "country": "us"},
            )
        ]

    def test_parse_system_prompt(self):
        pre_api, api_def, post_api = parse_system_prompt(self.SYSTEM_PROMPT)
        self.assertEqual(
            pre_api,
            [
                "You are an expert in composing functions.",
                "Here is a list of functions in JSON format that you can invoke:",
            ],
        )
        self.assertEqual(api_def, ['[{"name": "newAddress"}]'])
        self.assertEqual(post_api, ["NO other text MUST be included."])

    def test_tool_ace_parse_api_request(self):
        # API names may contain spaces; each call's params must map the right
        # value to the right key (comparing whole dicts checks the pairing).
        input_str = '[Get Coin Price Difference(name="ETH", date="2010"), Get Coin Price Difference(name="MATIC", date="2010")]'
        result = tool_ace_parse_api_request(input_str)
        self.assertIsNotNone(result)
        self.assertEqual(len(result), 2)
        self.assertEqual(result[0].api_name, "Get Coin Price Difference")
        self.assertEqual(result[1].api_name, "Get Coin Price Difference")
        self.assertEqual(dict(result[0].params), {"name": "ETH", "date": "2010"})
        self.assertEqual(dict(result[1].params), {"name": "MATIC", "date": "2010"})

        input_str = '[API1(name="A"), API2(name="B"), API3(name="C")]'
        result = tool_ace_parse_api_request(input_str)
        self.assertIsNotNone(result)
        self.assertEqual(len(result), 3)
        expected = [("API1", "A"), ("API2", "B"), ("API3", "C")]
        for call, (expected_name, expected_value) in zip(result, expected):
            with self.subTest(api=expected_name):
                self.assertEqual(call.api_name, expected_name)
                self.assertEqual(dict(call.params), {"name": expected_value})

        # Malformed inputs: unterminated call, and a doubled '=' operator.
        self.assertIsNone(
            tool_ace_parse_api_request('[Get Coin Price Difference(name="ETH"')
        )
        self.assertIsNone(tool_ace_parse_api_request('[GetDifference(name=="ETH")]'))

    def test_tool_ace_str_to_api_call(self):
        input_str = '[API_1(name="ETH", date="2010"), API_2(name="MATIC", date="2010")]'
        result = tool_ace_str_to_api_call(input_str)
        self.assertIsNotNone(result)
        self.assertEqual(len(result), 2)
        self.assertEqual(result[0].api_name, "API_1")
        self.assertEqual(result[1].api_name, "API_2")

        # With the second positional argument False, a malformed string
        # yields None; with the default, it raises RuntimeError.
        malformed = "[API_1(name"
        self.assertIsNone(tool_ace_str_to_api_call(malformed, False))
        with self.assertRaises(RuntimeError):
            tool_ace_str_to_api_call(malformed)

    def test_tool_ace_api_call_to_template(self):
        api_calls = [
            APICall(api_name="API_1", params={"param1": "value1", "param2": "value2"}),
            APICall(api_name="API_2", params={}),
        ]
        prompt = tool_ace_api_call_to_template(api_calls)
        self.assertEqual(
            prompt,
            [
                "Call the `API_1` API with following parameters: `param1` as `value1`, `param2` as `value2`",
                "Call the `API_2` API with no parameter",
            ],
        )

    def test_parse_conversation(self):
        parsed_conversation, truth_str, api_calls, template_output = (
            parse_conversation(self._conversation())
        )
        self.assertEqual(parsed_conversation, ["user: " + self.USER_TURN])
        self.assertEqual(truth_str, self.ASSISTANT_TURN)
        self.assertEqual(api_calls, self._expected_api_calls())
        self.assertEqual(template_output, self.EXPECTED_TEMPLATE_OUTPUT)

    def test_raw_dict_to_record(self):
        raw_dict = {
            "system": self.SYSTEM_PROMPT,
            "conversations": self._conversation(),
        }
        record = raw_dict_to_record(raw_dict)
        self.assertEqual(record.conversation, ["user: " + self.USER_TURN])
        self.assertEqual(record.output, self.ASSISTANT_TURN)
        self.assertEqual(record.api_call.api_calls, self._expected_api_calls())
        self.assertEqual(record.template_output, self.EXPECTED_TEMPLATE_OUTPUT)

    def test_generate_id(self):
        # A seeded local generator keeps the inputs reproducible across runs
        # without touching the global ``random`` state.
        rng = random.Random(0)
        random_text1 = "".join(rng.choices(string.ascii_letters, k=256))
        random_text2 = "".join(rng.choices(string.ascii_letters, k=256))
        self.assertNotEqual(random_text1, random_text2)

        id1 = generate_id(random_text1)
        id2 = generate_id(random_text2)
        # Different inputs give different ids; the same input is stable.
        self.assertNotEqual(id1, id2)
        self.assertEqual(id1, generate_id(random_text1))


if __name__ == "__main__":
    # Let the file be executed directly (``python <this file>``) as well as
    # via a test runner; ``module="__main__"`` is unittest.main's default.
    unittest.main(module="__main__")
