import re
import numpy as np
import unicodedata
from typing import List
import math
# references
# https://github.com/Leolty/tablellm/blob/aef85050f522900fd70920c2b7427a383e3066ab/utils/eval.py
# https://github.com/ppasupat/WikiTableQuestions/blob/master/evaluator.py



def stringify(x):
    """Coerce *x* to a string.

    ``None`` becomes the empty string, existing strings are returned as-is,
    and every other value goes through ``str()``.
    """
    if x is None:
        return ''
    if isinstance(x, str):
        return x
    return str(x)


def normalize_number(x: str):
    """Rewrite every numeric token found in *x* as ``str(float(token))``.

    A token is an optional sign, optional integer digits, an optional decimal
    point and at least one trailing digit, so ``"182"`` becomes ``"182.0"``
    and ``".5"`` becomes ``"0.5"``. Text outside the tokens is left untouched.
    """
    number_re = re.compile(r'[+-]?[0-9]*[.]?[0-9]+')

    def to_float_text(found):
        token = found.group(0)
        # A bare '.' would be passed through unchanged; every other token is
        # re-rendered in Python's float notation.
        return token if token == '.' else str(float(token))

    return number_re.sub(to_float_text, x)


def normalize(x: str) -> str:
    """Canonicalise an answer string so that equivalent answers compare equal.

    The steps run in a fixed order and later steps depend on earlier ones
    (for example the 'answer:' prefix is matched after lowercasing, and
    numbers are rewritten only after units and '$' have been removed).

    Parameters:
    - x: the answer text to normalise

    Returns:
    - the lowercased string with diacritics, quotes, citations, trailing
      parenthesised details, commas, '%', '$' and articles removed, numbers
      rewritten by ``normalize_number`` and whitespace collapsed.
    """
    # Remove diacritics: decompose (NFKD) and drop the nonspacing marks ('Mn')
    x = ''.join(c for c in unicodedata.normalize('NFKD', x) if unicodedata.category(c) != 'Mn')
    # Replace newlines with spaces
    x = re.sub(r'\n', ' ', x)
    # Remove asterisks
    x = re.sub(r'\*', '', x)
    # Remove single and double quote characters; map dash variants to '-'
    x = re.sub(r"[‘’´`']", "", x)
    x = re.sub(r"[“”\"]", "", x)
    x = re.sub(r"[‐‑‒–—−]", "-", x)
    # Replace a dash with a space unless a digit follows it (negative sign or range)
    x = re.sub(r'-(?!\d)', ' ', x)
    # Repeat the trailing clean-ups until the string stops changing, since
    # removing one suffix can expose another.
    while True:
        old_x = x
        # Remove trailing citations: bracketed text (not at the very start),
        # "[digits]", and the footnote symbols • ♦ † ‡ * # +
        x = re.sub(r"((?<!^)\[[^\]]*\]|\[\d+\]|[•♦†‡*#+])*$", "", x.strip())
        # Remove trailing " (...)" groups, unless they start the string
        x = re.sub(r"(?<!^)( \([^)]*\))*$", "", x.strip())
        # Remove a pair of double quotes wrapping the whole string.
        # NOTE(review): every '"' was already removed before this loop, so
        # this substitution appears unable to match.
        x = re.sub(r'^"([^"]*)"$', r'\1', x.strip())
        if x == old_x:
            break
    # Remove a single final '.'
    if x and x[-1] == '.':
        x = x[:-1]
    # Convert to lowercase
    x = x.lower()
    # Remove every comma (not only thousands separators between digits)
    x = re.sub(r',', '', x)
    # Remove percent sign
    x = re.sub(r'%', '', x)
    # Remove everything before and including "answer:"; repeated so that only
    # the text after the last occurrence is kept. The string is already
    # lowercase here, so differently-cased spellings are caught too.
    while True:
        if 'answer:' in x:
            x = re.sub(r'^.*?answer:', '', x)
        else:
            break
    # Replace the articles a/an/the with a space
    x = re.sub(r'\b(a|an|the)\b', ' ', x)
    # Remove unit: when the string starts with a number followed by a single
    # space and a word character, keep only that leading number
    x = re.sub(r'^([+-]?[0-9]*[.]?[0-9]+) \w+.*$', r'\1', x)
    # Remove $ dollar sign
    # NOTE(review): this runs after the unit removal above, so a value that
    # starts with '$' keeps any trailing words - confirm this is intended.
    x = re.sub(r'[$]', '', x)
    # Normalize number: rewrite each numeric token in float notation
    x = normalize_number(x)
    # Collapse runs of whitespace and trim both ends
    x = re.sub(r'\s+', ' ', x).strip()
    return x



