import random
import unittest

import numpy as np
import numpy.testing as npt
from sentence_transformers import SentenceTransformer

from src.evaluation import (
    precision,
    recall,
    f1,
    compute_metrics,
    calculate_sentence_similarity,
    sf_model_name,
    match_value,
    eval_result,
    calculate_score,
)
from src.utility import (
    EvalResult,
    Record,
    Status,
    APICall,
    APICallStatus,
    API_TOOL_CALL,
    API_INVALID_RESPONSE,
    API_CAN_NOT_ANSWER,
    API_REQUEST_FOR_INFO,
)

def _placeholder_truth_record():
    """Return a fresh, blank ``Record`` used as the truth of the metric fixtures.

    The ``compute_metrics`` fixtures below only vary the counters on
    ``EvalResult``; the embedded truth record carries no data.
    """
    return Record(
        id="",
        data_set="",
        output="",
        pre_api=[],
        post_api=[],
        ending=[],
        conversation=[],
        api_def=[],
        api_call=APICallStatus(api_call_status=API_TOOL_CALL, api_calls=[]),
    )


def _metric_fixture(
    result_id,
    status,
    *,
    predicted_api_num,
    correct_api_num,
    predicted_param_num,
    correct_param_num,
    truth_api_num=2,
    truth_param_num=4,
):
    """Build an ``EvalResult`` with the given counters and no prediction.

    By default every fixture has the same ground-truth size: 2 APIs, 4 params.
    """
    return EvalResult(
        id=result_id,
        truth=_placeholder_truth_record(),
        prediction=None,
        status=status,
        predicted_api_num=predicted_api_num,
        correct_api_num=correct_api_num,
        truth_api_num=truth_api_num,
        predicted_param_num=predicted_param_num,
        correct_param_num=correct_param_num,
        truth_param_num=truth_param_num,
    )


# Nothing predicted: every metric is 0.
eval_result1 = _metric_fixture(
    "1",
    Status.NoPrediction,
    predicted_api_num=0,
    correct_api_num=0,
    predicted_param_num=0,
    correct_param_num=0,
)
# One wrong API with three params, none correct: every metric is 0.
eval_result2 = _metric_fixture(
    "2",
    Status.IncorrectAPI,
    predicted_api_num=1,
    correct_api_num=0,
    predicted_param_num=3,
    correct_param_num=0,
)
# API P = R = 1/2; param P = 1/2, R = 1/4.
eval_result3 = _metric_fixture(
    "3",
    Status.IncorrectParam,
    predicted_api_num=2,
    correct_api_num=1,
    predicted_param_num=2,
    correct_param_num=1,
)
# API P = R = 1; param P = 1/2, R = 1/4.
eval_result4 = _metric_fixture(
    "4",
    Status.IncorrectParam,
    predicted_api_num=2,
    correct_api_num=2,
    predicted_param_num=2,
    correct_param_num=1,
)
# API P = R = 1; param P = 2/5, R = 2/4.
eval_result5 = _metric_fixture(
    "5",
    Status.IncorrectParamValue,
    predicted_api_num=2,
    correct_api_num=2,
    predicted_param_num=5,
    correct_param_num=2,
)
# Fully correct: every metric is 1.
eval_result6 = _metric_fixture(
    "6",
    Status.Correct,
    predicted_api_num=2,
    correct_api_num=2,
    predicted_param_num=4,
    correct_param_num=4,
)
eval_results = [
    eval_result1,
    eval_result2,
    eval_result3,
    eval_result4,
    eval_result5,
    eval_result6,
]

def _prediction_call(api_name, **params):
    """Build a predicted ``APICall``; params keep the keyword order given."""
    return APICall(api_name=api_name, params=dict(params))


