# ruff: noqa
"""
Code from https://github.com/NovaSky-AI/SkyThought/blob/e855aad095f4eeee00ba6a909dfe4300faf6d853/skythought/tools/util/math/testing_util.py

The logic in this file largely borrows from the Qwen2.5-Math codebase at https://github.com/QwenLM/Qwen2.5-Math.

Answer extraction (``extract_answer``, ``mmlu_pro_extract_answer``), answer
normalization (``strip_answer_string``) and equivalence checking
(``math_equal``) for math benchmarks.
"""

import re
from math import isclose

import regex
from latex2sympy2 import latex2sympy
from sympy import N, simplify
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
from word2number import w2n


def convert_word_number(text: str) -> str:
    """Convert an English number phrase (e.g. ``"twenty one"``) to its digits.

    Returns ``text`` unchanged when ``w2n.word_to_num`` cannot parse it.
    """
    try:
        return str(w2n.word_to_num(text))
    except Exception:
        # Best effort: anything w2n rejects is simply not a number phrase.
        return text


def _fix_fracs(string):
    """Add braces to shorthand fractions.

    ``\\frac12`` -> ``\\frac{1}{2}``, ``\\frac1b`` -> ``\\frac{1}{b}``,
    ``\\frac1{72}`` -> ``\\frac{1}{72}``. Already-braced fractions are kept.
    Returns ``string`` unchanged if some ``\\frac`` is followed by fewer than
    two characters, since it cannot be repaired.
    """
    substrs = string.split("\\frac")
    new_str = substrs[0]
    for substr in substrs[1:]:
        new_str += "\\frac"
        if substr.startswith("{"):
            new_str += substr
            continue
        if len(substr) < 2:
            return string
        a, b = substr[0], substr[1]
        post = substr[2:]
        if b != "{":
            new_str += "{" + a + "}{" + b + "}" + post
        else:
            # Denominator is already braced; only the numerator needs braces.
            new_str += "{" + a + "}" + b + post
    return new_str


def _fix_a_slash_b(string):
    """Rewrite a plain integer ratio ``a/b`` as ``\\frac{a}{b}``.

    Applies only when there is exactly one ``/``. Each side must be an
    integer, or contain ``sqrt``. The parsed form must reproduce ``string``
    exactly. Otherwise ``string`` is returned unchanged.
    """
    parts = string.split("/")
    if len(parts) != 2:
        return string
    a, b = parts
    try:
        if "sqrt" not in a:
            a = int(a)
        if "sqrt" not in b:
            b = int(b)
    except ValueError:
        return string
    # Reject inputs that int() normalized (e.g. "01/2" or " 1/2").
    if string != "{}/{}".format(a, b):
        return string
    return "\\frac{" + str(a) + "}{" + str(b) + "}"


def _fix_sqrt(string):
    """Brace a shorthand square root: ``\\sqrt2`` -> ``\\sqrt{2}``."""
    return re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)


