# from sympy import *
# from sympy.parsing.latex import parse_latex
# import re
# from wrapt_timeout_decorator import *
# from typing import Callable, Dict,Any, List, Optional, Tuple
# import random
# import pandas as pd
# import copy
# import signal
# from openai import OpenAI

def _fix_fracs(string):
    substrs = string.split("\\frac")
    new_str = substrs[0]
    if len(substrs) > 1:
        substrs = substrs[1:]
        for substr in substrs:
            new_str += "\\frac"
            if substr[0] == "{":
                new_str += substr
            else:
                try:
                    assert len(substr) >= 2
                except:
                    return string
                a = substr[0]
                b = substr[1]
                if b != "{":
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}{" + b + "}" + post_substr
                    else:
                        new_str += "{" + a + "}{" + b + "}"
                else:
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}" + b + post_substr
                    else:
                        new_str += "{" + a + "}" + b
    string = new_str
    return string

def _fix_a_slash_b(string):
    if len(string.split("/")) != 2:
        return string
    a = string.split("/")[0]
    b = string.split("/")[1]
    try:
        a = int(a)
        b = int(b)
        assert string == "{}/{}".format(a, b)
        new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
        return new_string
    except:
        return string

def _remove_right_units(string):
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    if "\\text{ " in string:
        splits = string.split("\\text{ ")
        # assert len(splits) == 2
        return splits[0]
    else:
        return string

def _fix_sqrt(string):
    if "\\sqrt" not in string:
        return string
    splits = string.split("\\sqrt")
    new_string = splits[0] 
    for split in splits[1:]:
        if split[0] != "{":
            a = split[0]
            new_substr = "\\sqrt{" + a + "}" + split[1:]
        else:
            new_substr = "\\sqrt" + split
        new_string += new_substr
    return new_string
def _strip_string(string):
    r"""Normalize a LaTeX answer string so equal answers compare equal.

    The steps run in a fixed order (later steps rely on earlier ones):
    strip newlines and ``\!``, collapse ``\\`` to ``\``, map ``tfrac`` /
    ``dfrac`` to ``frac``, drop ``\left`` / ``\right``, degree marks,
    ``\$``, trailing units and ``\%``, add a leading ``0`` to bare decimals,
    drop a short ``x =`` style prefix, brace ``\sqrt`` and ``\frac``
    shorthand, remove spaces, and finally canonicalize ``0.5`` and simple
    ``a/b`` ratios to ``\frac`` form.

    Args:
        string: Raw answer text.

    Returns:
        The normalized string (possibly empty).
    """
    # linebreaks
    string = string.replace("\n", "")

    # remove inverse spaces
    string = string.replace("\\!", "")

    # replace \\ with \
    string = string.replace("\\\\", "\\")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")

    # remove units (on the right)
    string = _remove_right_units(string)

    # remove percentage
    string = string.replace("\\%", "")
    # Second pass kept on purpose: removing one "\%" can join a backslash and
    # a percent sign into a new "\%".  It was previously spelled "\%", which
    # is the same two characters but an invalid escape sequence.
    string = string.replace("\\%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    # fix sqrt3 --> sqrt{3}
    string = _fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1).
    string = _fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string

# def is_equiv(str1, str2, verbose=False):
#     if str1 is None and str2 is None:
#         print("WARNING: Both None")
#         return True
#     if str1 is None or str2 is None:
#         return False

#     try:
#         ss1 = _strip_string(str1)
#         ss2 = _strip_string(str2)
#         ss1 = parse_latex(ss1)
#         ss2 = parse_latex(ss2)
#         if verbose:
#             print(ss1, ss2)
#         return ss1 == ss2
#     except Exception as e:
#         return str1 == str2


# def extract_last_num(text: str) -> str:
#         match = re.search(r'boxed\{(.*?)\}', text)
#         if match:
#             text=match.group(1)
#         text = re.sub(r"(\d),(\d)", r"\1\2", text)  
#         res = re.findall(r"(-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)", text)
#         if len(res) > 0:
#             num_str = res[-1]
#             return num_str
#         else:
#             return "Error"


