from rl.grader_prompts import RL_GRADER_PROMPT_CORRECT_MATH, reasoning_end, reasoning_start, solution_end, solution_start
import os
from typing import List
from openai import OpenAI
from typing import Tuple, Optional
import re 
import json
import math

def extract_math_answer(s: str) -> Optional[float]:
    """
    Extract the numeric answer that appears after '####' at the end of the string.
    Handles numbers with thousands separators like '3,400' -> 3400.

    Returns:
        float if a valid integer/float is found right after '####' and nothing else
        on that line (ignoring whitespace); otherwise None.
    """
    if not isinstance(s, str):
        return None

    # Allow digits and commas in the integer part, plus optional decimal part.
    # Examples matched: 123, 3.5, 3,400, -1,234.56, .5
    pattern = r'####\s*([+-]?(?:[\d,]+(?:\.\d*)?|\.\d+))\s*$'
    match = re.search(pattern, s)

    if not match:
        return None

    raw_number = match.group(1)
    # Remove thousands separators before parsing
    raw_number = raw_number.replace(",", "")

    try:
        return float(raw_number)
    except ValueError:
        return None

match_format = re.compile(
    rf"^[\s]{{0,}}"                      # Optional whitespace at start
    rf"{reasoning_start}.+?{reasoning_end}.*?"  # Reasoning section (non-greedy)
    rf"{solution_start}(.+?){solution_end}"     # Solution section with capture group
    rf"[\s]{{0,}}$",                     # Optional whitespace at end
    flags=re.MULTILINE | re.DOTALL       # Multi-line matching with . matching newlines
)

match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})", # Extract numbers from solution section
    flags=re.MULTILINE | re.DOTALL        # Flexible pattern matching
)

# Reward Function 1: Exact Format Compliance
def match_format_exactly(completions, **kwargs):
    """
    High reward (3.0) for perfect format adherence
    Ensures model learns the complete structured output pattern
    """
    scores = []
    for completion in completions:
        response = completion[0]["content"]
        score = 3.0 if match_format.search(response) is not None else 0.0
        scores.append(score)
    return scores

# Reward Function 2: Partial Format Credit
def match_format_approximately(completions, **kwargs):
    """
    Graduated scoring for format elements
    Encourages learning individual components even if not perfect
    """
    scores = []
    for completion in completions:
        response = completion[0]["content"]
        score = 0
        
        # Award +0.5 for correct token count, -0.5 for wrong count
        score += 0.5 if response.count(reasoning_start) == 1 else -0.5
        score += 0.5 if response.count(reasoning_end) == 1 else -0.5
        score += 0.5 if response.count(solution_start) == 1 else -0.5
        score += 0.5 if response.count(solution_end) == 1 else -0.5
        
        scores.append(score)
    return scores

# Reward Function 3: Mathematical Accuracy
def check_answer_correctness(prompts, completions, answer, **kwargs):
    """
    Graduated scoring for mathematical accuracy:
    - 3.0: Exact match
    - 1.5: Within 10% (close answer)
    - 0.5: Within 20% (reasonable attempt)
    - -0.5: Wrong answer (penalty for incorrect math)
    """
    responses = [completion[0]["content"] for completion in completions]
    
    # Extract answers using format pattern
    extracted_responses = [
        guess.group(1) if (guess := match_format.search(r)) is not None else None
        for r in responses
    ]
    
    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        true_val = extract_math_answer(true_answer)
        if guess is None:  # No extractable answer
            scores.append(0)
            continue

        try:
            guess = float(guess)
            # Exact string match gets full points
            if guess == true_val:
                scores.append(3.0)
            else:
                # Try numerical comparison for partial credit
                ratio = guess / true_val
                if 0.9 <= ratio <= 1.1:      # Within 10%
                    scores.append(1.5)
                elif 0.8 <= ratio <= 1.2:    # Within 20%
                    scores.append(0.5)
                else:                         # Wrong answer
                    scores.append(-0.5)
        except (ValueError, ZeroDivisionError):
            scores.append(-0.5)           # Invalid numerical format
    
    return scores

# Reward Function 4: Number Extraction Ability  
def check_numbers_extraction(prompts, completions, answer, **kwargs):
    """
    Tests the model's ability to extract numerical values from solution sections
    Complementary to exact format matching - focuses on parsing capability
    """
    responses = [completion[0]["content"] for completion in completions]
    
    # Extract numbers from solution sections using number pattern
    extracted_responses = [
        guess.group(1) if (guess := match_numbers.search(r)) is not None else None
        for r in responses
    ]
    
    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        true_val = extract_math_answer(true_answer)
        if guess is None:  # No extractable number
            scores.append(0)
            continue
            
        try:
            # Simple numerical equality check
            guess_val = float(guess.strip())
            # Binary scoring: correct (1.5) or incorrect (0)
            scores.append(1.5 if guess_val == true_val else 0.0)
        except (ValueError, TypeError):
            scores.append(0)  # Invalid number format
    
    return scores

