import regex
from copy import deepcopy
from evaluation.eval.eval_utils import math_equal
from evaluation.eval.ocwcourses_eval_utils import normalize_numeric, numeric_equality, normalize_symbolic_equation, SymbolicMathMixin
from sympy.parsing.latex import parse_latex

def is_correct(item, pred_key='prediction', prec=1e-3):
    """Decide whether the prediction stored in ``item`` matches its answer.

    Two input shapes are supported:

    * Both prediction and answer are lists: every prediction must match at
      least one answer and every answer must be matched by at least one
      prediction (pairs are compared by calling this function recursively).
    * Both are strings: if both contain ``\\cup`` they are split on it and
      compared as lists; otherwise a cascade of increasingly expensive checks
      is tried (numeric closeness, exact equality, ``math_equal``, a trailing
      percent sign, ``math_equal`` with ASCII letters removed, and finally
      numeric evaluation of both sides through ``parse_latex``).

    Args:
        item: Mapping that holds the prediction under ``pred_key`` and the
            reference under ``'answer'``. It is not modified.
        pred_key: Key of the prediction inside ``item``.
        prec: Absolute tolerance used by the numeric comparisons.

    Returns:
        A truthy value when prediction and answer match, a falsy one
        otherwise (a ``bool`` for lists; for strings it may be whatever
        ``math_equal`` returned).

    Raises:
        NotImplementedError: If prediction and answer are neither both lists
            nor both strings.
    """
    pred = item[pred_key]
    ans = item['answer']

    if isinstance(pred, list) and isinstance(ans, list):
        pred_matched = set()
        ans_matched = set()
        if pred and ans:
            # One scratch copy is enough: the recursive call never mutates
            # the mapping it receives, so it can be reused for every pair.
            scratch = deepcopy(item)
            for i, single_pred in enumerate(pred):
                for j, single_ans in enumerate(ans):
                    scratch.update({
                        pred_key: single_pred,
                        'answer': single_ans,
                    })
                    if is_correct(scratch, pred_key=pred_key, prec=prec):
                        pred_matched.add(i)
                        ans_matched.add(j)
        return len(pred_matched) == len(pred) and len(ans_matched) == len(ans)

    if isinstance(pred, str) and isinstance(ans, str):
        if '\\cup' in pred and '\\cup' in ans:
            # Compare unions piece by piece, ignoring the order of the pieces.
            item = deepcopy(item)
            item.update({
                pred_key: pred.split('\\cup'),
                'answer': ans.split('\\cup'),
            })
            return is_correct(item, pred_key=pred_key, prec=prec)

        # Some llama3 outputs contain doubled backslashes; collapse them.
        pred = pred.replace("\\\\", "\\")

        # Plain numeric comparison, ignoring thousands separators.
        label = False
        try:
            label = abs(float(pred.replace(',', '')) - float(ans.replace(',', ''))) < prec
        except ValueError:
            # At least one side is not a plain number; fall through.
            pass

        if not label:
            label = (ans and pred == ans) or math_equal(pred, ans, include_percentage=False, timeout=True)

        # The reference answer may omit the percent sign.
        if not label:
            label = pred.endswith("%") and pred[:-1] == ans

        # Every remaining check can only turn a miss into a hit, so stop as
        # soon as a match is established.
        if label:
            return label

        # Some answers carry units or other words; retry with ASCII letters
        # stripped from the prediction.
        squeeze_pred = "".join(
            ch for ch in pred
            if not ("a" <= ch <= "z" or "A" <= ch <= "Z")
        )
        if squeeze_pred != pred:
            label = label or math_equal(squeeze_pred, ans, include_percentage=False, timeout=True)
        if label:
            return label

        # Last resort: evaluate both sides numerically as LaTeX expressions.
        try:
            sympy_pred_result = float(parse_latex(pred).evalf())
            sympy_ans_result = float(parse_latex(ans).evalf())
            label = abs(sympy_pred_result - sympy_ans_result) < prec
        except Exception:
            # Parsing or evaluation failed; this is a best-effort check, so
            # keep the current verdict (interrupts still propagate).
            pass

        return label

    print(item, flush=True)
    raise NotImplementedError(
        f"unsupported prediction/answer types: "
        f"{type(pred).__name__} / {type(ans).__name__}"
    )

