import re
from typing import Optional, Sequence
import copy
import numpy as np
import sympy
import sympy.parsing.latex as latex
import contextlib
import signal
import os
import errno
# pip install antlr4-python3-runtime==4.11
# pip install antlr4-tools
import antlr4
import math_verify


# ---------------------------------------------------------------------------
# Answer-normalization rule tables.
#
# normalize_answer_math_helper applies these in a fixed order: literal
# substitutions, literal removals, regex substitutions, then regex removals
# (the latter with re.IGNORECASE). Each entry sees the output of the previous
# ones, so the order inside every list matters.
# ---------------------------------------------------------------------------

# V0 literal (str.replace) substitutions: (before, after).
MATH_SUBSTITUTIONS_V0 = [
    ('an ', ''),
    ('a ', ''),
    ('.$', '$'),
    ('\\$', ''),
    (r'\ ', ''),
    (' ', ''),
    ('mbox', 'text'),
    (',\\text{and}', ','),
    ('\\text{and}', ','),
    ('\\text{m}', '\\text{}'),
]

# V0 literal substrings deleted everywhere they occur (not only as whole
# words): units, filler words and TeX decorations. Note that
# '\\text{\ns}' and '\\text{\n}' contain a literal newline character.
MATH_REMOVED_EXPRESSIONS_V0 = [
    'square',
    'ways',
    'integers',
    'dollars',
    'mph',
    'inches',
    'ft',
    'hours',
    'km',
    'units',
    r'\ldots',
    'sue',
    'points',
    'feet',
    'minutes',
    'digits',
    'cents',
    'degrees',
    'cm',
    'gm',
    'pounds',
    'meters',
    'meals',
    'edges',
    'students',
    'childrentickets',
    'multiples',
    '\\text{s}',
    '\\text{.}',
    '\\text{\ns}',
    '\\text{}^2',
    '\\text{}^3',
    '\\text{\n}',
    '\\text{}',
    r'\mathrm{th}',
    r'^\circ',
    r'^{\circ}',
    r'\;',
    r',\!',
    '{,}',
    '"',
    '\\dots',
    '%',
]

# V0 regex substitutions: (pattern, replacement) for re.sub.
MATH_REGEX_SUBSTITUTIONS_V0 = [
    # Extract answer that is in LaTeX math, bold, surrounded by a box, etc.
    (r'(.*?)(\$)(.*?)(\$)(.*)', '$\\3$'),
    (r'(\\text\{)(.*?)(\})', '\\2'),
    (r'(\\textbf\{)(.*?)(\})', '\\2'),
    (r'(\\overline\{)(.*?)(\})', '\\2'),
    (r'(\\boxed\{)(.*)(\})', '\\2'),
    # Normalize shorthand TeX:
    #  \fracab -> \frac{a}{b}
    #  \frac{abc}{bef} -> \frac{abc}{bef}
    #  \fracabc -> \frac{a}{b}c
    #  \sqrta -> \sqrt{a}
    #  \sqrtab -> sqrt{a}b
    (r'(frac)([^{])(.)', 'frac{\\2}{\\3}'),
    (r'(sqrt)([^{])', 'sqrt{\\2}'),
]

MATH_REGEX_REMOVED_EXPRESSIONS_V0 = []

# Removing `a` can be incorrect, eg (a + b)
MATH_SUBSTITUTIONS_V1 = [s for s in MATH_SUBSTITUTIONS_V0 if s != ('a ', '')]

MATH_SUBSTITUTIONS_V1 += [
    ('\\dfrac', '\\frac'),
    ('\\tfrac', '\\frac'),
    ('\\cfrac', '\\frac'),
    ('interval(', '('),
    ('Interval(', '('),
    ('infinity', '\\infty'),
    ('Infinity', '\\infty'),
    ('**', '^'),
]

# Removing ft if not isolated can be incorrect (eg in \infty), so the literal
# removal is dropped here and replaced by the regex removal below.
MATH_REMOVED_EXPRESSIONS_V1 = [
    r for r in MATH_REMOVED_EXPRESSIONS_V0 if r != 'ft'
]

# Removes 'ft' only when followed by whitespace, end of string, '.' or ','.
MATH_REGEX_REMOVED_EXPRESSIONS_V1 = [r'ft(\s|$|\.|\,)']

