"""
This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:
- https://github.com/microsoft/ProphetNet/tree/master/CRITIC
- https://github.com/openai/prm800k
- https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py
- https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/eval/eval_utils.py
"""
import os
import time
import signal
import functools
import concurrent.futures
from pysnooper import snoop
import re
import regex
import multiprocessing
from math import isclose
from typing import Union

from sympy import simplify, N
from sympy.parsing.sympy_parser import parse_expr
from sympy.parsing.latex import parse_latex
from latex2sympy2 import latex2sympy

def parse_digits(num):
    """Convert ``num`` to a float, accepting thousands separators and percents.

    Commas are removed from ``str(num)`` before parsing. If the text is not a
    plain float but ends with ``%`` (optionally LaTeX-escaped as ``\\%``), the
    percent sign is dropped and the remaining number is divided by 100.

    Returns:
        The parsed float, or ``None`` when the text is not numeric.
    """
    # A literal replacement is all that is needed here; no regex required.
    text = str(num).replace(',', '')
    try:
        return float(text)
    except ValueError:
        pass

    if not text.endswith('%'):
        return None

    text = text[:-1]
    # LaTeX writes a percent sign as "\%"; drop the escaping backslash too.
    if text.endswith('\\'):
        text = text[:-1]
    try:
        return float(text) / 100
    except ValueError:
        return None

def is_digit(num):
    """Report whether ``parse_digits`` can turn ``num`` into a number."""
    # Companion of parse_digits: anything it parses counts as a digit.
    parsed = parse_digits(num)
    if parsed is None:
        return False
    return True


def str_to_pmatrix(input_str):
    r"""Rewrite brace-delimited, comma-separated groups as pmatrix strings.

    Every match of the pattern ``\{.*,.*\}`` in the stripped input has its
    surrounding braces removed and each comma replaced by one backslash
    character, and is then wrapped in ``\begin{pmatrix}`` ... ``\end{pmatrix}``.
    The converted groups are joined with ``", "``; the result is an empty
    string when nothing matches.

    NOTE(review): the separator inserted is a single backslash, while LaTeX
    row breaks are written with two -- confirm this is intended.
    """
    brace_groups = re.findall(r'\{.*,.*\}', input_str.strip())
    return ', '.join(
        r'\begin{pmatrix}' + group.strip('{}').replace(',', '\\') + r'\end{pmatrix}'
        for group in brace_groups
    )


