from abc import ABC, abstractmethod
from typing import Any, Callable

import asyncio
import json
import os
import re
import shutil
import unicodedata
import ast
from copy import deepcopy

from tqdm import tqdm

from olym_gen.utils.utils import get_logger, retrieve_id_from_name, UNKNOWN_INDEX, get_generator_base
from olym_gen.generator.problem_proof_generator import LogProblemProofMixin
from olym_gen.generator.base_generator import SystemPromptMixin, GeneratorBase

logger = get_logger()

class UserAnswerPairMixin():
    """
    Mixin that builds the user prompt used when two answers are compared.

    The prompt only lists the two answers, labelled in order.
    """

    def _user_prompt(self, orig_answer: str, new_answer: str) -> str:
        """Return a prompt presenting the original answer and the new answer."""
        labelled_answers = (
            f"Answer 1: {orig_answer}",
            f"Answer 2: {new_answer}",
        )
        # Every entry, including the last one, is followed by a blank line.
        return "".join(f"{entry}\n\n" for entry in labelled_answers)

class ExtractAnswerMixin:
    """
    Mixin class for extracting the answer from the proof.
    """
    def solve_frac(self, answer: str) -> str:
        r"""
        Solve the fraction in the answer.
        This function is used to solve the fraction in the answer, like \frac{a}{b} -> (a)/(b).
        Now this function will handle nested fractions, like \frac{\frac{a}{b}}{c} -> ((a)/(b))/(c).
        The num or den must be wrapped by a brace. If no two brace following the \frac, it will return the original answer.
        """
        # use a regular expression to find all \frac{}{} patterns
        # and use a loop to handle nested fractions
        while r'\frac' in answer:
            match = re.search(r'\\frac', answer) # Notice re need another escape for '\'
            if not match:
                break
            
            start_index = match.end()
            
            # Helper to find the content of a braced group
            def find_open_brace(start_pos: int) -> int:
                """
                Find the position of the opening brace '{' starting from start_pos.
                If any non-whitespace character is found before the opening brace, or no opening brace is found,
                it will return -1.
                """
                for i in range(start_pos, len(answer)):
                    if answer[i] == '{':
                        return i
                    if answer[i].isspace():
                        continue
                    else:
                        logger.warning(f"Found non-whitespace character before opening brace in \\frac: `{answer[i]}` at position {i}")
                        return -1
                logger.warning("No opening brace found for numerator in \\frac")
                return -1
            
            def find_braced_content(start_pos: int) -> tuple[str | None, int]:
                """
                try to find the content of a braced group starting at start_pos.
                Returns the content and the end position of the group.
                If the start_pos does not a brace, raise a ValueError.
                If the braces are not matched, return None and -1.
                If the content is empty, return None and -1 to avoid empty fractions.
                """
                if start_pos >= len(answer) or answer[start_pos] != '{':
                    raise ValueError(f"Expected opening brace at position {start_pos}, found: {answer[start_pos] if start_pos < len(answer) else 'end of string'}")
                
                stack = 1
                end_pos = start_pos + 1
                while end_pos < len(answer) and stack > 0:
                    if answer[end_pos] == '{':
                        stack += 1
                    elif answer[end_pos] == '}':
                        stack -= 1
                    end_pos += 1
                
                if stack == 0:
                    return answer[start_pos + 1 : end_pos - 1], end_pos
                return None, -1

            # Find numerator
            num_start = find_open_brace(start_index)
            if num_start == -1:
                return answer  # Malformed, no opening brace for numerator
            numerator, num_end_pos = find_braced_content(num_start)
            if numerator is None:
                return answer

            # Find denominator

            den_start = find_open_brace(num_end_pos)
            denominator, den_end_pos = find_braced_content(den_start)
            if denominator is None:
                break # Malformed, unmatched braces for denominator

            # Replace the \frac{...}{...} with (...) / (...)
            # Here, we normalize the numerator and denominator
            new_num = self.normalize_latex_answer(numerator)
            new_den = self.normalize_latex_answer(denominator)
            if not new_num or not new_den:
                # To avoid empty fractions, we return the original answer
                return answer
            replacement = f"({new_num})/({new_den})"
            answer = answer[:match.start()] + replacement + answer[den_end_pos:]
            
        return answer
    
    def _normalize_number(self, match: re.Match) -> str:
        """
        Normalize a matched number string. The function handles:
        - Removes leading zeros from integers (e.g., 01 -> 1).
        - Removes trailing zeros from decimals (e.g., 2.0 -> 2, 1.20 -> 1.2).
        - Removes trailing decimal points (e.g., 2. -> 2).
        The input is expected to be a regex match object, which is a number string.
        """
        
        number_str = match.group(0)
        number_str = number_str.strip()
        number_str = re.sub(",", "", number_str)  # Remove commas if any
        
        # If it's an integer, remove leading zeros.
        if '.' not in number_str:
            return str(int(number_str))
        
        # For decimals, remove trailing zeros.
        number_str = number_str.rstrip('0')
        if number_str.endswith('.') and len(number_str) > 1:
            number_str = number_str[:-1]
        
        # because we have already removed trailing zeros, we need to check if the number is an integer again.
        if '.' not in number_str:
            return str(int(number_str))
        
        # For decimals, remove trailing zeros and then a trailing dot if it exists.
        number_str = number_str.lstrip('0')
        
        if not number_str or number_str == '.':
            return '0'
        
        return number_str
    
    def match_and_normalize_number(self, answer: str) -> str:
        # process any numbers, considering comma in numbers
        
        part = r'\d+(?:,\d+)*' # match all numbers with commas, but no leading or trailing commas
        # For example, "1,000", "1,00", "1000" are OK,
        # Given "1,000,", will match "1,000" and remove the trailing comma.
        # Given ",1,000", will match "1,000" and remove the leading comma.
        
        # We consider three cases:
        # 1. part.part (e.g., 1.2, 3.234,003)
        # 2. .part (e.g., .3, .300,003)
        # 3. part. (e.g., 1.2, 1.200,134, 1.)
        
        pattern = part + r'\.' + part + '|' + r'\.' + part + '|' + part + r'\.' + r'|' + part
        
        # re.sub can use a callable as the second argument to process each match
        answer = re.sub(pattern, self._normalize_number, answer)
        return answer
    
    # def _remove_outer_braces(self, answer: str) -> str:
    #     r"""
    #     if the first letter and the last letter is a corresponding brace, remove it.
    #     Examples:
    #         - "(x + y)" -> "x + y"
    #         - "[a + b]" -> "a + b"
    #         - "{1 + 2}" -> "1 + 2"
    #         - "\(\frac{a}{b}\)" -> "\frac{a}{b}"
    #         - "\[x^2 + y^2\]" -> "x^2 + y^2"
    #         - "\{z\}" -> "z"
    #         If the answer does not start and end with a brace, return the answer as is.
    #         If the answer starts and ends with a brace, but they are not a pair, return the answer as is.
    #         - (x + y) / (z + w) -> (x + y) / (z + w)
    #     """
    #     pairs = [('(', ')'), ('[', ']'), ('{', '}'), ('\\(', '\\)'), ('\\[', '\\]'), ('\\{', '\\}')]
        
    #     matched_pair = None
    #     for open_brace, close_brace in pairs:
    #         if answer.startswith(open_brace) and answer.endswith(close_brace):
    #             matched_pair = (open_brace, close_brace)
    #             break
        
    #     if not matched_pair:
    #         return answer

    #     open_brace, close_brace = matched_pair
    #     open_len, close_len = len(open_brace), len(close_brace)
        
    #     # We need to handle brace characters that might be part of other characters
    #     # This is a simplified check. A full-blown parser would be too complex here.
    #     # We scan for the full brace strings.
        
    #     # This logic is tricky with overlapping matches (e.g. if a brace is `\\` and we search for `\`).
    #     # The current pairs don't have this issue.
        
    #     # Let's just check the balance.
        
    #     # Simplified stack check for non-overlapping braces
    #     # This assumes that the brace strings don't overlap in a confusing way.
    #     # e.g. open_brace `\\(` and some other char `(`.
        
    #     # Let's do a proper stack check by iterating through the string
        
    #     # We need to find the balance of the outermost parentheses.
    #     # If the balance becomes 0 before the end, then the outer braces are not enclosing the whole expression.
        
    #     balance = 0
    #     # We check the substring without the potential outer braces
    #     core_str = answer[open_len:-close_len]
        
    #     # This is a simple character-by-character stack, which doesn't work for multi-character braces.
    #     # A regex-based approach or a more careful string search is needed.
        
    #     # Let's find all occurrences of open and close braces
    #     open_indices = [m.start() for m in re.finditer(re.escape(open_brace), answer)]
    #     close_indices = [m.start() for m in re.finditer(re.escape(close_brace), answer)]
        
    #     # The first open brace must be at the start
    #     if not open_indices or open_indices[0] != 0:
    #         return answer
            
    #     balance = 0
    #     all_indices = sorted(open_indices + close_indices)
        
    #     for index in all_indices:
    #         if index in open_indices:
    #             balance += 1
    #         elif index in close_indices:
    #             balance -= 1
            
    #         # If balance is zero before the end of the string (excluding the final brace)
    #         # it means the braces are not wrapping the whole expression.
    #         if balance == 0 and index < len(answer) - close_len:
    #             return answer
        
    #     # After iterating through all, the balance should be 0.
    #     # This check is implicitly handled by the fact that we found a matching pair at start/end
    #     # and the balance doesn't drop to 0 prematurely.
    #     if balance == 0:
    #         return answer[open_len:-close_len]
        
    #     return answer

    def normalize_latex_answer(self, answer: str) -> str:
        r"""
        Normalize a LaTeX math answer for comparison by removing trivial formatting differences.
        This function handles:
        - Remove all whitespace, bacause it is not important for mathematical expressions.
        - Replace common LaTeX commands with a more readable format. (e.g., \pi -> π, \frac{a}{b} -> a/b)
        - Remove some unnecessary leading or trailing zeros in numbers.
        - Remove outer braces.
        - NFKC Unicode normalization to handle different representations of the same character.
        """
        if not answer:
            return ""
        
        normalized = re.sub(r'\s+', '', answer)  # Remove all whitespace
        
        # Normalize common fraction representations        
        while True:
            new_normalized = self.solve_frac(normalized)
            if new_normalized == normalized:
                    break
            normalized = new_normalized
        
        normalized = self.match_and_normalize_number(normalized)
        
        # replace some common LaTeX commands with more readable format
        replacements = {
            '\\pi': 'π',
            '\\infty': '∞',
            '\\alpha': 'α',
            '\\beta': 'β',
            '\\gamma': 'γ',
            '\\cdot': '*',
            '\\times': '*',
            '\\div': '/',
            '\\pm': '±',
            '\\geq': '>=',
            '\\leq': '<=',
            '\\ge': '>=',
            '\\le': '<=',
            '\\neq': '!=',
            '\\ne': '!=',
            '\\approx': '≈',
            '\\dots': '...',
        }
        
        for latex_form, normalized_form in replacements.items():
            normalized = normalized.replace(latex_form, normalized_form)
                
        # Use word boundaries to avoid matching function names
        # Only remove braces around single alphanumeric characters that are NOT preceded by letters
        
        # For example, \log_{2}{x} -> \log_2{x}, \binom{x}{y} -> \binom{x}{y}
        
        while True:
            new_normalized = re.sub(r'(?<![a-zA-Z\}0-9])\{([a-zA-Z0-9])\}', r'\1', normalized)
            # Same for numeric expressions
            new_normalized = re.sub(r'(?<![a-zA-Z\}0-9])\{(\d+)\}', r'\1', new_normalized)
            if new_normalized == normalized:
                break
            normalized = new_normalized
            
        
        # Remove any '()', '\(\)', '[]', '\[\]', '{}', '\{\}' at the start and end
        while True:
            new_normalized = self._remove_outer_braces(normalized)
            if new_normalized == normalized:
                break
            normalized = new_normalized
        
        # Unicode normalization to handle different representations of the same character
        normalized = unicodedata.normalize('NFKC', normalized)
        
        return normalized
    
    def _remove_outer_braces(self, answer: str) -> str:
        r"""
        if the first letter and the last letter is a corresponding brace, remove it.
        Examples:
            - "(x + y)" -> "x + y"
            - "[a + b]" -> "a + b"
            - "{1 + 2}" -> "1 + 2"
            - "\(\frac{a}{b}\)" -> "\frac{a}{b}"
            - "\[x^2 + y^2\]" -> "x^2 + y^2"
            - "\{z\}" -> "z"
            If the answer does not start and end with a brace, return the answer as is.
            If the answer starts and ends with a brace, but they are not a pair, return the answer as is.
            - (x + y) / (z + w) -> (x + y) / (z + w)
        """
        pairs = [('(', ')'), ('[', ']'), ('{', '}'), ('\\(', '\\)'), ('\\[', '\\]'), ('\\{', '\\}')]
        
        matched_pair = None
        for open_brace, close_brace in pairs:
            if answer.startswith(open_brace) and answer.endswith(close_brace):
                matched_pair = (open_brace, close_brace)
                break
        
        if not matched_pair:
            return answer

        open_brace, close_brace = matched_pair
        open_len, close_len = len(open_brace), len(close_brace)
        
        # use a stack to check if they are a pair
        open_nums = 0
        # We need to handle brace characters that might be part of other characters
        # This is a simplified check. A full-blown parser would be too complex here.
        # We scan for the full brace strings.
        
        # This logic is tricky with overlapping matches (e.g. if a brace is `\\` and we search for `\`).
        # The current pairs don't have this issue.
        
        # Let's just check the balance.
        
        # Simplified stack check for non-overlapping braces
        # This assumes that the brace strings don't overlap in a confusing way.
        # e.g. open_brace `\\(` and some other char `(`.
        
        # Let's do a proper stack check by iterating through the string
        
        # We need to find the balance of the outermost parentheses.
        # If the balance becomes 0 before the end, then the outer braces are not enclosing the whole expression.
        
        balance = 0
        # We check the substring without the potential outer braces
        core_str = answer[open_len:-close_len]
        
        # This is a simple character-by-character stack, which doesn't work for multi-character braces.
        # A regex-based approach or a more careful string search is needed.
        
        # Let's find all occurrences of open and close braces
        open_indices = [m.start() for m in re.finditer(re.escape(open_brace), answer)]
        close_indices = [m.start() for m in re.finditer(re.escape(close_brace), answer)]
        
        # The first open brace must be at the start
        if not open_indices or open_indices[0] != 0:
            return answer
            
        balance = 0
        all_indices = sorted(open_indices + close_indices)
        
        for index in all_indices:
            if index in open_indices:
                balance += 1
            elif index in close_indices:
                balance -= 1
            
            # If balance is zero before the end of the string (excluding the final brace)
            # it means the braces are not wrapping the whole expression.
            if balance == 0 and index < len(answer) - close_len:
                return answer
        
        # After iterating through all, the balance should be 0.
        # This check is implicitly handled by the fact that we found a matching pair at start/end
        # and the balance doesn't drop to 0 prematurely.
        if balance == 0:
            return answer[open_len:-close_len]
        
        return answer
        
        


    def are_answers_essentially_equal(self, answer1: str, answer2: str) -> bool | None:
        """
        Compare two LaTeX-formatted math answers to determine if they are essentially equal. These function will first normalize the answers, then 
            (1) If both the normalized answers are `numerical expressions`, it will evaluate them and compare the numerical values. If equal, return True. If not, return False.
            (2) If both the normalized answers are `simple expressions`, that can be explain as a polynomial without function calling, it will compare them directly. If equal, return True. If not, return False.
            (3) Some special cases we can easily determine if they are equal, like commutative expressions, equivalent fractions, etc.
        
        Args:
            answer1: First answer string (LaTeX format)
            answer2: Second answer string (LaTeX format)
            
        Returns:
            bool | None: 
                - True if the answers are definitely equal
                - False if the answers are definitely not equal  
                - None if the equality cannot be determined with confidence
        """
        
        # Apply normalization
        norm1 = self.normalize_latex_answer(answer1)
        norm2 = self.normalize_latex_answer(answer2)
        
        if not answer1 or not answer2:
            return False  # Definitely not equal if one is empty
        
        if norm1 == norm2:
            return True  # Definitely equal after normalization
        
        # Evaluate numerical expressions if both are simple enough
        if (numer_equ := self._are_numerically_equal(norm1, norm2)) is not None:
            return numer_equ
            
        # If both are simple expressions, compare them directly
        if (expre_equ := self._are_expression_equal(norm1, norm2)) is not None:
            return expre_equ
        
        # If we can't determine with confidence, return None
        return None

    def _are_expression_equal(self, norm1: str, norm2: str) -> bool | None:
        """
        Check if two normalized answers are definitely equal through safe transformations.
        Returns True only when we are certain they represent the same mathematical value.
        """
        
        # Check for simple commutative expressions
        if self._are_commutatively_equal(norm1, norm2):
            return True
        
        # Check for equivalent fraction representations
        frac1 = self._parse_simple_fraction(norm1)
        frac2 = self._parse_simple_fraction(norm2)
        
        if frac1 and frac2:
            num1, den1 = frac1
            num2, den2 = frac2
            
            if (
                self.are_answers_essentially_equal(num1, num2) and
                self.are_answers_essentially_equal(den1, den2)
            ):
                return True
        
        return None

    def _are_commutatively_equal(self, expr1: str, expr2: str) -> bool:
        """
        Check if two expressions are equal under simple commutative operations.
        This function handles expressions involving addition and multiplication,
        respecting that subtraction and division are not commutative. It can
        handle mixed expressions like "a + b * c" vs "c * b + a".

        It works by parsing expressions into terms, normalizing them (e.g.,
        sorting factors in products), and then comparing the sorted lists of
        normalized terms.

        Note: This function does not handle the distributive property, so it
        cannot recognize "a * (b + c)" and "a*b + a*c" as equal.

        Examples:
            - "a*b" and "b*a" -> True
            - "a+b*c" and "c*b+a" -> True
            - "a-b" and "b-a" -> False (subtraction is not commutative)
        """
        # Remove all whitespace for consistent comparison
        expr1 = re.sub(r'\s+', '', expr1)
        expr2 = re.sub(r'\s+', '', expr2)
        
        if expr1 == expr2:
            return True
        
        # Check for simple addition commutation: a + b vs b + a
        # We can solve '-' because they can be considered as a special case of addition without '*' and '/'
        if '*' not in expr1 and '*' not in expr2:
            terms1 = self._parse_additive_terms(expr1)
            terms2 = self._parse_additive_terms(expr2)
            if terms1 and terms2 and len(terms1) == len(terms2) and sorted(terms1) == sorted(terms2):
                return True
            return False
        
        # Check for simple multiplication commutation: a * b vs b * a  
        if '*' in expr1 and '*' in expr2 and '+' not in expr1 and '+' not in expr2 and '-' not in expr1 and '-' not in expr2:
            factors1 = self._parse_multiplicative_factors(expr1)
            factors2 = self._parse_multiplicative_factors(expr2)
            if factors1 and factors2 and len(factors1) == len(factors2) and sorted(factors1) == sorted(factors2):
                return True
            return False

        if (
            ('*' in expr1 and '*' not in expr2) or
            ('*' not in expr1 and "*" in expr2) or
            ('/' in expr1 and '/' not in expr2) or
            ('/' not in expr1 and "/" in expr2)
            ):
            return False
        
        
        # Handle mixed addition and multiplication patterns like "a + b * c" vs "b * c + a"
        if self._are_mixed_commutative(expr1, expr2):
            return True
        
        return False

    def _are_mixed_commutative(self, expr1: str, expr2: str) -> bool:
        """
        Check for commutative equality in expressions with mixed addition and multiplication.
        
        This function tokenizes expressions into additive terms, where each term
        can be a simple variable/number or a product of factors. It then normalizes
        each term by sorting its multiplicative factors and compares the sorted
        lists of terms from both expressions.
        
        For example, for "a + c*b" and "b*c + a":
        1. Terms of expr1: ["a", "c*b"] -> Normalized terms: ["a", "b*c"]
        2. Terms of expr2: ["b*c", "a"] -> Normalized terms: ["a", "b*c"]
        3. Sorted normalized terms are equal, so it returns True.
        
        It avoids expressions with complex structures like parentheses or functions.
        """
        # Only handle simple cases without parentheses
        if any(char in expr1 + expr2 for char in ['(', ')', 'sin', 'cos', 'log', 'sqrt']):
            return False
        
        # Check if expressions are just rearrangements of addition terms
        # where some terms might be products
        terms1 = self._parse_mixed_terms(expr1)
        terms2 = self._parse_mixed_terms(expr2)
        
        if terms1 and terms2 and len(terms1) == len(terms2):
            # Normalize each term (sort factors within products)
            norm_terms1 = [self._normalize_term(term) for term in terms1]
            norm_terms2 = [self._normalize_term(term) for term in terms2]
            
            if sorted(norm_terms1) == sorted(norm_terms2):
                return True
        
        return False

    def _parse_mixed_terms(self, expr: str) -> list[str] | None:
        """
        Parse expressions with mixed addition and multiplication into a list of terms.
        
        This function first normalizes spacing around '+' and '-' operators to
        use them as delimiters. It then splits the expression into terms. It also
        handles unary '+' and '-' signs, attaching them to the subsequent term.
        
        Example:
            - "a+b*c-d" -> ["a", "b*c", "-d"]
            - "-x*y+z" -> ["-x*y", "z"]
            
        Returns a list of string terms, or None if the expression contains
        unsupported constructs.
        """
        # remove all whitespace
        expr = re.sub(r'\s+', '', expr)

        # Split by + and - first, preserving signs
        terms = []
        current_term = ""
        i = 0
        
        while i < len(expr):
            char = expr[i]
            if char in ['+', '-']:
                if i == 0:
                    current_term += char
                else:
                    if current_term.strip():
                        terms.append(current_term.strip())
                    current_term = char if char == '-' else ""
            else:
                current_term += char
            i += 1
        
        if current_term.strip():
            terms.append(current_term.strip())
        
        # Clean up terms
        cleaned_terms = []
        i = 0
        while i < len(terms):
            term = terms[i]
            if term == '+':
                # handle cases like "a + + b" or leading "+ b"
                if i + 1 < len(terms):
                    # merge with next term if it's not an operator
                    if terms[i+1] not in ['+', '-']:
                        cleaned_terms.append(terms[i+1])
                        i += 1
                # if '+' is trailing or followed by another operator, it's ignored
            elif term == '-':
                # handle cases like "a - b" or "a - - b"
                if i + 1 < len(terms):
                     if terms[i+1] not in ['+', '-']:
                        cleaned_terms.append('-' + terms[i+1])
                        i += 1
                # if '-' is trailing or followed by another operator, it's part of a syntax error we don't handle
            else:
                cleaned_terms.append(term)
            i += 1
        
        # Check if all terms are simple enough
        if cleaned_terms and all(self._is_simple_mixed_term(term) for term in cleaned_terms):
            return cleaned_terms
        
        return None

    def _is_simple_mixed_term(self, term: str) -> bool:
        """
        Check if a term is simple enough for mixed commutative comparison.
        
        A simple mixed term can be a simple term (as defined by `_is_simple_term`)
        or a product of simple terms (e.g., "2*x*y"). It can also be negative.
        
        Examples:
            - "x" -> True
            - "-10" -> True
            - "2*y" -> True
            - "a*b*c" -> True
            - "a+b" -> False (contains an additive operator)
        """
        term = term.strip()
        if not term:
            return False
        
        # Handle negative terms
        if term.startswith('-'):
            term = term[1:].strip()
        
        # Simple term
        if self._is_simple_term(term):
            return True
        
        # Product of simple terms
        if '*' in term:
            factors = term.split('*')
            factors = [f.strip() for f in factors if f.strip()]
            return all(self._is_simple_term(factor) for factor in factors)
        
        return False

    def _normalize_term(self, term: str) -> str:
        """
        Normalize a term by sorting factors within products for consistent comparison.
        
        If the term is a product, its factors are sorted alphabetically. This allows
        for comparing product terms regardless of factor order (e.g., "c*b" and "b*c"
        both become "b * c"). It also preserves a leading negative sign.
        
        Examples:
            - "c*b*a" -> "a * b * c"
            - "-z*y" -> "-y * z"
            - "x" -> "x"
        """
        term = term.strip()
        
        # Handle negative terms
        negative = term.startswith('-')
        if negative:
            term = term[1:].strip()
        
        # If it's a product, sort the factors
        if '*' in term:
            factors = term.split('*')
            factors = [f.strip() for f in factors if f.strip()]
            factors.sort()
            term = ' * '.join(factors)
        
        if negative:
            term = '-' + term
        
        return term

    def _parse_additive_terms(self, expr: str) -> list[str] | None:
        """
        Parse simple additive expressions like 'a + b' or 'x + 1 - y'.
        
        This function is a wrapper around `_parse_mixed_terms` but is intended
        for expressions that should only contain addition or subtraction. It returns
        None if the expression contains more complex operations like multiplication,
        division, or functions, ensuring it's used for simple additive cases.
        
        Returns a list of terms or None if the expression is too complex.
        """
        # Only handle expressions without parentheses or nested operations
        if any(op in expr for op in ['(', ')', '*', '/', '^', 'sin', 'cos', 'log', 'sqrt']):
            return None
        
        # Check if there are actual + or - operators (not just leading signs)
        has_operators = False
        for i, char in enumerate(expr):
            if char in ['+', '-'] and i > 0:
                has_operators = True
                break
        
        if not has_operators:
            return None  # Single term, not an additive expression
        
        # Split by + and -, keeping track of signs
        terms = []
        current_term = ""
        i = 0
        
        while i < len(expr):
            char = expr[i]
            if char in ['+', '-']:
                # If it's the first character, it's a sign for the first term
                if i == 0:
                    current_term += char
                else:
                    # It's an operator - save the current term and start new one
                    if current_term.strip():
                        terms.append(current_term.strip())
                    current_term = char if char == '-' else ""
            else:
                current_term += char
            i += 1
        
        if current_term.strip():
            terms.append(current_term.strip())
        
        # Clean up terms - remove leading + signs and normalize spaces
        cleaned_terms = []
        for term in terms:
            if term.startswith('+'):
                cleaned_terms.append(term[1:].strip())
            else:
                # Also clean up spaces in negative terms
                cleaned_terms.append(term.strip().replace('- ', '-'))
        
        # Only return if all terms are simple and we have multiple terms
        if len(cleaned_terms) >= 2 and all(self._is_simple_term(term) for term in cleaned_terms):
            return cleaned_terms
        
        return None

    def _parse_multiplicative_factors(self, expr: str) -> list[str] | None:
        """
        Parse simple multiplicative expressions like 'a * b' or '2 * x * y'.
        
        This function splits a string by the '*' operator to get a list of
        multiplicative factors. It returns None if the expression contains other
        operations (like '+', '-', '/') or functions, ensuring it only processes
        purely multiplicative expressions.
        
        Returns a list of factors or None if the expression is too complex.
        """
        # Handle empty or None input
        if not expr or not expr.strip():
            return None
            
        # Only handle expressions without parentheses or complex operations
        if any(op in expr for op in ['(', ')', '+', '-', '/', '^', 'sin', 'cos', 'log', 'sqrt']):
            return None
        
        factors = expr.split('*')
        factors = [f.strip() for f in factors if f.strip()]
        
        # Need at least 2 factors for meaningful multiplication
        if len(factors) < 2:
            return None
        
        # Only return if all factors are simple
        if all(self._is_simple_term(factor) for factor in factors):
            return factors
        
        return None

    def _is_simple_term(self, term: str) -> bool:
        """
        Check if a term is a simple, indivisible unit for comparison.
        
        A simple term is typically a single variable (e.g., "x"), a number ("123"),
        a constant ("π"), or a number followed by a variable/constant ("2x", "3π").
        It does not contain operators like '+', '*', etc. It can be negative.
        
        Examples:
            - "x" -> True
            - "-10" -> True
            - "2y" -> True
            - "π" -> True
            - "a+b" -> False
            - "a*b" -> False
        """
        term = term.strip()
        if not term:
            return False
        
        # Handle negative terms
        if term.startswith('-'):
            term = term[1:].strip()  # Strip again after removing minus sign
        
        # Simple number
        try:
            float(term)  # Try to convert to float
            return True
        except ValueError:
            pass
        
        # Simple variable or constant
        if re.match(r'^[a-zA-Zπ∞]+$', term):
            return True
        
        # Simple coefficient with variable (like 2x, 3π)
        if re.match(r'^\d+[a-zA-Zπ∞]+$', term):
            return True
        
        return False

    def _are_numerically_equal(self, expr1: str, expr2: str) -> bool | None:
        """
        Check if two expressions are numerically equal by evaluating them.
        First determine whether they are numerical expressions by `_evaluate_safe)expression`, otherwise return False.
        Then compare their evaluated values.
        Examples:
            - "2+3" and "5" -> True
            - "π+π" and "2π" -> True
            - "2.5" and "5/2" -> True
        """
        try:
            val1 = self._evaluate_safe_expression(expr1)
            val2 = self._evaluate_safe_expression(expr2)
                
            if val1 is not None and val2 is not None:
                    return abs(val1 - val2) < 1e-10
                
        except (ValueError, SyntaxError):
            logger.debug(f"Error evaluating expressions: '{expr1}' or '{expr2}'")
            return None
        
        return None

    def _evaluate_safe_expression(self, expr: str) -> float | None:
        """
        Safely evaluate a mathematical expression containing only numbers and basic operations.
        Returns None if the expression is not safe to evaluate.
        """
        # Replace mathematical constants
        expr = expr.replace('π', str(3.141592653589793))
        expr = expr.replace('e', str(2.718281828459045))
        
        # Only allow safe characters: numbers, basic operators, parentheses, decimal points
        if not re.match(r'^[0-9+\-*/().\s]+$', expr):
            return None
        
        # Additional safety checks
        if any(dangerous in expr for dangerous in ['__', 'import', 'eval', 'exec']):
            return None
        
        try:
            # Use ast.literal_eval for simple expressions, or eval with restricted scope for basic math
            # First try to parse as a literal
            try:
                return float(ast.literal_eval(expr))
            except (ValueError, SyntaxError):
                # If that fails, try safe evaluation of basic math expressions
                # Create a restricted environment
                safe_dict = {
                    '__builtins__': {},
                    'abs': abs,
                    'max': max,
                    'min': min,
                    'pow': pow,
                    'round': round,
                }
                
                # Only evaluate if expression is simple enough
                if len(expr) > 100:  # Arbitrary length limit for safety
                    return None
                
                result = eval(expr, safe_dict, {})
                return float(result)
        except:
            return None

    def _is_simple_expression(self, expr: str) -> bool:
        """
        Check if an expression is simple (no complex operations).
        Simple expressions are single variables, numbers, or basic mathematical constants.
        """
        expr = re.sub(r'\s+', '', expr)  # Remove all whitespace
        if not expr:
            return True
        
        # Single variable or number
        if re.match(r'^[a-zA-Z0-9π∞]+$', expr):
            return True
        
        return False

    def _parse_simple_fraction(self, expr: str) -> tuple[str, str] | None:
        """
        Parse a simple fraction in the form a/b or (a)/(b).
        Returns (numerator, denominator) or None if not a simple fraction.
        Only accepts fractions without further fraction.
        """
        match = re.match(r'^([^/]+)/([^/]+)$', expr)
        if match:
            num, den = match.groups()
            return (num, den)
        return None

    def extract_boxed_answer(self, proof: str) -> str:
        """
        Return the content of the first `\\boxed{...}` in the proof.

        Nested braces inside the box are balanced, so `\\boxed{\\frac{1}{2}}`
        yields `\\frac{1}{2}`. An empty string is returned (with a warning)
        when there is no `\\boxed{` or when its braces never close. When
        several boxes exist, the first one is used and a debug message is
        logged.
        """
        boxed_openers = list(re.finditer(r'\\boxed\s*\{', proof))
        if not boxed_openers:
            logger.warning("No \\boxed found in the proof.")
            return ""
        if len(boxed_openers) != 1:
            logger.debug(f"Found {len(boxed_openers)} \\boxed in the proof, returning the first one.")

        # Start right after the opening brace of the first \boxed{.
        content_start = boxed_openers[0].end()
        depth = 1  # the opening brace of \boxed{ is already consumed
        for pos in range(content_start, len(proof)):
            symbol = proof[pos]
            if symbol == '{':
                depth += 1
            elif symbol == '}':
                depth -= 1
                if depth == 0:
                    # `pos` is the brace that closes the box; exclude it.
                    return proof[content_start:pos]

        # Ran off the end of the text without closing the box.
        logger.warning("Unmatched braces in the proof.")
        return ""

    def jsonl_to_json(self, jsonl_file: str, resume: bool = False, lines: int|None = None) -> None:
        """
        Convert a JSONL file to a directory of JSON files on the fly.
        """
        output_dir = os.path.abspath(os.path.join(jsonl_file, os.pardir, "data"))
        os.makedirs(output_dir, exist_ok=True)
        with open(jsonl_file, 'r', encoding='UTF-8') as f:
            for i,line in enumerate(f):
                json_obj = json.loads(line)
                output_file = f"{output_dir}/problem_{json_obj['problem_index']}_proof_{json_obj['proof_index']}_generate_0.json"

                if resume and os.path.exists(output_file):
                    data = json.load(open(output_file, 'r', encoding='UTF-8'))
                    if data['question'] == json_obj['question'] and data['new_solution'] == json_obj['new_solution']:
                        logger.debug(f"Skipping file {output_file} as it already exists and matches the input.")
                        continue

                with open(output_file, 'w', encoding='UTF-8') as out_f:
                    json.dump(json_obj, out_f, ensure_ascii=False, indent=4)

                if lines is not None and i + 1 >= lines:
                    break
        print(f"Converted {jsonl_file} to JSON files in {output_dir}.")