def eval_math(item, pred_key='prediction', prec=1e-3):
    """Score a MATH-style item after de-duplicating predictions and answers.

    A string prediction is wrapped into a one-element list only when
    ``pred_key`` is ``'program_output'``. When prediction and answer are both
    lists, repeated entries are dropped from each (first occurrence kept) and
    only the trailing predictions, as many as there are unique answers, are
    retained. The normalised values are written back into ``item`` before
    the comparison is delegated to ``is_correct``.

    Args:
        item: Mapping with the prediction under ``pred_key`` and the reference
            under ``'answer'``; updated in place.
        pred_key: Key of the prediction inside ``item``.
        prec: Tolerance forwarded to ``is_correct``.

    Returns:
        Whatever ``is_correct`` returns for the normalised item.
    """
    def _unique(values):
        # Order-preserving de-duplication based on equality, so entries do
        # not need to be hashable.
        kept = []
        for value in values:
            if value not in kept:
                kept.append(value)
        return kept

    prediction = item[pred_key]
    if isinstance(prediction, str) and pred_key == 'program_output':
        prediction = [prediction]
    reference = item['answer']

    if isinstance(prediction, list) and isinstance(reference, list):
        # References for some MATH questions repeat answers.
        reference = _unique(reference)
        # Predictions may repeat answers too, and may box non-answer strings
        # early on, so keep only the trailing unique ones.
        prediction = _unique(prediction)[-len(reference):]

    item.update({pred_key: prediction, 'answer': reference})
    return is_correct(item, pred_key=pred_key, prec=prec)

def eval_last_single_answer(item, pred_key='prediction', prec=1e-3):
    """Score an item whose prediction and answer are both single strings.

    Args:
        item: Mapping with the prediction under ``pred_key`` and the reference
            under ``'answer'``.
        pred_key: Key of the prediction inside ``item``.
        prec: Tolerance forwarded to ``is_correct``.

    Returns:
        Whatever ``is_correct`` returns for the item.

    Raises:
        AssertionError: If the prediction or the answer is not a ``str``.
    """
    for field in (pred_key, 'answer'):
        value = item[field]
        assert isinstance(value, str), f"{field} = `{value}` is not a str"
    return is_correct(item, pred_key=pred_key, prec=prec)

def _split_cloze_fragments(text):
    """Split ``text`` on ';' and on ',' that is not inside brackets.

    Returns the whitespace-stripped fragments in order (empty fragments are
    kept). The bracket depth restarts at zero after every split.
    """
    terminated = text + ";"  # sentinel so the last fragment is flushed too
    fragments = []
    depth = 0
    start = 0
    for idx, ch in enumerate(terminated):
        if ch == ';' or (ch == ',' and depth == 0):
            fragments.append(terminated[start:idx].strip())
            start = idx + 1
            depth = 0
        elif ch in '([{':
            depth += 1
        elif ch in ')]}':
            depth -= 1
    return fragments


def eval_agieval_gaokao_math_cloze(item, pred_key='prediction', prec=1e-3):
    """Score an AGIEval Gaokao math cloze item with several blanks.

    Every predicted string is split into fragments on ';' and on top-level
    ','. Fragments are de-duplicated across all predictions (first occurrence
    kept) and only the trailing ones, as many as there are answers, are
    compared position by position with the answers via ``is_correct``.

    Args:
        item: Mapping with a list of predicted strings under ``pred_key`` and
            a list of reference strings under ``'answer'``. A string stored
            under ``'program_output'`` is wrapped into a list in place;
            otherwise ``item`` is left untouched.
        pred_key: Key of the prediction inside ``item``.
        prec: Tolerance forwarded to ``is_correct``.

    Returns:
        ``True`` when the number of retained fragments equals the number of
        answers and every pair is judged correct, ``False`` otherwise.

    Raises:
        AssertionError: If the prediction or the answer is not a ``list``.
    """
    if pred_key == 'program_output' and isinstance(item[pred_key], str):
        item[pred_key] = [item[pred_key]]
    for key in [pred_key, 'answer']:
        assert isinstance(item[key], list), f"{key} = `{item[key]}` is not a list"

    answers = item['answer']
    candidates = []
    for raw in item[pred_key]:
        for fragment in _split_cloze_fragments(raw):
            if fragment not in candidates:
                candidates.append(fragment)
    # Earlier fragments are treated as non-answers; keep the trailing ones.
    candidates = candidates[-len(answers):]

    if len(candidates) != len(answers):
        return False

    # Compare on a scratch copy so the caller's list-valued prediction and
    # answer are not overwritten by single fragments.
    probe = deepcopy(item)
    for candidate, answer in zip(candidates, answers):
        probe.update({
            pred_key: candidate,
            'answer': answer,
        })
        if not is_correct(probe, pred_key=pred_key, prec=prec):
            return False
    return True