MATH_REGEX_SUBSTITUTIONS_V1 = list(MATH_REGEX_SUBSTITUTIONS_V0)

# Canonicalize special function escaping so that sympy recognizes them.
MATH_REGEX_SUBSTITUTIONS_V1 += [
    # Second pass over any \text{...} wrapper still present after the V0
    # rules. The backslash must be escaped: in a regex, a bare `\t` means
    # TAB, which would never match the LaTeX `\text` command.
    (r'(\\text\{)(.*?)(\})', '\\2'),
    (r'(?<!\\)log', r'\\log'),
    (r'(?<!\\)ln', r'\\ln'),
    (r'(?<!\\)sin', r'\\sin'),
    (r'(?<!\\)cos', r'\\cos'),
    (r'(?<!\\)tan', r'\\tan'),
    (r'(?<!\\)cot', r'\\cot'),
    (r'(?<!\\)sec', r'\\sec'),
    (r'(?<!\\)csc', r'\\csc'),
    (r'(?<!\\)sqrt', r'\\sqrt'),
]


def extract_between(
        x: str,
        start: str = 'Final Answer: The final answer is ',
        end: str = '. I hope it is correct.',
        last_occurrence: bool = False,
) -> str:
    """Returns the text of x that sits between `start` and `end`.

    Args:
      x: the string to extract from.
      start: the opening delimiter.
      end: the closing delimiter.
      last_occurrence: when x contains several `start` delimiters, take the
        segment after the last one instead of the segment after the first.

    Returns:
      The stripped extracted string, or '' when `start` does not occur in x.
    """
    if start not in x:
        return ''
    pieces = x.split(start)
    # First-occurrence mode keeps the raw segment (up to the next `start`);
    # last-occurrence mode strips the trailing segment before cutting at `end`.
    segment = pieces[-1].strip() if last_occurrence else pieces[1]
    return segment.split(end)[0].strip()


def find_boxed_content(s, last_occurrence):
    r"""Returns the brace-balanced contents of a ``\boxed{...}`` in s.

    Args:
      s: text that may contain one or more ``\boxed{`` openings.
      last_occurrence: use the last ``\boxed{`` if True, else the first.

    Returns:
      The text between the chosen ``\boxed{`` and its matching ``}``, or None
      when there is no ``\boxed{`` or the chosen one is never closed. Braces are
      counted naively (escaped ``\{`` / ``\}`` are counted too).
    """
    openings = [match.end() for match in re.finditer(r'\\boxed\{', s)]
    if not openings:
        return None
    begin = openings[-1] if last_occurrence else openings[0]
    depth = 1  # the opening brace of \boxed{ is already consumed
    for pos in range(begin, len(s)):
        char = s[pos]
        if char == '{':
            depth += 1
        elif char == '}':
            depth -= 1
            if depth == 0:
                return s[begin:pos]
    return None


def extract_between_and_with_boxes(
        x: str,
        last_occurrence: bool = False,
) -> str:
    r"""Returns the contents of a ``\boxed{...}`` in x, or '' if none is found.

    Args:
      x: model output text.
      last_occurrence: pick the last ``\boxed{`` instead of the first.
    """
    # find_boxed_content yields None when nothing usable is boxed; map that to ''.
    return find_boxed_content(x, last_occurrence=last_occurrence) or ''


