import json
import logging
import random
import re
from dataclasses import asdict
from typing import Tuple, Optional

from src.utility import (
    Record,
    APICall,
    dataset_tool_ace,
    Section,
    APICallStatus,
    API_TOOL_CALL,
    generate_id,
)

# Module-level logger named after this module; used by every function below.
logger = logging.getLogger(__name__)


def split_raw_file(
    json_file: str,
    out_train_file: str,
    out_test_file: str,
    test_set_size: int = 1000,
    random_seed: int = 666,
):
    """
    Load the raw data file, convert it to records and split into train/test sets.

    The raw objects are shuffled deterministically with ``random_seed`` before
    conversion. Objects for which ``raw_dict_to_record`` returns None are dropped.
    The first ``test_set_size`` valid records are written to ``out_test_file`` and
    the remaining ones to ``out_train_file``. If there are fewer valid records than
    ``test_set_size``, an error is logged and no output file is written.
    All files are read and written as UTF-8.

    :param json_file: Path of raw data file (a JSON list of raw record dicts).
    :param out_train_file: Path of output train file.
    :param out_test_file: Path of output test file.
    :param test_set_size: Size of test set.
    :param random_seed: Random seed used for the shuffle.
    """
    # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows) can
    # fail on, or garble, non-ASCII text in the dataset.
    with open(json_file, "r", encoding="utf-8") as fin:
        objects = json.load(fin)
    logger.info(f"Total records: {len(objects)}")
    # A private generator yields the same permutation as random.seed(random_seed)
    # followed by random.shuffle(), but leaves the global random state untouched.
    random.Random(random_seed).shuffle(objects)

    records = [raw_dict_to_record(x) for x in objects]
    records = [x for x in records if x is not None]
    logger.info(f"Valid records: {len(records)}")

    if len(records) < test_set_size:
        logger.error(
            f"Total record size is smaller than test set size: {len(records)} < {test_set_size}"
        )
        return

    _write_records_json(out_train_file, records[test_set_size:])
    _write_records_json(out_test_file, records[:test_set_size])


def _write_records_json(path: str, records: list) -> None:
    """
    Write dataclass records to ``path`` as a pretty-printed JSON list.
    :param path: Output file path; overwritten if it exists.
    :param records: Dataclass instances, converted with ``dataclasses.asdict``.
    """
    # ensure_ascii=False writes non-ASCII characters verbatim, so the file must be
    # opened with an encoding that can represent them.
    with open(path, "w", encoding="utf-8") as fout:
        json.dump(
            [asdict(x) for x in records],
            fout,
            ensure_ascii=False,
            indent=2,
        )


def parse_system_prompt(system_prompt: str) -> Tuple[list, list, list]:
    """
    Split a system prompt into three groups of non-empty, whitespace-stripped lines.

    Lines are assigned in order to the current section, starting in PreAPI.
    - A line that starts or ends with the "Here is a list of functions ..." header
      is itself kept in the current section. The following lines go to APIDef.
    - A line starting with '[{"name":' is kept in the current section. The following
      lines go to PostAPI.
    :param system_prompt: A string of system prompt.
    :return: Tuple of: lines before API definition, lines of API definition and lines after API definition.
    """
    api_header = "Here is a list of functions in JSON format that you can invoke:"
    pre_api, api_def, post_api = [], [], []
    section = Section.PreAPI

    for raw_line in system_prompt.split("\n"):
        line = raw_line.strip()
        if not line:
            continue

        if section == Section.PreAPI:
            bucket = pre_api
        elif section == Section.APIDef:
            bucket = api_def
        elif section == Section.PostAPI:
            bucket = post_api
        else:
            raise RuntimeError(f"Unexpected section:{section}")
        bucket.append(line)

        # The transition is decided after the line has been stored, so the header
        # line and the JSON definition line belong to the section they close.
        if line.startswith(api_header) or line.endswith(api_header):
            section = Section.APIDef
        elif line.startswith('[{"name":'):
            section = Section.PostAPI

    return pre_api, api_def, post_api