class LoadAnswerPairMixin():
    """
    Mixin class for loading answer pairs from a JSON file.
    """

    def _preprocess_answer_pair(self, data: dict[str, Any]) -> list[dict[str, Any]]:
        """
        Preprocess the answer pair data to ensure it has the required fields.
        """
        
        if not isinstance(data['new_solution'], list):
            data['new_solution'] = [data['new_solution']]
        data_list = []
        for i,new_solution in enumerate(data['new_solution']):
            new_data = deepcopy(data)
            new_data['new_solution'] = new_solution
            new_data['proof_index'] = i
            new_data['source'] = new_data['source'][i]
            data_list.append(new_data)
        # print(data_list)
        return data_list
    
    def _preprocess_answer_pair_check(self, data: dict[str, Any]) -> list[dict[str, Any]]:
        """
        Preprocess the answer pair data to ensure it has the required fields.
        """
        
        if not isinstance(data['new_solution'], list):
            data['new_solution'] = [data['new_solution']]
        data_list = []
        for i,new_solution in enumerate(data['new_solution']):
            new_data = deepcopy(data)
            new_data['new_solution'] = new_solution
            new_data['proof_index'] = i
            new_data['source'] = data['source'][i]
            new_data['check_result'] = data['check_result'][i]
            data_list.append(new_data)
        # print(data_list)
        return data_list

    def load_answer_pair(self, file: str, preprocess_func: Callable[[dict[str, Any]], dict[str, Any]] | None = None, keep_problem_index: bool = True) -> list[tuple[dict[str, Any],int]]:
        """
        Load the answer pairs from the given JSONL file.
        """
        preprocess_func = preprocess_func if preprocess_func else self._preprocess_answer_pair
        with open(file, 'r', encoding='UTF-8') as f:
            data_list = [json.loads(line) for line in tqdm(f.readlines(), desc=f"Loading {file}")]
        data_list_processed = []
        for data in data_list:
            data_list_processed.extend(self._preprocess_answer_pair(data))
        logger.info(f"Loaded {len(data_list_processed)} answer pairs from {file}.")
        return [(data, i if not keep_problem_index else data.get('problem_index', i)) for i, data in enumerate(data_list_processed)]
    
    def join_check_and_answer(self, check_path: str, answer_path: str, output_path: str):
        """
        Join the answer check data into the solution check files.
        """
        os.makedirs(output_path, exist_ok=True)
        for file in tqdm(os.listdir(check_path), desc=f"Joining check and answer for files in {check_path}"):
            if not file.endswith('.json'):
                continue
            prob_id, proof_id, _ = retrieve_id_from_name(file)
            check_file = os.path.join(check_path, file)
            answer_file = os.path.join(answer_path, f"problem_{prob_id}_proof_{proof_id}_generate_0.json")
            output_file = os.path.join(output_path, file)
            if not os.path.exists(answer_file):
                logger.warning(f"Answer file {answer_file} does not exist, skipping.")
                continue

            with open(check_file, 'r', encoding='UTF-8') as f:
                check_data = json.load(f)
            with open(answer_file, 'r', encoding='UTF-8') as f:
                answer_data = json.load(f)

            check_data['orig_solution'] = answer_data['orig_solution']
            check_data['answer_correct'] = answer_data['answer_correct']

            with open(output_file, 'w', encoding='UTF-8') as out_f:
                json.dump(check_data, out_f, ensure_ascii=False, indent=4)

