import regex
from copy import deepcopy
from eval.eval_utils import math_equal
from eval.ocwcourses_eval_utils import (
    normalize_numeric,
    numeric_equality,
    normalize_symbolic_equation,
    SymbolicMathMixin,
)


def is_correct(item, pred_key="prediction", prec=1e-3):
    """Return whether ``item[pred_key]`` matches ``item["answer"]``.

    Three shapes are supported:

    * list vs. list -- every prediction must match at least one answer and
      every answer must be matched by at least one prediction (pairs are
      compared recursively; this is a mutual-coverage check, not a strict
      one-to-one assignment);
    * str vs. str where both contain ``\\cup`` -- both sides are split on
      ``\\cup`` and compared as lists;
    * str vs. str otherwise -- accepted when the two values parse as floats
      (commas removed) and differ by less than ``prec``, or when they are
      identical non-empty strings, or when ``math_equal`` accepts them.

    Args:
        item: mapping holding the prediction under ``pred_key`` and the
            reference under ``"answer"``. It is not modified.
        pred_key: key of the prediction inside ``item``.
        prec: absolute tolerance for the numeric comparison.

    Returns:
        A truthy value when the prediction is accepted, falsy otherwise
        (the last fallback returns whatever ``math_equal`` returns).

    Raises:
        NotImplementedError: if prediction and answer are neither both
            lists nor both strings.
    """
    pred = item[pred_key]
    ans = item["answer"]

    if isinstance(pred, list) and isinstance(ans, list):
        matched_preds = set()
        matched_answers = set()
        for i, single_pred in enumerate(pred):
            for j, single_ans in enumerate(ans):
                # A shallow copy is enough: only the two compared keys are
                # replaced and nothing nested is mutated downstream.
                pair = {**item, pred_key: single_pred, "answer": single_ans}
                if is_correct(pair, pred_key=pred_key, prec=prec):
                    matched_preds.add(i)
                    matched_answers.add(j)
        return len(matched_preds) == len(pred) and len(matched_answers) == len(ans)

    if isinstance(pred, str) and isinstance(ans, str):
        if "\\cup" in pred and "\\cup" in ans:
            # Compare unions piecewise through the list branch above.
            pieces = {
                **item,
                pred_key: pred.split("\\cup"),
                "answer": ans.split("\\cup"),
            }
            return is_correct(pieces, pred_key=pred_key, prec=prec)

        try:
            # Commas are dropped so "1,000" parses like "1000".
            numeric_match = (
                abs(float(pred.replace(",", "")) - float(ans.replace(",", "")))
                < prec
            )
        except (ValueError, TypeError):
            # Not numeric (or tolerance not comparable): fall through to
            # the string / symbolic checks.
            numeric_match = False
        return numeric_match or (ans and pred == ans) or math_equal(pred, ans)

    print(item, flush=True)
    raise NotImplementedError(
        f"unsupported prediction/answer types: "
        f"{type(pred).__name__} vs {type(ans).__name__}"
    )


def eval_math(item, pred_key="prediction", prec=1e-3):
    """Score a MATH-style item after de-duplicating answers and predictions.

    A string prediction stored under ``"program_output"`` is first wrapped
    in a one-element list. When both prediction and answer are lists,
    repeated entries are dropped from each (first occurrence kept) and only
    the trailing ``len(answer)`` predictions are retained. The cleaned
    values are written back into ``item`` before delegating to
    ``is_correct``.

    Args:
        item: mapping with the prediction under ``pred_key`` and the
            reference under ``"answer"``; updated in place.
        pred_key: key of the prediction inside ``item``.
        prec: tolerance forwarded to ``is_correct``.

    Returns:
        The result of ``is_correct`` on the cleaned item.
    """

    def _unique_in_order(values):
        # Membership test on a list (not a set) so unhashable entries work.
        kept = []
        for value in values:
            if value not in kept:
                kept.append(value)
        return kept

    prediction = item[pred_key]
    if isinstance(prediction, str) and pred_key == "program_output":
        prediction = [prediction]
    reference = item["answer"]

    if isinstance(prediction, list) and isinstance(reference, list):
        # For some questions in MATH, the reference repeats answers.
        reference = _unique_in_order(reference)
        # Some predictions repeat answers too, and some mistakenly box
        # non-answer strings, so keep only the last len(reference) entries.
        prediction = _unique_in_order(prediction)[-len(reference) :]

    item.update({pred_key: prediction, "answer": reference})
    return is_correct(item, pred_key=pred_key, prec=prec)