def normalize_answer_math_helper(
        final_answer: str,
        substitutions: Sequence[tuple[str, str]],
        removed_expressions: Sequence[str],
        regex_substitutions: Sequence[tuple[str, str]],
        regex_removed_expressions: Sequence[str],
        remove_trailing_slash: bool,
        remove_slash: bool,
) -> str:
    r"""Applies a configurable pipeline of math-answer string normalizations.

    Stages, in order:
      1. keep only the text after the last '='.
      2. literal replacements from `substitutions` (str.replace).
      3. literal deletions of every string in `removed_expressions`.
      4. regex replacements from `regex_substitutions` (re.sub, no flags).
      5. regex deletions from `regex_removed_expressions` (re.IGNORECASE).
      6. delete all '$' characters; delete all backslashes if `remove_slash`.
      7. drop commas when what remains is all digits (100,000 -> 100000; note
         this also turns '1,2' into '12').
      8. strip trailing backslashes if `remove_trailing_slash`.

    Used by normalize_answer_math.

    Args:
      final_answer: the text to normalize.
      substitutions: (before, after) literal replacement pairs.
      removed_expressions: literal substrings to delete.
      regex_substitutions: (pattern, replacement) regex pairs.
      regex_removed_expressions: regex patterns to delete.
      remove_trailing_slash: strip trailing backslashes at the end.
      remove_slash: delete every backslash.

    Returns:
      The normalized text.
    """
    text = final_answer.split('=')[-1]

    for old, new in substitutions:
        text = text.replace(old, new)
    for needle in removed_expressions:
        text = text.replace(needle, '')
    for pattern, replacement in regex_substitutions:
        text = re.sub(pattern, replacement, text)
    for pattern in regex_removed_expressions:
        text = re.sub(pattern, '', text, flags=re.IGNORECASE)

    text = text.replace('$', '')
    if remove_slash:
        text = text.replace('\\', '')

    without_commas = text.replace(',', '')
    if without_commas.isdigit():
        text = without_commas

    if remove_trailing_slash:
        text = text.rstrip('\\')
    return text


def normalize_answer_math(final_answer: str, remove_slash) -> str:
    r"""Normalizes a math answer string with the V1 rule tables.

    Compared with the V0 tables, V1:
      - escapes function names (sin, log, ...) so sympy parses them;
      - adds substitutions such as \dfrac / \tfrac / \cfrac -> \frac;
      - always strips trailing backslashes;
      - avoids risky removals ('a ' breaks (a + b); a bare 'ft' breaks \infty).

    Args:
      final_answer: the text to normalize.
      remove_slash: if truthy, every backslash is deleted as well.

    Returns:
      The normalized text.
    """
    v1_tables = {
        'substitutions': MATH_SUBSTITUTIONS_V1,
        'removed_expressions': MATH_REMOVED_EXPRESSIONS_V1,
        'regex_substitutions': MATH_REGEX_SUBSTITUTIONS_V1,
        'regex_removed_expressions': MATH_REGEX_REMOVED_EXPRESSIONS_V1,
    }
    return normalize_answer_math_helper(
        final_answer,
        remove_trailing_slash=True,
        remove_slash=remove_slash,
        **v1_tables,
    )


def split_answer_separator(text: str, separator: Optional[str] = None) -> str:
    """Strips text and, if a separator is given, keeps only what precedes it."""
    stripped = text.strip()
    if separator is None:
        return stripped
    head, _, _ = stripped.partition(separator)
    return head.strip()


def process_sample_gpqa(sample, few_shot_separator, extract_last_occurrence):
    """Extracts the multiple-choice letter (A-D) from a GPQA model sample.

    Args:
      sample: raw model output, or None.
      few_shot_separator: optional separator; text after it is discarded.
      extract_last_occurrence: pick the last '(X)' label after 'Final Answer:'
        instead of the first.

    Returns:
      None for a None sample, '' when there is no 'Final Answer:' marker or no
      '(A)'-'(D)' label after it, otherwise the bare letter, e.g. 'B'.
    """
    if sample is None:
        return None
    # Drop '*' (Markdown emphasis) so '**Final Answer:**' is still detected.
    cleaned = split_answer_separator(sample, few_shot_separator).replace("*", "")

    marker = "Final Answer:"
    if marker not in cleaned:
        return ""

    tail = cleaned.split(marker)[-1].strip()
    labels = re.findall(r"\([A-D]\)", tail)
    if not labels:
        return ""
    chosen = labels[-1] if extract_last_occurrence else labels[0]
    return chosen[1]  # '(B)' -> 'B'


def equivalence_partition(iterable, relation):
    """Partitions a collection into equivalence classes under `relation`.

    The relation is assumed (not checked) to be reflexive, symmetric and
    transitive. Each item joins the first existing class whose representative
    (first element) it is related to. Two deliberate exceptions:
      - None items always start a new singleton class;
      - classes whose representative is falsy ('' , None, 0, ...) never absorb
        further items, so each such item ends up in its own class.

    Args:
      iterable: collection of objects to be partitioned.
      relation: callable; relation(o1, o2) is truthy iff o1 and o2 are
        equivalent.

    Returns:
      A list of classes, each a list of items, in first-seen order.
    """
    groups = []
    for item in iterable:
        home = None
        if item is not None:
            home = next(
                (group for group in groups
                 if group[0] and relation(group[0], item)),
                None,
            )
        if home is None:
            groups.append([item])
        else:
            home.append(item)
    return groups