# Model predictions keyed by sample id; ``truth`` holds the matching ground truth.
predictions = {
    # Invalid responses with no usable calls (both the ``None`` and empty variants).
    "1": APICallStatus(api_call_status=API_INVALID_RESPONSE, api_calls=None),
    "2": APICallStatus(api_call_status=API_INVALID_RESPONSE, api_calls=[]),
    # A single call to an API name that is absent from the ground truth.
    "3": APICallStatus(
        api_call_status=API_TOOL_CALL,
        api_calls=[_prediction_call("incorrect_api", incorrect_param="incorrect")],
    ),
    # Only ``correct_api_1`` is called, and ``left_param`` is omitted.
    "4": APICallStatus(
        api_call_status=API_TOOL_CALL,
        api_calls=[
            _prediction_call(
                "correct_api_1", other_param="glad", correct_param="correct"
            ),
        ],
    ),
    # Both expected APIs, but with missing and unexpected parameter names.
    "5": APICallStatus(
        api_call_status=API_TOOL_CALL,
        api_calls=[
            _prediction_call(
                "correct_api_1", other_param="glad", correct_param="correct"
            ),
            _prediction_call(
                "correct_api_2", incorrect_param="incorrect", correct_param="correct"
            ),
        ],
    ),
    # Both expected APIs with every parameter name, but several values differ.
    "6": APICallStatus(
        api_call_status=API_TOOL_CALL,
        api_calls=[
            _prediction_call(
                "correct_api_1",
                left_param="right",
                other_param="sad",
                correct_param="correct",
            ),
            _prediction_call(
                "correct_api_2",
                left_param="right",
                other_param="glad",
                correct_param="omg",
            ),
        ],
    ),
    # Matches the ground truth except ``other_param="happy"`` where "glad" is
    # expected -- a synonym the similarity threshold is meant to accept.
    "7": APICallStatus(
        api_call_status=API_TOOL_CALL,
        api_calls=[
            _prediction_call(
                "correct_api_1",
                left_param="left",
                other_param="happy",
                correct_param="correct",
            ),
            _prediction_call(
                "correct_api_2",
                left_param="left",
                other_param="glad",
                correct_param="correct",
            ),
        ],
    ),
    # Both correct calls, preceded by an extra call to an unknown API.
    "8": APICallStatus(
        api_call_status=API_TOOL_CALL,
        api_calls=[
            _prediction_call(
                api_name,
                left_param="left",
                other_param="glad",
                correct_param="correct",
            )
            for api_name in ("wrong_api_1", "correct_api_1", "correct_api_2")
        ],
    ),
    # "Cannot answer" with no calls.
    "9": APICallStatus(
        api_call_status=API_CAN_NOT_ANSWER,
        api_calls=[],
    ),
    # "Request for info" without any calls.
    "10": APICallStatus(
        api_call_status=API_REQUEST_FOR_INFO,
        api_calls=None,
    ),
    # "Request for info" that nevertheless carries a call.
    "11": APICallStatus(
        api_call_status=API_REQUEST_FOR_INFO,
        api_calls=[
            _prediction_call(
                "correct_api_1",
                left_param="left",
                other_param="glad",
                correct_param="correct",
            ),
        ],
    ),
    # "Cannot answer" that nevertheless carries a call.
    "12": APICallStatus(
        api_call_status=API_CAN_NOT_ANSWER,
        api_calls=[
            _prediction_call(
                "correct_api_1",
                left_param="left",
                other_param="glad",
                correct_param="correct",
            ),
        ],
    ),
}

def _truth_record_with(api_call_status, api_calls):
    """Return a ground-truth ``Record`` (id "1", blank text) with the given calls."""
    return Record(
        id="1",
        data_set="",
        output="",
        pre_api=[],
        post_api=[],
        ending=[],
        conversation=[],
        api_def=[],
        api_call=APICallStatus(api_call_status=api_call_status, api_calls=api_calls),
    )


# Ground truth for most samples: two tool calls sharing the same parameters.
# Each call gets its own params dict because the literal is re-evaluated per item.
truth_record = _truth_record_with(
    API_TOOL_CALL,
    [
        APICall(
            api_name=api_name,
            params={
                "left_param": "left",
                "other_param": "glad",
                "correct_param": "correct",
            },
        )
        for api_name in ("correct_api_1", "correct_api_2")
    ],
)