# old exact match without number tolerance
# def judge_exact_match(pred_answer: List[str], gt_answer: List[str]):
#     '''
#     pred_answer: List[str]
#     gt_answer: List[str]
#     '''
#     pred_answer = [normalize(stringify(pred)) for pred in pred_answer]
#     gt_answer = [normalize(stringify(gt)) for gt in gt_answer]
#     exact_match = set(pred_answer) == set(gt_answer)
#     return exact_match

# new exact match with number tolerance - important for symbolic reasoning
def safe_equal_str(a: str, b: str, rel_tol=1e-2, abs_tol=1e-2) -> bool:
    """
    Compare two strings that may represent numbers.

    Identical strings always match. Otherwise, if both can be converted to
    float, their magnitudes are compared with ``math.isclose``; if either one
    is not numeric the strings are considered different.

    Parameters:
    - a, b: the (already normalized) strings to compare
    - rel_tol: relative tolerance passed to ``math.isclose``
    - abs_tol: absolute tolerance passed to ``math.isclose``

    Returns:
    - True when the strings are identical or numerically close, else False.

    NOTE(review): the numeric comparison uses absolute values, so "-5" and
    "5" are treated as equal - confirm this sign-insensitivity is intended.
    """
    # Identical text is always a match. Without this check, text that float()
    # parses as NaN (e.g. "nan") would never equal itself, because
    # math.isclose(nan, nan) is False.
    if a == b:
        return True
    try:
        a_float = abs(float(a))
        b_float = abs(float(b))
    except ValueError:
        # At least one side is not numeric and the strings already differ.
        return False
    return math.isclose(a_float, b_float, rel_tol=rel_tol, abs_tol=abs_tol)


def _all_answers_pair_up(preds: List[str], gts: List[str]) -> bool:
    """Return True when every prediction can be paired with its own matching ground truth.

    Uses augmenting-path bipartite matching with ``safe_equal_str`` as the
    compatibility test, so the result does not depend on the order of either
    list. Callers are expected to pass lists of equal length.
    """
    # taken_by[j] is the index of the prediction currently paired with gts[j].
    taken_by = [None] * len(gts)

    def try_pair(i, seen):
        for j, gt in enumerate(gts):
            if j in seen or not safe_equal_str(preds[i], gt):
                continue
            seen.add(j)
            # Take a free ground truth, or move the prediction that holds it
            # to another ground truth it also matches.
            if taken_by[j] is None or try_pair(taken_by[j], seen):
                taken_by[j] = i
                return True
        return False

    return all(try_pair(i, set()) for i in range(len(preds)))


def judge_exact_match(pred_answer: List[str], gt_answer: List[str]) -> bool:
    """
    Judge exact match between two lists of answers, using numeric tolerance if applicable.

    Both lists are normalized element by element and then compared as
    multisets: order is ignored, but the lengths must agree and every
    prediction must be paired with a distinct ground truth that
    ``safe_equal_str`` accepts.

    Parameters:
    - pred_answer: list of predicted answers (non-strings are stringified)
    - gt_answer: list of ground truth answers (non-strings are stringified)

    Returns:
    - True if all predicted answers match the ground truths, considering float tolerance.

    Raises:
    - AssertionError if either argument is not a list.
    """
    assert isinstance(pred_answer, list), f"prediction {pred_answer} must be a list"
    assert isinstance(gt_answer, list), f"groundtruth {gt_answer} must be a list"

    preds = [normalize(stringify(pred)) for pred in pred_answer]
    gts = [normalize(stringify(gt)) for gt in gt_answer]

    # Lengths must match
    if len(preds) != len(gts):
        return False

    # Fast path: lexicographic order usually lines the two lists up already.
    if all(safe_equal_str(p, g) for p, g in zip(sorted(preds), sorted(gts))):
        return True

    # String order is not numeric order, so values within tolerance can land
    # in different positions (e.g. ['9.995', '10.5'] vs ['10.0', '10.5']).
    # Fall back to an order-independent pairing before declaring a mismatch.
    return _all_answers_pair_up(preds, gts)



if __name__ == '__main__':
    # Manual smoke check: prints the verdict for one prediction/ground-truth
    # pair. Other pairs that have been tried by hand:
    #   judge_exact_match(['No', '3.2', 'yes', '123123'], ['yes', 'no', '123123', '3.2'])
    #   judge_exact_match(['True'], ['True', 'False'])
    #   judge_exact_match(['1.892617'], ['1.8926174496644295'])
    #   judge_exact_match(['182.0'], ['182'])
    #   judge_exact_match(['120,907'], ['120907.0'])
    #   judge_exact_match(['71% for women, 75% for men, 71% total'], ['71.0'])
    #   judge_exact_match(['7.1 percentage points'], ['7.1'])
    sample_prediction = [2.21]
    sample_truth = ['$2.209954$']
    print(judge_exact_match(sample_prediction, sample_truth))