import re
from decimal import Decimal, InvalidOperation

# Sentinel string returned by the extractors below when no numeric answer
# can be found in the given text.
INVALID_ANS = "[INVALID_ANSWER]"
# Locates a final-answer marker of the form "#### <number>" in a reference
# solution; group 1 captures an optional minus sign followed by a run of
# digits, periods and commas (commas are stripped by the caller).
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")

def clean_answer(model_pred):
    """Extract a normalized numeric answer from a model's free-text prediction.

    The prediction is lower-cased. If it contains the trigger ``"answer:"``,
    only the text after the first trigger (up to any second trigger) is
    searched; otherwise the whole prediction is searched. Commas used as
    thousands separators are removed so that ``"1,000"`` yields ``"1000"``,
    matching how the reference answer's commas are stripped.

    The last number found in the searched text is returned.

    Args:
        model_pred: Raw text produced by the model.

    Returns:
        The selected number as a string with a single trailing period
        removed, or ``INVALID_ANS`` if the searched text contains no number.
    """
    pred = model_pred.lower()

    # Keep only the segment between the first and (optional) second trigger.
    parts = pred.split("answer:")
    if len(parts) > 1:
        pred = parts[1]

    # Drop commas only when they separate groups of three digits, so "1,000"
    # becomes "1000" while an enumeration such as "3, 4" stays two numbers.
    pred = re.sub(r"(?<=\d),(?=\d{3}(?!\d))", "", pred)

    numbers = re.findall(r"-?\d+\.?\d*", pred)
    if not numbers:
        return INVALID_ANS

    # Use the last number: the conclusion of the reasoning usually comes last.
    answer = numbers[-1]

    # The pattern can swallow a sentence-ending period ("... is 18."); drop it.
    if answer.endswith("."):
        answer = answer[:-1]

    return answer

def extract_answer_from_output(completion):
    """Extract the numeric final answer from a ``#### <number>`` marker.

    Only the first match of ``ANS_RE`` in *completion* is used. All commas
    are removed (``"1,000"`` becomes ``"1000"``), and a single trailing
    period is dropped, as ``clean_answer`` does for predictions, so the two
    results can be compared as strings.

    Args:
        completion: Text expected to contain a marker such as ``#### 72``.

    Returns:
        The normalized number as a string, or ``INVALID_ANS`` if no
        ``#### <number>`` marker is present.
    """
    match = ANS_RE.search(completion)
    if match is None:
        return INVALID_ANS

    answer = match.group(1).replace(",", "")
    # ANS_RE also accepts periods, so "#### 18." would otherwise keep the
    # sentence-ending period and could never equal a cleaned prediction.
    if answer.endswith("."):
        answer = answer[:-1]
    return answer

def is_correct(model_pred, answer):
    """Return True if *model_pred* states the same final answer as *answer*.

    Args:
        model_pred: Raw model output; normalized with ``clean_answer``.
        answer: Reference solution containing a ``#### <number>`` marker;
            parsed with ``extract_answer_from_output``.

    Returns:
        True when both extracted answers denote the same decimal number
        (so ``"18.00"`` matches ``"18"``); if either one is not a valid
        decimal, True only when the two strings are identical.

    Raises:
        AssertionError: If *answer* has no ``#### <number>`` marker. Raised
            explicitly instead of via ``assert`` so the check still runs
            under ``python -O``.
    """
    model_answer = clean_answer(model_pred)
    gt_answer = extract_answer_from_output(answer)
    if gt_answer == INVALID_ANS:
        # Without this check a malformed reference would compare equal to a
        # prediction that contains no number (both are INVALID_ANS).
        raise AssertionError(
            f"reference answer has no '#### <number>' marker: {answer!r}"
        )

    try:
        return Decimal(model_answer) == Decimal(gt_answer)
    except InvalidOperation:
        # e.g. the prediction is INVALID_ANS or an odd capture like "1.2.3".
        return model_answer == gt_answer