def eval_last_single_answer(item, pred_key="prediction", prec=1e-3):
    """Score an item whose prediction and answer are both single strings.

    Args:
        item: mapping with the prediction under ``pred_key`` and the
            reference under ``"answer"``.
        pred_key: key of the prediction inside ``item``.
        prec: tolerance forwarded to ``is_correct``.

    Returns:
        The result of ``is_correct`` for the item.

    Raises:
        AssertionError: if either value is not a ``str``.
    """
    for field in (pred_key, "answer"):
        value = item[field]
        assert isinstance(value, str), f"{field} = `{value}` is not a str"
    return is_correct(item, pred_key=pred_key, prec=prec)


def _split_top_level(text):
    """Split *text* on every ``;`` and on each ``,`` at bracket depth zero."""
    fragments = []
    depth = 0
    start = 0
    for pos, char in enumerate(text):
        if char == ";" or (char == "," and depth == 0):
            fragments.append(text[start:pos].strip())
            start = pos + 1
            # Each fragment starts counting brackets afresh.
            depth = 0
        elif char in "([{":
            depth += 1
        elif char in ")]}":
            depth -= 1
    fragments.append(text[start:].strip())
    return fragments


def eval_agieval_gaokao_math_cloze(item, pred_key="prediction", prec=1e-3):
    """Score a cloze item with one or more blanks, answer by answer.

    Each prediction string is split on ``;`` and on commas that are not
    nested inside ``()``, ``[]`` or ``{}``. Empty fragments (for example
    from a trailing ``;``) and duplicates are discarded. The last
    ``len(answer)`` fragments are then compared position by position with
    the answers via ``is_correct``.

    Args:
        item: mapping with a list of prediction strings under ``pred_key``
            and a list of answers under ``"answer"``. A string stored under
            ``"program_output"`` is wrapped into a list in place; the
            pairwise comparison itself works on copies and leaves ``item``
            untouched.
        pred_key: key of the prediction inside ``item``.
        prec: tolerance forwarded to ``is_correct``.

    Returns:
        ``True`` when the number of retained fragments equals the number of
        answers and every pair is accepted, ``False`` otherwise.

    Raises:
        AssertionError: if the prediction or the answer is not a list.
    """
    if pred_key == "program_output" and isinstance(item[pred_key], str):
        item[pred_key] = [item[pred_key]]
    for key in (pred_key, "answer"):
        assert isinstance(item[key], list), f"{key} = `{item[key]}` is not a list"

    ans = item["answer"]
    candidates = []
    for raw in item[pred_key]:
        for fragment in _split_top_level(raw):
            # Skip blanks so they cannot displace real answers in the tail
            # selection below; skip repeats to keep first occurrences only.
            if fragment and fragment not in candidates:
                candidates.append(fragment)

    pred = candidates[-len(ans) :]
    if len(pred) != len(ans):
        return False

    # Compare on shallow copies so the caller's record keeps its original
    # prediction and answer lists.
    return all(
        is_correct({**item, pred_key: p, "answer": a}, pred_key=pred_key, prec=prec)
        for p, a in zip(pred, ans)
    )


