# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
To add a new task, register an entry in TASKS of the form:

TASK_NAME: {
    'metric_fn': metric function that takes (predictions: [str], references: [[str]]) and returns a score.
}
"""
import re
import string
from collections import Counter

def extract_ans_from_response(raw_response, split_seg_list=("The answer is: ", "The answer is:", "the answer is:")):
    """Extract the final answer text from a raw model response.

    If the response contains an ``<Output>`` and/or ``<Answer>`` section
    (opening tag followed by a newline), only the text after the last such
    opening tag is kept and the matching closing tag is removed. The markers
    in ``split_seg_list`` are then tried in order; for the first marker found,
    the text after its last occurrence, up to the first blank line, is
    returned.

    Args:
        raw_response: The full response string produced by the model.
        split_seg_list: Iterable of answer-marker strings, tried in order.
            The default is an immutable tuple so it cannot be altered
            between calls.

    Returns:
        The extracted answer string, or a single space ``" "`` when none of
        the markers occurs in the response.
    """
    # Responses from "thinking" models wrap the final text in tags; keep only
    # the content following the last opening tag of each kind.
    tag_pairs = (("<Output>\n", "</Output>"), ("<Answer>\n", "</Answer>"))
    for opening, closing in tag_pairs:
        if opening in raw_response:
            raw_response = raw_response.split(opening)[-1].replace(closing, "").strip("\n")

    for split_seg in split_seg_list:
        if split_seg in raw_response:
            # Text after the last marker, cut at the first blank line.
            return raw_response.split(split_seg)[-1].split("\n\n")[0]
    return " "


def normalize_answer(s):
    """Return a normalized form of ``s`` for token-level comparison.

    The steps are applied in this order: lowercase the text, delete every
    character in ``string.punctuation``, replace the standalone words
    "a", "an" and "the" with a space, and collapse all whitespace runs
    into single spaces (also trimming both ends).
    """
    lowered = s.lower()
    # One pass that deletes all ASCII punctuation characters.
    without_punct = lowered.translate(str.maketrans('', '', string.punctuation))
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punct)
    return ' '.join(without_articles.split())

def f1_score(prediction, ground_truth, **kwargs):
    """Compute the F1 score between two token sequences.

    Overlap is counted as a multiset intersection, so a repeated token only
    matches as many times as it appears in both sequences. Returns 0 when
    the sequences share no token. Extra keyword arguments are accepted and
    ignored.
    """
    overlap = Counter(prediction) & Counter(ground_truth)
    num_same = sum(overlap.values())
    if not num_same:
        return 0

    precision = 1.0 * num_same / len(prediction)
    recall = 1.0 * num_same / len(ground_truth)
    return (2 * precision * recall) / (precision + recall)

def qa_f1_score(prediction, ground_truth, **kwargs):
    """Token-level F1 between a predicted answer and one reference answer.

    Both strings are passed through ``normalize_answer`` and split on
    whitespace before being scored with ``f1_score``. Extra keyword
    arguments are accepted and ignored.
    """
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    return f1_score(prediction_tokens, ground_truth_tokens)



def string_match_part(preds, refs):
    """Percentage of predictions that contain at least one reference string.

    Each raw prediction is first reduced to its answer span with
    ``extract_ans_from_response``. A sample scores 1.0 when any of its
    references occurs (case-insensitively) as a substring of that answer,
    otherwise 0.0.

    Args:
        preds: Raw model responses, one string per sample.
        refs: For each sample, a list of acceptable reference strings.

    Returns:
        The mean sample score times 100, rounded to two decimals. Returns
        0.0 when ``preds`` is empty; a sample with no references scores 0.0.
    """
    total = len(preds)
    if total == 0:
        # No samples: report 0.0 instead of dividing by zero.
        return 0.0

    answers = [extract_ans_from_response(pred) for pred in preds]
    hits = 0.0
    for answer, ref in zip(answers, refs):
        answer_lower = answer.lower()  # lowercase once per sample
        # any() is False for an empty reference list, so the sample scores 0.
        if any(r.lower() in answer_lower for r in ref):
            hits += 1.0
    return round(hits / total * 100, 2)

def qa_f1(preds, refs):
    """Mean best-reference F1 over all samples, as a percentage.

    Each raw prediction is first reduced to its answer span with
    ``extract_ans_from_response``. A sample's score is the highest
    ``qa_f1_score`` between that answer and any of its references.

    Args:
        preds: Raw model responses, one string per sample.
        refs: For each sample, a list of acceptable reference strings.

    Returns:
        The mean sample score times 100, rounded to two decimals. Returns
        0.0 when ``preds`` is empty; a sample with no references scores 0.0.
    """
    total = len(preds)
    if total == 0:
        # No samples: report 0.0 instead of dividing by zero.
        return 0.0

    answers = [extract_ans_from_response(pred) for pred in preds]
    best_scores = [
        # default=0.0 keeps max() from raising on an empty reference list.
        max((qa_f1_score(answer, r) for r in ref), default=0.0)
        for answer, ref in zip(answers, refs)
    ]
    return round(sum(best_scores) / total * 100, 2)


def string_match_all(preds, refs):
    """Mean fraction of references found in each prediction, as a percentage.

    For every sample, the score is the share of its reference strings that
    occur (case-insensitively) as a substring of the prediction. Predictions
    are used as given; no answer extraction is applied.

    Args:
        preds: Model responses, one string per sample.
        refs: For each sample, a list of reference strings to look for.

    Returns:
        The mean sample score times 100, rounded to two decimals. Returns
        0.0 when ``preds`` is empty; a sample with no references scores 0.0.
    """
    total = len(preds)
    if total == 0:
        # No samples: report 0.0 instead of dividing by zero.
        return 0.0

    score_sum = 0.0
    for pred, ref in zip(preds, refs):
        if len(ref) == 0:
            # Nothing to match against; contributes 0 rather than dividing by zero.
            continue
        pred_lower = pred.lower()  # lowercase once per sample
        matched = sum(1.0 for r in ref if r.lower() in pred_lower)
        score_sum += matched / len(ref)
    return round(score_sum / total * 100, 2)
    

# Registry mapping each task name to the metric function(s) defined in this
# module. Every entry provides a 'metric_fn' taking (predictions, references).
TASKS = {
    # These four tasks are scored by the fraction of references found in
    # each prediction.
    'niah': {
        'metric_fn': string_match_all,
    },
    'variable_tracking': {
        'metric_fn': string_match_all,
    },
    'common_words_extraction': {
        'metric_fn': string_match_all,
    },
    'freq_words_extraction': {
        'metric_fn': string_match_all
    },
    # 'qa' registers two metrics: any-reference substring match, plus a
    # token-level F1 under 'metric_fn_2'.
    # NOTE(review): how 'metric_fn_2' is consumed is not visible here; confirm with the caller.
    'qa': {
        'metric_fn': string_match_part,
        'metric_fn_2': qa_f1,
    },
}