def equivalence_partition_with_weights(iterable, iterable_weights, relation):
    """Like equivalence_partition, but carries a weight alongside each item.

    The relation is assumed (not checked) to be reflexive, symmetric and
    transitive. None items always form their own class, and classes headed by
    a falsy representative never absorb further items.

    Args:
      iterable: collection of objects to be partitioned.
      iterable_weights: per-item weights, paired with `iterable` via zip (so
        iteration stops at the shorter of the two).
      relation: callable; relation(o1, o2) is truthy iff o1 and o2 are
        equivalent.

    Returns:
      A tuple (classes, class_weights): two parallel lists of lists, where
      class_weights[i][j] is the weight of classes[i][j].
    """
    classes, class_weights = [], []
    for member, weight in zip(iterable, iterable_weights):
        slot = None
        if member is not None:
            slot = next(
                (pos for pos, group in enumerate(classes)
                 if group[0] and relation(group[0], member)),
                None,
            )
        if slot is None:
            classes.append([member])
            class_weights.append([weight])
        else:
            classes[slot].append(member)
            class_weights[slot].append(weight)
    return classes, class_weights


def random_sample_prediction_partition(prediction_partition, num_samples):
    """Randomly keeps `num_samples` predictions of an equivalence partition.

    Every prediction in `prediction_partition` (a list of classes, each a list
    of predictions) gets a flat index in reading order. `num_samples` distinct
    indices are drawn with np.random.choice(..., replace=False), and only those
    predictions are kept. Class order and within-class order are preserved,
    classes left empty are dropped, and the input is not modified.

    Args:
      prediction_partition: list of lists, e.g. as built by
        equivalence_partition.
      num_samples: number of predictions to keep.

    Returns:
      A new partition (list of non-empty lists) holding deep copies of the
      retained predictions.

    Raises:
      ValueError: (from NumPy) if num_samples exceeds the total number of
        predictions or is negative.
    """
    total = sum(len(cls) for cls in prediction_partition)
    # One RNG call; a set makes each membership test O(1) (testing `in` on an
    # ndarray scans the whole array every time).
    keep = set(np.random.choice(total, num_samples, replace=False).tolist())

    sampled_partition = []
    flat_index = 0
    for cls in copy.deepcopy(prediction_partition):
        kept = []
        for prediction in cls:
            if flat_index in keep:
                kept.append(prediction)
            flat_index += 1
        if kept:
            sampled_partition.append(kept)
    return sampled_partition


def compute_majority_vote_correct(processed_predictions, predictions_correctness, predictions_partition, strict_tie_breaking=True, partition_weights=None):
    """Returns the correctness of the majority-vote answer.

    Each class in `predictions_partition` is scored by its size, or by
    np.sum(partition_weights[i]) when weights are given. Classes whose first
    element is falsy (None, '', ...) are ignored because they stand for the
    model not reaching a final answer.

    Ties for the top score: with strict_tie_breaking the vote counts as
    incorrect (False); otherwise one tied answer is picked with
    np.random.choice.

    Args:
      processed_predictions: list of normalized predictions.
      predictions_correctness: per-prediction correctness, parallel to
        processed_predictions.
      predictions_partition: list of equivalence classes of predictions.
      strict_tie_breaking: treat ties as incorrect instead of sampling.
      partition_weights: optional per-class lists of weights.

    Returns:
      predictions_correctness at the first index of the winning answer in
      processed_predictions, or False on a strict tie / no valid answer.
    """
    top_weight = 0
    winner = None
    tied = False
    contenders = []
    for class_index, cls in enumerate(predictions_partition):
        representative = cls[0]
        if not representative:
            continue  # the model failed to produce a final answer
        if partition_weights is None:
            weight = len(cls)
        else:
            weight = np.sum(partition_weights[class_index])
        if weight > top_weight:
            top_weight = weight
            winner = representative
            tied = False
            contenders = [winner]
        elif weight == top_weight:
            tied = True
            contenders.append(representative)

    if tied:
        if strict_tie_breaking:
            return False
        winner = np.random.choice(contenders)
    if not winner:
        # Every class was None / empty, so there is no majority answer.
        return False
    return predictions_correctness[processed_predictions.index(winner)]


