import torch
import re
import string
from nltk.metrics import f_measure

# Captures the contents of each ``\boxed{...}`` occurrence in a string.
# NOTE: the lazy ``.*?`` stops at the FIRST closing brace, so nested braces are
# truncated (``\boxed{\frac{1}{2}}`` yields ``\frac{1``). No DOTALL flag is set,
# so ``.`` does not match newlines and boxed content spanning lines is not captured.
COMPILED_REGEX = re.compile(r"\\boxed\{(.*?)\}")

def normalize_text(s):
    """Normalize a string before answer comparison.

    The steps run in this order:
      1. lower-case the text;
      2. delete every character in ``string.punctuation``;
      3. replace the standalone words "a", "an" and "the" with a space;
      4. collapse runs of whitespace into single spaces and trim both ends.

    Args:
        s: The text to normalize.

    Returns:
        The normalized string (possibly empty).
    """
    lowered = s.lower()
    # One C-level pass that drops every ASCII punctuation character.
    without_punct = lowered.translate(str.maketrans("", "", string.punctuation))
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punct, flags=re.UNICODE)
    return ' '.join(without_articles.split())


def exact_match(generation, reference):
    """Score whether two strings agree after ``normalize_text``.

    Args:
        generation: The model's answer string.
        reference: The ground-truth answer string.

    Returns:
        int: 1 if the normalized strings are identical, otherwise 0.
    """
    return int(normalize_text(generation) == normalize_text(reference))


def f1_score(generation, reference):
    """Token-set F1 between a generated answer and a reference answer.

    Both strings are passed through ``normalize_text`` and split on whitespace.
    The resulting token *sets* are scored with ``nltk.metrics.f_measure``, which
    uses its default ``alpha=0.5`` (the harmonic mean of precision and recall).
    Because sets are compared, a repeated token counts only once.

    Args:
        generation: The model's answer string.
        reference: The ground-truth answer string.

    Returns:
        float in [0, 1]. If either side has no tokens after normalization, the
        result is 1.0 when both are empty and 0.0 otherwise. This keeps it
        consistent with ``exact_match``, which scores two strings that both
        normalize to "" as a match.
    """
    reference_tokens = set(normalize_text(reference).split())
    generation_tokens = set(normalize_text(generation).split())
    # f_measure returns None when either set is empty (the answer may be an
    # empty string after normalizing), so that case is resolved here explicitly.
    if not reference_tokens or not generation_tokens:
        return float(reference_tokens == generation_tokens)
    # f_measure may return an int 0 when there is no overlap; always return a float.
    return float(f_measure(reference_tokens, generation_tokens))


def _extract_boxed_answer(query):
    """Return the last ``COMPILED_REGEX`` capture in *query*, or "" if none matches."""
    matches = COMPILED_REGEX.findall(query)
    return matches[-1] if matches else ""


def _canonicalize_answer(answer):
    """Collapse whitespace runs (trimming both ends) and put exactly one space after each comma."""
    collapsed = " ".join(answer.split())
    return ", ".join(part.strip() for part in collapsed.split(","))


def reward_func(queries, prompts, labels, **kwargs):
    """
    Reward function for calculating rewards of model outputs.

    For each query, the answer is the content of the last ``\\boxed{...}`` match
    ("" if there is none). After whitespace and comma canonicalization, it is
    compared with the matching label using exact match plus token-set F1.

    Args:
        queries (Sequence[str]): Complete text sequences containing prompts and responses
        prompts (Sequence[str]): Input prompt sequences (not used by this function)
        labels (Sequence[str]): Ground truth answer strings, one per query
        **kwargs: Additional optional parameters (ignored)

    Returns:
        dict: A dictionary containing the following key-value pairs:
            - rewards: em + f1 per sample, in range [0, 2], used for calculating advantage function
            - scores: rewards / 2, in range [0, 1], used for dynamic filtering

    Raises:
        ValueError: If ``queries`` and ``labels`` have different lengths.
    """
    queries = list(queries)
    labels = list(labels)
    # zip() would silently drop the unmatched tail and misalign rewards with samples.
    if len(queries) != len(labels):
        raise ValueError(
            f"queries and labels must have the same length, got {len(queries)} and {len(labels)}"
        )

    # Extract every answer first so debug output keeps its original order.
    answers = []
    for query in queries:
        answers.append(_extract_boxed_answer(query))
        print(f"end of query: {query[-100:]}")

    rewards = []
    for raw_answer, label in zip(answers, labels):
        answer = _canonicalize_answer(raw_answer)

        em = exact_match(answer, label)
        f1 = f1_score(answer, label)

        print(f"answer: {answer}, label: {label}, em: {em}, f1: {f1}")
        rewards.append(em + f1)

    reward_tensor = torch.tensor(rewards, dtype=torch.float32)
    # em is 0/1 and f1 is in [0, 1], so em + f1 is in [0, 2]. Halve it for the
    # documented [0, 1] score range used by dynamic filtering.
    score_tensor = reward_tensor / 2.0
    return {
        "rewards": reward_tensor,  # Rewards for advantage calculation
        "scores": score_tensor,  # Scores for dynamic filtering (0-1 reward)
    }