
import os
import re
import backoff
import re
from pathlib import Path
import openai
from dotenv import load_dotenv


# Load environment variables (e.g. OPENAI_API_KEY, read by query_llm) from the
# ".env" file in the grandparent of this module's directory. With
# override=True (per python-dotenv), values from the file take precedence over
# variables already set in the process environment.
env_path = Path(__file__).parent.parent.parent / ".env"
load_dotenv(dotenv_path=env_path, override=True)

# One million; token prices below are written per million tokens and divided by M.
M = 1_000_000

# Matches an optionally negative number with optional ",ddd" thousands groups
# and an optional decimal part. extract_numeric_answer removes commas before
# matching, so the thousands-separator group is not exercised there.
ANSWER_REGEX = re.compile(r"-?\d+(?:,\d{3})*(?:\.\d+)?")


def extract_numeric_answer(text: str) -> str:
    """Return the last number in ``text`` with redundant leading zeros removed.

    Commas are deleted before matching, so "1,234" yields "1234". Leading zeros
    of the integer part are dropped ("007" -> "7", "-05" -> "-5"), but a lone
    zero integer part is kept ("0" -> "0", "0.5" -> "0.5").

    If ``text`` contains no number, the whitespace-stripped ``text`` is returned.
    """
    matches = ANSWER_REGEX.findall(text.replace(",", ""))
    if not matches:
        return text.strip()
    # ``str.lstrip("0")`` would turn "0" into "" and "0.5" into ".5"; only
    # remove zeros that are followed by another digit.
    return re.sub(r"^(-?)0+(?=\d)", r"\1", matches[-1])


def backoff_handler(details):
    """Print a progress line before ``backoff`` waits between retries.

    ``details`` is the dict handed to ``on_backoff`` callbacks. Its
    ``"exception"``, ``"tries"`` and ``"wait"`` entries are read. Nothing is
    printed when no (truthy) exception is present.
    """
    error = details.get("exception")
    if not error:
        return
    attempt, delay = details["tries"], details["wait"]
    print(
        f"OpenAI - Retry {attempt} due to error: {error}. "
        f"Waiting {delay:0.1f}s..."
    )


# Per-token prices keyed by model name: "input" applies to prompt tokens and
# "output" to completion tokens (see the cost computation in query_llm). Each
# value is written as a per-million-token price divided by M. Currency is
# presumably USD; confirm against the provider's pricing page. A model missing
# from this table cannot be priced (the dict lookup raises KeyError).
costs_per_token = {
    "gpt-4.1-nano": {"input": 0.1 / M, "output": 0.4 / M},
    "gpt-4.1-mini": {"input": 0.4 / M, "output": 1.6 / M},
    "gpt-4.1": {"input": 2.0 / M, "output": 8.0 / M},
    "gpt-4o-mini": {"input": 0.15 / M, "output": 0.6 / M},
    "o4-mini": {"input": 1.1 / M, "output": 4.4 / M},
}


class MaxCallsExceededError(Exception):
    """Raised when a call-limited LLM wrapper is invoked more times than allowed."""


def create_call_limited_query_llm(base_query_llm, max_calls=3):
    """Wrap ``base_query_llm`` with a per-thread call budget.

    Each thread has its own counter, stored on a ``threading.local``. After a
    thread has made ``max_calls`` calls through the wrapper, further calls
    raise ``MaxCallsExceededError`` without invoking ``base_query_llm``. The
    counter is bumped before the underlying call, so a call that raises still
    consumes budget.

    The returned callable also exposes:
      * ``reset_calls()``: set the calling thread's counter back to zero.
      * ``get_call_count()``: the calling thread's count (0 if never used).
    """
    import threading

    per_thread = threading.local()

    def get_call_count():
        return getattr(per_thread, "call_count", 0)

    def reset_calls():
        per_thread.call_count = 0

    def limited_query_llm(*args, **kwargs):
        used = get_call_count()
        if used >= max_calls:
            raise MaxCallsExceededError(
                f"Maximum number of LLM calls ({max_calls}) exceeded"
            )
        per_thread.call_count = used + 1
        return base_query_llm(*args, **kwargs)

    # Expose the per-thread helpers as attributes on the wrapper itself.
    limited_query_llm.reset_calls = reset_calls
    limited_query_llm.get_call_count = get_call_count
    return limited_query_llm


