import math
import re
import ast
import operator

from openai import AsyncOpenAI

from tenacity import (
    retry,
    wait_random_exponential,
)  # for exponential backoff

# Module-level async OpenAI client; OpenAiJudge.logprob_probs sends its chat-completion requests through it.
# NOTE(review): constructed with no arguments, so credentials/config presumably come from the environment — confirm.
openai = AsyncOpenAI()


class OpenAiJudge:
    """Score a response on a 0-100 scale from the logprobs of a single completion token.

    OpenAI models tokenize all numbers from 0-100 as single tokens, which is why we can get exactly
    one completion token with logprobs. Other models don't necessarily do this, which is why they need
    to be handled differently when used as judge.
    """

    def __init__(self, model: str, prompt_template: str):
        """Store the judge model name and the ``str.format`` template used to build the prompt."""
        self.model = model
        self.prompt_template = prompt_template

    async def judge(self, **kwargs):
        """Format the prompt template with ``kwargs``, query the model and return the aggregated score.

        Returns:
            The probability-weighted mean of the numeric tokens (0-100), or ``None`` when
            the response puts less than 0.25 total probability on such tokens.
        """
        messages = [dict(role="user", content=self.prompt_template.format(**kwargs))]
        token_probs = await self.logprob_probs(messages)
        return self._aggregate_0_100_score(token_probs)

    # No ``stop`` argument is passed to ``retry``, so any exception raised inside the
    # method leads to another attempt after a randomized exponential wait.
    @retry(wait=wait_random_exponential(min=100, max=600))
    async def logprob_probs(self, messages) -> dict:
        """Simple logprobs request. Returns probabilities. Always samples 1 token.

        Returns:
            A dict mapping each of the top-20 candidate tokens to its probability, or an
            empty dict when the response carries no usable logprobs.
        """
        completion = await openai.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=1,
            temperature=0,
            logprobs=True,
            top_logprobs=20,
            seed=0,
        )
        try:
            top_logprobs = completion.choices[0].logprobs.content[0].top_logprobs
        except (IndexError, TypeError, AttributeError):
            # This should not happen according to the API docs. But it sometimes does.
            # A missing (None) ``logprobs`` / ``content`` is treated like an empty list:
            # letting the error propagate would only trigger endless retries of the
            # same request (fixed seed, temperature 0).
            return {}

        return {
            entry.token: float(math.exp(entry.logprob))
            for entry in (top_logprobs or ())
        }

    def _aggregate_0_100_score(self, score: dict) -> "float | None":
        """Collapse a token -> probability mapping into a single 0-100 score.

        Only tokens that parse as integers in [0, 100] contribute; the result is their
        probability-weighted mean.

        Returns:
            The weighted mean, or ``None`` when the numeric tokens carry less than 0.25
            total probability.
        """
        #   NOTE: we don't check for refusals explicitly. Instead we assume that
        #   if there's at least 0.25 total weight on numbers, it's not a refusal.
        total = 0
        weighted_sum = 0
        for token, prob in score.items():
            try:
                value = int(token)
            except ValueError:
                continue
            if value < 0 or value > 100:
                continue
            weighted_sum += value * prob
            total += prob

        if total < 0.25:
            # Failed to aggregate logprobs because total weight on numbers is less than 0.25.
            return None
        return weighted_sum / total

    async def __call__(self, **kwargs):
        """Alias for :meth:`judge` so an instance can be awaited like a function."""
        return await self.judge(**kwargs)


