import json
import logging
import re

from src.api_bank_data_conversion import api_bank_str_to_api_call
from src.tool_ace_data_conversion import tool_ace_str_to_api_call
from src.utility import (
    APICallStatus,
    dataset_api_bank,
    API_TOOL_CALL,
    API_INVALID_RESPONSE,
    dataset_tool_ace,
    dataset_when2call,
    Record,
    ModelResponse,
    API_CAN_NOT_ANSWER,
    API_REQUEST_FOR_INFO,
)
from src.when2call_data_conversion import when2call_str_to_api_call

# Dictionary key names for record fields.
# NOTE(review): these keys are not referenced by the code visible in this
# module; presumably consumed by other modules — verify before removing.
id_key = "id"
dataset_key = "dataset"
apis_key = "apis"

# Regexes for natural-language API call descriptions. The format below is
# inferred from the patterns themselves:
#   Call the `<api>` API with no parameter
#   Call the `<api>` API with following parameters: `<name>` as `<value>` ...
# Both API patterns are applied with .match(), so they only match at the start
# of a string; group(1) captures the API name.
pattern_api_no_arg = re.compile(r"Call the `([^`]+)` API with no parameter")
pattern_api = re.compile(r"Call the `([^`]+)` API with following parameters:.*")
# Used with findall() to extract (param_name, param_value) pairs.
pattern_values = re.compile(r"`([^`]+)` as `([^`]+)`")

logger = logging.getLogger(__name__)


def str_to_api_call(
    input_str: str, dataset: str, raise_error_on_failure: bool = True
) -> APICallStatus:
    """
    Parse a raw API-call string into an APICallStatus using the dataset's parser.

    For API-Bank and ToolACE, a ``None`` result from the dataset parser
    (presumably produced on parse failure when ``raise_error_on_failure`` is
    False) becomes an API_INVALID_RESPONSE status with no calls; any other
    result becomes an API_TOOL_CALL status. When2Call's parser result is
    returned unchanged.
    :param input_str: Text containing the API call(s).
    :param dataset: Dataset name selecting the parser.
    :param raise_error_on_failure: Forwarded to the dataset-specific parser.
    :return: APICallStatus describing the parsed call(s).
    :raises RuntimeError: If ``dataset`` is not a known dataset name.
    """
    if dataset == dataset_api_bank:
        # The API-Bank parser yields a single call; wrap it so both branches
        # below produce a list (or None).
        single_call = api_bank_str_to_api_call(input_str, raise_error_on_failure)
        parsed_calls = None if single_call is None else [single_call]
    elif dataset == dataset_tool_ace:
        parsed_calls = tool_ace_str_to_api_call(input_str, raise_error_on_failure)
    elif dataset == dataset_when2call:
        # When2Call's converter already returns a complete status object.
        return when2call_str_to_api_call(input_str, raise_error_on_failure)
    else:
        raise RuntimeError(f"Unknown data set:{dataset}")

    if parsed_calls is None:
        return APICallStatus(api_call_status=API_INVALID_RESPONSE, api_calls=[])
    return APICallStatus(api_call_status=API_TOOL_CALL, api_calls=parsed_calls)