# Matches one "ApiName(params)" call inside a comma-separated list of calls,
# capturing the API name and the raw parameter text. The first name character may
# not be a comma or space, so ", " separators between calls are skipped. The rest
# of the name may be empty (`*`, not `+`), so single-character API names are
# found too. Without that, "[f(x=1)]" would pass tool_ace_api_pattern but yield no call.
# NOTE: parameters are matched with [^)]*, so a ")" inside a parameter value
# truncates the match.
tool_ace_pattern = re.compile(r"(?P<api>[^, ][^(]*)\((?P<param>[^)]*)\)")
# Validates the overall request shape: "[Name(params)" optionally followed by
# repeated ", Name(params)" and a closing "]". It is used with .match(), so it is
# anchored at the start of the string only.
tool_ace_api_pattern = re.compile(
    r"\[(?P<api>[^(]+)\((?P<param>[^)]*)\)(?:, ([^(]+)\(([^)]*)\))*]"
)
# Matches a parameter name immediately followed by "=", e.g. "country=".
tool_ace_param_name_pattern = re.compile(r"(\w+)=")


# Tokenizer for ToolACE keyword-argument text.
# - Group 1 captures a complete quoted string literal (double or single quotes,
#   backslash escapes allowed); it is copied through unchanged.
# - Group 2 captures a bare parameter name directly followed by "=".
# Because string literals are consumed first, "=" characters and "word=" sequences
# inside values are never rewritten.
_tool_ace_param_token_pattern = re.compile(
    r"(\"(?:[^\"\\]|\\.)*\"|'(?:[^'\\]|\\.)*')|(\w+)="
)


def _tool_ace_params_to_dict_literal(params: str) -> str:
    """
    Rewrite ToolACE keyword arguments into Python dict-literal source text.

    ``trend_type="MARKET_INDEXES", country="us"`` becomes
    ``{"trend_type":"MARKET_INDEXES", "country":"us"}``. Only names outside string
    literals are rewritten, so values such as ``"a=b"`` or ``"https://x?q=1"`` keep
    their ``=`` characters.
    :param params: Raw text between the parentheses of one API call.
    :return: Dict-literal source text (not yet evaluated).
    """

    def _rewrite(match: re.Match) -> str:
        literal, name = match.group(1), match.group(2)
        if literal is not None:
            return literal
        return f'"{name}":'

    return "{" + _tool_ace_param_token_pattern.sub(_rewrite, params) + "}"


def tool_ace_parse_api_request(api_request: str) -> Optional[list[APICall]]:
    """
    Parse a ToolACE API request string into APICall objects.

    Expected input looks like
    ``[Market Trends API(trend_type="MARKET_INDEXES", country="us")]``, optionally
    with several comma-separated calls inside the brackets.
    :param api_request: String based API request (e.g. ground truth or model output).
    :return: List of APICall objects, or None in any of these cases:
        - the string does not have the expected shape;
        - any call's parameters cannot be parsed;
        - no call is found.
    """
    # .match() is anchored at the start only: it validates the "[Name(...)" shape.
    if tool_ace_api_pattern.match(api_request) is None:
        logger.error("Unable to match API request for ToolACE:" + api_request)
        return None

    api_request = api_request.removeprefix("[").removesuffix("]")
    results = []
    for api_name, params in tool_ace_pattern.findall(api_request):
        dict_literal = _tool_ace_params_to_dict_literal(params)
        try:
            # SECURITY NOTE(review): eval() executes arbitrary Python, and this text
            # comes from dataset/model output. Consider ast.literal_eval if parameter
            # values are always plain literals.
            # The JSON round trip normalizes the result (e.g. tuples become lists)
            # and rejects values that are not JSON-serializable.
            params_json = json.loads(json.dumps(eval(dict_literal)))
        except Exception:
            logger.info(f"Unable to parse parameters for ToolACE: {dict_literal}")
            return None
        results.append(APICall(api_name=api_name, params=params_json))

    if not results:
        logger.info(f"unable to parse {api_request}")
        return None
    return results


def tool_ace_str_to_api_call(
    input_str: str, raise_error_on_failure: bool = True
) -> Optional[list[APICall]]:
    """
    Parse a ToolACE API call string into a list of APICall objects.
    :param input_str: String based API call information, e.g. ``[Name(a="x")]``.
    :param raise_error_on_failure: If True, raise RuntimeError when parsing fails;
        if False, log the failure and return None.
    :return: List of parsed APICall objects, or None on failure when
        ``raise_error_on_failure`` is False.
    :raises RuntimeError: If parsing fails and ``raise_error_on_failure`` is True.
    """
    parsed_calls = tool_ace_parse_api_request(input_str)
    if parsed_calls is not None:
        return parsed_calls

    failure_message = f"Failed to parse string to ToolACE API call: {input_str}"
    logger.error(failure_message)
    if raise_error_on_failure:
        raise RuntimeError(failure_message)
    return None


