import logging
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
from sentence_transformers import SentenceTransformer, util

from src.utility import (
    EvalCountResult,
    Record,
    EvalResult,
    Status,
    APICallStatus,
    API_TOOL_CALL,
    API_INVALID_RESPONSE,
)

# Name of the sentence-transformers model that eval_result passes to
# SentenceTransformer(model_name_or_path=...) for semantic value matching.
sf_model_name = "sentence-transformers/all-MiniLM-L6-v2"
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def precision(correct: int, predicted: int) -> float:
    """
    Compute precision as the share of predictions that are correct.
    :param correct: Count of correct predictions.
    :param predicted: Count of all predictions made.
    :return: ``correct / predicted``, or 0.0 when nothing was predicted.
    """
    if predicted > 0:
        return correct / predicted
    # No predictions at all: define precision as zero instead of dividing by zero.
    return 0.0


def recall(correct: int, golden: int) -> float:
    """
    Compute recall as the share of golden items that were correctly predicted.
    :param correct: Count of correct predictions.
    :param golden: Count of golden (ground-truth) items.
    :return: ``correct / golden``, or 0.0 when there are no golden items.
    """
    if golden > 0:
        return correct / golden
    # Nothing to recall: define recall as zero instead of dividing by zero.
    return 0.0


def f1(p: float, r: float) -> float:
    """
    Compute the F1 score, i.e. the harmonic mean of precision and recall.
    :param p: Precision.
    :param r: Recall.
    :return: ``2 * p * r / (p + r)``, or 0.0 when ``p + r`` is not positive.
    """
    total = p + r
    if total > 0:
        return 2 * p * r / total
    # Both inputs are zero (or the sum is non-positive): avoid dividing by zero.
    return 0.0