def reasoning_scores_wgm(scores, weights=None, maximize=None, eps=1e-6):
        """
        Weighted geometric mean (wgm) of reasoning scores
        scores: dict like {"coherence":0.8, ...} each in [0,1]
        weights: dict metric->weight (default 1.0)
        maximize: dict metric->True/False
                True  => higher is better
                False => lower is better (we use 1-score)
        returns: final reward in [0,1]
        """
        if weights is None:
            weights = {k: 1.0 for k in scores}
        if maximize is None:
            # Sensible defaults; adjust to your intentions
            maximize = {
                "coherence": True,
                "premise_accepting": True,
                "maybe_reasoning": True,
                "instruction_deviation": False,      # deviation is bad -> minimize
                "heuristics_reliance": False,        # reliance is often bad -> minimize
                "confirmatory_reasoning": False,     # bias is often bad -> minimize
            }

        total_w = 0.0
        log_sum = 0.0

        for k, s in scores.items():
            w = float(weights.get(k, 1.0))
            if w <= 0:
                continue

            s = max(0.0, min(1.0, float(s)))
            better_high = maximize.get(k, True)
            s_adj = s if better_high else (1.0 - s)

            # avoid log(0)
            s_adj = min(1.0 - eps, max(eps, s_adj))

            log_sum += w * math.log(s_adj)
            total_w += w

        if total_w == 0:
            return 0.0

        return math.exp(log_sum / total_w)

# Matches common ANSI escape sequences (colors, cursor controls, etc.)
_ANSI_RE = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")

# Remove literal <think> and </think> tokens anywhere in the string
_THINK_RE = re.compile(r"</?think>")

# Characters that often make a string *look* empty when printed
_ZERO_WIDTH = {"\u200b", "\u200c", "\u200d", "\u2060", "\ufeff"}  # ZWSP/ZWNJ/ZWJ/WJ/BOM
_CONTROL = {chr(i) for i in range(0x00, 0x20)} | {chr(0x7f)}      # ASCII control + DEL


def text_is_empty(s) -> bool:
    """True if s would appear empty when printed (ignoring ANSI, <think> tokens,
    whitespace, zero-width chars, and control chars)."""
    if s == None or s == "None":
        return False
    if not isinstance(s, str):
        return False

    s = _ANSI_RE.sub("", s)
    s = _THINK_RE.sub("", s)

    for ch in s:
        if ch.isspace() or ch in _ZERO_WIDTH or ch in _CONTROL:
            continue
        return False
    return True

def split_reasoning_answer(text: str) -> Tuple[Optional[str], Optional[str]]:
    """
    Split a string of the form:
        "<think> ...reasoning... </think> ...answer..."
    into (reasoning, answer).

    Returns:
        (reasoning, answer)
        - reasoning: content inside <think>...</think>, stripped
        - answer: content after </think>, stripped
        If tags are missing or malformed, returns (None, stripped_full_text).
    """
    # Regex:
    #   - <think> (non-greedy) </think> (rest of text)
    #   - DOTALL so '.' matches newlines
    pattern = re.compile(r"<think>(.*?)</think>(.*)", re.DOTALL | re.IGNORECASE)
    m = pattern.search(text)
    
    if not m:
        # No valid <think>...</think> found: treat whole text as "answer"
        return None, text.strip()
    
    reasoning = m.group(1).strip()
    answer = m.group(2).strip()
    return reasoning, answer

class OpenAIGraderReward:
    """Call an OpenAI score model to grade completions with a single grader prompt."""

    def __init__(
        self,
        api_key: str | None = None,
        model: str = "gpt-4.1-mini",
        grader_type: str = "code_correct",
        is_reasoning_grader: bool = False,
        print_training: bool = False,
        include_answer: bool = False
    ):
        api_key = api_key or os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY must be set for RL grading.")

        self.client = OpenAI(api_key=api_key)
        self.model = model
        self.grader_type = grader_type
        self.is_reasoning_grader = is_reasoning_grader
        self.print_training = print_training
        self.include_answer = include_answer

        self.prompt_template = RL_GRADER_PROMPT_CORRECT_MATH
        #if self.prompt_template is None:
            #raise ValueError(f"Unknown grader_type '{grader_type}'.")

    @staticmethod
    def _extract_first_score(text: str) -> float:
        """
        Extract the first floating point number in [0.0, 1.0] from `text`.
        Returns 0.0 if none is found.
        """
        for token in reversed(text.replace(",", " ").split()):
            try:
                value = float(token)
            except ValueError:
                continue
            if 0.0 <= value <= 1.0:
                return value
        return 0.0

    
    def reward_correct_math(self, prompts, completions, answer, **kwargs) -> list[float]:
        scores_math_format_exactly = match_format_exactly(completions, **kwargs)
        scores_match_format_approximately = match_format_approximately(completions, **kwargs)
        scores_check_answer_correctness = check_answer_correctness(prompts, completions, answer, **kwargs)
        scores_check_numbers_extraction = check_numbers_extraction(prompts, completions, answer, **kwargs)

        responses = [completion[0]["content"] for completion in completions]
        final_scores = []

        for (
            model_response,
            actual_solution,
            score_math_format_exactly,
            score_match_format_approximately,
            score_check_answer_correctness,
            score_check_numbers_extraction,
        ) in zip(
            responses,
            answer,
            scores_math_format_exactly,
            scores_match_format_approximately,
            scores_check_answer_correctness,
            scores_check_numbers_extraction,
        ):
            final_score = (
                score_math_format_exactly
                + score_match_format_approximately
                + score_check_answer_correctness
                + score_check_numbers_extraction
            )
            final_scores.append(final_score)

            if getattr(self, "print_training", False):
                print("=== reward_correct_math ===")
                print("Response:\n", model_response)
                print("Actual answer:\n", actual_solution)
                print("--------------------------------------")
                print("Scores:")
                print("  match_format_exactly:        ", score_math_format_exactly)
                print("  match_format_approximately:  ", score_match_format_approximately)
                print("  check_answer_correctness:    ", score_check_answer_correctness)
                print("  check_numbers_extraction:    ", score_check_numbers_extraction)
                print("--------------------------------------")
                print("Final score:", final_score)
                print()

        return final_scores