def eval_agieval_gaokao_mathqa(item, pred_key='prediction', prec=1e-3):
    """Score an AGIEval Gaokao multiple-choice item with options A-D.

    The predictions are joined with spaces. Among the letters ``A``-``D``
    present in that text, the one whose first occurrence lies furthest to the
    right is taken as the chosen option and compared with the answer. If no
    letter is present the chosen option is ``None``.

    Args:
        item: Mapping with the prediction (an iterable of strings) under
            ``pred_key`` and the reference option under ``'answer'``. A string
            stored under ``'program_output'`` is wrapped into a list in place.
        pred_key: Key of the prediction inside ``item``.
        prec: Unused; kept for a uniform evaluator signature.

    Returns:
        ``True`` when the chosen option equals the answer, else ``False``.
    """
    if pred_key == 'program_output' and isinstance(item[pred_key], str):
        item[pred_key] = [item[pred_key]]
    joined = " ".join(item[pred_key])
    expected = item['answer']

    # str.find yields -1 for an absent letter, so such letters never win.
    first_seen = {letter: joined.find(letter) for letter in 'ABCD'}
    latest = max(first_seen, key=first_seen.get)
    chosen = latest if first_seen[latest] >= 0 else None
    return chosen == expected

def eval_math_sat(item, pred_key='prediction', prec=1e-3):
    """Compare prediction and answer as strings, ignoring letter case.

    Args:
        item: Mapping with the prediction under ``pred_key`` and the reference
            under ``'answer'``.
        pred_key: Key of the prediction inside ``item``.
        prec: Unused; kept for a uniform evaluator signature.

    Returns:
        ``True`` when the lower-cased strings are equal, else ``False``.

    Raises:
        AssertionError: If the prediction or the answer is not a ``str``.
    """
    for field in (pred_key, 'answer'):
        value = item[field]
        assert isinstance(value, str), f"{field} = `{value}` is not a str"
    predicted = item[pred_key].lower()
    expected = item['answer'].lower()
    return predicted == expected

def eval_mmlu_stem(item, pred_key='prediction', prec=1e-3):
    """Score an MMLU-STEM item by delegating to ``eval_math_sat``.

    All arguments are forwarded unchanged and the delegate's result is
    returned as is.
    """
    verdict = eval_math_sat(item, pred_key=pred_key, prec=prec)
    return verdict

def eval_ocwcourses(item, pred_key='prediction', prec=1e-3):
    """Score an OCW Courses item, choosing the comparison by answer type.

    The reference decides how both sides are normalised and compared:

    * parses as a float: ``normalize_numeric`` + ``numeric_equality``;
    * contains ``=``: ``normalize_symbolic_equation`` + plain equality;
    * otherwise: ``SymbolicMathMixin`` TeX normalisation and equivalence.

    Args:
        item: Mapping with the prediction under ``pred_key`` and the reference
            under ``'answer'``; both must be strings.
        pred_key: Key of the prediction inside ``item``.
        prec: Unused; kept for a uniform evaluator signature.

    Returns:
        ``1`` when the normalised prediction is equivalent to the normalised
        reference, ``0`` otherwise (including an empty prediction or one that
        normalises to the invalid-answer marker).

    Raises:
        AssertionError: If the prediction or the answer is not a ``str``.
    """
    INVALID_ANSWER = "[invalidanswer]"
    for field in (pred_key, 'answer'):
        value = item[field]
        assert isinstance(value, str), f"{field} = `{value}` is not a str"
    prediction = item[pred_key]
    reference = item['answer']

    try:
        float(reference)
    except ValueError:
        if "=" in reference:
            # Equation answer: compare the normalised forms directly.
            normalize = normalize_symbolic_equation
            equivalent = lambda lhs, rhs: lhs == rhs
        else:
            # General symbolic expression.
            normalize = SymbolicMathMixin().normalize_tex
            equivalent = SymbolicMathMixin().is_tex_equiv
    else:
        # Numeric answer.
        normalize = normalize_numeric
        equivalent = numeric_equality

    normalized_reference = normalize(reference)

    # An empty prediction is replaced by the invalid-answer marker, which is
    # still passed through the normaliser before being rejected below.
    raw_prediction = prediction or INVALID_ANSWER
    normalized_prediction = normalize(raw_prediction)

    if raw_prediction == INVALID_ANSWER:
        return 0
    if normalized_prediction == INVALID_ANSWER:
        return 0
    if equivalent(normalized_prediction, normalized_reference):
        return 1
    return 0

def eval_minif2f_isabelle(item, pred_key='prediction', prec=1e-3):
    """Report every miniF2F-Isabelle item as correct.

    No argument is inspected; the function unconditionally returns ``True``.
    NOTE(review): presumably the actual proof checking happens elsewhere -
    confirm against the caller.
    """
    return True