def load_truth_file(truth_file_path: str) -> dict[str, Record]:
    """
    Load the truth file containing prompts and expected results.

    The file is a single JSON document whose top-level value is iterated
    (presumably a JSON array of record objects); each item is converted with
    ``Record.from_dict``.
    :param truth_file_path: Path of truth file.
    :return: A dict mapping each record's id to its Record. If several records
        share an id, the last one in the file wins.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale
    # encoding, which differs e.g. on Windows.
    with open(truth_file_path, "r", encoding="utf-8") as fin:
        truth_json = json.load(fin)
    # Conversion does not need the file, so it happens after the handle closes.
    records = (Record.from_dict(item) for item in truth_json)
    return {record.id: record for record in records}


def load_prediction_file(
    pred_file_path: str, intermediate_result: bool = False
) -> dict[str, APICallStatus | ModelResponse]:
    """
    Load a JSON Lines prediction file containing model responses.

    Every non-blank line is a JSON object whose fields are passed as keyword
    arguments to ``ModelResponse``.
    :param pred_file_path: Prediction file path.
    :param intermediate_result: If True, keep each ModelResponse unparsed (it
        will be used as the prompt for a follow-up LLM call). Otherwise parse
        each response into an APICallStatus.
    :return: A dict keyed by response id. Values are ModelResponse objects when
        ``intermediate_result`` is True, else APICallStatus objects. Later lines
        overwrite earlier ones with the same id.
    """
    results: dict[str, APICallStatus | ModelResponse] = {}
    # Explicit UTF-8: JSON is UTF-8 and the locale default varies by platform.
    with open(pred_file_path, "r", encoding="utf-8") as fin:
        for line in fin:
            # Skip blank lines (e.g. a stray empty line at EOF) instead of
            # failing in json.loads.
            if not line.strip():
                continue
            model_response = ModelResponse(**json.loads(line))
            if intermediate_result:
                results[model_response.id] = model_response
            else:
                # raise_error_on_failure=False so a single malformed response
                # does not abort the whole load.
                results[model_response.id] = str_to_api_call(
                    model_response.response, model_response.data_set, False
                )
    return results


def api_bank_api_call_status_to_api_str(api_call: APICallStatus) -> str:
    """
    Render the first API call of an APICallStatus in API-Bank request syntax.

    Output looks like ``API-Request: [Name(key='value', ...)]``. Values are
    formatted via f-string and wrapped in single quotes without escaping.
    Only ``api_calls[0]`` is used; any further calls are ignored.
    :param api_call: Status object whose first API call is rendered.
    :return: The API-Bank request string.
    """
    first_call = api_call.api_calls[0]
    rendered_args = ", ".join(
        f"{name}='{val}'" for name, val in first_call.params.items()
    )
    return f"API-Request: [{first_call.api_name}({rendered_args})]"


def tool_ace_api_call_status_to_api_str(api_call: APICallStatus) -> str:
    """
    Render every API call of an APICallStatus as a ToolACE-style call list.

    Output looks like ``[name1(k="v", ...), name2(...)]``, with calls in their
    original order. Values are formatted via f-string and wrapped in double
    quotes without escaping.
    :param api_call: Status object whose API calls are rendered.
    :return: Bracketed, comma-separated list of calls ("[]" when there are none).
    """

    def render(single_call) -> str:
        # One call -> "name(k="v", ...)".
        arg_text = ", ".join(
            f'{name}="{val}"' for name, val in single_call.params.items()
        )
        return f"{single_call.api_name}({arg_text})"

    return "[" + ", ".join(render(c) for c in api_call.api_calls) + "]"


def when2call_api_call_status_to_api_str(input_value: APICallStatus) -> str:
    """
    Render a When2Call APICallStatus as a string.

    :param input_value: Parsed When2Call status.
    :return: For API_CAN_NOT_ANSWER or API_REQUEST_FOR_INFO, the status value
        itself. Otherwise, a JSON string ``{"name": ..., "arguments": ...}``
        built from the single API call.
    :raises RuntimeError: If the status is API_INVALID_RESPONSE, or if a
        tool-call status does not hold exactly one API call.
    """
    status = input_value.api_call_status
    if status == API_INVALID_RESPONSE:
        raise RuntimeError(f"Invalid record: {input_value}")
    if status in (API_CAN_NOT_ANSWER, API_REQUEST_FOR_INFO):
        # Non-tool-call outcomes are represented by the status value itself.
        return status

    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would silently render only the first of several calls.
    if len(input_value.api_calls) != 1:
        raise RuntimeError(f"when2call should have only 1 API call: {input_value}")
    call = input_value.api_calls[0]
    return json.dumps({"name": call.api_name, "arguments": call.params})


def api_call_status_to_api_str(data_set: str, api_call: APICallStatus) -> str:
    """
    Render an APICallStatus in the API-call string format of ``data_set``.

    :param data_set: Dataset the record belongs to.
    :param api_call: Parsed API call status to render.
    :return: The rendered string. For API-Bank, an empty string is returned
        when the record does not contain exactly one API call.
    :raises RuntimeError: If ``data_set`` is not a known dataset name.
    """
    if data_set == dataset_api_bank:
        # API-Bank records carry exactly one call; anything else renders as "".
        return (
            api_bank_api_call_status_to_api_str(api_call)
            if len(api_call.api_calls) == 1
            else ""
        )
    if data_set == dataset_tool_ace:
        return tool_ace_api_call_status_to_api_str(api_call)
    if data_set == dataset_when2call:
        return when2call_api_call_status_to_api_str(api_call)
    raise RuntimeError(f"Unknown dataset:{data_set}")


def api_bank_tool_ace_parse_template(line: str) -> APICallStatus:
    """
    Parse templated text into an APICallStatus for API-Bank and ToolACE.

    Each line is matched independently against
    "Call the `<api>` API with no parameter" and
    "Call the `<api>` API with following parameters: ...". In the second form,
    parameters are read from "`<name>` as `<value>`" fragments on that line.
    Lines matching neither template are logged and skipped.
    :param line: Text with one API call description per line.
    :return: APICallStatus with status API_TOOL_CALL and the parsed calls in
        input order. The list may be empty; no status downgrade happens here.
    """
    apis = []
    # Strip each line so leading indentation does not defeat the start-anchored
    # .match(), and skip blank lines rather than logging them as mismatches
    # (same handling as when2call_parse_template).
    for resp in (raw.strip() for raw in line.split("\n")):
        if not resp:
            continue
        match = pattern_api_no_arg.match(resp)
        if match:
            apis.append({"api_name": match.group(1), "params": {}})
            continue
        match = pattern_api.match(resp)
        if not match:
            logger.info("Unable to match template: %s", resp)
            continue
        # Duplicate parameter names keep the last value, as before.
        params = {
            name: value.strip() for name, value in pattern_values.findall(resp)
        }
        apis.append({"api_name": match.group(1), "params": params})
    return APICallStatus.from_dict(
        {"api_call_status": API_TOOL_CALL, "api_calls": apis}
    )


def when2call_parse_template(line: str) -> APICallStatus:
    """
    Parse templated text into an APICallStatus for When2Call.

    If the text contains the API_CAN_NOT_ANSWER marker, or else the
    API_REQUEST_FOR_INFO marker, anywhere, that status is returned with no
    calls. Otherwise each non-blank line is matched against the
    "Call the `<api>` API ..." templates. When2Call expects exactly one call.
    :param line: Text describing the expected response, one statement per line.
    :return: API_TOOL_CALL with the single parsed call, or API_INVALID_RESPONSE
        (no calls) when zero or several calls were found.
    """
    if API_CAN_NOT_ANSWER in line:
        return APICallStatus(api_call_status=API_CAN_NOT_ANSWER, api_calls=[])
    if API_REQUEST_FOR_INFO in line:
        return APICallStatus(api_call_status=API_REQUEST_FOR_INFO, api_calls=[])

    apis = []
    for resp in (raw.strip() for raw in line.split("\n")):
        if not resp:
            continue
        match = pattern_api_no_arg.match(resp)
        if match:
            apis.append({"api_name": match.group(1), "params": {}})
            continue
        match = pattern_api.match(resp)
        if not match:
            logger.info("Unable to match: %s", resp)
            continue
        # Extract parameters from the matched line only. Scanning the whole
        # input would attach "`x` as `y`" fragments from other lines to this call.
        params = {
            name: value.strip() for name, value in pattern_values.findall(resp)
        }
        apis.append({"api_name": match.group(1), "params": params})

    if len(apis) != 1:
        logger.info("When2Call should have 1 API call: %s", line)
        return APICallStatus(api_call_status=API_INVALID_RESPONSE, api_calls=[])
    return APICallStatus.from_dict(
        {"api_call_status": API_TOOL_CALL, "api_calls": apis}
    )


def template_to_api_call(line: str, data_set: str) -> APICallStatus:
    """
    Route templated API-call text to the parser for ``data_set``.

    API-Bank and ToolACE share one template parser; When2Call has its own.
    :param line: Templated text describing the API call(s).
    :param data_set: Dataset name.
    :return: The parsed APICallStatus.
    :raises RuntimeError: If ``data_set`` is not a known dataset name.
    """
    if data_set in (dataset_api_bank, dataset_tool_ace):
        return api_bank_tool_ace_parse_template(line)
    if data_set == dataset_when2call:
        return when2call_parse_template(line)
    raise RuntimeError(f"Unknown data set:{data_set}")