# def extract_final_answers_gsm8k(outputs):
#     # need to consider robustness
#     final_answers = []
#     for output in outputs:
#             final_answer = extract_last_num(output)
#             final_answers.append(final_answer)
#     return final_answers

# def most_frequent(ans_list):
#     # need to consider robustness
#     if len(ans_list)==0:
#         return "-1"
#     counter = 0
#     num = ans_list[0]

#     for i in ans_list:
#         current_frequency = ans_list.count(i)
#         if current_frequency > counter:
#             counter = current_frequency
#             num = i

#     return num

# def get_boxed_answer(solutions):
#     final_answers=[]
#     for solution in solutions:
#         tmp = solution.split('oxed{')[-1]
#         count = 0
#         answer = ''
#         for item in tmp:
#             if item == '}':
#                 if count == 0:
#                     break
#                 else:
#                     count -=1
#             elif item == '{':
#                 count += 1
#             answer += item
#         final_answers.append(answer)
#     return final_answers

# def get_last_uppercase(solutions):
#     final_answers=[]
#     for solution in solutions:
#         match = re.findall(r'[A-J]', solution)  

#         if match:
#             final_answers.append(match[0])
#         else:
#             final_answers.append("")
#     return final_answers

# def latex_to_python(latex_expr):

#     latex_expr = latex_expr.replace('\\+', '+') 
#     latex_expr = latex_expr.replace('\\-', '-')  
#     latex_expr = latex_expr.replace('\\times', '*')  
#     latex_expr = latex_expr.replace('\\cdot', '*')  
#     latex_expr = latex_expr.replace('\\div', '/') 
#     latex_expr = latex_expr.replace('\\frac', '/')  
#     return latex_expr


# def compute_correctness(ans_list,s,dataset):
#     is_correct=0
#     ans_list=[item for item in ans_list if item !="Error" and item!=""]
#     ans_list=[item.replace(" ","") for item in ans_list]
#     ans= most_frequent(ans_list)

#     if dataset in ["GSM8K","MGSM"]:
#         a = float(ans.replace(',',''))
#         if dataset=="MGSM":
#             s=float(s["solution"])
#         else:
#             s=float(s["solution"].replace(',',''))
#         if abs(s-a) < 1e-6:
#             is_correct=1
#     elif dataset=="MATH":
#         if is_equiv(s["solution"].replace(" ",""),ans):
#             is_correct=1
#     else:
#         matches = re.findall(r'\{(.*?)\}', ans)
#         if matches:
#             ans=matches[0]
#         if ans.replace(" ","").lower()==s["solution"].replace(" ","").lower():
#             is_correct=1
#     return is_correct
# def calculate_num_sampling(ans_list):
#     set_list=set(ans_list)
#     num_list=[]
#     for set_ans in set_list:
#         num_list.append(ans_list.count(set_ans))
#     return max(num_list)/len(ans_list)

import json
import re
import os
import torch
import random
import tiktoken
import numpy as np
from backoff import on_exception, expo
from math_verify import parse, verify
    
def write_json(obj, file_name):
    """Serialize ``obj`` as JSON into ``file_name``.

    The file is opened in ``"w"`` mode, so existing content is replaced.
    """
    with open(file_name, mode="w") as handle:
        json.dump(obj, fp=handle)

def read_json(file_name):
    """Load and return the JSON document stored in ``file_name``.

    The file is decoded as UTF-8 explicitly rather than with the platform's
    locale encoding, so JSON containing non-ASCII text loads the same way on
    every system.

    Args:
        file_name: Path of the JSON file to read.

    Returns:
        The decoded Python object.
    """
    with open(file_name, "r", encoding="utf-8") as f:
        return json.load(f)