def math_equal(prediction: Union[bool, float, str],
                reference: Union[float, str],
                include_percentage: bool = True,
                is_close: bool = True,
                timeout: bool = False,
                ) -> bool:
    """
    Decide whether ``prediction`` matches ``reference`` as a math answer.

    The checks below run in order; the first one that succeeds returns True.
    0. identical string forms
    1. numerical equal: both parse with ``parse_digits`` and are equal
       (when both are numeric the verdict of this step is final)
    2. symbolic equal: bracket-insensitive text match, element-wise comparison
       of bracketed lists and of pmatrix/bmatrix entries, equation handling,
       and finally ``symbolic_equal`` on the two strings

    Args:
        prediction: the answer being graded.
        reference: the expected answer.
        include_percentage: if True, a numeric prediction may also match the
            reference divided by 100 or multiplied by 100.
        is_close: if True, numbers are compared with ``numeric_equal``;
            otherwise with ``==``.
        timeout: if True, the final symbolic comparison is executed through
            ``call_with_timeout`` instead of being called directly.

    Returns:
        True if any check accepts the pair, otherwise False.
    """
    # print("Judge:", prediction, reference)
    # Fast path: textually identical answers.
    if str(prediction) == str(reference):
        return True

    try: # 1. numerical equal
        if is_digit(prediction) and is_digit(reference):
            prediction = parse_digits(prediction)
            reference = parse_digits(reference)
            # number questions
            # Candidate reference values: optionally also the percent-scaled variants.
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if numeric_equal(prediction, item):
                            return True
                    else:
                        if item == prediction:
                            return True
                except Exception:
                    continue
            # Both sides are numeric but no candidate matched: do not fall
            # through to the symbolic checks.
            return False
    except:
        pass

    # Reject empty predictions (e.g. "" or None) while still allowing 0 and False.
    if not prediction and prediction not in [0, False]:
        return False
    # print("try math_eval")

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    ## pmatrix (amps)
    # Only the prediction is in pmatrix form: convert the reference with str_to_pmatrix.
    if "pmatrix" in prediction and not 'pmatrix' in reference:
        reference = str_to_pmatrix(reference)

    ## deal with [], (), {}
    pred_str, ref_str = prediction, reference
    # Strip outer square/round brackets from both sides unless the prediction
    # uses one bracket kind and the reference starts with the other.
    if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or \
        (prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")):
        pred_str = pred_str.strip("[]()")
        ref_str = ref_str.strip("[]()")
    # Remove every brace and parenthesis, then compare case-insensitively.
    for s in ['{', "}", "(", ")"]:
        ref_str = ref_str.replace(s, "")
        pred_str = pred_str.replace(s, "")
    if pred_str.lower() == ref_str.lower():
        return True

    ## [a, b] vs. [c, d], return a==c and b==d
    # The first and last characters are dropped and the rest is split on commas.
    # NOTE: the recursive calls do not forward `timeout`, so the elements are
    # compared with timeout=False.
    if regex.match(r'(\(|\[).+(\)|\])', prediction) is not None and regex.match(r'(\(|\[).+(\)|\])', reference) is not None:
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts):
            if all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]):
                return True
    # Both sides are a single pmatrix/bmatrix environment: compare entry by entry.
    # The slices use the pmatrix delimiter lengths, which are the same as the
    # bmatrix ones. Rows are separated by a double backslash, columns by "&".
    if (prediction.startswith("\\begin{pmatrix}") or prediction.startswith("\\begin{bmatrix}")) and (prediction.endswith("\\end{pmatrix}") or prediction.endswith("\\end{bmatrix}")) and \
        (reference.startswith("\\begin{pmatrix}") or reference.startswith("\\begin{bmatrix}")) and (reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")):
        pred_lines = [line.strip() for line in prediction[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()]
        ref_lines = [line.strip() for line in reference[len("\\begin{pmatrix}"): -len("\\end{pmatrix}")].split("\\\\") if line.strip()]
        matched = True
        if len(pred_lines) == len(ref_lines):
            for pred_line, ref_line in zip(pred_lines, ref_lines):
                pred_parts = pred_line.split("&")
                ref_parts = ref_line.split("&")
                if len(pred_parts) == len(ref_parts):
                    if not all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]):
                        matched = False
                        break
                else:
                    # Different number of columns in this row.
                    matched = False
                if not matched:
                    break
        else:
            # Different number of rows.
            matched = False
        if matched:
            return True

    # Equations: when both sides contain exactly one "=", rewrite each as
    # "lhs - (rhs)" and compare symbolically, also trying the negated prediction.
    if prediction.count('=') == 1 and reference.count('=') == 1:
        pred = prediction.split('=')
        pred = f"{pred[0].strip()} - ({pred[1].strip()})"
        ref = reference.split('=')
        ref = f"{ref[0].strip()} - ({ref[1].strip()})"
        if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
            return True
    # Only the prediction has the form "<at most 2 chars> = value": compare the
    # right-hand side with the reference.
    elif prediction.count('=') == 1 and len(prediction.split('=')[0].strip()) <= 2 and '=' not in reference:
        if math_equal(prediction.split('=')[1], reference, include_percentage, is_close):
            return True
    # Mirror case: only the reference has the form "<at most 2 chars> = value".
    elif reference.count('=') == 1 and len(reference.split('=')[0].strip()) <= 2 and '=' not in prediction:
        if math_equal(prediction, reference.split('=')[1], include_percentage, is_close):
            return True

    # print("try final")
    # symbolic equal with sympy
    # With timeout=True the comparison is delegated to call_with_timeout.
    if timeout:
        if call_with_timeout(symbolic_equal_process, prediction, reference):
            return True
    else:
        if symbolic_equal(prediction, reference):
            return True

    return False


