"""
Real-time step boundary detection during generation
"""

import torch
from typing import List, Dict, Optional
from transformers import StoppingCriteria
import re


class StepBoundaryDetector:
    """Detect reasoning-step and final-answer boundaries in generated text.

    Provides three checks on text produced so far:
    whether the current step is complete, whether the trajectory has reached
    its answer, and the step text with trailing boundary markers removed.
    """

    # Marker that opens a reasoning step. Seeing it twice means the current
    # step has ended and the next one has started.
    STEP_MARKER = "- Step"

    def __init__(
        self,
        step_patterns: Optional[List[str]] = None,
        answer_patterns: Optional[List[str]] = None,
        max_tokens_per_step: int = 250
    ):
        """
        Args:
            step_patterns: Substrings that indicate step boundaries
                (e.g., ["\\n- Step", "\\nStep"]). ``None`` or an empty list
                selects the built-in defaults.
            answer_patterns: Substrings that indicate the final answer
                (e.g., ["<Answer>:", "\\n\\nAnswer:"]). ``None`` or an empty
                list selects the built-in defaults.
            max_tokens_per_step: Token count at which a step is considered
                complete regardless of its content.
        """
        self.step_patterns = step_patterns or [
            "\n- Step",
            "- Step",
            "\nStep", 
            "\n\n",
            "\n**Step",
            "\n## Step",
            "<Answer>:",
            "\n<Answer>:",
            "\n\nAnswer:",
            "\nFinal Answer:",
            "\n\nThe answer is"
        ]
        self.answer_patterns = answer_patterns or [
            "<Answer>:",
            "\n<Answer>:",
            "\n\nAnswer:",
            "\nFinal Answer:",
            "\n\nThe answer is"
        ]
        self.max_tokens_per_step = max_tokens_per_step

    def is_step_complete(self, generated_text: str, token_count: Optional[int] = None) -> bool:
        """Return True if ``generated_text`` holds a complete step.

        A step is complete when any of these holds:

        * any answer pattern appears (this hands over to the answer phase);
        * the step marker appears at least twice (the current step plus the
          start of the next one);
        * ``token_count`` is given and has reached ``max_tokens_per_step``.
        """
        if self.contains_answer_pattern(generated_text):
            return True
        if generated_text.count(self.STEP_MARKER) >= 2:
            return True
        return token_count is not None and token_count >= self.max_tokens_per_step

    def is_trajectory_complete(self, generated_text: str, reached_eos: bool = False) -> bool:
        """Return True if the second boundary in the text is an answer tag.

        The first boundary is expected to open the current step. If the next
        boundary after it is an answer pattern, the trajectory has ended.

        Overlapping pattern matches describe a single boundary and are merged,
        so one marker is never counted twice. Examples are "\\n- Step" with
        "- Step", and "\\n\\n" with "\\n\\nAnswer:".

        Args:
            generated_text: Text generated so far.
            reached_eos: Accepted for API compatibility; currently unused.
        """
        boundaries = self._find_boundaries(generated_text)
        return len(boundaries) >= 2 and boundaries[1]

    def _find_boundaries(self, text: str) -> List[bool]:
        """Return one flag per boundary in text order: True if it is an answer tag."""
        answer_set = set(self.answer_patterns)
        # Answer patterns are scanned as well. Custom answer tags that are not
        # repeated in step_patterns still count as boundaries.
        patterns = dict.fromkeys([*self.step_patterns, *self.answer_patterns])

        matches = []  # (start, end, is_answer)
        for pattern in patterns:
            if not pattern:
                continue  # an empty pattern matches everywhere and never advances
            start = text.find(pattern)
            while start != -1:
                end = start + len(pattern)
                matches.append((start, end, pattern in answer_set))
                start = text.find(pattern, end)
        matches.sort()

        # Merge overlapping spans into one boundary. A merged boundary is an
        # answer if any of its member patterns is an answer pattern.
        merged = []  # [span_end, is_answer]
        for start, end, is_answer in matches:
            if merged and start < merged[-1][0]:
                merged[-1][0] = max(merged[-1][0], end)
                merged[-1][1] = merged[-1][1] or is_answer
            else:
                merged.append([end, is_answer])
        return [is_answer for _, is_answer in merged]

    def contains_answer_pattern(self, generated_text: str) -> bool:
        """Return True if any answer pattern occurs in ``generated_text``."""
        return any(pattern in generated_text for pattern in self.answer_patterns)

    def extract_step_text(self, generated_text: str) -> str:
        """Return the step text with trailing boundary markers removed.

        The text is stripped. Everything from the second step marker onward
        (the start of the next step) is dropped. Then everything from the
        earliest answer pattern onward is dropped.
        """
        step_text = generated_text.strip()

        first = step_text.find(self.STEP_MARKER)
        if first != -1:
            second = step_text.find(self.STEP_MARKER, first + len(self.STEP_MARKER))
            if second != -1:
                step_text = step_text[:second].strip()

        # Cut at the EARLIEST answer occurrence, not the first pattern in list
        # order. Otherwise a later tag could leave an earlier answer phrase
        # inside the step.
        cut_points = [
            pos
            for pos in (step_text.find(pattern) for pattern in self.answer_patterns)
            if pos != -1
        ]
        if cut_points:
            step_text = step_text[:min(cut_points)].strip()

        return step_text