# str.format templates used by tool_ace_api_call_to_template.
# Sentence for a call without parameters; "{}" is the API name.
tool_ace_no_param_format = "Call the `{}` API with no parameter"
# Sentence prefix for a call with parameters; "{}" is the API name. The rendered
# parameter list is appended directly after it.
tool_ace_prefix_format = "Call the `{}` API with following parameters: "
# One rendered parameter: "`name` as `value`".
tool_ace_param_format = "`{}` as `{}`"


def tool_ace_api_call_to_template(api_calls: list[APICall]) -> list[str]:
    """
    Render each APICall as a human-readable sentence.

    A call with an empty ``params`` uses ``tool_ace_no_param_format``. Otherwise the
    ``tool_ace_prefix_format`` sentence is followed by every parameter, rendered with
    ``tool_ace_param_format`` and joined by ", " in ``params`` iteration order.
    :param api_calls: List of APICall objects to convert.
    :return: One template based string per APICall, in input order.
    """

    def render(call) -> str:
        if len(call.params) == 0:
            return tool_ace_no_param_format.format(call.api_name)
        rendered_params = [
            tool_ace_param_format.format(key, value)
            for key, value in call.params.items()
        ]
        return tool_ace_prefix_format.format(call.api_name) + ", ".join(
            rendered_params
        )

    return [render(call) for call in api_calls]


def parse_conversation(input_list: list[dict]) -> Tuple[list, str, list, list]:
    """
    Parse the first user/assistant exchange of a conversation.

    Only ``input_list[0]`` (whose "from" must be "user") and ``input_list[1]``
    (whose "from" must be "assistant") are read; any later turns are ignored.
    :param input_list: List of conversation turns, each a dict with "from" and "value".
    :return: Tuple of:
        - a one-element list ``["user: <user text>"]``;
        - the raw assistant string (truth);
        - the parsed list of APICall, or None if the truth could not be parsed;
        - the template based strings, or None when parsing failed.
    :raises IndexError: If fewer than two turns are given.
    :raises AssertionError: If the first two turns do not have the expected roles.
    """
    human_turn, bot_turn = input_list[0], input_list[1]
    assert "user" == human_turn["from"], f'Key "user" not in {human_turn}'
    assert "assistant" == bot_turn["from"], f'Key "assistant" not in {bot_turn}'

    dialogue = [f'user: {human_turn["value"]}']
    truth = bot_turn["value"]
    calls = tool_ace_str_to_api_call(truth, raise_error_on_failure=False)
    templates = None if calls is None else tool_ace_api_call_to_template(calls)
    return dialogue, truth, calls, templates


def raw_dict_to_record(input_dict: dict) -> Optional[Record]:
    """
    Build a ToolACE Record from one raw dataset entry.

    The entry is rejected (None is returned) in either case:
    - its system prompt yields no API definition lines;
    - the assistant answer cannot be parsed into API calls.
    :param input_dict: Raw entry containing the keys "system" (system prompt text)
        and "conversations" (list of conversation turns).
    :return: Record object, or None if the entry is unusable.
    :raises AssertionError: If "system" or "conversations" is missing.
    """
    assert "system" in input_dict, f'Key "system" does not exist: {input_dict}'
    assert (
        "conversations" in input_dict
    ), f'Key "conversations" does not exist: {input_dict}'

    system_prompt = input_dict["system"]
    turns = input_dict["conversations"]

    pre_api, api_def, post_api = parse_system_prompt(system_prompt)
    if not api_def:
        return None

    conversation, truth_str, api_calls, template_output = parse_conversation(turns)
    if api_calls is None:
        return None

    # The id is derived from the raw system prompt and conversation text.
    return Record(
        id=generate_id(str(system_prompt) + str(turns)),
        data_set=dataset_tool_ace,
        pre_api=pre_api,
        api_def=api_def,
        post_api=post_api,
        conversation=conversation,
        ending=[],
        output=truth_str,
        api_call=APICallStatus(api_call_status=API_TOOL_CALL, api_calls=api_calls),
        template_output=template_output,
    )