def math_equal_process(param):
    """Grade one work item whose last two entries are (prediction, reference).

    Only ``param[-2]`` and ``param[-1]`` are read; any leading entries are
    ignored. ``math_equal`` is called with its default options.
    """
    prediction = param[-2]
    reference = param[-1]
    return math_equal(prediction, reference)


def numeric_equal(prediction: float, reference: float):
    """Return True when the two numbers agree within a relative tolerance of 1e-4.

    No absolute tolerance is passed to ``isclose``, so a value of exactly zero
    only matches zero.
    """
    # The relative tolerance noticeably affects scores on the synthesized
    # gsm_hard dataset, so change it with care.
    relative_tolerance = 1e-4
    return isclose(reference, prediction, rel_tol=relative_tolerance)


def symbolic_equal(a, b):
    """Return True when ``a`` and ``b`` denote the same mathematical object.

    Each argument is parsed with the first of ``parse_latex``, ``parse_expr``
    and ``latex2sympy`` that accepts it; each parser is tried first on the
    text with doubled backslashes collapsed, then on the text as given. Input
    that no parser accepts is kept unchanged.

    The parsed values then go through these checks, in order, and the first
    success returns True:
      1. same string form, or ``a == b``
      2. ``a.equals(b)`` or ``simplify(a - b) == 0``
      3. equations: ``abs(lhs - rhs)`` of both sides are equal
      4. numeric evaluation of both agrees under ``numeric_equal``
      5. matrices of the same shape whose entries match after rounding to
         3 decimal places

    Every check is best-effort: an ``Exception`` raised by a check counts as
    "no match" and the next check is tried. ``KeyboardInterrupt`` and
    ``SystemExit`` are deliberately not caught, so a long-running comparison
    can still be interrupted.
    """
    def _parse(s):
        for f in [parse_latex, parse_expr, latex2sympy]:
            try:
                # Inputs may carry doubled backslashes; try the collapsed form first.
                return f(s.replace("\\\\", "\\"))
            except Exception:
                try:
                    return f(s)
                except Exception:
                    pass
        # Nothing could parse it: fall back to the raw value.
        return s

    a = _parse(a)
    b = _parse(b)

    # direct equal
    try:
        if str(a) == str(b) or a == b:
            return True
    except Exception:
        pass

    # simplify equal
    try:
        if a.equals(b) or simplify(a - b) == 0:
            return True
    except Exception:
        pass

    # equation equal (only applies to objects exposing lhs/rhs)
    try:
        if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)):
            return True
    except Exception:
        pass

    # numeric evaluation equal
    try:
        if numeric_equal(float(N(a)), float(N(b))):
            return True
    except Exception:
        pass

    # matrix equal, tolerant to differences beyond 3 decimal places
    try:
        if a.shape == b.shape:
            _a = a.applyfunc(lambda x: round(x, 3))
            _b = b.applyfunc(lambda x: round(x, 3))
            if _a.equals(_b):
                return True
    except Exception:
        pass

    return False


def symbolic_equal_process(a, b, output_queue):
    """Run ``symbolic_equal(a, b)`` and publish the verdict on ``output_queue``.

    The function returns nothing; the result is delivered only through the
    queue's ``put`` method.
    """
    output_queue.put(symbolic_equal(a, b))


# def timeout_wrapper(timeout_seconds):
#     def decorator(func):
#         @functools.wraps(func)
#         def wrapper(*args, **kwargs):
#             with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
#                 future = executor.submit(func, *args, **kwargs)
#                 try:
#                     return future.result(timeout=timeout_seconds)
#                 except concurrent.futures.TimeoutError:
#                     print(f"Function {func.__name__} timed out after {timeout_seconds} seconds")
#                     return False  # Or a default value/error code
#         return wrapper
#     return decorator


# @timeout_wrapper(timeout_seconds=1)
# def safe_symbolic_equal_process(a, b):
#     result = symbolic_equal(a, b)
#     return result


# def call_with_timeout(func, *args, timeout=1, **kwargs):
#     with Pool(processes=4) as pool:  # Only spawn 1 worker for this function
#         future = pool.apply_async(func, args=args, kwds=kwargs)
#         try:
#             return future.get(timeout=timeout)
#         except multiprocessing.TimeoutError:
#             return False