def eval_agieval_gaokao_mathqa(item, pred_key="prediction", prec=1e-3):
    """Score a multiple-choice item by the option letter found in the output.

    The prediction entries are joined with spaces. Among the letters
    ``A``-``D`` present in that text, the one whose *first* occurrence lies
    furthest to the right is selected and compared with ``item["answer"]``.
    NOTE(review): a letter that appears early and again at the very end is
    ranked by its early position -- confirm this is the intended rule.

    Args:
        item: mapping with the prediction under ``pred_key`` and the
            reference letter under ``"answer"``. A string stored under
            ``"program_output"`` is wrapped into a list in place.
        pred_key: key of the prediction inside ``item``.
        prec: unused; kept for a uniform evaluator signature.

    Returns:
        ``True`` if the selected letter equals the answer, else ``False``
        (also ``False`` when no option letter is present).
    """
    if pred_key == "program_output" and isinstance(item[pred_key], str):
        item[pred_key] = [item[pred_key]]
    joined = " ".join(item[pred_key])

    # (first index, letter) for every option letter that appears at all.
    sightings = [
        (joined.find(letter), letter) for letter in "ABCD" if letter in joined
    ]
    chosen = max(sightings)[1] if sightings else None
    return chosen == item["answer"]


def eval_math_sat(item, pred_key="prediction", prec=1e-3):
    """Compare prediction and answer as case-insensitive strings.

    Args:
        item: mapping with the prediction under ``pred_key`` and the
            reference under ``"answer"``.
        pred_key: key of the prediction inside ``item``.
        prec: unused; kept for a uniform evaluator signature.

    Returns:
        ``True`` when the lower-cased strings are equal, else ``False``.

    Raises:
        AssertionError: if either value is not a ``str``.
    """
    for field in (pred_key, "answer"):
        value = item[field]
        assert isinstance(value, str), f"{field} = `{value}` is not a str"
    predicted = item[pred_key].lower()
    expected = item["answer"].lower()
    return predicted == expected


def eval_mmlu_stem(item, pred_key="prediction", prec=1e-3):
    """Score an MMLU-STEM item using the same rule as ``eval_math_sat``.

    All arguments are forwarded unchanged; the return value and raised
    exceptions are those of ``eval_math_sat``.
    """
    return eval_math_sat(item, pred_key, prec)


def eval_ocwcourses(item, pred_key="prediction", prec=1e-3):
    """Score an OCW-courses item with a comparison chosen from the answer.

    The reference answer picks the normaliser / equivalence pair:

    * it parses with ``float`` -> ``normalize_numeric`` / ``numeric_equality``;
    * otherwise it contains ``=`` -> ``normalize_symbolic_equation`` and
      plain ``==`` on the normalised values;
    * otherwise -> ``SymbolicMathMixin().normalize_tex`` /
      ``SymbolicMathMixin().is_tex_equiv``.

    An empty prediction is replaced by the ``"[invalidanswer]"`` marker
    before normalisation and always scores 0, as does a prediction whose
    normalised form equals that marker.

    Args:
        item: mapping with the prediction under ``pred_key`` and the
            reference under ``"answer"``.
        pred_key: key of the prediction inside ``item``.
        prec: unused; kept for a uniform evaluator signature.

    Returns:
        ``1`` if the prediction is accepted, ``0`` otherwise.

    Raises:
        AssertionError: if either value is not a ``str``.
    """
    INVALID_ANSWER = "[invalidanswer]"

    for field in (pred_key, "answer"):
        value = item[field]
        assert isinstance(value, str), f"{field} = `{value}` is not a str"

    reference = item["answer"]
    candidate = item[pred_key] or INVALID_ANSWER

    try:
        float(reference)
    except ValueError:
        if "=" in reference:
            normalize_fn = normalize_symbolic_equation

            def is_equiv(x, y):
                return x == y

        else:
            normalize_fn = SymbolicMathMixin().normalize_tex
            is_equiv = SymbolicMathMixin().is_tex_equiv
    else:
        normalize_fn = normalize_numeric
        is_equiv = numeric_equality

    # Both sides are normalised up front, before any validity check.
    correct_answer = normalize_fn(reference)
    model_answer = normalize_fn(candidate)

    if candidate == INVALID_ANSWER or model_answer == INVALID_ANSWER:
        return 0
    return 1 if is_equiv(model_answer, correct_answer) else 0


def eval_minif2f_isabelle(item, pred_key="prediction", prec=1e-3):
    """Return ``True`` unconditionally; no argument is inspected.

    NOTE(review): presumably the actual proof checking for this benchmark
    happens elsewhere -- confirm against the caller.
    """
    return True
