import re
import json
import string

from datasets import load_dataset
from collections import Counter
from pathlib import Path

def normalize_answer(s):
    """Canonicalize an answer string for lexical (SQuAD-style) comparison.

    The steps are applied in this order:

    1. lowercase the text;
    2. delete every character in ``string.punctuation``;
    3. replace the whole words "a", "an" and "the" with a space;
    4. collapse all whitespace runs to single spaces and trim both ends.
    """
    punctuation = set(string.punctuation)

    text = s.lower()
    text = ''.join(c for c in text if c not in punctuation)
    # \b anchors keep words such as "theater" or "another" intact.
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return ' '.join(text.split())

def f1_score(prediction, ground_truth, **kwargs):
    """Return the bag-of-tokens F1 between two token sequences.

    Overlap is counted with multiplicity: the multiset intersection of the two
    sequences. Returns the int ``0`` when there is no overlap, which also
    covers the case where either sequence is empty. Extra keyword arguments
    are accepted and ignored.
    """
    overlap = sum((Counter(prediction) & Counter(ground_truth)).values())
    if not overlap:
        return 0
    p = overlap / len(prediction)
    r = overlap / len(ground_truth)
    return 2 * p * r / (p + r)

def qa_f1_score(prediction, ground_truth, **kwargs):
    """Token-level F1 between two raw answer strings.

    Both strings are run through ``normalize_answer`` and split on whitespace.
    The resulting token lists are scored with ``f1_score``. Extra keyword
    arguments are accepted and ignored.
    """
    def tokenize(text):
        return normalize_answer(text).split()

    return f1_score(tokenize(prediction), tokenize(ground_truth))

def drqa_metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Score a prediction against several reference answers and keep the best.

    Args:
        metric_fn: Callable ``metric_fn(prediction, ground_truth) -> number``.
        prediction: Model output, passed unchanged to ``metric_fn``.
        ground_truths: A single string, or a sequence whose elements are
            either answers or lists/tuples of answers. One level of nesting
            is flattened.

    Returns:
        The maximum value ``metric_fn`` produces over all reference answers.

    Raises:
        ValueError: If no reference answer is supplied (e.g. ``[]`` or
            ``[[]]``).
    """
    if isinstance(ground_truths, str):
        ground_truths = [ground_truths]

    # Flatten each element independently so that mixed inputs such as
    # ["a", ["b", "c"]] are handled, not just inputs whose first element is a list.
    candidates = []
    for item in ground_truths:
        if isinstance(item, (list, tuple)):
            candidates.extend(item)
        else:
            candidates.append(item)

    if not candidates:
        raise ValueError("ground_truths must contain at least one reference answer")
    return max(metric_fn(prediction, gt) for gt in candidates)

def ensure_path_exists(path, *, is_file=False):
    """Make sure the directory needed for *path* exists.

    If *path* names an existing file, or the caller passes ``is_file=True``,
    only its parent directory is created. Otherwise *path* itself is created
    as a directory (with any missing parents). Already-existing directories
    are left untouched.

    Note: without ``is_file=True``, a file path that does not exist yet
    (e.g. ``"out/results.json"``) is created as a *directory*. Pass
    ``is_file=True`` before writing a new file.

    Args:
        path: A ``str`` or ``os.PathLike`` path.
        is_file: Treat *path* as a file path even if it does not exist yet.
    """
    path_obj = Path(path)
    if is_file or path_obj.is_file():
        path_obj.parent.mkdir(parents=True, exist_ok=True)
    else:
        path_obj.mkdir(parents=True, exist_ok=True)


def extract_ans_from_response(
    raw_response,
    split_seg_list=(
        "The answer is: ",
        "The answer is:",
        "the answer is:",
        "答案是：",  # Chinese "The answer is:" (full-width colon)
        "答案是",  # Chinese "The answer is" (no colon)
    ),
):
    """Extract the final answer span from a model's free-form response.

    Two phases:

    1. Thinking-model format. If the response contains an ``<Output>`` tag
       followed by a newline, only the text after its last occurrence is kept,
       with ``</Output>`` removed and surrounding newlines stripped. The same
       is then done for ``<Answer>``/``</Answer>``.
    2. The first marker in ``split_seg_list`` that occurs in the text is used.
       The answer is whatever follows that marker's last occurrence, cut at
       the next blank line (``"\\n\\n"``). Markers are tried in order, so
       longer variants (e.g. with a trailing space or colon) should precede
       their prefixes.

    Args:
        raw_response: The model's raw text output.
        split_seg_list: Ordered answer markers. Any iterable of strings is
            accepted.

    Returns:
        The extracted answer text (not whitespace-stripped), or ``" "`` (a
        single space) when no marker is found.
    """
    # For thinking models: keep only the content of the final output section.
    for open_tag, close_tag in (("<Output>\n", "</Output>"), ("<Answer>\n", "</Answer>")):
        if open_tag in raw_response:
            raw_response = raw_response.split(open_tag)[-1].replace(close_tag, "").strip("\n")

    for split_seg in split_seg_list:
        if split_seg in raw_response:
            return raw_response.split(split_seg)[-1].split("\n\n")[0]
    return " "


######## data processing 
def load_custom_dataset(file_path):
    """Load a local JSON / JSON Lines file with the ``datasets`` "json" builder.

    The file is registered as the "test" split, and that single split is
    returned, rather than a dict of splits.
    """
    split_name = "test"
    return load_dataset(
        "json",
        data_files={split_name: file_path},
        split=split_name,
    )

def save_data(data, file_path):
    """Write *data* to *file_path* as JSON or JSON Lines, chosen by extension.

    * ``.jsonl``: *data* is iterated, and each item is written as one JSON
      object per line.
    * ``.json``: *data* is written as a single indented (4 spaces) JSON
      document.

    Output is UTF-8 with non-ASCII characters kept verbatim
    (``ensure_ascii=False``). The target file is overwritten.

    Args:
        data: An iterable of JSON-serializable items (``.jsonl``) or any
            JSON-serializable object (``.json``).
        file_path: A ``str`` or ``os.PathLike`` (e.g. ``pathlib.Path``).

    Raises:
        ValueError: If the path ends in neither ``.json`` nor ``.jsonl``.
            Raised before the file is opened, so nothing is created.
    """
    import os  # local: only needed to accept os.PathLike paths

    path_str = os.fspath(file_path)
    is_jsonl = path_str.endswith(".jsonl")
    if not is_jsonl and not path_str.endswith(".json"):
        raise ValueError("Please use `.json` or `.jsonl` ")

    with open(path_str, 'w', encoding='utf-8') as f:
        if is_jsonl:
            for item in data:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')
        else:
            json.dump(data, f, ensure_ascii=False, indent=4)