def get_number_choice(text):
    """Extract a single-digit, parenthesized choice such as ``(3)``.

    The last occurrence of ``answer is (d)`` wins; when that phrase never
    appears, the last bare ``(d)`` in the text is used instead.

    Args:
        text: Model output to scan; may be empty or ``None``.

    Returns:
        The digit as a string, or ``"N/A"`` when the text is empty or holds
        no parenthesized single digit.
    """
    if not text:
        return "N/A"
    explicit = re.findall(r"answer is \((\d)\)", text)
    if explicit:
        return explicit[-1]
    # No explicit "answer is (d)" phrase: fall back to any "(d)".
    fallback = re.findall(r"\((\d)\)", text)
    return fallback[-1] if fallback else "N/A"

def get_alphabet_choice(text, num_choice=4):
    """Extract the last upper-case choice letter from ``text``.

    Valid letters are the first ``num_choice`` capitals starting at ``A``.
    A letter immediately followed by ``)`` (as in ``(B)`` or ``B)``) is
    preferred; only when no such form exists is a bare letter accepted.

    Args:
        text: Model output to scan; may be empty or ``None``.
        num_choice: Number of answer options (4 means A-D).

    Returns:
        The chosen letter, or ``"N/A"`` when the text is empty or contains
        none of the valid letters.
    """
    if not text:
        return "N/A"
    # Build the character class from the letters only.  Joining them with "|"
    # would put a literal "|" inside the brackets and make it a valid
    # "answer".
    letters = re.escape("".join(chr(65 + i) for i in range(num_choice)))
    # First try to match with a closing parenthesis
    match = re.findall(rf"([{letters}])\)", text)
    if not match:
        # If no match with parentheses, try without
        match = re.findall(rf"([{letters}])", text)
    return match[-1] if match else "N/A"
    
def get_true_false(text):
    """Return the last ``true``/``false`` found in ``text``, lower-cased.

    Matching is case-insensitive and substring-based.  ``"N/A"`` is returned
    for empty/``None`` text or when neither word occurs.
    """
    if text:
        hits = re.findall(r"(true|false)", text, flags=re.IGNORECASE)
        if hits:
            return hits[-1].lower()
    return "N/A"

def get_yes_no(text):
    """Return the last standalone ``yes``/``no`` in ``text``, lower-cased.

    Matching is case-insensitive and restricted to whole words, so the "no"
    inside words such as "know", "not" or "nothing" is not mistaken for an
    answer.

    Args:
        text: Model output to scan; may be empty or ``None``.

    Returns:
        ``"yes"``, ``"no"``, or ``"N/A"`` when the text is empty or contains
        neither word.
    """
    if not text:
        return "N/A"
    match = re.findall(r"\b(yes|no)\b", text, re.IGNORECASE)
    return match[-1].lower() if match else "N/A"

def get_keywords(output):
    """Split the text after the last ``Keywords:`` marker into keywords.

    The tail (or the whole text when the marker is absent) is split on
    commas; each piece is stripped, lower-cased and has its periods removed.
    Returns the pieces as a list, in order.
    """
    tail = output.rpartition("Keywords:")[2]
    cleaned = []
    for piece in tail.split(","):
        cleaned.append(piece.strip().lower().replace(".", ""))
    return cleaned

def get_token_count(string, encoding_name="gpt-3.5-turbo"):
    """Count the tokens ``string`` encodes to.

    ``encoding_name`` is handed to ``tiktoken.encoding_for_model``, so
    despite its name it is looked up as a model name.
    """
    tokenizer = tiktoken.encoding_for_model(encoding_name)
    return len(tokenizer.encode(string))

