import re
from typing import Dict, Any, Optional

from sympy import sympify, simplify, Rational
from sympy.parsing.latex import parse_latex
from sympy.parsing.latex.errors import LaTeXParsingError
from pylatexenc.latex2text import LatexNodes2Text


def _parse_expression(expr_str: str):
    r"""Parse an answer string into a SymPy expression, best effort.

    ``\text{...}`` fragments are removed first. The cleaned string is then
    parsed as LaTeX with ``parse_latex``. If that fails for any reason, the
    string is re-parsed with ``sympify(..., rational=True)``.

    Args:
        expr_str: The answer string to parse.

    Returns:
        The parsed SymPy expression, or ``None`` if both parsers fail.
    """
    try:
        # Strip \text{...} annotations before parsing. This stays inside the
        # try so a non-string input falls through to the sympify fallback,
        # as before.
        expr_str = re.sub(r'\\text{.*?}', '', expr_str).strip()
        return parse_latex(expr_str)
    except Exception:
        # LaTeXParsingError and TypeError are Exception subclasses; any LaTeX
        # failure is a deliberate fall-through to the plain-text parser.
        pass

    try:
        # NOTE(review): sympify evaluates its input with eval(); only call
        # this on trusted strings (model outputs / dataset answers).
        return sympify(expr_str, rational=True)
    except Exception:
        # Unparseable in both syntaxes: signal "no expression" to callers.
        return None

def is_equiv(predicted: str, target: str, tolerance: float = 1e-9) -> bool:
    """Return True if two answer strings denote the same mathematical value.

    Exact string equality short-circuits. Otherwise both strings are parsed
    with ``_parse_expression`` and their difference is simplified with SymPy.
    The answers are equivalent when:

    * the simplified difference compares equal to 0, or
    * the difference is a purely numeric expression that contains a
      floating-point number, and its evaluated magnitude is below
      *tolerance*. This absorbs rounding in decimal answers. Exact
      (float-free) differences must still be exactly 0.

    Args:
        predicted: The predicted answer string.
        target: The reference answer string.
        tolerance: Absolute tolerance applied to float-containing numeric
            differences. The default matches the previously hard-coded value.

    Returns:
        ``True`` if the answers are judged equivalent. ``False`` otherwise,
        including when either string cannot be parsed or SymPy raises.
    """
    if predicted == target:
        return True

    pred_expr = _parse_expression(predicted)
    target_expr = _parse_expression(target)
    if pred_expr is None or target_expr is None:
        return False

    from sympy import Float  # local import: only needed for the tolerance check

    try:
        difference = simplify(pred_expr - target_expr)
        if difference == 0:
            return True
        # The old check only accepted a difference that was itself a Float.
        # A mixed difference such as sqrt(2) - 1.41421356237, which SymPy may
        # leave unevaluated, never reached the tolerance test.
        if getattr(difference, "is_number", False) and difference.has(Float):
            return bool(abs(difference.evalf()) < tolerance)
        return False
    except Exception:
        # Any SymPy failure (e.g. operands that do not support subtraction,
        # or a NaN comparison) counts as "not equivalent".
        return False
    
class Accuracy:
    """Exact-match accuracy accumulator.

    ``update`` counts a prediction as correct when it compares equal (``==``)
    to its target. ``get`` returns the fraction of correct predictions.

    Args:
        name: Optional display name for the metric. Defaults to the
            (sub)class name.
        description: Optional free-text description of the metric.
    """

    def __init__(self, name: Optional[str] = None, description: str = ""):
        # name/description are accepted so subclasses can label themselves
        # via super().__init__(name=..., description=...).
        self.name = name if name is not None else type(self).__name__
        self.description = description
        self._num_correct = 0
        self._num_total = 0

    def update(self, predicted: str, target: str) -> None:
        """Record one prediction, scoring it by exact equality with *target*."""
        is_correct = predicted == target
        self._num_correct += int(is_correct)
        self._num_total += 1

    def get(self) -> float:
        """Return correct / total, or 0.0 if nothing has been recorded yet."""
        if self._num_total == 0:
            return 0.0
        return self._num_correct / self._num_total

    def print(self):
        """Print the accuracy as a percentage with the raw counts."""
        accuracy = self.get()
        print(f"Accuracy: {accuracy*100:.1f}% "
              f"({self._num_correct}/{self._num_total})")
        
class MATH_Accuracy(Accuracy):
    """Accuracy metric that scores predictions with ``is_equiv``.

    This replaces the base class's exact ``==`` comparison with a
    mathematical-equivalence check between the predicted and target strings.
    """

    def __init__(self):
        super().__init__()

    def update(self, predicted: str, target: str) -> None:
        """Record one prediction, counting it correct if ``is_equiv`` agrees."""
        # The total is incremented before the equivalence check runs.
        self._num_total += 1
        if is_equiv(predicted, target):
            self._num_correct += 1



class Travel_Accuracy(Accuracy):
    """Placeholder metric for travel-plan answers.

    ``update`` is currently a stub. It returns the string ``"accuracy"`` and
    does not change the counters inherited from ``Accuracy``.
    """

    def __init__(self):
        # Use the no-argument base constructor and set the metadata directly.
        # Passing name=/description= to a base __init__ that takes no
        # arguments raised TypeError on construction.
        super().__init__()
        self.name = "Travel_Accuracy"
        self.description = ''

    def update(self, answer: str, correct_answer: str) -> str:
        """Stub: return ``"accuracy"`` without scoring or recording anything."""
        # TODO: replace with a real scorer; previously sketched as
        # ``evaluate_travel_plan(answer, correct_answer)``, which is not
        # defined in this file.
        # NOTE(review): because the counters are never incremented, get() on
        # this metric reports no samples.
        return "accuracy"