def _is_non_retryable_status(exc):
    """Return True for HTTP errors that retrying cannot fix.

    This is the ``giveup`` predicate for ``query_llm``. Most 4xx responses
    (bad request, authentication, permission, not found, unprocessable, ...)
    are deterministic, so retrying them only burns through ``max_tries``
    waits. 408, 409 and 429 are treated as transient and still retried.
    5xx errors and connection/timeout errors are also retried; the latter
    carry no status code.
    """
    status = getattr(exc, "status_code", None)
    if not isinstance(exc, openai.APIStatusError) or status is None:
        return False
    return 400 <= status < 500 and status not in (408, 409, 429)


@backoff.on_exception(
    backoff.expo,
    (
        openai.APIConnectionError,
        openai.APIStatusError,
        openai.RateLimitError,
        openai.APITimeoutError,
    ),
    max_tries=20,
    max_value=20,
    giveup=_is_non_retryable_status,
    on_backoff=backoff_handler,
)
def query_llm(prompt, system, temperature=0.0, model_name="gpt-4.1-nano"):
    """Send one chat-completion request and return ``(text, cost)``.

    Args:
        prompt: Content of the user message.
        system: Optional system message; omitted from the request when None.
        temperature: Sampling temperature (forced to 1.0 for "o4-mini").
        model_name: Model id; must be a key of ``costs_per_token``.

    Returns:
        Tuple of the first choice's message content and the request cost,
        computed from the prompt/completion token counts and
        ``costs_per_token``.

    Raises:
        KeyError: ``model_name`` has no pricing entry. This is raised before
            any request is sent.
        openai.APIStatusError / openai.APIConnectionError: when retries are
            exhausted, or immediately for non-retryable 4xx responses.
    """
    # Look up pricing first so an unknown model fails fast, instead of after
    # a paid request whose result would then be discarded.
    pricing = costs_per_token[model_name]

    if system is not None:
        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": prompt},
        ]
    else:
        messages = [{"role": "user", "content": prompt}]

    if model_name == "o4-mini":
        # NOTE(review): presumably because this reasoning model only accepts
        # the default temperature; confirm against the API docs.
        temperature = 1.0

    # A client is created per attempt; the context manager closes it, which
    # releases its HTTP connection pool once the response has been read.
    with openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY")) as client:
        response = client.chat.completions.create(
            model=model_name,
            messages=messages,
            temperature=temperature,
        )
    out_text = response.choices[0].message.content
    cost = (
        response.usage.prompt_tokens * pricing["input"]
        + response.usage.completion_tokens * pricing["output"]
    )
    return out_text, cost




def is_equiv(str1, str2, verbose=False):
    """Return whether two answer strings match after normalization.

    Two ``None`` values count as equal, and a warning is printed. A single
    ``None`` never matches. Otherwise both strings are passed through
    ``strip_string`` and the results compared. If normalization raises, the
    raw strings are compared instead. With ``verbose`` the normalized forms
    are printed.
    """
    if str1 is None or str2 is None:
        both_none = str1 is None and str2 is None
        if both_none:
            print("WARNING: Both None")
        return both_none

    try:
        left = strip_string(str1)
        right = strip_string(str2)
        if verbose:
            print(left, right)
        return left == right
    except Exception:
        return str1 == str2