def calculate_majority_vote_acc(example_data, num_samples):
    """Mean majority-vote accuracy over examples, using `num_samples` draws each.

    Each example is a dict with 'processed_predictions',
    'predictions_correctness' and (for list predictions) 'predictions_partition'.
    When 'processed_predictions' is a plain str, the stored
    'predictions_correctness' value is used directly. Otherwise num_samples
    predictions are subsampled from the partition and graded by
    compute_majority_vote_correct (strict tie-breaking).

    Returns:
      np.mean of the per-example results.
    """
    outcomes = []
    for record in example_data:
        predictions = record['processed_predictions']
        if type(predictions) is str:
            outcomes.append(record['predictions_correctness'])
            continue
        correctness = record['predictions_correctness']
        partition = record['predictions_partition']
        assert num_samples <= len(predictions)
        subsample = random_sample_prediction_partition(partition, num_samples)
        outcomes.append(
            compute_majority_vote_correct(predictions, correctness, subsample))
    return np.mean(outcomes)


def calculate_majority_vote_acc_bootstrap(example_data, num_samples, N):
    """Returns a list of N independent calculate_majority_vote_acc estimates."""
    return [
        calculate_majority_vote_acc(example_data, num_samples)
        for _ in range(N)
    ]


def process_sample(sample, few_shot_separator=None, extract_last_occurrence=True, remove_slash=False):
    r"""Turns a raw model sample into a normalized math answer.

    Steps: cut at `few_shot_separator` (guards against the model hallucinating
    further few-shot problems), take the ``\boxed{...}`` contents (the last one
    by default, which suits most cases), then apply normalize_answer_math.
    Returns None for a None sample.
    """
    if sample is None:
        return None
    trimmed = split_answer_separator(sample, few_shot_separator)
    boxed = extract_between_and_with_boxes(trimmed, extract_last_occurrence)
    return normalize_answer_math(boxed, remove_slash)


def sample_match_strict(sample, reference):
    """Exact-match grader: the result of ``sample == reference``."""
    is_match = sample == reference
    return is_match


class CustomTimeout(contextlib.ContextDecorator):
    """Context manager / decorator that raises TimeoutError after a delay.

    Here, we use it mainly to ensure that our STEMSolver evaluation does not hang
    due to Sympy simplification hanging.

    The timer uses SIGALRM via signal.alarm. Per the `signal` module docs, it
    therefore only works in the main thread, on platforms that provide SIGALRM.
    `seconds` is truncated with int(), so a value below 1 gives
    signal.alarm(0), which disables the timer.

    Args:
      seconds: timeout in (whole) seconds.
      timeout_message: message of the raised TimeoutError.
      suppress_timeout_errors: if True, a TimeoutError leaving the `with` body
        is swallowed instead of propagated.
    """

    def __init__(
            self,
            seconds,
            *,
            timeout_message=os.strerror(errno.ETIME),
            suppress_timeout_errors=False,
    ):
        self.seconds = int(seconds)
        self.timeout_message = timeout_message
        self.suppress = bool(suppress_timeout_errors)
        self._orig_handler = None

    def _timeout_handler(self, unused_signum, unused_frame):
        # Raise an alarm again in case the exception is swallowed by user code;
        # __exit__ cancels it with signal.alarm(0).
        signal.alarm(1)
        raise TimeoutError(self.timeout_message)

    def __enter__(self):
        self._orig_handler = signal.signal(signal.SIGALRM, self._timeout_handler)
        signal.alarm(self.seconds)
        # Returning self makes `with CustomTimeout(...) as t:` bind the instance.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        signal.alarm(0)
        previous = self._orig_handler
        # signal.signal returns None when the previous handler was not set from
        # Python; None cannot be re-installed (TypeError), so use SIG_DFL.
        if previous is None:
            previous = signal.SIG_DFL
        signal.signal(signal.SIGALRM, previous)
        # A true return value tells the interpreter to swallow the exception.
        return self.suppress and exc_type is TimeoutError