def compute_metrics(
    data: list[EvalResult],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Derive per-sample precision, recall and F1 from evaluation results.

    Every returned array has one entry per element of ``data``, in input order.
    :param data: List of EvalResult.
    :return: Tuple of: API precision, API recall, API F1, parameter precision,
        parameter recall, parameter F1.
    """

    def _scores(correct, predicted, golden):
        # One (precision, recall, F1) triple for a single sample.
        p = precision(correct, predicted)
        r = recall(correct, golden)
        return p, r, f1(p, r)

    def _column(rows, index):
        # Pull one metric out of every triple into its own array.
        return np.array([row[index] for row in rows])

    api_rows = []
    param_rows = []
    for sample in data:
        api_rows.append(
            _scores(
                sample.correct_api_num,
                sample.predicted_api_num,
                sample.truth_api_num,
            )
        )
        param_rows.append(
            _scores(
                sample.correct_param_num,
                sample.predicted_param_num,
                sample.truth_param_num,
            )
        )

    return (
        _column(api_rows, 0),
        _column(api_rows, 1),
        _column(api_rows, 2),
        _column(param_rows, 0),
        _column(param_rows, 1),
        _column(param_rows, 2),
    )


def calculate_sentence_similarity(
    source: str, target: str, model: SentenceTransformer
) -> float:
    """
    Calculate the cosine similarity between the embeddings of two strings.

    Encoding is best effort: if the model fails to encode either string, the
    error is logged and 0.0 is returned so one bad pair does not abort the
    whole evaluation.
    :param source: Source string.
    :param target: Target string.
    :param model: SentenceTransformer model used to embed both strings.
    :return: Similarity value, or 0.0 when encoding fails.
    """
    # Embed both strings as tensors so they can be compared directly.
    try:
        embedding_1 = model.encode(
            source, convert_to_tensor=True, show_progress_bar=False
        )
        embedding_2 = model.encode(
            target, convert_to_tensor=True, show_progress_bar=False
        )
    except Exception as e:
        logger.error(f"Failed to process {source=}, {target=}. Error:{e}")
        # Return a float (not int 0) so the result type matches the success path
        # and the Tuple[bool, float] contract of match_value.
        return 0.0

    result = util.pytorch_cos_sim(embedding_1, embedding_2)
    # Unwrap the single similarity score into a plain Python number.
    return result.detach().item()


def match_value(
    predicted_param_value: str,
    truth_param_value: str,
    param_similarity_threshold: float,
    model: SentenceTransformer,
) -> Tuple[bool, float]:
    """
    Decide whether two string values are semantically close enough to match.
    :param predicted_param_value: Predicted value.
    :param truth_param_value: Truth value.
    :param param_similarity_threshold: Minimum similarity (inclusive) for a match.
    :param model: SentenceTransformer model.
    :return: Tuple of (similarity reached the threshold, similarity value).
    """
    similarity = calculate_sentence_similarity(
        predicted_param_value, truth_param_value, model
    )
    reached_threshold = bool(similarity >= param_similarity_threshold)
    return reached_threshold, similarity


def _status_without_api_calls(
    record: Record,
    prediction: APICallStatus,
    truth_items: list,
    total_truth_param_num: int,
) -> Status:
    """Classify a sample whose prediction object exists but holds no API calls."""
    truth_status = record.api_call.api_call_status
    if truth_status != API_TOOL_CALL:
        # A non-tool-call truth must not carry any API calls or parameters.
        assert (
            len(truth_items) == 0 and total_truth_param_num == 0
        ), f"Truth should have no API call for non-API call cases:{record}"
        if truth_status != prediction.api_call_status:
            return Status.IncorrectAPIStatus
        return Status.Correct

    # Truth expects a tool call but none was produced. A prediction status of
    # tool-call or invalid-response counts as "no prediction"; any other
    # status means the model chose the wrong kind of response.
    predicted_status = prediction.api_call_status
    if predicted_status != API_TOOL_CALL and predicted_status != API_INVALID_RESPONSE:
        return Status.IncorrectAPIStatus
    return Status.NoPrediction


def _count_matching_params(
    predicted_parameters: dict,
    truth_parameters: dict,
    param_similarity_threshold: float,
    model: SentenceTransformer,
) -> int:
    """Count truth parameters whose predicted value matches exactly or semantically."""
    matched_num = 0
    for parameter_name in truth_parameters.keys():
        if parameter_name not in predicted_parameters:
            continue

        # Values are compared as strings so that e.g. 1 and "1" are equal.
        predicted_param_value = str(predicted_parameters[parameter_name])
        truth_param_value = str(truth_parameters[parameter_name])
        if predicted_param_value == truth_param_value:
            matched_num += 1
            continue

        # Fall back to embedding similarity only when the strings differ.
        matched, sim = match_value(
            predicted_param_value,
            truth_param_value,
            param_similarity_threshold,
            model,
        )
        if matched:
            matched_num += 1
        else:
            logger.debug(
                f"value diff: score={sim}, {predicted_parameters=}, {truth_parameters=}"
            )
    return matched_num


def _status_with_api_calls(
    record: Record,
    prediction: APICallStatus,
    truth_items: list,
    sample_correct_api_num: int,
    sample_correct_param_num: int,
    total_truth_param_num: int,
) -> Status:
    """Classify a sample whose prediction contains at least one API call."""
    truth_status = record.api_call.api_call_status
    if truth_status != API_TOOL_CALL:
        if truth_status != prediction.api_call_status:
            return Status.IncorrectAPIStatus
        return Status.Correct

    if sample_correct_api_num != len(truth_items):
        return Status.IncorrectAPI
    if sample_correct_param_num == total_truth_param_num:
        return Status.Correct

    # Every truth call was matched positionally at this point, so the
    # prediction has at least len(truth_items) calls and zip covers all truth
    # items. Distinguish "right names, wrong values" from "wrong names".
    names_match = all(
        sorted(truth_item.params.keys()) == sorted(predicted_call.params.keys())
        for truth_item, predicted_call in zip(truth_items, prediction.api_calls)
    )
    return Status.IncorrectParamValue if names_match else Status.IncorrectParam


def eval_result(
    predictions: dict[str, APICallStatus],
    truth: dict[str, Record],
    param_similarity_threshold: float,
    device: Optional[str] = None,
) -> Tuple[EvalCountResult, list[EvalResult]]:
    """
    Evaluate predictions against the truth, sample by sample.

    API calls are compared positionally: the i-th predicted call is matched
    against the i-th truth call by API name. For matched calls, each truth
    parameter present in the prediction counts as correct when the stringified
    values are equal or their semantic similarity reaches the threshold.
    :param predictions: Dict of id and predicted APICall.
    :param truth: Dict of id and truth APICall.
    :param param_similarity_threshold: Threshold for similarity.
    :param device: Device to use when calculating semantic similarity.
    :return: Tuple of summarized result and detail results (one EvalResult per
        truth entry, in the iteration order of ``truth``).
    :raises AssertionError: If a truth record that is not a tool call still
        carries API calls or parameters.
    """
    model = SentenceTransformer(model_name_or_path=sf_model_name, device=device)

    predicted_api_num = 0
    predicted_param_num = 0
    correct_api_num = 0
    correct_param_num = 0

    reports = []
    for sample_id, record in truth.items():
        truth_items = record.api_call.api_calls
        total_truth_param_num = sum(len(x.params) for x in truth_items)

        sample_prediction = predictions.get(sample_id)
        predict_api_calls = (
            None if sample_prediction is None else sample_prediction.api_calls
        )

        if not predict_api_calls:
            # Missing prediction, or a prediction without any API call:
            # nothing to count, only the status has to be decided.
            if sample_prediction is None:
                status = Status.NoPrediction
            else:
                status = _status_without_api_calls(
                    record, sample_prediction, truth_items, total_truth_param_num
                )
            reports.append(
                EvalResult(
                    id=sample_id,
                    truth=record,
                    prediction=None,
                    status=status,
                    predicted_api_num=0,
                    correct_api_num=0,
                    truth_api_num=len(truth_items),
                    predicted_param_num=0,
                    correct_param_num=0,
                    truth_param_num=total_truth_param_num,
                )
            )
            continue

        # Every predicted call and parameter counts toward the predicted
        # totals, including calls beyond the number of truth calls.
        sample_predicted_api_num = len(predict_api_calls)
        sample_predicted_param_num = sum(
            len(predicted_call.params) for predicted_call in predict_api_calls
        )

        sample_correct_api_num = 0
        sample_correct_param_num = 0
        # Order of API calls matters: prediction and truth are matched 1:1 by
        # position; zip stops at the shorter of the two sequences.
        for predicted_call, truth_item in zip(predict_api_calls, truth_items):
            if truth_item.api_name != predicted_call.api_name:
                continue
            sample_correct_api_num += 1
            sample_correct_param_num += _count_matching_params(
                predicted_call.params,
                truth_item.params,
                param_similarity_threshold,
                model,
            )

        predicted_api_num += sample_predicted_api_num
        predicted_param_num += sample_predicted_param_num
        correct_api_num += sample_correct_api_num
        correct_param_num += sample_correct_param_num

        status = _status_with_api_calls(
            record,
            sample_prediction,
            truth_items,
            sample_correct_api_num,
            sample_correct_param_num,
            total_truth_param_num,
        )
        reports.append(
            EvalResult(
                id=sample_id,
                truth=record,
                prediction=predict_api_calls,
                status=status,
                predicted_api_num=sample_predicted_api_num,
                correct_api_num=sample_correct_api_num,
                truth_api_num=len(truth_items),
                predicted_param_num=sample_predicted_param_num,
                correct_param_num=sample_correct_param_num,
                truth_param_num=total_truth_param_num,
            )
        )

    eval_count = EvalCountResult(
        predicted_api_num=predicted_api_num,
        predicted_param_num=predicted_param_num,
        correct_api_num=correct_api_num,
        correct_param_num=correct_param_num,
    )
    return eval_count, reports


@dataclass
class PerSampleData:
    """
    Per-sample precision/recall/F1 arrays for API calls and parameters.

    calculate_score fills each field with the matching array returned by
    compute_metrics, which holds one entry per evaluated sample.
    """

    api_p: np.ndarray  # API-level precision per sample
    api_r: np.ndarray  # API-level recall per sample
    api_f1: np.ndarray  # API-level F1 per sample
    param_p: np.ndarray  # parameter-level precision per sample
    param_r: np.ndarray  # parameter-level recall per sample
    param_f1: np.ndarray  # parameter-level F1 per sample


def calculate_score(
    truth: dict[str, Record],
    predictions: dict[str, APICallStatus],
    param_similarity_threshold: float,
    device: Optional[str] = None,
) -> Tuple[dict, list[EvalResult], PerSampleData]:
    """
    Calculate metrics of the prediction.

    Micro-averaged scores are computed from counts summed over all samples;
    macro-averaged scores are the mean of the per-sample scores. An empty
    ``truth`` yields zeros for all averages instead of raising.
    :param truth: Truth.
    :param predictions: Predictions. Ids that are missing or map to ``None``
        are treated as having no prediction.
    :param param_similarity_threshold: Threshold for the semantic similarity.
    :param device: Device to use when calculating semantic similarity.
    :return: Tuple of (summary metrics dict, per-sample EvalResult list,
        per-sample precision/recall/F1 arrays).
    :raises AssertionError: If the summed per-sample counts disagree with the
        overall counts.
    """
    golden_api_num = 0
    golden_param_num = 0
    for item in truth.values():
        for api_call in item.api_call.api_calls:
            golden_api_num += 1
            golden_param_num += len(api_call.params.keys())

    result_count, reports = eval_result(
        predictions=predictions,
        truth=truth,
        param_similarity_threshold=param_similarity_threshold,
        device=device,
    )
    correct_api_num = result_count.correct_api_num
    predicted_api_num = result_count.predicted_api_num
    correct_param_num = result_count.correct_param_num
    predicted_param_num = result_count.predicted_param_num

    # verify the total numbers are matched
    assert correct_api_num == sum(x.correct_api_num for x in reports)
    assert predicted_api_num == sum(x.predicted_api_num for x in reports)
    assert golden_api_num == sum(x.truth_api_num for x in reports)
    assert correct_param_num == sum(x.correct_param_num for x in reports)
    assert predicted_param_num == sum(x.predicted_param_num for x in reports)
    assert golden_param_num == sum(x.truth_param_num for x in reports)

    result_dict: dict[str, Union[int, float, str]] = {
        "golden_api_num": golden_api_num,
        "golden_param_num": golden_param_num,
        "predicted_api_num": predicted_api_num,
        "predicted_param_num": predicted_param_num,
        "correct_api_num": correct_api_num,
        "correct_param_num": correct_param_num,
    }
    # Average number of predicted API calls per truth sample.
    if predicted_api_num > 0:
        result_dict["prediction_amount"] = 1.0 * predicted_api_num / len(truth)
    else:
        result_dict["prediction_amount"] = 0

    # Micro averages: all three scores are zero unless there is at least one
    # correct item (which implies non-zero predicted and golden counts).
    micro_inputs = (
        ("api", correct_api_num, predicted_api_num, golden_api_num),
        ("param", correct_param_num, predicted_param_num, golden_param_num),
    )
    for kind, correct, predicted, golden in micro_inputs:
        if correct * predicted * golden > 0:
            micro_p = 1.0 * correct / predicted
            micro_r = 1.0 * correct / golden
            micro_f = f1(micro_p, micro_r)
        else:
            micro_p = micro_r = micro_f = 0
        result_dict[f"micro_averaged_P_{kind}"] = micro_p
        result_dict[f"micro_averaged_R_{kind}"] = micro_r
        result_dict[f"micro_averaged_F1_{kind}"] = micro_f

    (
        api_p,
        api_r,
        api_f1,
        param_p,
        param_r,
        param_f1,
    ) = compute_metrics(reports)

    # Macro averages. The "marco" spelling of the keys is kept on purpose:
    # renaming them would break existing consumers of the result dict.
    macro_inputs = (
        ("P_api", api_p),
        ("R_api", api_r),
        ("F1_api", api_f1),
        ("P_param", param_p),
        ("R_param", param_r),
        ("F1_param", param_f1),
    )
    for suffix, values in macro_inputs:
        # np.mean of an empty array is NaN (with a warning); report 0.0 instead.
        result_dict[f"marco_averaged_{suffix}"] = (
            float(np.mean(values)) if len(values) > 0 else 0.0
        )

    per_sample_data = PerSampleData(
        api_p=api_p,
        api_r=api_r,
        api_f1=api_f1,
        param_p=param_p,
        param_r=param_r,
        param_f1=param_f1,
    )

    # Share of samples whose predicted API-call status equals the truth status.
    total = len(truth)
    api_accuracy = 0
    for key, item in truth.items():
        prediction = predictions.get(key)
        # A missing or None prediction cannot match any status.
        if prediction is None:
            continue
        if item.api_call.api_call_status == prediction.api_call_status:
            api_accuracy += 1
    api_accuracy_ratio = api_accuracy / total if total > 0 else 0.0
    logger.info(f"API acc:{api_accuracy_ratio}")
    return (
        result_dict,
        reports,
        per_sample_data,
    )