def is_math_equiv(ref, pred):
    """Test math equivalence of ``ref`` and ``pred`` using ``math_verify``.

    Three comparisons are tried in order and the first success wins:
    both sides wrapped in ``$...$``, both sides as given, and ``pred`` with
    its ``\\(`` / ``\\)`` delimiters removed.  According to the original
    author's note this also handles answer choices, e.g. ``A`` vs. ``(A)``.

    Args:
        ref: Reference answer; passed as the first argument of ``verify``.
        pred: Predicted answer.

    Returns:
        ``True`` if any comparison succeeds; ``False`` otherwise, including
        when parsing or verification raises.
    """
    try:
        # Evaluate lazily: a later variant that raises (or is merely slow)
        # must not discard a match that an earlier variant already found.
        if verify(parse(f"${ref}$"), parse(f"${pred}$")):
            return True
        if verify(parse(ref), parse(pred)):
            return True
        if verify(parse(ref), parse(pred.replace("\\(", "").replace("\\)", ""))):
            return True
    except Exception:
        # Best effort: treat unparseable input as "not equivalent", but let
        # KeyboardInterrupt / SystemExit propagate.
        return False
    return False
    
def has_consensus(predictions):
    """Report whether every prediction equals the first one.

    A single-element sequence is trivially in consensus.  An empty sequence
    raises ``IndexError`` because the first element is read unconditionally.
    """
    first = predictions[0]
    return all(first == other for other in predictions[1:])
    
def last_boxed_only_string(string):
    r"""Return the last ``\boxed{...}`` (or, failing that, ``\fbox{...}``) span.

    The span runs from the command to the brace that balances its first
    opening brace.  Falsy input yields ``"N/A"``.  When there is no such
    command, or its braces never balance, the input is returned unchanged.
    """
    if not string:
        return "N/A"

    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
    if start < 0:
        return string

    # Walk forward from the command, tracking brace nesting depth.
    depth = 0
    for pos in range(start, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return string[start:pos + 1]

    # Ran off the end without closing the group.
    return string

def remove_boxed(s):
    r"""Strip a surrounding ``\boxed{`` ... ``}`` wrapper from ``s``.

    Args:
        s: Candidate text, normally the output of a ``\boxed`` extraction.

    Returns:
        The content between ``\boxed{`` and the final ``}`` when ``s`` is a
        string that starts and ends that way; otherwise ``s`` unchanged
        (this includes non-string values such as ``None``).
    """
    left = "\\boxed{"
    # Explicit checks instead of ``assert``: asserts vanish under
    # ``python -O``, which would slice arbitrary strings.
    if isinstance(s, str) and s.startswith(left) and s.endswith("}"):
        return s[len(left):-1]
    return s

def parse_number(text):
    """Return the last number in ``text`` as a float.

    Recognized forms are plain digit runs with at most one ``.`` or ``,``
    inside them, and properly grouped thousands (``1,234,567`` or
    ``1,234.56``).  Commas are dropped before conversion and a leading ``$``
    is skipped.  Signs are not captured, so ``-5`` yields ``5.0``.

    Args:
        text: Text to scan.

    Returns:
        The parsed float, or ``"N/A"`` when ``text`` is not a string or
        contains no number.
    """
    if not isinstance(text, str):
        return "N/A"
    pattern = (
        r"\$?("
        # Thousands-grouped number, optionally with decimals.  Tried first so
        # that "1,234,567" is one number instead of "1,234" followed by "567".
        r"(?:[0-9]{1,3}(?:,[0-9]{3})+(?![0-9])(?:\.[0-9]+)?)"
        # Otherwise: digits with at most one "." or "," separator.
        r"|(?:[0-9]+[\.,]?[0-9]*)"
        r")"
    )
    match = re.findall(pattern, text)
    return float(match[-1].replace(",", "")) if match else "N/A"

def parse_boxed(s):
    """Extract the numeric value of the last boxed answer in ``s``.

    Chains ``last_boxed_only_string``, ``remove_boxed`` and ``parse_number``
    and returns the result of the last one.  Falsy input short-circuits to
    ``"N/A"``.
    """
    if not s:
        return "N/A"
    return parse_number(remove_boxed(last_boxed_only_string(s)))