# Ground truth for samples where the correct answer is "cannot answer".
truth_record_1 = _truth_record_with(API_CAN_NOT_ANSWER, [])

# Key order is significant. The tests appear to read results by position, and
# index 10 is the "x" sample, which has no entry in ``predictions``.
truth = {
    **{str(sample): truth_record for sample in range(1, 9)},
    "9": truth_record_1,
    "10": truth_record_1,
    "x": truth_record,
    "11": truth_record_1,
    "12": truth_record_1,
}


class EvaluationTestCases(unittest.TestCase):
    """Tests for the scoring helpers exported by ``src.evaluation``."""

    def _assert_close(self, actual, expected):
        """Compare element-wise using ``np.allclose``'s default tolerances.

        ``npt.assert_allclose`` reports the offending elements on failure and
        rejects arrays whose shapes differ; ``assertTrue(np.allclose(...))``
        does neither.
        """
        npt.assert_allclose(actual, expected, rtol=1e-5, atol=1e-8)

    def test_precision_recall_f1(self):
        """precision/recall are plain ratios that return 0 for a zero denominator."""
        for name, metric in (("precision", precision), ("recall", recall)):
            with self.subTest(metric=name):
                self.assertEqual(metric(0, 1), 0)
                self.assertEqual(metric(1, 1), 1.0)
                self.assertAlmostEqual(metric(1, 3), 0.33333333)
                self.assertEqual(metric(1, 0), 0.0)

        # f1 is the harmonic mean of precision and recall. A seeded local RNG
        # makes the sampled inputs reproducible without touching global random
        # state, and the tolerant comparison absorbs rounding differences that
        # arise when the implementation orders the arithmetic differently.
        rng = random.Random(0)
        for _ in range(5):
            p, r = rng.random(), rng.random()
            with self.subTest(p=p, r=r):
                self.assertAlmostEqual(f1(p, r), 2 * p * r / (p + r))
        self.assertAlmostEqual(f1(0.5, 0.25), 1 / 3)
        self.assertEqual(f1(0, 0), 0)

    def test_compute_metrics(self):
        """Per-row API/param precision, recall and F1 for the ``eval_results`` fixtures."""
        api_p, api_r, api_f1, param_p, param_r, param_f1 = compute_metrics(eval_results)
        self.assertEqual(api_p.tolist(), [0, 0, 0.5, 1, 1, 1])
        self.assertEqual(api_r.tolist(), [0, 0, 0.5, 1, 1, 1])
        self.assertEqual(api_f1.tolist(), [0, 0, 0.5, 1, 1, 1])

        self.assertEqual(param_p.tolist(), [0, 0, 0.5, 0.5, 0.4, 1])
        self.assertEqual(param_r.tolist(), [0, 0, 0.25, 0.25, 0.5, 1])
        npt.assert_array_almost_equal(
            param_f1, np.array([0, 0, 0.33333333333, 0.33333333333, 0.44444444444, 1])
        )

    def test_calculate_sentence_similarity(self):
        """Similarity scores and threshold matching with the configured model.

        Loads the ``sf_model_name`` SentenceTransformer, which may need a local
        cache or network access.
        """
        model = SentenceTransformer(model_name_or_path=sf_model_name)
        score = calculate_sentence_similarity("abc", "abc", model)
        self.assertAlmostEqual(score, 1, places=5)
        score = calculate_sentence_similarity("abc", "", model)
        self.assertLess(score, 0.5)
        score = calculate_sentence_similarity("happy", "glad", model)
        self.assertGreater(score, 0.5)

        matched, _ = match_value("abc", "", 0, model)
        self.assertTrue(matched)
        matched, _ = match_value("abc", "", 0.9, model)
        self.assertFalse(matched)

    def test_eval_result(self):
        """Aggregate counts and per-sample status/correctness from ``eval_result``."""
        eval_count, results = eval_result(predictions, truth, 0.5)
        self.assertEqual(eval_count.predicted_api_num, 13)
        self.assertEqual(eval_count.predicted_param_num, 34)
        self.assertEqual(eval_count.correct_api_num, 7)
        self.assertEqual(eval_count.correct_param_num, 13)

        # (label, status, correct_api_num, correct_param_num) for each index.
        # The label is the ``truth`` key each position appears to correspond to.
        expected = [
            ("1", Status.NoPrediction, 0, 0),
            ("2", Status.NoPrediction, 0, 0),
            ("3", Status.IncorrectAPI, 0, 0),
            ("4", Status.IncorrectAPI, 1, 2),
            ("5", Status.IncorrectParam, 2, 3),
            ("6", Status.IncorrectParamValue, 2, 2),
            ("7", Status.Correct, 2, 6),
            ("8", Status.IncorrectAPI, 0, 0),
            ("9", Status.Correct, 0, 0),
            ("10", Status.IncorrectAPIStatus, 0, 0),
            ("x", Status.NoPrediction, 0, 0),
            ("11", Status.IncorrectAPIStatus, 0, 0),
            ("12", Status.Correct, 0, 0),
        ]
        self.assertEqual(len(results), len(expected))
        for index, (label, status, api_num, param_num) in enumerate(expected):
            with self.subTest(index=index, sample=label):
                result = results[index]
                self.assertEqual(result.status, status)
                self.assertEqual(result.correct_api_num, api_num)
                self.assertEqual(result.correct_param_num, param_num)

    def test_calculate_score(self):
        """Totals, micro/macro averages and per-sample metrics from ``calculate_score``."""
        result_dict, _, per_sample_data = calculate_score(truth, predictions, 0.5)

        expected_counts = {
            "golden_api_num": 18,
            "golden_param_num": 54,
            "predicted_api_num": 13,
            "predicted_param_num": 34,
            "correct_api_num": 7,
            "correct_param_num": 13,
        }
        for key, value in expected_counts.items():
            with self.subTest(key=key):
                self.assertEqual(result_dict[key], value)

        # NOTE: the macro-averaged keys are spelled "marco" (sic).
        expected_averages = {
            "micro_averaged_P_api": 0.5384,
            "micro_averaged_R_api": 0.3888,
            "micro_averaged_F1_api": 0.4516,
            "micro_averaged_P_param": 0.3823,
            "micro_averaged_R_param": 0.2407,
            "micro_averaged_F1_param": 0.2954,
            "marco_averaged_P_api": 0.3076,
            "marco_averaged_R_api": 0.2692,
            "marco_averaged_F1_api": 0.2820,
            "marco_averaged_P_param": 0.2371,
            "marco_averaged_R_param": 0.1666,
            "marco_averaged_F1_param": 0.1871,
        }
        for key, value in expected_averages.items():
            with self.subTest(key=key):
                self.assertAlmostEqual(result_dict[key], value, places=3)

        # One value per sample, in the same positional order as test_eval_result.
        expected_per_sample = {
            "api_p": [0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
            "api_r": [0, 0, 0, 0.5, 1, 1, 1, 0, 0, 0, 0, 0, 0],
            "api_f1": [0, 0, 0, 2 / 3, 1, 1, 1, 0, 0, 0, 0, 0, 0],
            "param_p": [0, 0, 0, 1, 0.75, 1 / 3, 1, 0, 0, 0, 0, 0, 0],
            "param_r": [0, 0, 0, 1 / 3, 0.5, 1 / 3, 1, 0, 0, 0, 0, 0, 0],
            "param_f1": [0, 0, 0, 0.5, 0.6, 1 / 3, 1, 0, 0, 0, 0, 0, 0],
        }
        for field, values in expected_per_sample.items():
            with self.subTest(field=field):
                self._assert_close(getattr(per_sample_data, field), values)


if __name__ == "__main__":
    unittest.main()