def strip_answer_string(string):
    """Normalize an answer string so that equivalent answers compare equal.

    The steps are:
    - Strip formatting: ``\\left``/``\\right``, ``$``, a trailing ``\\text{}``
      unit, degree and percent signs, and spaces.
    - Canonicalize macros: ``\\tfrac``/``\\dfrac`` -> ``\\frac``,
      ``bmatrix``/``array`` -> ``pmatrix``, ``\\leq`` -> ``\\le``.
    - Convert number words to digits.
    - Repair shorthand fractions and roots.
    - Sort a bare comma-separated list of integers.
    """
    string = str(string).strip()
    # linebreaks
    string = string.replace("\n", "")

    # right "."
    string = string.rstrip(".")

    # remove LaTeX negative thin spaces (\!)
    string = string.replace("\\!", "")

    # matrix
    string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string)
    string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string)
    string = string.replace("bmatrix", "pmatrix")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    string = (
        string.replace("\\neq", "\\ne")
        .replace("\\leq", "\\le")
        .replace("\\geq", "\\ge")
    )

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    string = string.replace("\\{", "{")
    string = string.replace("\\}", "}")

    # Replace \text{<number word>} with its digits; leave other \text{} alone.
    def replace_match(match):
        word = match.group(1).lower()
        converted = convert_word_number(word)
        return match.group(0) if converted == word else converted

    string = re.sub(r"\\text\{([a-zA-Z]+)\}", replace_match, string)

    # Before removing unit, check if the unit is squared (for surface area)
    string = re.sub(r"(cm|inches)\}\^2", r"\1}", string)

    # Remove a trailing \text{...} unit (miles, dollars, ...) unless nothing would remain.
    without_unit = re.sub(r"\\text{.*?}$", "", string).strip()
    if without_unit != "" and without_unit != string:
        string = without_unit

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")
    string = string.replace("$", "")
    string = string.replace("\\(", "").replace("\\)", "")

    # convert word number to digit
    string = convert_word_number(string)

    # replace "\\text{...}" to "..."
    string = re.sub(r"\\text\{(.*?)\}", r"\1", string)
    for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]:
        string = string.replace(key, "")
    string = string.replace("\\emptyset", r"{}")
    string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}")

    # remove percentage
    string = string.replace("\\%", "")
    string = string.replace("%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{."
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    # Unwrap a single alphanumeric token enclosed in {}, () or [] (e.g. "(A)" -> "A").
    # Upstream tested isalnum() on the whole string, brackets included, which
    # can never be true, so the unwrap never fired.
    if (
        len(string) >= 2
        and (string[0], string[-1]) in {("{", "}"), ("(", ")"), ("[", "]")}
        and string[1:-1].isalnum()
    ):
        string = string[1:-1]

    # inf
    string = string.replace("infinity", "\\infty")
    if "\\infty" not in string:
        string = string.replace("inf", "\\infty")
    # Drop an explicit "+" on infinity. Upstream matched the typo "+\\inity",
    # which can never occur, so this step used to be a no-op.
    string = string.replace("+\\infty", "\\infty")

    # and
    string = string.replace("and", "")
    string = string.replace("\\mathbf", "")

    # use regex to remove \mbox{...}
    string = re.sub(r"\\mbox{.*?}", "", string)

    # Quotes are left in place. Upstream called string.replace("'", "") without
    # using the result, so it was a no-op. Stripping "'" would also erase primes
    # such as f'(x).

    # i, j
    if "j" in string and "i" not in string:
        string = string.replace("j", "i")

    # replace a.000b where b is not number or b is end, with ab
    string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string)
    string = re.sub(r"(\d+)\.0*$", r"\1", string)

    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # get rid of e.g. "k = " or "q = " at beginning
    parts = string.split("=")
    if len(parts) == 2 and len(parts[0]) <= 2:
        string = parts[1]

    string = _fix_sqrt(string)
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with
    # \frac1{72} (but not \frac{72}1).
    string = _fix_fracs(string)

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in
    # case the model output is X/Y
    string = _fix_a_slash_b(string)

    # Remove unnecessary '\' before integers
    string = re.sub(r"\\(?=\-?\d+(\\|\)|,|\]|$))", "", string)

    # Remove grade level (e.g., 12th grade) and just maintain the integer
    string = re.sub(r"thgrade$", "", string)

    # If the answer is a list of integers (without parenthesis), sort them
    if re.fullmatch(r"(\s*-?\d+\s*,)*\s*-?\d+\s*", string):
        # The fullmatch guarantees every comma-separated piece is an integer;
        # the fallback mirrors upstream just in case.
        try:
            integers = [int(piece) for piece in string.split(",")]
        except ValueError:
            integers = [-1, -1]
        string = ",".join(map(str, sorted(integers)))

    return string


def _read_until_closing_brace(text):
    """Return ``text`` up to (not including) the ``}`` that closes an already-consumed ``{``.

    Nested braces are kept. If the braces never balance, all of ``text`` is returned.
    """
    depth = 1
    for index, char in enumerate(text):
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return text[:index]
    return text