class AnswerPairGenerator(LogProblemProofMixin, ExtractAnswerMixin, UserAnswerPairMixin, LoadAnswerPairMixin, SystemPromptMixin):
    """
    Generator that decides whether a new answer matches the original answer.

    Answers are first compared with `are_answers_essentially_equal`; only
    when that comparison is undecided (returns None) is the language model
    asked, and its verdict is read from a `\\boxed{True}` / `\\boxed{False}`
    in the response.
    """
    
    def __init__(
        self,
        provider: str = "dummy",
        model: str | None = None,
        extra_model_paras: dict[str, Any] | None = None,
    ) -> None:
        """Create the underlying generator for the given provider/model."""
        # Use composition instead of inheritance for GeneratorBase
        self.generator_base: GeneratorBase = get_generator_base(provider, model, extra_model_paras)
        logger.debug(f"Initialized {self.__class__.__name__} with generator_base id: {id(self.generator_base)}")
    
    async def single_turn_request(self, *args, **kwargs):
        """Delegate to the composed generator_base"""
        return await self.generator_base.single_turn_request(*args, **kwargs)

    @property
    def system_prompt_file(self) -> str:
        """Path of the system prompt used for answer comparison."""
        return 'prompts/compare_answer_system_prompt.txt'

    def log_step_start(self, problem_index: int, proof_index: int) -> str:
        """Message logged before one answer pair is checked."""
        return f"Checking answer for problem {problem_index} and proof {proof_index}..."

    def log_step_finish(self, problem_index: int, proof_index: int) -> str:
        """Message logged after one answer pair has been checked."""
        return f"Finished checking answer for problem {problem_index} and proof {proof_index}."

    def log_start(self, file: str, num_worker: int, num_returns: int) -> str:
        """Message describing the start of a processing run."""
        return f"Started to process the file {file} with {num_worker} workers and {num_returns} returns per problem. This will check the answers for the problems in the file."

    def log_finish(self, file: str, num_pairs: int, num_returns: int, save_path: str) -> str:
        """Message describing the end of a processing run."""
        return f'Finished processing the file {file}. Verified {num_pairs} answer pairs with {num_returns} checks each. The thinking process and checking results are saved to {save_path}.'

    async def _generate(
        self,
        orig_answer: str,
        new_answer: str,
        problem_index: int,
        proof_index: int,
        shared_semaphore: asyncio.Semaphore,
        num_returns: int = 1,
    ) -> list[tuple[str, str, bool] | None]:
        """
        Ask the model whether the two answers agree and parse its verdict.

        Args:
            orig_answer: Reference answer.
            new_answer: Answer to be checked.
            problem_index: Used for log messages only.
            proof_index: Used for log messages only.
            shared_semaphore: Forwarded to `single_turn_request`.
            num_returns: Number of generations requested.

        Returns:
            One entry per generation: (thinking, solution, answer_correct),
            or None when the generation is missing or its boxed verdict is
            neither true nor false.
        """

        logger.debug(self.log_step_start(problem_index, proof_index))

        system_prompt = self._system_prompt
        user_prompt = self._user_prompt(orig_answer, new_answer)
        return_list = await self.single_turn_request(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            shared_semaphore=shared_semaphore,
            num_returns=num_returns,
            max_tokens=8_192
        )

        extracted_return_list = []

        # check for the return list
        for generation in return_list:
            if generation is None:
                extracted_return_list.append(None)
                continue
            (thinking, solution) = generation
            result = self.extract_boxed_answer(solution)
            # Tolerate surrounding whitespace and letter case in the verdict
            # (e.g. "\boxed{ True }" or "\boxed{TRUE}").
            verdict = result.strip().lower()
            if verdict == 'true':
                answer_correct = True
            elif verdict == 'false':
                answer_correct = False
            else:
                extracted_return_list.append(None)
                logger.warning(f"Invalid answer in the generated result. Expected 'True' or 'False', got {result}.")
                continue
            extracted_return_list.append((thinking, solution, answer_correct))

        logger.debug(self.log_step_finish(problem_index, proof_index))

        return extracted_return_list

    async def process(
        self,
        load_path: str,
        lines: int | None,
        num_worker: int = 1,
        resume: bool = False,
        output_path: str | None = None,
        async_mode: bool = True,
        check: bool = False,
        batch_id: str | None = None,
    ):
        """
        Check every answer pair of the input file and write the verdicts.

        Args:
            load_path: JSONL file with the answer pairs.
            lines: Maximum number of (not yet processed) pairs to handle;
                None means all.
            num_worker: Size of the semaphore passed to the model requests.
            resume: When True and `output_path` exists, previously saved
                results are loaded and their pairs are not processed again.
            output_path: JSONL file for all results. Defaults to
                `processed_output/processed_<input name>` next to the input.
                `correct.jsonl` and `incorrect.jsonl` are written to the
                same directory.
            async_mode: Process pairs concurrently (True) or one by one.
            check: Load the input with `_preprocess_answer_pair_check`
                instead of `_preprocess_answer_pair`.
            batch_id: Stored on the instance as `self.batch_id`.
        """
        self.batch_id = batch_id
        incorrect_data = []
        correct_data = []
        processed_data = {}
        
        if output_path is None:
            output_path = os.path.abspath(os.path.join(load_path, os.path.pardir,"processed_output", "processed_"+os.path.basename(load_path)))
        else:
            # A bare file name has an empty dirname, which os.makedirs rejects.
            output_path = os.path.abspath(output_path)
        correct_path = os.path.abspath(os.path.join(output_path, os.pardir, "correct.jsonl"))
        incorrect_path = os.path.abspath(os.path.join(output_path, os.pardir, "incorrect.jsonl"))

        if check:
            data_list = self.load_answer_pair(load_path, preprocess_func=self._preprocess_answer_pair_check)
        else:
            data_list = self.load_answer_pair(load_path, preprocess_func=self._preprocess_answer_pair)
        
        if resume and os.path.exists(output_path):
            # Re-load earlier results and drop the pairs they already cover.
            save_list = self.load_answer_pair(output_path, preprocess_func=lambda x: x)
            for data, index in save_list:
                processed_data[(index, data['proof_index'])] = data
                if data['answer_correct']:
                    correct_data.append((index, data['proof_index']))
                else:
                    incorrect_data.append((index, data['proof_index']))
            data_list = [data for data in data_list if (data[1], data[0]['proof_index']) not in processed_data]
        
        data_list = data_list if lines is None else data_list[:lines]

        semaphore = asyncio.Semaphore(num_worker)

        async def process_data(data: dict[str, Any], index: int):
            """Decide one pair and record the verdict in the shared result containers."""
            orig_proof = data.get("orig_solution", "Not Available")
            new_proof = data.get("new_solution", "Not Available")
            # An explicit 'answer' field takes precedence over the boxed
            # answer of the original solution.
            if 'answer' in data:
                orig_answer = data['answer']
            else:
                orig_answer = self.extract_boxed_answer(orig_proof)
            new_answer = self.extract_boxed_answer(new_proof)
            if orig_answer != "" and new_answer != "":
                # Rule-based comparison first; True/False are final, None is undecided.
                equality_result = self.are_answers_essentially_equal(orig_answer, new_answer)
                if equality_result is True:
                    logger.debug(f"Skipping Problem {index} as the answers are definitely the same: '{orig_answer}' ≈ '{new_answer}'")
                    data['answer_correct'] = True
                    data['problem_index'] = index
                    processed_data[(index, data['proof_index'])] = data
                    correct_data.append((index, data['proof_index']))
                    return
                elif equality_result is False:
                    logger.debug(f"Problem {index} has definitely different answers: '{orig_answer}' ≠ '{new_answer}'")
                    data['answer_correct'] = False
                    data['problem_index'] = index
                    processed_data[(index, data['proof_index'])] = data
                    incorrect_data.append((index, data['proof_index']))
                    return
                # If equality_result is None, continue with LLM-based comparison

            if orig_answer == "" or new_answer == "":
                # Nothing to compare; the pair is left out of the results.
                logger.debug(f"One of the answers is empty for Problem {index}: {orig_answer}, {new_answer}")
                return

            proof_index = data['proof_index']

            response = await self._generate(orig_answer, new_answer, index, proof_index, shared_semaphore=semaphore)

            res = response[0] if response else None
            if res is None:
                logger.warning(f"Response for problem {index} is None, skipping.")
                return
            (thinking, solution, answer_correct) = res
            data['answer_correct'] = answer_correct
            data['problem_index'] = index
            if not answer_correct:
                incorrect_data.append((index, proof_index))
            else:
                correct_data.append((index, proof_index))
            processed_data[(index, proof_index)] = data

        if not async_mode:
            for data, index in tqdm(data_list, desc='Processing files'):
                await process_data(data, index)
        else:
            tasks = [process_data(data, index) for data, index in data_list]
            for task in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc='Processing files'):
                await task
                
        count = len(data_list)
        valid_count = len(correct_data) + len(incorrect_data)
        logger.info(f"Processed {count} files, {valid_count} of them have valid responses.")
        if valid_count > 0:
            logger.info(f"Correct rate : {(len(correct_data)) / valid_count * 100:.2f}%")
        else:
            # Avoid dividing by zero when nothing could be decided.
            logger.info("Correct rate : N/A (no valid responses)")

        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        os.makedirs(os.path.dirname(correct_path), exist_ok=True)
        os.makedirs(os.path.dirname(incorrect_path), exist_ok=True)
        with open(output_path, 'w', encoding='UTF-8') as f:
            for index, data in sorted(processed_data.items(), key=lambda x: x[0]):
                f.write(json.dumps(data, ensure_ascii=False) + '\n')
        with open(correct_path, 'w', encoding='UTF-8') as f:
            for index in sorted(correct_data):
                f.write(json.dumps(processed_data[index], ensure_ascii=False) + '\n')
        with open(incorrect_path, 'w', encoding='UTF-8') as f:
            for index in sorted(incorrect_data):
                f.write(json.dumps(processed_data[index], ensure_ascii=False) + '\n')

        return

async def main(sys_argv: list[str] | None = None):
    """
    Command-line driver: parse the arguments and run the answer check.

    The parser extends the shared parser returned by `common_parse_args`
    with a `--check` flag. `sys_argv` is handed to `parse_args`; None makes
    argparse read the process arguments.
    """
    from argparse import ArgumentParser
    from olym_gen.generator.base_generator import common_parse_args

    shared_parser = common_parse_args()
    cli_parser = ArgumentParser(
        parents=[shared_parser],
        description="Check solutions for problems using a language model."
    )
    cli_parser.add_argument(
        '--check', action="store_true", help="Read in checks instead of solutions."
    )
    options = cli_parser.parse_args(sys_argv)

    answer_checker = AnswerPairGenerator(provider=options.provider, model=options.model)

    # Map the parsed options onto the keyword arguments of process().
    process_kwargs = {
        'lines': options.lines,
        'num_worker': options.num_worker,
        'resume': options.resume,
        'output_path': options.save_path,
        'async_mode': not options.no_async,
        'check': options.check,
        'batch_id': options.batch_id,
    }
    await answer_checker.process(options.file, **process_kwargs)
    logger.info("Finished processing all problems.")

# Script entry point: when the file is executed directly, run the async
# command-line driver on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())