import json
import logging
from dataclasses import asdict
from typing import Optional

from src.utility import (
    Record,
    dataset_when2call,
    APICall,
    API_CAN_NOT_ANSWER,
    API_REQUEST_FOR_INFO,
    APICallStatus,
    API_TOOL_CALL,
    API_INVALID_RESPONSE,
    generate_id,
)

logger = logging.getLogger(__name__)


def when2call_convert(
    jsonl_file: str,
    out_file: str,
    filter_data: bool,
    data_type: Optional[str] = None,
):
    """
    Convert original When2Call data to unified format.

    Each non-blank line of ``jsonl_file`` is parsed as one JSON object and
    converted with ``when2call_process_train_data`` when ``data_type`` is
    ``"train"``, otherwise (including ``None``) with
    ``when2call_process_test_data``. Records that cannot be converted, or whose
    status is ``API_INVALID_RESPONSE``, are dropped. The remaining records are
    written to ``out_file`` as a single indented JSON array.

    :param jsonl_file: Input When2Call data file (one JSON object per line).
    :param out_file: Converted When2Call data file.
    :param filter_data: Remove non-API call data.
    :param data_type: The data file is train data or test data.
    """
    results = []
    count = 0
    # Read as UTF-8 explicitly so decoding does not depend on the platform
    # locale.
    with open(jsonl_file, "r", encoding="utf-8") as fin:
        for line in fin:
            # Blank lines (e.g. a trailing newline at the end of the file) are
            # not records; json.loads would fail on them.
            if not line.strip():
                continue
            count += 1
            record = json.loads(line)
            if data_type == "train":
                new_record = when2call_process_train_data(record)
            else:
                # Test records are parsed leniently: failures yield None
                # instead of raising.
                new_record = when2call_process_test_data(
                    record, raise_error_on_failure=False
                )
            if (
                new_record is None
                or new_record.api_call.api_call_status == API_INVALID_RESPONSE
            ):
                continue
            if filter_data and new_record.api_call.api_call_status != API_TOOL_CALL:
                continue
            results.append(new_record)
    logger.info(f"Original data size: {count}")
    logger.info(f"Processed data size: {len(results)}")

    # ensure_ascii=False emits raw non-ASCII characters, so the output file
    # must be opened with an encoding that can represent them.
    with open(out_file, "w", encoding="utf-8") as fout:
        json.dump(
            [asdict(x) for x in results],
            fout,
            ensure_ascii=False,
            indent=2,
        )


def when2call_process_train_data(record: dict) -> Optional[Record]:
    """
    Convert original When2Call training data to unified format.

    The chosen response is classified as a tool call (wrapped in
    ``<TOOLCALL>...</TOOLCALL>``), a request for more information, or a
    "cannot answer" reply, using the keyword heuristics below.

    :param record: Dict containing a training instance.
    :return: Training instance in Record format, or None when the record has
        no tools or does not have exactly one message and a chosen response.
    :raises RuntimeError: If the tool call payload does not describe exactly
        one call, or the response matches none of the known categories.
    """
    # A missing, None or empty tool list means there is nothing to call.
    if not record.get("tools"):
        return None
    if (
        "chosen_response" not in record
        or "messages" not in record
        or len(record["messages"]) != 1
    ):
        logger.error(f"Invalid When2Call train record: {record}")
        return None
    tool_def = record["tools"]
    conversation = record["messages"][0]["content"]
    content = record["chosen_response"]["content"]
    # Lower-case once; several of the keyword checks are case-insensitive.
    lowered = content.lower()
    if content.startswith("<TOOLCALL>") and content.endswith("</TOOLCALL>"):
        tool_call = content.removeprefix("<TOOLCALL>").removesuffix("</TOOLCALL>")
        tool_call_json = json.loads(tool_call)
        # Accept a bare call object as well as a one-element list of calls.
        if isinstance(tool_call_json, dict):
            tool_call_json = [tool_call_json]
        # Exactly one call is supported; empty lists, multiple calls and
        # non-container payloads are all rejected with the same error.
        if not isinstance(tool_call_json, list) or len(tool_call_json) != 1:
            raise RuntimeError(f"Invalid When2Call train record: {record}")
        single_call = tool_call_json[0]
        name = single_call["name"]
        params = single_call["arguments"]
        api_call = APICallStatus(
            api_call_status=API_TOOL_CALL,
            api_calls=[APICall(api_name=name, params=params)],
        )
        output = json.dumps(single_call, ensure_ascii=False)
    elif (
        # NOTE: the first three keywords are matched case-sensitively, the
        # rest case-insensitively; this mirrors the existing labelling rules.
        "specify" in content
        or "need" in content
        or "provide" in content
        or "which" in lowered
        or "how many" in lowered
        or "may i know" in lowered
        or "what" in lowered
    ):
        api_call = APICallStatus(api_calls=[], api_call_status=API_REQUEST_FOR_INFO)
        output = API_REQUEST_FOR_INFO
    elif (
        "unable" in content
        or "can't" in content
        or "I don't have" in content
        or "apologies" in lowered
        or "I'm sorry" in content
    ):
        api_call = APICallStatus(api_calls=[], api_call_status=API_CAN_NOT_ANSWER)
        output = API_CAN_NOT_ANSWER
    else:
        raise RuntimeError(f"Unknown When2Call record type: {record}")
    template_output = when2call_api_call_to_template(api_call)

    # The id is derived from the tools, the user message and the response.
    converted = Record(
        id=generate_id(str(tool_def) + str(conversation) + str(content)),
        data_set=dataset_when2call,
        pre_api=[],
        api_def=tool_def,
        post_api=[],
        conversation=[conversation],
        ending=[],
        output=output,
        api_call=api_call,
        template_output=template_output,
    )

    return converted