def extract_answer(pred_str, use_last_number=True):
    """Extract the final answer from a model completion and normalize it.

    The following patterns are tried in order:
    1. Minerva-style ``final answer is $...$. I hope``.
    2. The last ``\\boxed...``, brace-balanced.
    3. ``...he answer is``.
    4. ``final answer is``.
    5. The Chinese ``答案是`` ("the answer is").
    6. If ``use_last_number`` is true, the last number in the text.

    Returns the extracted answer passed through ``strip_answer_string``, or
    ``""`` if nothing was found.
    """
    # NOTE(review): "\u043a\u0438" is Cyrillic "ки"; presumably a step-tag token
    # some datasets leave in completions. Confirm.
    pred_str = pred_str.replace("\u043a\u0438", "")
    if "final answer is $" in pred_str and "$. I hope" in pred_str:
        # minerva_math
        tmp = pred_str.split("final answer is $", 1)[1]
        pred = tmp.split("$. I hope", 1)[0].strip()
    elif "boxed" in pred_str:
        ans = pred_str.split("boxed")[-1]
        if len(ans) == 0:
            return ""
        if ans[0] == "{":
            pred = _read_until_closing_brace(ans[1:])
        else:
            # Unbraced form such as "\boxed 5$".
            pred = ans.split("$")[0].strip()
    elif "he answer is" in pred_str:
        # Matches both "The answer is" and "the answer is".
        pred = pred_str.split("he answer is")[-1].strip()
    elif "final answer is" in pred_str:
        pred = pred_str.split("final answer is")[-1].strip()
    elif "答案是" in pred_str:
        # Handle Chinese few-shot multiple choice problem answer extraction
        pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
    elif use_last_number:
        numbers = re.findall(r"-?\d*\.?\d+", pred_str.replace(",", ""))
        pred = numbers[-1] if numbers else ""
    else:
        pred = ""

    # join multi-line answers
    pred = re.sub(r"\n\s*", "", pred)
    if pred.startswith(":"):
        pred = pred[1:]
    if pred.endswith("."):
        pred = pred[:-1]
    if pred.endswith("/"):
        pred = pred[:-1]
    return strip_answer_string(pred)


def get_multiple_choice_answer(pred: str):
    """Return the last standalone letter A-D in ``pred``, upper-cased.

    Matching is case-insensitive. If no letter is found, falls back to ``pred``
    stripped of surrounding whitespace and periods. Any trailing ``.`` and
    ``/`` are removed.
    """
    letters = re.findall(r"\b(A|B|C|D)\b", pred.upper())
    pred = letters[-1] if letters else pred.strip().strip(".")
    # Remove the period at the end, again!
    return pred.rstrip(".").rstrip("/")


def mmlu_pro_extract_answer(text):
    """Extract an MMLU-Pro option letter (A-J) from ``text``, or return ``None``.

    The following patterns are tried in order:
    1. ``answer is (X)`` or ``answer is X``.
    2. ``Answer: X``: the last such occurrence on the first line that has one.
    3. The last standalone capital letter A-J anywhere in the text.
    """
    match = re.search(r"answer is \(?([A-J])\)?", text)
    if match:
        return match.group(1)
    match = re.search(r".*[aA]nswer:\s*([A-J])", text)
    if match:
        return match.group(1)
    match = re.search(r"\b[A-J]\b(?!.*\b[A-J]\b)", text, re.DOTALL)
    if match:
        return match.group(0)
    return None


def choice_answer_clean(pred: str):
    """Reduce a multiple-choice prediction to its last standalone letter A-E.

    Surrounding newlines, spaces, a leading ``:`` and trailing ``.``/``/`` are
    stripped first. If no letter is found, falls back to the cleaned text.
    """
    pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")
    letters = re.findall(r"\b(A|B|C|D|E)\b", pred.upper())
    pred = letters[-1] if letters else pred.strip().strip(".")
    # Remove the period at the end, again!
    return pred.rstrip(".").rstrip("/")


def parse_digits(num):
    """Parse ``num`` as a float, ignoring ``,`` thousands separators.

    A trailing ``%`` (optionally escaped as ``\\%``) is accepted, and the value
    is divided by 100. Returns ``None`` if the text is not numeric.
    """
    num = str(num).replace(",", "")
    try:
        return float(num)
    except ValueError:
        pass
    if num.endswith("%"):
        num = num[:-1]
        if num.endswith("\\"):
            num = num[:-1]
        try:
            return float(num) / 100
        except ValueError:
            pass
    return None