def to_numeric(s: str) -> float | None:
    """Converts s to number if possible; otherwise, returns None."""
    if '/' not in s:
        # not a fraction.
        try:
            return float(s)
        except ValueError:
            return None
    elif len(s.split('/')) == 2:
        num, den = s.split('/')
        try:
            return float(num) / float(den)
        except ValueError:
            return None
        except ZeroDivisionError:
            return None
    else:
        return None


def numerical_correctness(
        expr1: str,
        expr2: str,
        abstol: float = 1e-4,
        reltol: float = 1e-6,
) -> bool:
    """Checks if the expr1 and expr2 strings represent the same number.

    Both strings are converted with to_numeric. They match when the absolute
    difference is <= abstol, or when the difference divided by
    (min(|a|, |b|) + reltol) is <= reltol. If either string consists only of
    digits (str.isdigit), both tolerances are forced to 0, i.e. an exact
    comparison.

    Args:
      expr1: first numeric string (plain number or 'a/b' fraction).
      expr2: second numeric string.
      abstol: absolute tolerance.
      reltol: relative tolerance.

    Returns:
      True if both parse and are within tolerance, else False.
    """
    expr1_numeric = to_numeric(expr1)
    expr2_numeric = to_numeric(expr2)
    if expr1.isdigit() or expr2.isdigit():
        # No tolerance for integers.
        abstol = 0
        reltol = 0

    if expr1_numeric is None or expr2_numeric is None:
        # At least one string did not represent a number.
        return False

    difference = abs(expr1_numeric - expr2_numeric)
    if difference <= abstol:
        return True
    min_val = min(abs(expr1_numeric), abs(expr2_numeric)) + reltol
    if min_val == 0:
        # Relative error is undefined here. Dividing would only yield inf/nan
        # (never <= reltol) and emit a NumPy divide-by-zero warning.
        return False
    return difference / min_val <= reltol


def symbolic_correctness(
        prediction: str,
        target: str,
        abstol: float = 0,
        reltol: float = 0,
        timeout: int = 10,
) -> bool:
    r"""Checks if the prediction and target are equivalent LaTeX expressions.

    Both strings get \dfrac -> \frac (parse_latex can't parse \dfrac) and are
    stripped of surrounding '$' and spaces. They are then parsed with
    sympy's parse_latex, and the simplified difference is compared against 0.
    The tolerance comparison only runs when both abstol and reltol are > 0.

    Falls back to exact string equality when the prediction contains '!'
    (factorials can exhaust memory), when parsing raises LaTeXParsingError,
    or when the `timeout`-second CustomTimeout fires. Other exceptions
    propagate to the caller.
    """
    # parse_latex can't parse dfrac properly.
    pred_src = prediction.replace('\\dfrac', '\\frac').strip('$ ')
    target_src = target.replace('\\dfrac', '\\frac').strip('$ ')

    # Factorial sometimes OOMs, will avoid for now.
    if '!' in pred_src:
        return pred_src == target_src

    with CustomTimeout(seconds=timeout):
        try:
            pred_expr = latex.parse_latex(pred_src)
            target_expr = latex.parse_latex(target_src)
            gap = np.abs(sympy.simplify(pred_expr - target_expr))
            if gap == 0:
                return True
            if not (abstol > 0 and reltol > 0):
                return False
            magnitudes = [
                np.abs(sympy.simplify(pred_expr)),
                np.abs(sympy.simplify(target_expr)),
            ]
            scale = np.min(magnitudes) + reltol
            # Both conditions are evaluated before combining them.
            within_abs = gap <= abstol
            within_rel = gap / scale <= reltol
            return bool(within_abs or within_rel)
        except TimeoutError:
            return pred_src == target_src
        except sympy.parsing.latex.errors.LaTeXParsingError:
            return pred_src == target_src