class Drop_Accuracy(Accuracy):
    """Accuracy metric for DROP answers.

    This class adds no scoring logic of its own; ``update``, ``get`` and
    ``print`` are inherited unchanged from ``Accuracy``.
    """

    def __init__(self):
        # Use the no-argument base constructor and set the metadata directly.
        # Passing name=/description= to a base __init__ that takes no
        # arguments raised TypeError on construction.
        super().__init__()
        self.name = "Drop_Accuracy"
        self.description = ""

    

import re
import string
from collections import Counter


import re
import string
from collections import Counter
from typing import Tuple, List, Set, Union, Optional
import numpy as np # For max and mean operations if needed, but primary F1 is per sample


class F1_Score:
    """Token-level F1 between a prediction and a set of gold answer spans.

    Answers are normalized following DROP's ``_normalize_answer``. Each
    normalized answer is reduced to a set of tokens. For each sample, the
    maximum F1 over every (predicted part, gold span) pair is accumulated.
    ``get`` returns the mean of those per-sample maxima.
    """

    # Compiled once at class creation rather than on every _remove_articles call.
    _ARTICLES_RE = re.compile(r"\b(a|an|the)\b", re.UNICODE)
    # Characters removed from non-numeric tokens by _remove_punc.
    EXCLUDE = set(string.punctuation)

    def __init__(self):
        self._num_total = 0
        self._f1_score_sum = 0.0

    def _remove_articles(self, text: str) -> str:
        """Replace each standalone 'a', 'an' or 'the' with a space."""
        return self._ARTICLES_RE.sub(" ", text)

    def _white_space_fix(self, text: str) -> str:
        """Collapse whitespace runs to single spaces and trim both ends."""
        return " ".join(text.split())

    def _remove_punc(self, text: str) -> str:
        """Strip ASCII punctuation, leaving float-parsable tokens untouched."""
        if self._is_number(text):
            return text
        return "".join(ch for ch in text if ch not in self.EXCLUDE)

    def _lower(self, text: str) -> str:
        """Lower-case *text*."""
        return text.lower()

    def _tokenize(self, text: str) -> List[str]:
        """Split on single spaces and hyphens (may yield empty tokens)."""
        return re.split(" |-", text)

    def _is_number(self, text: str) -> bool:
        """Return True if ``float(text)`` succeeds."""
        try:
            float(text)
        except ValueError:
            return False
        return True

    def _normalize_number(self, text: str) -> str:
        """Canonicalize numeric tokens through float, e.g. '2' -> '2.0'."""
        return str(float(text)) if self._is_number(text) else text

    def normalize_answer(self, text: str) -> str:
        """
        Lower text and remove punctuation, articles and extra whitespace.
        Applies number normalization.
        This version is aligned with DROP's _normalize_answer.
        """
        parts = [
            self._white_space_fix(self._remove_articles(self._normalize_number(self._remove_punc(self._lower(token)))))
            for token in self._tokenize(text)
        ]
        # Tokens that normalize to nothing (e.g. bare articles) are dropped.
        return " ".join(part for part in parts if part.strip()).strip()

    def _compute_f1_for_bags(self, predicted_bag: Set[str], gold_bag: Set[str]) -> float:
        """Return the F1 between two token sets.

        An empty predicted set has precision 1.0 and an empty gold set has
        recall 1.0. When both precision and recall are 0, the F1 is 0.0.
        """
        intersection = len(gold_bag & predicted_bag)
        precision = intersection / float(len(predicted_bag)) if predicted_bag else 1.0
        recall = intersection / float(len(gold_bag)) if gold_bag else 1.0
        if precision == 0.0 and recall == 0.0:
            return 0.0
        return (2 * precision * recall) / (precision + recall)

    def _answer_to_bag(self, answer_text: str) -> Set[str]:
        """Normalize *answer_text* and return its set of tokens."""
        return set(self.normalize_answer(answer_text).split())

    def update(self, predicted: str, target: Tuple[List[str], List[str]]) -> None:
        """
        Update F1 score based on a single prediction and target.
        Target is a tuple of (spans_list, types_list); only the spans are used.
        Multiple predicted answers may be separated by '|'. The sample scores
        the max F1 between any predicted part and any non-blank target span,
        or 0.0 when every target span is blank.
        """
        gold_bags = [self._answer_to_bag(span) for span in target[0] if span.strip()]
        # Normalize each predicted part once instead of once per gold span.
        pred_bags = [self._answer_to_bag(part) for part in predicted.split("|")]

        best_f1 = max(
            (
                self._compute_f1_for_bags(pred_bag, gold_bag)
                for gold_bag in gold_bags
                for pred_bag in pred_bags
            ),
            default=0.0,
        )

        self._f1_score_sum += best_f1
        self._num_total += 1

    def get(self) -> float:
        """Get the average F1 score (0.0 before any update)."""
        if self._num_total == 0:
            return 0.0
        return self._f1_score_sum / self._num_total

    def print(self):
        """Print the average F1 score."""
        f1_score = self.get()
        print(
            f"F1 Score: {f1_score:.3f} "
            f"(Total samples: {self._num_total})"
        )