def is_digit(num):
    """Return True if ``parse_digits`` can read ``num`` as a number."""
    return parse_digits(num) is not None


def str_to_pmatrix(input_str):
    """Convert brace-delimited comma lists into pmatrix column vectors.

    Each ``{a,b,...}`` match becomes ``\\begin{pmatrix}a\\\\b...\\end{pmatrix}``.
    Multiple matches are joined with ", ".
    """
    input_str = input_str.strip()
    matrix_str = re.findall(r"\{.*,.*\}", input_str)
    pmatrix_list = []
    for m in matrix_str:
        m = m.strip("{}")
        # LaTeX separates rows with a double backslash, and math_equal splits
        # rows on it. Upstream inserted a single backslash, so the converted
        # reference never split into rows.
        pmatrix = r"\begin{pmatrix}" + m.replace(",", r"\\") + r"\end{pmatrix}"
        pmatrix_list.append(pmatrix)
    return ", ".join(pmatrix_list)


def math_equal(
    prediction,
    reference,
    include_percentage: bool = True,
    is_close: bool = True,
    timeout: bool = False,
) -> bool:
    """Return True if ``prediction`` and ``reference`` denote the same answer.

    Checks are tried in order:
      0. Case-insensitive string equality, and multiple-choice letters A-E.
      1. Numerical equality: both parse via ``parse_digits`` and agree (within
         ``num_equal`` tolerance if ``is_close``). When ``include_percentage``
         is set, ``reference`` scaled by 1/100 or by 100 is also accepted.
      2. Element-wise equality of tuples/intervals and of pmatrix/bmatrix
         entries, using this function recursively.
      3. Equations ``lhs = rhs``, compared as ``lhs - (rhs)`` up to sign.
         A short ``x = value`` form is also compared against a bare value.
      4. Symbolic equality via ``symbolic_equal``.

    ``timeout`` is accepted for interface compatibility but is not used here.
    """
    if prediction is None or reference is None:
        return False
    if str(prediction).strip().lower() == str(reference).strip().lower():
        return True
    if (
        reference in ["A", "B", "C", "D", "E"]
        and choice_answer_clean(str(prediction)) == reference
    ):
        return True

    # 1. numerical equal
    try:
        if is_digit(prediction) and is_digit(reference):
            prediction = parse_digits(prediction)
            reference = parse_digits(reference)
            if include_percentage:
                candidates = [reference / 100, reference, reference * 100]
            else:
                candidates = [reference]
            for item in candidates:
                if is_close:
                    if num_equal(prediction, item):
                        return True
                elif item == prediction:
                    return True
            return False
    except Exception:
        pass

    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    ## pmatrix (amps)
    if "pmatrix" in prediction and "pmatrix" not in reference:
        reference = str_to_pmatrix(reference)

    ## deal with [], (), {}
    pred_str, ref_str = prediction, reference
    if (
        prediction.startswith("[")
        and prediction.endswith("]")
        and not reference.startswith("(")
    ) or (
        prediction.startswith("(")
        and prediction.endswith(")")
        and not reference.startswith("[")
    ):
        pred_str = pred_str.strip("[]()")
        ref_str = ref_str.strip("[]()")
    for s in ["{", "}", "(", ")"]:
        ref_str = ref_str.replace(s, "")
        pred_str = pred_str.replace(s, "")
    if pred_str.lower() == ref_str.lower():
        return True

    ## [a, b] vs. [c, d], return a==c and b==d
    if (
        regex.match(r"(\(|\[).+(\)|\])", prediction) is not None
        and regex.match(r"(\(|\[).+(\)|\])", reference) is not None
    ):
        if _elementwise_equal(
            prediction[1:-1].split(","),
            reference[1:-1].split(","),
            include_percentage,
            is_close,
        ):
            return True

    matrix_begins = ("\\begin{pmatrix}", "\\begin{bmatrix}")
    matrix_ends = ("\\end{pmatrix}", "\\end{bmatrix}")
    if (
        prediction.startswith(matrix_begins)
        and prediction.endswith(matrix_ends)
        and reference.startswith(matrix_begins)
        and reference.endswith(matrix_ends)
    ):
        pred_rows = _matrix_rows(prediction)
        ref_rows = _matrix_rows(reference)
        if len(pred_rows) == len(ref_rows) and all(
            _elementwise_equal(
                pred_row.split("&"), ref_row.split("&"), include_percentage, is_close
            )
            for pred_row, ref_row in zip(pred_rows, ref_rows)
        ):
            return True

    if prediction.count("=") == 1 and reference.count("=") == 1:
        pred_lhs, pred_rhs = prediction.split("=")
        pred = f"{pred_lhs.strip()} - ({pred_rhs.strip()})"
        ref_lhs, ref_rhs = reference.split("=")
        ref = f"{ref_lhs.strip()} - ({ref_rhs.strip()})"
        if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
            return True
    elif (
        prediction.count("=") == 1
        and len(prediction.split("=")[0].strip()) <= 2
        and "=" not in reference
    ):
        if math_equal(
            prediction.split("=")[1], reference, include_percentage, is_close
        ):
            return True
    elif (
        reference.count("=") == 1
        and len(reference.split("=")[0].strip()) <= 2
        and "=" not in prediction
    ):
        if math_equal(
            prediction, reference.split("=")[1], include_percentage, is_close
        ):
            return True

    return symbolic_equal(prediction, reference)