def numeric_or_symbolic_correctness(
        prediction: str,
        target: str,
        abstol: float = 1e-4,
        reltol: float = 1e-4,
        symb_abstol: float = 0,
        symb_reltol: float = 0,
) -> bool:
    r"""Checks whether an answer to a problem-set question is correct.

    This function is to be used in, e.g., grading answers to MATH questions where
    it is not known whether the answer type is numeric or symbolic. An exact
    string match is always correct. Otherwise, if either string contains a comma
    (tuples, intervals, lists), the answer is graded incorrect. If both strings
    convert to numbers, they are compared numerically. If not, the prediction is
    graded as a symbolic expression, and any exception raised by that symbolic
    check counts as incorrect.

    Args:
    prediction: str, 'student' answer to score (Note: this should not be a long
      string with \\boxed{}, but should already have extracted the boxed
      expression from the ans)
    target: str, ground truth answer (As above, this should be the extracted
      expression itself)
    abstol: float, absolute tolerance for comparing numerical answers
    reltol: float, relative tolerance for comparing numerical answers
    symb_abstol: float, absolute tolerance for comparing symbolic answers
    symb_reltol: float, relative tolerance for comparing symbolic answers

    Returns:
    correct: bool, indicating whether or not the prediction was correct

    E.g.
    numeric_or_symbolic_correctness("4.0", "4.0", 1e-4, 1e-4) -> True
    numeric_or_symbolic_correctness("0.5mv^2", "\\frac{mv^2}{2}") -> True
    """
    if prediction == target:
        return True
    if ',' in prediction or ',' in target:
        # Comma-separated answers only count on an exact string match.
        return False
    if to_numeric(prediction) is not None and to_numeric(target) is not None:
        # prediction and target are numeric quantities.
        return numerical_correctness(prediction, target, abstol, reltol)
    # prediction or target is a symbolic expression (e.g., 2x)
    try:
        return symbolic_correctness(prediction, target, symb_abstol, symb_reltol)
    except Exception:  # pylint: disable=broad-except
        # Any sympy/parsing/timeout failure is graded as incorrect. Unlike a bare
        # `except:`, this lets KeyboardInterrupt and SystemExit propagate.
        return False


def quick_evaluate_single(dataset_type, solution_or_answer, few_shot_separator, extract_last_occurrence, match_fn, raw_prediction):
    """Grades one raw prediction against its reference for `dataset_type`.

    - Reference: for 'MATH' it is processed with process_sample
      (remove_slash=False); for every other dataset it is used as-is.
    - Prediction: 'GPQA' uses process_sample_gpqa; 'GSM8K' uses process_sample
      with remove_slash=True; anything else uses process_sample with
      remove_slash=False.

    Returns:
      match_fn(processed_prediction, processed_reference).
    """
    if dataset_type == 'MATH':
        reference = process_sample(solution_or_answer, few_shot_separator, extract_last_occurrence, False)
    else:
        reference = solution_or_answer

    if dataset_type == 'GPQA':
        candidate = process_sample_gpqa(raw_prediction, few_shot_separator, extract_last_occurrence)
    else:
        strip_slashes = True if dataset_type == 'GSM8K' else False
        candidate = process_sample(raw_prediction, few_shot_separator, extract_last_occurrence, strip_slashes)

    return match_fn(candidate, reference)


def math_verify_check(expr1, expr2, symmetric=True):
    """Checks equivalence of two expressions with math_verify.

    Each expression is wrapped in '$...$', parsed with math_verify.parse, and
    compared with math_verify.verify at numeric_precision=6. When `symmetric`
    is True, verify is also run with the arguments swapped, and the result is
    True if either direction succeeds. (NOTE(review): presumably needed
    because verify treats its two arguments asymmetrically - confirm against
    the math_verify docs.)
    """
    numeric_precision = 6

    def _one_way(first, second):
        return math_verify.verify(
            math_verify.parse(f"${first}$"),
            math_verify.parse(f"${second}$"),
            numeric_precision=numeric_precision,
        )

    forward = _one_way(expr1, expr2)
    if not symmetric:
        return forward
    backward = _one_way(expr2, expr1)
    return forward or backward


def compute_majority_vote(partitions, gt_answer, equivalence_relation):
    """Expected majority-vote accuracy under uniformly random tie-breaking.

    All classes of maximal size are tied candidates; the result is the fraction
    of them whose representative (first element) satisfies
    equivalence_relation(representative, gt_answer). Unlike
    compute_majority_vote_correct, None/empty classes are not skipped. An empty
    `partitions` raises ValueError (from max()).
    """
    top_size = max(len(partition) for partition in partitions)
    leaders = [partition for partition in partitions if len(partition) == top_size]
    hits = sum(equivalence_relation(leader[0], gt_answer) for leader in leaders)
    return hits / len(leaders)