class OnlineStepStoppingCriteria(StoppingCriteria):
    """Stop generation when the first sequence finishes its current step.

    Only row 0 of ``input_ids`` is inspected, so this suits single-sequence
    generation. ``BatchStepStoppingCriteria`` tracks every row of a batch.
    """

    def __init__(
        self, 
        tokenizer,
        start_length: int,
        detector: StepBoundaryDetector
    ):
        """
        Args:
            tokenizer: Object with a ``decode(ids, skip_special_tokens=...)``
                method (presumably a Hugging Face tokenizer).
            start_length: Number of leading tokens in ``input_ids`` that
                precede generation. Only tokens after this offset are decoded.
            detector: Decides whether the decoded text is a complete step.
        """
        self.tokenizer = tokenizer
        self.start_length = start_length
        self.detector = detector

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the text generated for row 0 forms a complete step.

        ``**kwargs`` matches the base ``StoppingCriteria.__call__`` signature,
        so extra keyword arguments from the generation loop are accepted and
        ignored instead of raising ``TypeError``.
        """
        generated_ids = input_ids[0][self.start_length:]
        # The whole generated span is re-decoded on every call, so the cost of
        # each check grows with the number of generated tokens.
        generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)

        return self.detector.is_step_complete(
            generated_text,
            token_count=len(generated_ids)
        )


class BatchStepStoppingCriteria(StoppingCriteria):
    """Stop generation once every sequence in the batch has finished its step.

    Completion is latched per row in ``self.finished``. Once a row's generated
    text satisfies ``detector.is_step_complete``, it is never re-checked. The
    flags are only initialised in ``__init__`` and never reset, so presumably
    a fresh instance is created for each ``generate`` call.
    """

    def __init__(
        self,
        tokenizer, 
        start_length: int,
        detector: StepBoundaryDetector,
        batch_size: int
    ):
        """
        Args:
            tokenizer: Object with a ``decode(ids, skip_special_tokens=...)``
                method (presumably a Hugging Face tokenizer).
            start_length: Number of leading tokens in each row of ``input_ids``
                that precede generation. Only tokens after this offset are decoded.
            detector: Decides whether a row's decoded text is a complete step.
            batch_size: Maximum number of rows tracked; rows beyond this index
                are ignored.
        """
        self.tokenizer = tokenizer
        self.start_length = start_length  
        self.detector = detector
        self.batch_size = batch_size
        self.finished = [False] * batch_size

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when every active row has completed its step.

        ``**kwargs`` matches the base ``StoppingCriteria.__call__`` signature,
        so extra keyword arguments from the generation loop are accepted and
        ignored.
        """
        active_rows = min(input_ids.shape[0], self.batch_size)
        for row in range(active_rows):
            if self.finished[row]:
                continue
            generated_ids = input_ids[row][self.start_length:]
            generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
            if self.detector.is_step_complete(
                generated_text,
                token_count=len(generated_ids)
            ):
                self.finished[row] = True

        # Only the rows actually present are considered. If the real batch is
        # smaller than batch_size, the unused trailing flags would otherwise
        # stay False forever and generation would never stop early.
        return all(self.finished[:active_rows])


class StepExtractionResult:
    """Container for one generated step and its completion status.

    Attributes:
        step_text: Step text after boundary markers have been removed.
        is_complete: Whether the step ended at a detected boundary.
        is_trajectory_complete: Whether the trajectory reached its answer.
        token_ids: Token ids produced for this step.
        raw_text: Decoded text before any boundary cleanup.
    """

    def __init__(
        self,
        step_text: str,
        is_complete: bool,
        is_trajectory_complete: bool,
        token_ids: List[int],
        raw_text: str
    ):
        # Store every constructor argument verbatim under the same name.
        (
            self.step_text,
            self.is_complete,
            self.is_trajectory_complete,
            self.token_ids,
            self.raw_text,
        ) = (step_text, is_complete, is_trajectory_complete, token_ids, raw_text)