# def call_with_timeout(func, *args, timeout=1, **kwargs):
#     # Use a process pool with limited size
#     output_queue = multiprocessing.Queue()
#     process_args = args + (output_queue,)
#     with multiprocessing.Pool(processes=1) as pool:
#         async_result = pool.apply_async(func, args=process_args, kwds=kwargs)
#         try:
#             # return async_result.get(timeout=timeout)
#             async_result.get(timeout=timeout)
#             return output_queue.get()
#         except multiprocessing.TimeoutError:
#             return False



# def call_with_timeout(func, *args, timeout=1, **kwargs):
#     with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
#         future = executor.submit(func, *args, **kwargs)
#         try:
#             if future.result(timeout=timeout):  # Wait at most 1 second
#                 return future.result()
#         except concurrent.futures.TimeoutError:
#             return False  # Skip to next line



# TODO
# Number of slots in the semaphore below.
MAX_PROCESSES = 1
# Acquired by call_with_timeout and call_with_timeout_original for the whole
# lifetime of each worker process they start.
semaphore = multiprocessing.Semaphore(MAX_PROCESSES)


def _func_wrapper(func, args, output_queue):
    """Call ``func(*args, output_queue=output_queue)`` and report failures.

    ``func`` is expected to put its own result on ``output_queue``; this
    wrapper adds nothing on success. If ``func`` raises an ``Exception``, the
    tuple ``("ERROR", <message>)`` is put on the queue instead.
    """
    try:
        func(*args, output_queue=output_queue)
    except Exception as exc:
        failure = ("ERROR", str(exc))
        output_queue.put(failure)


def call_with_timeout(func, *args, timeout=0.2):
    """Run ``func`` in a child process and return what it puts on its queue.

    The child executes ``_func_wrapper(func, args, output_queue)``, which calls
    ``func(*args, output_queue=output_queue)``; ``func`` must put its result
    on that queue. The module-level ``semaphore`` is held while the child is
    running.

    Args:
        func: callable accepting ``*args`` plus an ``output_queue`` keyword.
        *args: positional arguments forwarded to ``func``.
        timeout: seconds to wait for a result before giving up.

    Returns:
        The value the child put on the queue, or ``False`` when the child
        reported an error, exited without a result, or timed out.
    """
    # Local import: needed for queue.Empty, which multiprocessing.Queue.get raises.
    import queue

    output_queue = multiprocessing.Queue()
    no_result = object()  # sentinel, so a legitimate falsy result is not mistaken for "nothing"
    result = no_result

    with semaphore:
        process = multiprocessing.Process(
            target=_func_wrapper,
            args=(func, args, output_queue),
        )
        process.daemon = True  # Make sure process exits when parent does
        process.start()

        deadline = time.monotonic() + timeout
        while True:
            # Sample liveness BEFORE reading: if the worker had already exited
            # here, anything it produced is in the queue, so one more read
            # cannot miss it.
            worker_alive = process.is_alive()
            remaining = deadline - time.monotonic()
            try:
                # Short waits keep the loop responsive to an early worker exit.
                result = output_queue.get(timeout=max(0.0, min(0.05, remaining)))
                break
            except queue.Empty:
                if not worker_alive or remaining <= 0:
                    break

        if process.is_alive():
            if result is no_result:
                print(f"Process timed out after {timeout} seconds")
            process.terminate()
            process.join(0.1)
            # In case terminate doesn't work, force-kill as last resort
            if process.is_alive():
                process.kill()
                process.join(0.1)

    if result is no_result:
        return False

    # _func_wrapper reports exceptions raised by func as ("ERROR", message).
    if isinstance(result, tuple) and result[:1] == ("ERROR",):
        print(f"Error in process: {result[1]}")
        return False
    return result