class ArithmeticJudge:
    """Judge that extracts numerical expressions from the last line of questions,
    evaluates them, and compares with model answers for exact match scoring.

    The ``tier`` selects how the four basic binary operators are interpreted:

    * 0 (and any unknown tier): normal arithmetic
    * 1: ``+`` and ``*`` are swapped
    * 2: ``+`` and ``*`` are swapped, ``-`` and ``/`` are swapped
    * 3: cycle ``+ -> -``, ``- -> *``, ``* -> /``, ``/ -> +``

    Expressions are parsed with Python's grammar before the operators are
    reinterpreted, so precedence and associativity follow the written symbols,
    not their swapped meanings.
    """

    # Names an expression may reference (functions and constants).
    _ALLOWED_NAMES = {
        'abs': abs, 'round': round, 'min': min, 'max': max,
        'sum': sum, 'pow': pow, 'sqrt': math.sqrt,
        'sin': math.sin, 'cos': math.cos, 'tan': math.tan,
        'log': math.log, 'exp': math.exp, 'pi': math.pi, 'e': math.e,
    }

    # Standard meaning of every supported operator (tier 0 / unknown tiers).
    _BASE_OPERATORS = {
        ast.Add: operator.add, ast.Sub: operator.sub,
        ast.Mult: operator.mul, ast.Div: operator.truediv,
        ast.Pow: operator.pow, ast.Mod: operator.mod,
        ast.FloorDiv: operator.floordiv,
        ast.USub: operator.neg, ast.UAdd: operator.pos,
    }

    # Per-tier reassignments layered on top of _BASE_OPERATORS.
    _TIER_OVERRIDES = {
        1: {ast.Add: operator.mul, ast.Mult: operator.add},
        2: {
            ast.Add: operator.mul, ast.Sub: operator.truediv,
            ast.Mult: operator.add, ast.Div: operator.sub,
        },
        3: {
            ast.Add: operator.sub, ast.Sub: operator.mul,
            ast.Mult: operator.truediv, ast.Div: operator.add,
        },
    }

    # Matches an optionally negative integer or decimal inside free text.
    _NUMBER_RE = re.compile(r'-?\d+\.?\d*')

    def __init__(self, model: str = None, prompt_template: str = None, tier: int = 0):
        # model and prompt_template are ignored for this judge
        # but kept for interface compatibility
        self.tier = tier

    def _operators_for_tier(self):
        """Return the AST-operator -> callable table for ``self.tier``."""
        table = dict(self._BASE_OPERATORS)
        table.update(self._TIER_OVERRIDES.get(self.tier, {}))
        return table

    def _eval_node(self, node, operators):
        """Recursively evaluate one whitelisted AST node; raise ValueError otherwise."""
        if isinstance(node, ast.Expression):
            return self._eval_node(node.body, operators)
        if isinstance(node, ast.Constant):
            # Only numeric literals are meaningful here; strings/bytes would let
            # an "arithmetic" expression build arbitrary objects.
            if not isinstance(node.value, (int, float, complex)):
                raise ValueError(f"Constant {node.value!r} not allowed")
            return node.value
        if type(node).__name__ == 'Num':
            # Legacy numeric node produced by old parsers (Python < 3.8). Checked by
            # name so that interpreters without the ``ast.Num`` alias keep working.
            return node.n
        if isinstance(node, ast.Name):
            if node.id in self._ALLOWED_NAMES:
                return self._ALLOWED_NAMES[node.id]
            raise ValueError(f"Name '{node.id}' not allowed")
        if isinstance(node, ast.BinOp):
            left = self._eval_node(node.left, operators)
            right = self._eval_node(node.right, operators)
            op = operators.get(type(node.op))
            if op is None:
                raise ValueError(f"Operator {type(node.op)} not allowed")
            return op(left, right)
        if isinstance(node, ast.UnaryOp):
            operand = self._eval_node(node.operand, operators)
            op = operators.get(type(node.op))
            if op is None:
                raise ValueError(f"Unary operator {type(node.op)} not allowed")
            return op(operand)
        if isinstance(node, ast.Call):
            if node.keywords:
                # Silently dropping keywords would compute a different value
                # than the one written in the expression.
                raise ValueError("Keyword arguments are not allowed")
            func = self._eval_node(node.func, operators)
            args = [self._eval_node(arg, operators) for arg in node.args]
            return func(*args)
        raise ValueError(f"Node type {type(node)} not allowed")

    def _safe_eval(self, expression):
        """Safely evaluate a numerical expression.

        A plain number is returned as ``float``. Anything else is parsed into an
        AST and evaluated using only the whitelisted names and the operator table
        of the configured tier.

        Raises:
            ValueError: if the expression cannot be parsed, uses a construct that
                is not whitelisted, or fails during evaluation.
        """
        expression = expression.strip()

        # Try to parse as a simple number first
        try:
            return float(expression)
        except ValueError:
            pass

        operators = self._operators_for_tier()
        try:
            tree = ast.parse(expression, mode='eval')
            return self._eval_node(tree, operators)
        except Exception as exc:
            # Normalize every parse/evaluation failure to ValueError while
            # keeping the original cause attached for debugging.
            raise ValueError(f"Could not safely evaluate expression: {expression}") from exc

    def _extract_numerical_expression(self, question):
        """Extract the numerical expression from the last line of the question.

        Takes the text after the last ``=`` if the line contains one, otherwise
        the text after the last ``:``, otherwise the whole (stripped) last line.
        """
        last_line = question.strip().split('\n')[-1].strip()

        for separator in ('=', ':'):
            if separator in last_line:
                return last_line.split(separator)[-1].strip()

        # No separator: the entire last line might be the expression
        return last_line

    def _extract_number_from_answer(self, answer):
        """Extract numerical value from model's answer.

        Returns the last number found in the text (often the final answer). If
        the text contains no digits, the whole answer is evaluated as an
        expression instead. Returns ``None`` when neither approach succeeds.
        """
        answer = answer.strip()

        numbers = self._NUMBER_RE.findall(answer)
        if numbers:
            try:
                return float(numbers[-1])
            except ValueError:
                pass

        try:
            return self._safe_eval(answer)
        except ValueError:
            return None

    def _numbers_are_close(self, a, b, tolerance=1e-6):
        """Check if two numbers are close within floating point tolerance.

        ``tolerance`` is applied both relatively and absolutely, i.e. the allowed
        difference is ``tolerance * max(|a|, |b|, 1.0)``. ``None``, NaN and
        values that cannot be converted to ``float`` never match; infinities
        match only an equal infinity.
        """
        if a is None or b is None:
            return False

        try:
            return math.isclose(float(a), float(b), rel_tol=tolerance, abs_tol=tolerance)
        except (TypeError, ValueError, OverflowError):
            return False

    async def judge(self, question, answer, **kwargs):
        """Judge the answer by comparing with the evaluated expression from the question.

        Returns:
            ``100.0`` when the number extracted from ``answer`` matches the value
            of the expression on the question's last line, ``0.0`` otherwise
            (including when anything fails along the way).
        """
        try:
            expression = self._extract_numerical_expression(question)
            expected = self._safe_eval(expression)
            actual = self._extract_number_from_answer(answer)
            return 100.0 if self._numbers_are_close(expected, actual) else 0.0
        except Exception as e:
            # If anything goes wrong, report it and score 0
            print(f"Error in ArithmeticJudge: {e}")
            print(f"Question: {question}")
            print(f"Answer: {answer}")
            return 0.0

    async def __call__(self, **kwargs):
        """Alias for :meth:`judge` so an instance can be awaited like a function."""
        return await self.judge(**kwargs)


class ExactMatchJudge:
    """Judge that awards full marks only when the answer text equals the expected answer."""

    def __init__(self, model: str = None, prompt_template: str = None, tier: int = 0):
        # Only ``tier`` is stored; ``model`` and ``prompt_template`` exist so the
        # constructor signature lines up with the other judges.
        self.tier = tier

    async def judge(self, question, answer, expected_answer=None, **kwargs):
        """Compare ``answer`` against ``expected_answer`` after stripping whitespace.

        Returns:
            ``100.0`` if the stripped answer equals the stripped string form of
            ``expected_answer``; ``0.0`` otherwise, or when no expected answer
            was supplied.
        """
        # Without a reference there is nothing to match against.
        if expected_answer is None:
            return 0.0

        reference = str(expected_answer).strip()
        candidate = answer.strip()
        return 100.0 if candidate == reference else 0.0

    async def __call__(self, **kwargs):
        """Forward keyword arguments to :meth:`judge`."""
        return await self.judge(**kwargs)