def clean_answer(s):
    """Apply light LaTeX cleanup to an extracted answer.

    The steps are applied in this order:
      1. Rewrite ``\\dfrac`` to ``\\frac``.
      2. Delete the literal text ``x \\in``.
      3. Unwrap ``\\mathbf{...}`` and ``\\textbf{...}``. Only contents without
         a ``}`` are matched, so nested braces are not fully handled.
    """
    for old, new in (("\\dfrac", "\\frac"), ("x \\in", "")):
        s = s.replace(old, new)

    for pattern in (r"\\mathbf\s*{([^}]*)}", r"\\textbf\s*{([^}]*)}"):
        s = re.sub(pattern, r"\1", s)
    return s


def remove_boxed(s):
    """Extract the contents of a boxed answer, or return None.

    ``s`` must *start* with one of:
      * ``"\\boxed "``: everything after the prefix is returned;
      * ``"\\boxed{"`` or ``"\\fbox{"``: the text between that brace and the
        final ``}`` is returned.

    The extracted text is normalized with ``clean_answer``. Returns None for
    empty or None input, when no recognized prefix starts ``s``, or when a
    braced form does not end with ``}``.
    """
    if not s:
        return None

    space_prefix = "\\boxed "
    # startswith (not ``in``): the marker appearing mid-string must not be
    # treated as a prefix.
    if s.startswith(space_prefix):
        return clean_answer(s[len(space_prefix):])

    # "\\fbox{" matches the fallback used when locating the last boxed answer.
    for left in ("\\boxed{", "\\fbox{"):
        if s.startswith(left):
            if not s.endswith("}"):
                # Unterminated box; an explicit check survives ``python -O``.
                return None
            return clean_answer(s[len(left):-1])
    return None


def last_boxed_only_string(string: str) -> str:
    """Return the last ``\\boxed{...}`` (or, failing that, ``\\fbox{...}``) span.

    Search starts at the last ``\\boxed``. ``\\fbox`` is used only if no
    ``\\boxed`` exists. The span runs from that command to the ``}`` that
    balances the first ``{`` found after it, with nested braces counted.
    Returns "" when there is no such command, no ``{`` after it, or the braces
    never balance.
    """
    start = string.rfind("\\boxed")
    if start == -1:
        start = string.rfind("\\fbox")
        if start == -1:
            return ""

    open_pos = string.find("{", start)
    if open_pos == -1:
        return ""

    depth = 0
    for pos, ch in enumerate(string[open_pos:], start=open_pos):
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return string[start:pos + 1]
    return ""


def fix_fracs(string):
    """Add braces to shorthand ``\\frac`` arguments.

    ``\\frac12`` becomes ``\\frac{1}{2}`` and ``\\frac1{2}`` becomes
    ``\\frac{1}{2}``. Arguments that already start with ``{`` are left alone.
    If any ``\\frac`` is followed by fewer than two characters (including none
    at all), the input is returned unchanged.
    """
    head, *tails = string.split("\\frac")
    pieces = [head]
    for tail in tails:
        pieces.append("\\frac")
        if tail.startswith("{"):
            pieces.append(tail)
            continue
        # Explicit check: an ``assert`` would vanish under ``python -O``, and
        # an empty tail must not reach the indexing below.
        if len(tail) < 2:
            return string
        numerator, second, rest = tail[0], tail[1], tail[2:]
        if second != "{":
            # \frac12... -> \frac{1}{2}...
            pieces.append("{" + numerator + "}{" + second + "}" + rest)
        else:
            # \frac1{...} -> \frac{1}{...}
            pieces.append("{" + numerator + "}" + second + rest)
    return "".join(pieces)


def fix_a_slash_b(string):
    """Rewrite a plain integer ratio ``a/b`` as ``\\frac{a}{b}``.

    Only applies when ``string`` contains exactly one ``/`` and both sides are
    canonical integer literals (``str(int(x)) == x``, so no leading zeros,
    ``+`` signs, or spaces). Anything else, including non-numeric sides such
    as ``x/2``, is returned unchanged.
    """
    parts = string.split("/")
    if len(parts) != 2:
        return string
    try:
        a, b = int(parts[0]), int(parts[1])
    except ValueError:
        # Non-integer operand: leave the string as-is rather than raising.
        return string
    if string != f"{a}/{b}":
        return string
    return "\\frac{" + str(a) + "}{" + str(b) + "}"