def _matrix_rows(matrix):
    """Split a pmatrix/bmatrix string into its non-empty, stripped rows."""
    # pmatrix and bmatrix delimiters have identical lengths, so one slice fits both.
    body = matrix[len("\\begin{pmatrix}") : -len("\\end{pmatrix}")]
    return [row.strip() for row in body.split("\\\\") if row.strip()]


def _elementwise_equal(pred_parts, ref_parts, include_percentage, is_close):
    """Return True if both sequences have equal length and every pair is ``math_equal``."""
    return len(pred_parts) == len(ref_parts) and all(
        math_equal(pred_part, ref_part, include_percentage, is_close)
        for pred_part, ref_part in zip(pred_parts, ref_parts)
    )


def num_equal(prediction: float, reference: float):
    """Return True if the two numbers agree within a relative tolerance of 1e-4."""
    return isclose(reference, prediction, rel_tol=1e-4)


def symbolic_equal(a, b):
    """Return True if ``a`` and ``b`` are equal as SymPy expressions.

    Each side is parsed with ``parse_latex``, then ``parse_expr``, then
    ``latex2sympy``. If every parser fails, the raw string is kept. The sides
    count as equal if any of these checks succeeds:
    - equal ``str`` or ``==``;
    - ``equals`` or ``simplify(a - b) == 0``;
    - equal ``|lhs - rhs|`` for equations;
    - numeric values that agree per ``num_equal``;
    - same-shape matrices whose entries agree after rounding to 3 places.
    """

    def _parse(s):
        # NOTE(security): sympy's parse_expr evaluates its input with eval();
        # only run this on trusted text or inside a sandboxed worker.
        for parser in [parse_latex, parse_expr, latex2sympy]:
            try:
                return parser(s.replace("\\\\", "\\"))
            except Exception:
                try:
                    return parser(s)
                except Exception:
                    pass
        return s

    a = _parse(a)
    b = _parse(b)

    def _matrices_match():
        if a.shape == b.shape:
            rounded_a = a.applyfunc(lambda x: round(x, 3))
            rounded_b = b.applyfunc(lambda x: round(x, 3))
            return rounded_a.equals(rounded_b)
        return False

    checks = (
        # direct equal
        lambda: str(a) == str(b) or a == b,
        # simplify equal
        lambda: a.equals(b) or simplify(a - b) == 0,
        # equation equal
        lambda: (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)),
        # numeric equal
        lambda: num_equal(float(N(a)), float(N(b))),
        # matrix
        _matrices_match,
    )
    for check in checks:
        try:
            if check():
                return True
        except Exception:
            # A check that does not apply to these objects (e.g. no .lhs) just fails.
            continue
    return False


# Aliases for the abbreviated names this module previously defined, so existing
# imports keep working.
get_multiple_ce_answer = get_multiple_choice_answer
ce_answer_clean = choice_answer_clean
symboc_equal = symbolic_equal