def call_with_timeout_original(func, *args, timeout=0.3, **kwargs):
    """Run ``func`` in a child process, waiting at most ``timeout`` seconds.

    The child is started as ``func(*args, output_queue, **kwargs)``: a fresh
    ``multiprocessing.Queue`` is appended to the positional arguments and
    ``func`` must put its result on it. The module-level ``semaphore`` is held
    while the child is running.

    Returns:
        The value read from the queue, or ``False`` when the child timed out
        (it is terminated), exited with a non-zero exit code, or exited
        without putting anything on the queue.
    """
    # Local import: needed for queue.Empty, which multiprocessing.Queue.get raises.
    import queue

    output_queue = multiprocessing.Queue()
    process_args = args + (output_queue,)

    with semaphore:  # Limit concurrent processes
        process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
        process.start()
        process.join(timeout)

        if process.is_alive():
            # Still running after the deadline: stop it and give up.
            process.terminate()
            process.join()
            return False

        # Only read the queue if the process completed successfully.
        if process.exitcode != 0:
            return False

        try:
            return output_queue.get(block=False)  # Non-blocking get
        except queue.Empty:
            # Process completed but didn't put anything in the queue.
            return False


# @snoop()
# def call_with_timeout(func, *args, timeout=1, **kwargs):
#     output_queue = multiprocessing.Queue()
#     process_args = args + (output_queue,)

#     with semaphore:  # Limit concurrent processes
#         process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
#         process.start()
#         process.join(timeout)

#         if process.is_alive():
#             process.terminate()
#             process.join()
#             return False

#         return output_queue.get()


# def call_with_timeout(func, *args, timeout=1, **kwargs):
#     output_queue = multiprocessing.Queue()
#     process_args = args + (output_queue,)
#     process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
#     process.start()
#     process.join(timeout)

#     if process.is_alive():
#         process.terminate()
#         process.join()
#         return False

#     return output_queue.get()


def _test_math_equal():
    """Ad-hoc smoke check: print the verdict of ``math_equal`` for one pair.

    The pair currently exercised is a half-open interval written with
    decimals versus the same interval written with fractions, compared with
    ``timeout=True``.
    """
    # Other pairs tried during development, kept for quick manual re-checks:
    # print(math_equal("0.0833333333333333", "\\frac{1}{12}"))
    # print(math_equal("(1,4.5)", "(1,\\frac{9}{2})"))
    # print(math_equal("\\frac{x}{7}+\\frac{2}{7}", "\\frac{x+2}{7}", timeout=True))
    # print(math_equal("\\sec^2(y)", "\\tan^2(y)+1", timeout=True))
    # print(math_equal("\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\end{pmatrix}", "(\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\\\\\end{pmatrix})", timeout=True))
    #
    # pred = '\\begin{pmatrix}\\frac{1}{3x^{2/3}}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\end{pmatrix}'
    # gt = '(\\begin{pmatrix}\\frac{1}{3\\sqrt[3]{x}^2}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\\\\\end{pmatrix})'
    #
    # pred= '-\\frac{8x^2}{9(x^2-2)^{5/3}}+\\frac{2}{3(x^2-2)^{2/3}}'
    # gt= '-\\frac{2(x^2+6)}{9(x^2-2)\\sqrt[3]{x^2-2}^2}'
    #
    # pred =  '-34x-45y+20z-100=0'
    # gt = '34x+45y-20z+100=0'
    #
    # pred = '\\frac{100}{3}'
    # gt = '33.3'
    #
    # pred = '\\begin{pmatrix}0.290243531202435\\\\0.196008371385084\\\\-0.186381278538813\\end{pmatrix}'
    # gt = '(\\begin{pmatrix}0.29\\\\0.196\\\\-0.186\\\\\\end{pmatrix})'
    #
    # pred = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{2\\sqrt{33}+15}'
    # gt = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{15+2\\sqrt{33}}'
    #
    # pred = '(+5)(b+2)'
    # gt = '(a+5)(b+2)'
    #
    # pred = '\\frac{1+\\sqrt{5}}{2}'
    # gt = '2'
    #
    # pred = '\\frac{34}{16}+\\frac{\\sqrt{1358}}{16}', gt = '4'
    # pred = '1', gt = '1\\\\sqrt{19}'

    candidate = '(0.6,2.6667]'
    expected = '(\\frac{3}{5},\\frac{8}{3}]'
    verdict = math_equal(candidate, expected, timeout=True)
    print(verdict)


# Run the ad-hoc smoke check only when this file is executed as a script.
if __name__ == "__main__":
    _test_math_equal()