def remove_right_units(string):
    """Drop a trailing unit annotation written as ``\\text{ ...}``.

    Returns the text before the ``"\\text{ "`` marker (note the space after
    the brace). When the marker is absent, ``string`` is returned unchanged.
    The marker is expected at most once; more occurrences trip an ``assert``.
    """
    head, *rest = string.split("\\text{ ")
    if not rest:
        return string
    assert len(rest) == 1
    return head


def fix_sqrt(string):
    """Add braces to single-character ``\\sqrt`` arguments.

    ``\\sqrt3`` becomes ``\\sqrt{3}``. Only the first character after
    ``\\sqrt`` is wrapped, and arguments already starting with ``{`` are
    left alone. A ``\\sqrt`` with nothing after it is kept unchanged; this
    covers the end of the string or a ``\\sqrt`` directly followed by
    another ``\\sqrt``.
    """
    if "\\sqrt" not in string:
        return string
    head, *tails = string.split("\\sqrt")
    pieces = [head]
    for tail in tails:
        # Guard the empty tail so ``tail[0]`` below is always valid.
        if not tail or tail.startswith("{"):
            pieces.append("\\sqrt" + tail)
        else:
            pieces.append("\\sqrt{" + tail[0] + "}" + tail[1:])
    return "".join(pieces)


def strip_string(string):
    """Normalize a LaTeX answer string so equivalent answers compare equal.

    The transformations are applied in order. Formatting noise is removed
    (newlines, ``\\!``, ``\\left``/``\\right``, degree marks, escaped ``$``
    and ``%``, and a trailing ``\\text{ ...}`` unit). Bare decimals get a
    leading zero. A short left-hand side before a single ``=`` is dropped.
    ``\\sqrt`` and ``\\frac`` arguments are braced and spaces are removed.
    Finally, a few hard-coded forms and simple integer ratios ``a/b`` are
    rewritten.
    """
    # Drop linebreaks and inverse spaces.
    string = string.replace("\n", "")
    string = string.replace("\\!", "")

    # Collapse doubled backslashes.
    string = string.replace("\\\\", "\\")

    # Normalize fraction variants to \frac.
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # Remove \left / \right delimiter sizing.
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Remove degree symbols.
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # Remove escaped dollar signs.
    string = string.replace("\\$", "")

    # Drop a trailing unit written as "\text{ ...}".
    string = remove_right_units(string)

    # Remove escaped percent signs ("\%").
    string = string.replace("\\%", "")

    # Give bare decimals a leading zero: " .5" -> " 0.5", "{.5" -> "{0.5".
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    if not string:
        return string
    if string[0] == ".":
        string = "0" + string

    # Drop a short (<= 2 chars) left-hand side before a single "=", e.g. "k = 5".
    parts = string.split("=")
    if len(parts) == 2 and len(parts[0]) <= 2:
        string = parts[1]

    # Brace \sqrt arguments, then remove all spaces.
    string = fix_sqrt(string)
    string = string.replace(" ", "")

    # Brace \frac arguments: \frac12 -> \frac{1}{2}, \frac1{72} -> \frac{1}{72}.
    string = fix_fracs(string)

    # Hard-coded rewrites of specific forms.
    if string == "0.5":
        string = "\\frac{1}{2}"
    if string == "5.5":
        string = "\\frac{11}{2}"
    # Spaces were removed above, so the space-free spelling must be matched;
    # "(x - 3)(x + 3)" can never occur at this point.
    if "(x-3)(x+3)" in string:
        string = string.replace("(x-3)(x+3)", "(x+3)(x-3)")

    # Simple integer ratios: "a/b" -> \frac{a}{b}.
    string = fix_a_slash_b(string)

    return string