def when2call_process_test_data(
    record: dict, raise_error_on_failure: bool = True
) -> Optional[Record]:
    """
    Convert original When2Call test data to unified format.
    :param record: Dict containing a test instance.
    :param raise_error_on_failure: Raise an error when the tool call string
        cannot be parsed (forwarded to ``when2call_str_to_api_call``).
    :return: Test instance in Record format, or None when the tool call is
        reported as an invalid response.
    """
    assert "question" in record, f"Unexpected record: {record}"
    assert isinstance(record["question"], str), f"Unexpected question: {record}"

    answer_kind = record["correct_answer"]
    if answer_kind == "tool_call":
        # The expected answer is a serialized tool call that must be parsed.
        output = record["answers"]["tool_call"]
        api_call = when2call_str_to_api_call(output, raise_error_on_failure)
        if api_call.api_call_status == API_INVALID_RESPONSE:
            return None
        template_output = when2call_api_call_to_template(api_call)
    else:
        # Both remaining answer kinds carry no call; the status constant is
        # used as the output and as the single template line.
        if answer_kind == "cannot_answer":
            marker = API_CAN_NOT_ANSWER
        elif answer_kind == "request_for_info":
            marker = API_REQUEST_FOR_INFO
        else:
            raise RuntimeError(f"Unexpected When2Call answer type: {record}")
        output = marker
        api_call = APICallStatus(api_call_status=marker, api_calls=[])
        template_output = [marker]

    return Record(
        id=record["uuid"],
        data_set=dataset_when2call,
        pre_api=[],
        api_def=record["tools"],
        post_api=[],
        conversation=[record["question"]],
        ending=[],
        output=output,
        api_call=api_call,
        template_output=template_output,
    )


def when2call_str_to_api_call(
    input_str: str, raise_error_on_failure: bool
) -> APICallStatus:
    """
    Convert a string to APICallStatus object.

    The string is first checked for the "cannot answer" and "request for
    info" markers. Otherwise it is parsed as a JSON object with ``name`` and
    ``arguments`` keys, optionally wrapped in a Markdown ```json fence.

    :param input_str: Input string containing API call information.
    :param raise_error_on_failure: Raise error if failed to parse APICall object.
    :return: APICallStatus object containing API call information. When
        parsing fails and ``raise_error_on_failure`` is False, the status is
        ``API_INVALID_RESPONSE``.
    :raises RuntimeError: If parsing fails and ``raise_error_on_failure`` is
        True.
    """
    if API_CAN_NOT_ANSWER in input_str:
        return APICallStatus(api_call_status=API_CAN_NOT_ANSWER, api_calls=[])
    if API_REQUEST_FOR_INFO in input_str:
        return APICallStatus(api_call_status=API_REQUEST_FOR_INFO, api_calls=[])

    # Strip surrounding whitespace before looking for the fence, so a fenced
    # payload followed by a newline is still unwrapped.
    payload = input_str.strip()
    if payload.startswith("```json") and payload.endswith("```"):
        payload = payload.removeprefix("```json").removesuffix("```").strip()
    try:
        api_obj = json.loads(payload)
        api_call = APICall(
            api_name=api_obj["name"],
            params=api_obj["arguments"],
        )
        return APICallStatus(api_call_status=API_TOOL_CALL, api_calls=[api_call])
    except Exception as e:
        # Deliberately broad: any failure to decode or build the call means
        # the response is treated as unparsable.
        logger.info(f"Unable to parse API: {e} : {payload}")
        if raise_error_on_failure:
            # Chain the original error so the root cause stays visible.
            raise RuntimeError(f"Unable to parse API: {payload}") from e
    return APICallStatus(api_call_status=API_INVALID_RESPONSE, api_calls=[])


# Sentence templates used by when2call_api_call_to_template.
# Whole sentence for a call without parameters; `{}` receives the API name.
when2call_no_param_format = "Call the `{}` API with no parameter"
# Sentence prefix for a call with parameters; `{}` receives the API name and
# the rendered parameters are appended after it.
when2call_prefix_format = "Call the `{}` API with following parameters: "
# One rendered parameter; the placeholders receive the parameter name and its
# value, in that order.
when2call_param_format = "`{}` as `{}`"


def when2call_api_call_to_template(api_call_status: APICallStatus) -> list[str]:
    """
    Describe an APICallStatus object as template based sentences.
    :param api_call_status: APICallStatus object to convert.
    :return: One sentence per API call, or a single-element list holding the
        status constant for the "cannot answer" / "request for info" statuses.
    :raises RuntimeError: If the status is ``API_INVALID_RESPONSE``.
    """
    status = api_call_status.api_call_status
    # Statuses without a call are passed through as their own description.
    for passthrough in (API_CAN_NOT_ANSWER, API_REQUEST_FOR_INFO):
        if status == passthrough:
            return [passthrough]
    if status == API_INVALID_RESPONSE:
        raise RuntimeError(f"Invalid APICallStatus: {api_call_status}")

    def describe(call) -> str:
        # A call without parameters uses the dedicated one-piece sentence.
        if len(call.params) == 0:
            return when2call_no_param_format.format(call.api_name)
        rendered_params = ", ".join(
            when2call_param_format.format(key, value)
            for key, value in call.params.items()
        )
        return when2call_prefix_format.format(call.api_name) + rendered_params

    return [describe(call) for call in api_call_status.api_calls]
