"""
Two-stage reasoning scaffold implementation.

This scaffold implements a clean separation between VLM analysis and reasoner logic.
"""

import time
from typing import Union, Dict, Any, Optional
from PIL import Image

from ..core.base_scaffold import BaseReasoningScaffold, ReasoningResult
from ..core.vlm_interface import VLMInterface, VLMConfig
from ..core.reasoner_interface import ReasonerInterface, ReasonerConfig
from ..prompts.manager import PromptManager
from ..utils.logging import get_logger
from ..utils.unified_logger import UnifiedReasoningLogger


class TwoStageScaffold(BaseReasoningScaffold):
    """
    Two-stage reasoning scaffold: VLM analysis → Reasoner logic.
    
    This scaffold implements a simple but effective two-stage approach:
    1. VLM analyzes the image and question (tunable component)
    2. Reasoner performs logical reasoning based on VLM analysis (frozen component)
    
    The clear separation allows for targeted VLM training while maintaining
    consistent reasoning capabilities.
    """
    
    def __init__(
        self,
        vlm: VLMInterface,
        reasoner: ReasonerInterface,
        prompt_manager: Optional[PromptManager] = None,
        logger: Optional[UnifiedReasoningLogger] = None,
        **kwargs
    ):
        """
        Initialize the two-stage scaffold.

        Args:
            vlm: VLM interface used for the image-description stage.
            reasoner: Reasoner interface used for the reasoning stage.
            prompt_manager: Prompt manager used to load prompt collections;
                a default ``PromptManager()`` is created when omitted.
            logger: Optional unified logger instance for session/step logging.
            **kwargs: Optional flags read by name:
                ``enable_verification`` (default False),
                ``debug_mode`` (default True),
                ``enable_vlm_confidence`` (default False),
                ``use_confidence_in_reasoner`` (default False),
                ``use_logprobs_confidence`` (default False).
        """
        super().__init__(vlm, reasoner, "two_stage")

        self.prompt_manager = prompt_manager or PromptManager()

        # Configuration. Verification is off unless explicitly enabled.
        self.enable_verification = kwargs.get('enable_verification', False)
        self.debug_mode = kwargs.get('debug_mode', True)

        # Confidence estimation experiment flags
        self.enable_vlm_confidence = kwargs.get('enable_vlm_confidence', False)
        self.use_confidence_in_reasoner = kwargs.get('use_confidence_in_reasoner', False)
        self.use_logprobs_confidence = kwargs.get('use_logprobs_confidence', False)

        # Unified (session/step) logger supplied by the caller; may be None.
        self.unified_logger = logger

        # Module logger
        self.logger = get_logger(__name__)

        # Load prompt templates: prefer two_stage_math_v1 (confidence experiment
        # support) and fall back to two_stage_v1 when it cannot be loaded.
        try:
            self.prompts = self.prompt_manager.load_prompt_collection("two_stage_math_v1")
            self.logger.info("Loaded two_stage_math_v1 prompt collection")
        except Exception as e:
            self.logger.warning(f"Could not load prompt templates: {e}")
            try:
                self.prompts = self.prompt_manager.load_prompt_collection("two_stage_v1")
                self.logger.info("Loaded fallback two_stage_v1 prompt collection")
            except Exception as e2:
                self.logger.warning(f"Could not load fallback prompt templates: {e2}")
                # Only give up on prompts when BOTH collections failed to load;
                # resetting outside this handler would discard a good fallback.
                self.prompts = None
    
    def solve(
        self,
        image_path: str,
        question: str,
        dataset: str = "UnknownDataset",
        **kwargs
    ) -> Dict[str, Any]:
        """
        Solve method to match the interface expected by TwoStageFrameworkModel.
        
        Args:
            image_path: Path to the image file
            question: Question to answer about the image
            dataset: Dataset name for prompt selection
            **kwargs: Additional parameters including generation_kwargs
            
        Returns:
            Dictionary with answer and debug information
        """
        start_time = time.time()
        generation_kwargs = kwargs.get('generation_kwargs', {})
        
        debug_info = {}
        
        try:
            # Start logging session if logger is available
            session_id = f"two_stage_{dataset}_{int(start_time)}"
            if self.unified_logger:
                session = self.unified_logger.start_session(
                    session_id=session_id,
                    sample_id=dataset,
                    question=question,
                    image_path=image_path,
                    reasoning_approach="two_stage",
                    original_index=None,
                    dataset=dataset
                )
            
            # Stage 1: VLM Initial Description
            desc_start_time = time.time()
            desc_result = self._stage1_description(image_path, question, generation_kwargs)
            desc_end_time = time.time()
            
            if not desc_result['success']:
                if self.unified_logger:
                    self.unified_logger.finish_session(
                        final_answer=f"Error in VLM description: {desc_result.get('error', 'Unknown error')}",
                        success=False,
                        termination_reason="vlm_description_failed",
                        error=desc_result.get('error'),
                        session_id=session_id
                    )
                return {
                    'answer': f"Error in VLM description: {desc_result.get('error', 'Unknown error')}",
                    'raw_output': f"Error in VLM description: {desc_result.get('error', 'Unknown error')}",
                    'success': False,
                    'debug_info': debug_info
                }
            
            debug_info['vlm_prompt'] = desc_result['prompt']
            debug_info['vlm_initial_response'] = desc_result['description']
            debug_info['vlm_description_time'] = f"{desc_end_time - desc_start_time:.2f}s"
            debug_info['vlm_used_confidence_prompt'] = desc_result.get('used_confidence_prompt', False)
            
            # Store confidence data for experiment analysis
            confidence_score = desc_result.get('confidence')
            if confidence_score is not None:
                debug_info['vlm_confidence_score'] = confidence_score
                self.logger.info(f"VLM confidence score: {confidence_score}/100")
            else:
                debug_info['vlm_confidence_score'] = None
            
            # Store raw VLM response for analysis
            debug_info['vlm_raw_response'] = desc_result.get('raw_vlm_response', desc_result['description'])
            
            initial_description = desc_result['description']
            
            # Stage 1.5: VLM Verification (if enabled)
            if self.enable_verification:
                ver_start_time = time.time()
                verification_result = self._stage1_5_verification(
                    image_path, question, initial_description, generation_kwargs
                )
                ver_end_time = time.time()
                
                if verification_result['success']:
                    debug_info['vlm_verification_prompt'] = verification_result['prompt']
                    debug_info['vlm_verification_response'] = verification_result['verification_output']
                    debug_info['vlm_verification_time'] = f"{ver_end_time - ver_start_time:.2f}s"
                    final_description = verification_result['final_description']
                else:
                    debug_info['vlm_verification_response'] = f"Verification failed: {verification_result.get('error', 'Unknown error')}"
                    debug_info['vlm_verification_time'] = f"{ver_end_time - ver_start_time:.2f}s"
                    final_description = initial_description
            else:
                debug_info['vlm_verification_response'] = "Verification disabled"
                debug_info['vlm_verification_time'] = "N/A"
                final_description = initial_description
            
            debug_info['vlm_final_description'] = final_description
            
            # Stage 2: Reasoner Logic
            reasoner_start_time = time.time()
            reasoning_result = self._stage2_reasoning(
                question, final_description, dataset, generation_kwargs, confidence=confidence_score
            )
            reasoner_end_time = time.time()
            
            if not reasoning_result['success']:
                if self.unified_logger:
                    self.unified_logger.finish_session(
                        final_answer=f"Error in reasoning: {reasoning_result.get('error', 'Unknown error')}",
                        success=False,
                        termination_reason="reasoning_failed",
                        error=reasoning_result.get('error'),
                        session_id=session_id
                    )
                return {
                    'answer': f"Error in reasoning: {reasoning_result.get('error', 'Unknown error')}",
                    'raw_output': f"Error in reasoning: {reasoning_result.get('error', 'Unknown error')}",
                    'success': False,
                    'debug_info': debug_info
                }
            
            debug_info['reasoner_prompt'] = reasoning_result['prompt']
            debug_info['reasoner_response'] = reasoning_result['raw_output']
            debug_info['reasoner_time'] = f"{reasoner_end_time - reasoner_start_time:.2f}s"
            debug_info['reasoning_content'] = reasoning_result['reasoning']
            debug_info['reasoner_used_confidence_prompt'] = reasoning_result.get('used_confidence_prompt', False)
            debug_info['reasoner_confidence_used'] = reasoning_result.get('confidence_used')
            
            final_answer = reasoning_result['answer']
            
            # Return format that matches original TwoStageModel
            result = {
                'answer': final_answer,
                'raw_output': reasoning_result['raw_output'],  # For backwards compatibility
                'success': True,
                'debug_info': debug_info
            }
            
            if self.unified_logger:
                self.unified_logger.log_step(
                    step_name="Reasoner Problem Solving",
                    step_type="reasoner",
                    input_data={
                        "question": question,
                        "image_description": final_description
                    },
                    output_data={
                        "answer": final_answer,
                        "reasoning": reasoning_result['reasoning']
                    },
                    runtime=reasoner_end_time - reasoner_start_time,
                    success=True,
                    error=None,
                    metadata={"prompt": reasoning_result['prompt']},
                    session_id=session_id
                )
                self.unified_logger.finish_session(
                    final_answer=final_answer,
                    success=True,
                    termination_reason="completed",
                    session_id=session_id
                )
            
            return result
            
        except Exception as e:
            processing_time = time.time() - start_time
            self.logger.error(f"Two-stage reasoning failed: {e}")
            
            if self.unified_logger:
                self.unified_logger.finish_session(
                    final_answer=f"Error: {str(e)}",
                    success=False,
                    termination_reason="unexpected_error",
                    error=str(e),
                    session_id=session_id if 'session_id' in locals() else None
                )
            
            return {
                'answer': f"Error: {str(e)}",
                'raw_output': f"Error: {str(e)}",
                'success': False,
                'debug_info': debug_info
            }

    def reason(
        self,
        image: Union[str, Image.Image],
        question: str,
        **kwargs
    ) -> ReasoningResult:
        """
        Perform two-stage reasoning on the image and question.

        Stage 1: VLM analyzes the image and question
        Stage 2: Reasoner performs logical analysis
        Optionally (``enable_verification``) the extracted answer is verified.

        Args:
            image: Input image (path or PIL Image)
            question: Question to answer about the image
            **kwargs: Additional parameters forwarded to both stages

        Returns:
            ReasoningResult with the reasoning trace. Exceptions raised by
            the stages are caught and reported as ``success=False`` with
            ``error_message`` set.
        """
        start_time = time.time()
        self.total_queries += 1
        # This entry point has no dataset/sample identity, hence "Unknown".
        session_id = f"two_stage_Unknown_{int(start_time)}"

        try:
            # Start logging session if logger is available
            if self.unified_logger:
                # NOTE(review): `image` may be a PIL Image rather than a path
                # here — confirm the unified logger accepts that.
                self.unified_logger.start_session(
                    session_id=session_id,
                    sample_id='Unknown',
                    question=question,
                    image_path=image,
                    reasoning_approach="two_stage",
                    original_index=None,
                    dataset='Unknown'
                )

            # Stage 1: VLM Analysis
            self.logger.debug("Starting Stage 1: VLM Analysis")
            vlm_response = self._stage1_vlm_analysis(image, question, **kwargs)

            # Stage 2: Reasoner Logic
            self.logger.debug("Starting Stage 2: Reasoner Logic")
            reasoner_response = self._stage2_reasoner_logic(vlm_response, question, **kwargs)

            # Extract final answer
            final_answer = self._extract_final_answer(reasoner_response)

            # Optional verification stage
            verification_result = None
            if self.enable_verification:
                self.logger.debug("Starting verification")
                verification_result = self._verify_answer(vlm_response, question, final_answer)

            processing_time = time.time() - start_time

            debug_info = None
            if self.debug_mode:
                # NOTE(review): the template attributes are not assigned in the
                # visible __init__; presumably the base class provides them.
                # Read defensively so debug output cannot fail the whole call.
                debug_info = {
                    'verification_result': verification_result,
                    'vlm_template': getattr(self, 'vlm_template', None),
                    'reasoner_template': getattr(self, 'reasoner_template', None),
                }

            result = ReasoningResult(
                final_answer=final_answer,
                reasoning_steps=[
                    f"Stage 1 (VLM): {vlm_response}",
                    f"Stage 2 (Reasoner): {reasoner_response}"
                ],
                success=True,
                vlm_initial_response=vlm_response,
                reasoner_analysis=reasoner_response,
                reasoner_steps=[reasoner_response],
                scaffold_type="two_stage",
                total_iterations=1,
                processing_time=processing_time,
                debug_info=debug_info
            )

            self.successful_queries += 1
            self.logger.info(f"Two-stage reasoning completed successfully in {processing_time:.2f}s")

            if self.unified_logger:
                # Pass the session id explicitly, as solve() does.
                self.unified_logger.finish_session(
                    final_answer=final_answer,
                    success=True,
                    termination_reason="completed",
                    session_id=session_id
                )

            return result

        except Exception as e:
            processing_time = time.time() - start_time
            self.logger.error(f"Two-stage reasoning failed: {e}")

            if self.unified_logger:
                self.unified_logger.finish_session(
                    final_answer=f"Error: {str(e)}",
                    success=False,
                    termination_reason="unexpected_error",
                    error=str(e),
                    session_id=session_id
                )

            return ReasoningResult(
                final_answer="",
                reasoning_steps=[],
                success=False,
                scaffold_type="two_stage",
                processing_time=processing_time,
                error_message=str(e)
            )
    


    def _extract_final_answer(self, reasoner_response: str) -> str:
        """
        Extract the final answer from reasoner response.
        
        This method attempts to find the final answer in the reasoning output.
        It looks for common patterns like "Answer:", "Final answer:", etc.
        
        Args:
            reasoner_response: Full response from reasoner
            
        Returns:
            Extracted final answer
        """
        response = reasoner_response.strip()
        
        # Common answer patterns
        patterns = [
            r"Final answer:\s*(.+)",
            r"Answer:\s*(.+)",
            r"The answer is\s*(.+)",
            r"Therefore,?\s*(.+)",
            r"\\boxed\{([^}]+)\}",  # LaTeX boxed answers
        ]
        
        import re
        for pattern in patterns:
            match = re.search(pattern, response, re.IGNORECASE | re.MULTILINE)
            if match:
                answer = match.group(1).strip()
                # Clean up common formatting
                answer = re.sub(r'[.!]*$', '', answer)  # Remove trailing punctuation
                return answer
        
        # If no pattern found, try to get the last meaningful line
        lines = [line.strip() for line in response.split('\n') if line.strip()]
        if lines:
            last_line = lines[-1]
            # If it's reasonably short, it might be the answer
            if len(last_line) < 100:
                return last_line
        
        # Fallback: return the entire response (truncated)
        if len(response) > 200:
            return response[:200] + "..."
        return response
    
    @classmethod
    def from_config(cls, config_path: str) -> 'TwoStageScaffold':
        """
        Create TwoStageScaffold from a JSON configuration file.

        The file must contain ``vlm_config`` and ``reasoner_config`` objects;
        an optional ``scaffold_config`` object is forwarded to the constructor
        as keyword arguments.

        Args:
            config_path: Path to configuration file

        Returns:
            Configured TwoStageScaffold instance

        Raises:
            KeyError: If ``vlm_config`` or ``reasoner_config`` is missing.
        """
        import json

        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        # Create VLM interface
        vlm_config = VLMConfig.from_dict(config['vlm_config'])
        vlm_interface = VLMInterface.create(vlm_config)

        # Create reasoner interface
        reasoner_config = ReasonerConfig.from_dict(config['reasoner_config'])
        reasoner_interface = ReasonerInterface.create(reasoner_config)

        scaffold_config = config.get('scaffold_config', {})

        # The constructor's parameters are named `vlm` and `reasoner`; any
        # other keyword would fall into **kwargs and leave them unset.
        return cls(
            vlm=vlm_interface,
            reasoner=reasoner_interface,
            **scaffold_config
        )
    
    def save_config(self, config_path: str):
        """Save scaffold configuration to file."""
        config = {
            'scaffold_name': self.scaffold_name,
            'vlm_config': self.vlm.get_config(),
            'reasoner_config': self.reasoner.get_config(),
            'scaffold_config': {
                'vlm_template': self.vlm_template,
                'reasoner_template': self.reasoner_template,
                'enable_verification': self.enable_verification,
                'debug_mode': self.debug_mode,
            }
        }
        
        import json
        with open(config_path, 'w') as f:
            json.dump(config, f, indent=2)

    def reason_from_description(
        self,
        description: str,
        question: str,
        dataset: str = "MathDataset",
        prompt_template_name: str = "two_stage_math_v1",
        generation_kwargs: Optional[Dict[str, Any]] = None,
        confidence: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Run two-stage reasoning from a given description (Part 2 for reward function).
        
        This method runs only the reasoning part of the two-stage scaffold:
        1. Takes a description as input (skipping VLM stage entirely)
        2. Runs stage 2 reasoning with the reasoner
        3. Handles continuation if response is incomplete
        4. Extracts the final answer
        
        Args:
            description: Image description (from external captioner)
            question: Question about the image  
            dataset: Dataset type for prompt selection (default: "MathDataset")
            prompt_template_name: Template name to use (default: "two_stage_math_v1")
            generation_kwargs: Generation parameters for reasoner
            
        Returns:
            Dict with 'answer', 'success', 'reasoning', etc.
        """
        try:
            start_time = time.time()
            
            # Load the specified prompt template (math-specific)
            if prompt_template_name != "two_stage_v1":
                try:
                    math_prompts = self.prompt_manager.load_prompt_collection(prompt_template_name)
                    self.logger.info(f"Loaded {prompt_template_name} prompt collection for reasoning")
                    # Temporarily use math prompts
                    original_prompts = self.prompts
                    self.prompts = math_prompts
                except Exception as e:
                    self.logger.warning(f"Could not load {prompt_template_name} prompts, using default: {e}")
                    original_prompts = None
            else:
                original_prompts = None
            
            # Stage 2: Reasoning only (skip VLM entirely)
            reasoning_result = self._stage2_reasoning(
                question=question,
                image_description=description,
                dataset=dataset,
                generation_kwargs=generation_kwargs,
                confidence=confidence
            )
            
            if not reasoning_result['success']:
                if original_prompts:
                    self.prompts = original_prompts
                return {
                    'answer': f"Error in reasoning: {reasoning_result.get('error', 'Unknown error')}",
                    'success': False,
                    'error': reasoning_result.get('error'),
                    'total_time': time.time() - start_time,
                    'reasoning': ''
                }
            
            raw_output = reasoning_result['raw_output']
            
            # Check for incomplete response and try continuation if needed
            if not self._has_complete_response(raw_output):
                self.logger.warning("Reasoner response appears incomplete, attempting continuation...")
                
                continuation_result = self._continue_reasoning(
                    question=question,
                    description=description,
                    incomplete_reasoning=raw_output,
                    generation_kwargs=generation_kwargs
                )
                
                if continuation_result['success']:
                    # Use the continued response
                    raw_output = continuation_result['raw_output']
                    reasoning_result.update(continuation_result)
            
            # Extract final answer using LLM-based extraction (similar to adaptive)
            extracted_answer = self._extract_final_answer_llm(
                question=question,
                description=description,
                reasoning_response=raw_output,
                generation_kwargs=generation_kwargs
            )
            
            # Restore original prompts if we switched
            if original_prompts:
                self.prompts = original_prompts
            
            return {
                'answer': extracted_answer,
                'success': True,
                'reasoning': reasoning_result.get('reasoning', raw_output),
                'total_time': time.time() - start_time,
                'iterations': 1,
                'termination_reason': 'completed'
            }
            
        except Exception as e:
            # Restore original prompts on error
            if 'original_prompts' in locals() and original_prompts:
                self.prompts = original_prompts
                
            error_msg = f"Unexpected error in two-stage reasoning from description: {str(e)}"
            self.logger.error(error_msg, exc_info=True)
            
            return {
                'answer': f"Error: {error_msg}",
                'success': False,
                'error': error_msg,
                'total_time': time.time() - start_time,
                'iterations': 0
            }

    def _has_complete_response(self, response: str) -> bool:
        """Check if the reasoning response appears complete (matching adaptive logic)."""
        if not response:
            return False
        
        # Use same logic as adaptive scaffold's _is_abrupt_cutoff (but inverted)
        # For two-stage, we don't check for "Status:" but for proper completion tags
        response_lower = response.lower()
        
        # Check for clear completion indicators (structured response format)
        has_answer_tag = '<answer>' in response_lower and '</answer>' in response_lower
        has_think_tag = '<think>' in response_lower and '</think>' in response_lower
        
        # Also check for "Final Answer:" pattern as used in two-stage prompts
        has_final_answer = 'final answer:' in response_lower
        
        # Consider complete if it has proper structured tags or final answer
        return has_answer_tag or (has_think_tag and has_final_answer)

    def _continue_reasoning(
        self,
        question: str,
        description: str,
        incomplete_reasoning: str,
        generation_kwargs: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Continue reasoning from an incomplete response."""
        try:
            if not self.prompts or 'reasoner_continuation_prompt' not in self.prompts:
                # Fallback: just return the incomplete reasoning as-is
                return {
                    'success': False,
                    'error': 'No continuation prompt available',
                    'raw_output': incomplete_reasoning
                }
            
            # Fix prompt formatting issue by escaping {} characters (like in adaptive scaffold)
            continuation_prompt_template = self.prompts['reasoner_continuation_prompt'].replace("boxed{}", "boxed{{}}")
            continuation_prompt = continuation_prompt_template.format(
                question=question,
                description=description,
                incomplete_reasoning=incomplete_reasoning,
                answer='{answer}'
            )
            
            continued_result = self.reasoner.reason(
                context=continuation_prompt,
                **(generation_kwargs or {})
            )
            
            # Combine original and continuation
            combined_response = incomplete_reasoning + "\n\n--- CONTINUATION ---\n" + continued_result
            
            return {
                'success': True,
                'raw_output': combined_response,
                'reasoning': combined_response
            }
            
        except Exception as e:
            self.logger.error(f"Error in continuation: {e}")
            return {
                'success': False,
                'error': str(e),
                'raw_output': incomplete_reasoning
            }

    def _extract_final_answer_llm(
        self,
        question: str,
        description: str,
        reasoning_response: str,
        generation_kwargs: Optional[Dict[str, Any]] = None
    ) -> str:
        """Extract final answer using LLM-based extraction (similar to adaptive scaffold)."""
        try:
            # First try parsing the structured response
            reasoning_content, final_answer = self._parse_reasoner_response(reasoning_response)
            
            # If we got a good answer from parsing, use it
            if final_answer and final_answer != reasoning_response and len(final_answer) < len(reasoning_response):
                return final_answer
            
            # Otherwise, use LLM-based extraction if available
            if self.prompts and 'answer_extraction_prompt' in self.prompts:
                # Fix prompt formatting issue by escaping {} characters (like in adaptive scaffold)
                extraction_prompt_template = self.prompts['answer_extraction_prompt'].replace("boxed{}", "boxed{{}}")
                extraction_prompt = extraction_prompt_template.format(
                    question=question,
                    description=description,
                    reasoning_response=reasoning_response,
                    current_answer=final_answer
                )
                
                extraction_response = self.reasoner.reason(
                    context=extraction_prompt,
                    **(generation_kwargs or {})
                )
                
                # Extract final answer from extraction response (matching adaptive scaffold exactly)
                import re
                patterns = [
                    # Pattern 1: Most specific. Handles LaTeX \boxed{} format, which is common.
                    r"\\{1,4}boxed\{([^}]+)\}",

                    # Pattern 2: Handles "Final Answer: <text>" on a single line, preventing greedy matches.
                    r"\*\*?Final Answer\*\*?[:\s]+([^\n]+)",

                    # Pattern 3: Handles conversational phrases like "The final answer is <text>".
                    r"The final answer is\s+([^.\n]+)",

                    # Pattern 4: A more robust version for "Answer: <text>".
                    r"\b[Aa]nswer\s*:\s*([^\n]+)",
                ]
                
                final_answer = ""
                
                for pattern in patterns:
                    match = re.search(pattern, extraction_response, re.IGNORECASE)
                    if match:
                        candidate = match.group(1).strip()
                        if candidate:
                            self.logger.debug(f"Extracted final answer via pattern '{pattern}': {repr(candidate)}")
                            final_answer = candidate
                            break
                
                # If regex did not succeed, fall back to last non-empty line (matching adaptive exactly)
                if not final_answer:
                    cleaned = extraction_response.strip().split("\n")[-1].strip()
                    final_answer = cleaned
                    self.logger.debug(f"Cleaned extraction response fallback: {repr(final_answer)}")
                
                # Guard: if we ended up with something clearly invalid (matching adaptive exactly)
                if not any(ch.isalnum() for ch in final_answer):
                    if final_answer:  # Only use parsed answer if it exists
                        self.logger.debug("Extracted answer looked invalid – falling back to parsed answer")
                        return final_answer
                
                return final_answer
            
            # Fallback to parsed answer or original response
            return final_answer if final_answer else reasoning_response.strip()
            
        except Exception as e:
            self.logger.warning(f"Answer extraction failed: {e}")
            # Fallback to simple parsing
            reasoning_content, final_answer = self._parse_reasoner_response(reasoning_response)
            return final_answer if final_answer else reasoning_response.strip()

    def _get_vlm_prompt_key(self) -> str:
        """
        Get the appropriate VLM prompt key, handling both naming conventions.
        
        Returns:
            The prompt key to use for VLM description
        """
        # Try both naming conventions for backward compatibility
        if self.prompts:
            if 'vlm_initial_description_prompt' in self.prompts:
                return 'vlm_initial_description_prompt'
            elif 'vlm_description_prompt' in self.prompts:
                return 'vlm_description_prompt'
        
        # Default fallback
        return 'vlm_description_prompt'

    def _stage1_description(
        self,
        image_path: str,
        question: str,
        generation_kwargs: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Stage 1: Generate initial image description."""
        try:
            # Choose prompt based on confidence experiment flags
            if self.use_logprobs_confidence:
                if not self.prompts or 'vlm_logprobs_confidence_prompt' not in self.prompts:
                    self.logger.warning("VLM logprobs confidence prompt not found, falling back to regular prompt")
                    prompt_key = self._get_vlm_prompt_key()
                    use_logprobs = False
                else:
                    prompt_key = 'vlm_logprobs_confidence_prompt'
                    use_logprobs = True
            elif self.enable_vlm_confidence:
                if not self.prompts or 'vlm_confidence_prompt' not in self.prompts:
                    self.logger.warning("VLM confidence prompt not found, falling back to regular prompt")
                    prompt_key = self._get_vlm_prompt_key()
                else:
                    prompt_key = 'vlm_confidence_prompt'
                use_logprobs = False
            else:
                prompt_key = self._get_vlm_prompt_key()
                use_logprobs = False
            
            if not self.prompts or prompt_key not in self.prompts:
                raise ValueError(
                    f"VLM prompt template '{prompt_key}' not loaded. "
                    "Cannot proceed without proper prompt template. "
                    "Check that template file exists and contains required prompt."
                )
            
            prompt = self.prompts[prompt_key]
            
            # Format the prompt with the question (escape {} characters for math prompts)
            prompt_template = prompt.replace("boxed{}", "boxed{{}}")
            formatted_prompt = prompt_template.format(question=question)
            
            # Generate response (with or without logprobs)
            if use_logprobs:
                self.logger.info(f"🔬 DEBUG - Calling generate_with_logprobs")
                response_data = self.vlm.generate_with_logprobs(
                    image=image_path,
                    prompt=formatted_prompt,
                    **(generation_kwargs or {})
                )
                result = response_data['text']
                logprobs_data = response_data['logprobs']
                self.logger.info(f"🔬 DEBUG - Got response_data: text length={len(result)}, logprobs={logprobs_data is not None}")
            else:
                result = self.vlm.generate(
                    image=image_path,
                    prompt=formatted_prompt,
                    **(generation_kwargs or {})
                )
                logprobs_data = None
            
            # Parse confidence based on method used
            confidence_score = None
            description = result
            
            if self.use_logprobs_confidence and logprobs_data:
                try:
                    description, confidence_score = self._parse_logprobs_confidence_response(result, logprobs_data)
                    self.logger.info(f"🎯 Logprobs confidence: {confidence_score}/100")
                except Exception as e:
                    self.logger.warning(f"Error parsing logprobs confidence: {e}")
                    # Safe fallback: parse as regular confidence response
                    description, confidence_score = self._parse_vlm_confidence_response(result)
            elif self.enable_vlm_confidence:
                try:
                    description, confidence_score = self._parse_vlm_confidence_response(result)
                    # DEBUG: Log raw VLM response to investigate the 95% issue
                    self.logger.info(f"🔍 DEBUG - Raw VLM response: {repr(result[:-200])}...")
                    self.logger.info(f"🔍 DEBUG - Parsed description length: {len(description)}")
                    self.logger.info(f"🔍 DEBUG - Parsed confidence: {confidence_score}")
                except Exception as e:
                    self.logger.warning(f"Error parsing confidence from VLM response: {e}")
                    # Safe fallback: use entire response as description
                    description = result
                    confidence_score = None
            
            return {
                'success': True,
                'description': description,
                'confidence': confidence_score,
                'raw_vlm_response': result,
                'prompt': formatted_prompt,
                'used_confidence_prompt': self.enable_vlm_confidence or self.use_logprobs_confidence,
                'used_logprobs': use_logprobs
            }
            
        except Exception as e:
            self.logger.error(f"Error in stage 1 description: {e}", exc_info=True)
            return {
                'success': False,
                'error': str(e),
                'description': '',
                'confidence': None,
                'raw_vlm_response': '',
                'prompt': prompt if 'prompt' in locals() else '',
                'used_confidence_prompt': False,
                'used_logprobs': False
            }
    
    def _stage1_5_verification(
        self,
        image_path: str,
        question: str,
        initial_description: str,
        generation_kwargs: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Stage 1.5: Verify and refine the image description."""
        try:
            if not self.prompts or 'vlm_verification_prompt' not in self.prompts:
                raise ValueError(
                    "VLM verification prompt template not loaded. "
                    "Cannot proceed without proper prompt template. "
                    "Check that 'two_stage_v1.yaml' template file exists and contains 'vlm_verification_prompt'."
                )
            
            prompt = self.prompts['vlm_verification_prompt']
            
            # Format the prompt with question and initial description (escape {} characters for math prompts)
            prompt_template = prompt.replace("boxed{}", "boxed{{}}")
            formatted_prompt = prompt_template.format(
                question=question,
                description=initial_description
            )
            
            verification_output = self.vlm.generate(
                image=image_path,
                prompt=formatted_prompt,
                **(generation_kwargs or {})
            )
            
            # Extract final description from verification output
            # Look for specific markers or use the entire output as final description
            final_description = verification_output
            if "FINAL_DESCRIPTION:" in verification_output:
                final_description = verification_output.split("FINAL_DESCRIPTION:")[1].strip()
            elif "Final description:" in verification_output:
                final_description = verification_output.split("Final description:")[1].strip()
            
            return {
                'success': True,
                'verification_output': verification_output,
                'final_description': final_description,
                'prompt': formatted_prompt
            }
            
        except Exception as e:
            self.logger.error(f"Error in stage 1.5 verification: {e}", exc_info=True)
            return {
                'success': False,
                'error': str(e),
                'verification_output': '',
                'final_description': initial_description,
                'prompt': ''
            }
    
    def _stage2_reasoning(
        self,
        question: str,
        image_description: str,
        dataset: str = "UnknownDataset",
        generation_kwargs: Optional[Dict[str, Any]] = None,
        confidence: Optional[int] = None
    ) -> Dict[str, Any]:
        """Stage 2: Solve the problem using the reasoner."""
        try:
            # Get the appropriate reasoner prompt based on dataset type and confidence
            if self.use_confidence_in_reasoner and confidence is not None:
                try:
                    formatted_prompt = self._get_confidence_aware_reasoner_prompt(
                        question, image_description, confidence, dataset
                    )
                    used_confidence_prompt = True
                except Exception as e:
                    self.logger.warning(f"Error getting confidence-aware prompt: {e}, falling back to regular prompt")
                    formatted_prompt = self._get_reasoner_prompt(question, image_description, dataset)
                    used_confidence_prompt = False
            else:
                formatted_prompt = self._get_reasoner_prompt(question, image_description, dataset)
                used_confidence_prompt = False
            
            result = self.reasoner.reason(
                context=formatted_prompt,
                **(generation_kwargs or {})
            )
            
            # Extract answer from result
            answer = result
            reasoning = result
            
            # Try to extract structured answer if the reasoner provides it
            if "ANSWER:" in result:
                parts = result.split("ANSWER:")
                if len(parts) > 1:
                    answer = parts[1].strip()
                    reasoning = parts[0].strip()
            elif "Answer:" in result:
                parts = result.split("Answer:")
                if len(parts) > 1:
                    answer = parts[1].strip()
                    reasoning = parts[0].strip()
            
            return {
                'success': True,
                'answer': answer,
                'reasoning': reasoning,
                'raw_output': result,
                'prompt': formatted_prompt,
                'used_confidence_prompt': used_confidence_prompt,
                'confidence_used': confidence if used_confidence_prompt else None
            }
            
        except Exception as e:
            self.logger.error(f"Error in stage 2 reasoning: {e}", exc_info=True)
            return {
                'success': False,
                'error': str(e),
                'answer': '',
                'reasoning': '',
                'raw_output': '',
                'prompt': prompt if 'prompt' in locals() else '',
                'used_confidence_prompt': False,
                'confidence_used': None
            } 
    def _get_vlm_description_prompt(self, question: str) -> str:
        """Get VLM description prompt."""
        prompt_key = self._get_vlm_prompt_key()
        
        if not self.prompts or prompt_key not in self.prompts:
            raise ValueError(
                f"VLM description prompt template '{prompt_key}' not loaded. "
                "Cannot proceed without proper prompt template. "
                "Check that template file exists and contains required prompt."
            )
        
        prompt = self.prompts[prompt_key]
        prompt_template = prompt.replace("boxed{}", "boxed{{}}")
        return prompt_template.format(question=question)

    def _get_vlm_verification_prompt(self, question: str, description: str) -> str:
        """Get VLM verification prompt."""
        if not self.prompts or 'vlm_verification_prompt' not in self.prompts:
            raise ValueError(
                "VLM verification prompt template not loaded. "
                "Cannot proceed without proper prompt template. "
                "Check that 'two_stage_v1.yaml' template file exists and contains 'vlm_verification_prompt'."
            )
        
        prompt = self.prompts['vlm_verification_prompt']
        prompt_template = prompt.replace("boxed{}", "boxed{{}}")
        return prompt_template.format(question=question, description=description)

    def _get_reasoner_prompt(self, question: str, description: str, dataset: str) -> str:
        """Get reasoner prompt based on dataset type"""
        # Get format instruction
        if not self.prompts or 'reasoner_format_instruction' not in self.prompts:
            raise ValueError(
                "Reasoner format instruction not loaded. "
                "Check that 'two_stage_v1.yaml' template file exists and contains 'reasoner_format_instruction'."
            )

        format_instruction = self.prompts['reasoner_format_instruction']
        
        # Determine prompt type based on dataset (matching original logic)
        if dataset is not None:
            # Check for MCQ datasets
            if self._is_mcq_dataset(dataset):
                if 'reasoner_mcq_prompt' not in self.prompts:
                    raise ValueError("MCQ reasoner prompt not loaded.")
                prompt_template = self.prompts['reasoner_mcq_prompt']
            # Check for math datasets
            elif self._is_math_dataset(dataset):
                if 'reasoner_math_prompt' not in self.prompts:
                    raise ValueError("Math reasoner prompt not loaded.")
                prompt_template = self.prompts['reasoner_math_prompt']
            # Check for detail datasets
            elif self._is_detail_dataset(dataset):
                if 'reasoner_detail_prompt' not in self.prompts:
                    raise ValueError("Detail reasoner prompt not loaded.")
                prompt_template = self.prompts['reasoner_detail_prompt']
            else:
                # Default VQA
                if 'reasoner_default_prompt' not in self.prompts:
                    raise ValueError("Default reasoner prompt not loaded.")
                prompt_template = self.prompts['reasoner_default_prompt']
        else:
            # Default VQA when dataset is None
            if 'reasoner_default_prompt' not in self.prompts:
                raise ValueError("Default reasoner prompt not loaded.")
            prompt_template = self.prompts['reasoner_default_prompt']
        
        # Format the prompt with variables (escape {} characters for math prompts)
        escaped_template = prompt_template.replace("boxed{}", "boxed{{}}")
        return escaped_template.format(
            description=description,
            question=question,
            format_instruction=format_instruction,
            answer='{answer}'
        )
    
    def _is_mcq_dataset(self, dataset: str) -> bool:
        """Check if dataset is MCQ type (matching original logic)."""
        # This would normally use DATASET_TYPE function, but for now we'll use simple heuristics
        # MCQ datasets typically have "MCQ" in their name or are known MCQ datasets
        mcq_patterns = ['mcq', 'choice', 'option']
        return any(pattern in dataset.lower() for pattern in mcq_patterns)
    
    def _is_math_dataset(self, dataset: str) -> bool:
        """Check if dataset is math type (matching original logic)."""
        # Math datasets: MathVista, MathVerse, MM-Math
        math_datasets = ['MathVista', 'MathVerse', 'MM-Math']
        return any(math_dataset in dataset for math_dataset in math_datasets)
    
    def _is_detail_dataset(self, dataset: str) -> bool:
        """Check if dataset requires detailed answers (matching original logic)."""
        # Detail datasets: LLaVABench
        detail_datasets = ['LLaVABench']
        return any(detail_dataset in dataset for detail_dataset in detail_datasets)

    def _process_verification_response(self, initial_description: str, verification_response: str) -> str:
        """Process verification response to extract final description (matching original logic)."""
        # Store the initial description before potential updates
        image_description = initial_description
        
        # Match the exact logic from original two_stage.py
        if "The description is accurate and complete" not in verification_response:
            if "Improved description:" in verification_response:
                corrected_parts = verification_response.split("Improved description:", 1)
                if len(corrected_parts) > 1:
                    image_description = corrected_parts[1].strip()
                # else: Keep original if parsing fails
            elif "accurate and complete" not in verification_response.lower():
                # If not accurate and no "Improved" tag, use the whole response as fallback
                image_description = verification_response
        
        return image_description

    def _parse_reasoner_response(self, reasoner_response: str) -> tuple[str, str]:
        """Parse reasoner response to extract reasoning content and final answer (matching original logic)."""
        import re
        
        BOXED_PREFIX_RE = re.compile(r"\\{1,4}boxed\{")

        def extract_boxed(text: str) -> str | None:
            m = BOXED_PREFIX_RE.search(text)
            if not m:
                return None
            i = m.end()
            depth = 1
            start = i
            while i < len(text) and depth:
                if text[i] == '{':
                    depth += 1
                elif text[i] == '}':
                    depth -= 1
                    if depth == 0:
                        return text[start:i].strip()
                i += 1
            return None
        
        # Extract reasoning content - try both methods (matching original logic)
        reasoning_content = "[No Reasoning Available]"
        
        # Method 1: Parse <think> tags from content
        think_match = re.search(r"<think>(.*?)</think>", reasoner_response, re.DOTALL | re.IGNORECASE)
        if think_match:
            reasoning_content = think_match.group(1).strip()
        
        # Extract final answer from between <answer> tags, if present (matching original logic)
        final_answer_text = reasoner_response
        answer_match = re.search(r"<answer>(.*?)</answer>", reasoner_response, re.DOTALL | re.IGNORECASE)
        if answer_match:
            answer_content = answer_match.group(1).strip()
            # Further parse "Final Answer: " from within <answer> if present
            final_answer_match = re.search(r"Final Answer:(.*)", answer_content, re.IGNORECASE | re.DOTALL)
            if final_answer_match:
                final_answer_text = final_answer_match.group(1).strip()
            else:
                boxed_match = extract_boxed(answer_content)
                if boxed_match:
                    final_answer_text = boxed_match
                else:
                    # If no "Final Answer:" prefix, use entire <answer> content
                    final_answer_text = answer_content
        

        return reasoning_content, final_answer_text

    def _parse_vlm_confidence_response(self, vlm_response: str) -> tuple[str, Optional[int]]:
        """
        Parse VLM response to extract description and confidence score.
        
        Args:
            vlm_response: Raw VLM response potentially containing confidence
            
        Returns:
            Tuple of (description, confidence_score)
            confidence_score is None if parsing fails or confidence not found
        """
        import re
        
        try:
            # Look for the structured format: DESCRIPTION: ... CONFIDENCE: ...
            desc_match = re.search(r"DESCRIPTION:\s*(.*?)(?=CONFIDENCE:|$)", vlm_response, re.DOTALL | re.IGNORECASE)
            conf_match = re.search(r"CONFIDENCE:\s*(\d+)", vlm_response, re.IGNORECASE)
            
            if desc_match and conf_match:
                description = desc_match.group(1).strip()
                confidence = int(conf_match.group(1))
                
                # Validate confidence range
                if 0 <= confidence <= 100:
                    return description, confidence
                else:
                    self.logger.warning(f"Confidence score out of range (0-100): {confidence}")
                    return description, None
            else:
                # Fallback: if structured format not found, use entire response as description
                return vlm_response.strip(), None
                
        except Exception as e:
            self.logger.warning(f"Error parsing VLM confidence response: {e}")
            # Safe fallback: return entire response as description with no confidence
            return vlm_response.strip(), None

    def _parse_logprobs_confidence_response(self, vlm_response: str, logprobs_data: Any) -> tuple[str, Optional[float]]:
        """
        Parse VLM response to extract description and confidence from logprobs.
        
        Args:
            vlm_response: Raw VLM response text
            logprobs_data: Logprobs data from the API response
            
        Returns:
            Tuple of (description, confidence_score)
            confidence_score is None if parsing fails
        """
        import re
        import math
        
        try:
            # Extract description from response text
            desc_match = re.search(r"DESCRIPTION:\s*(.*?)(?=ADEQUATE:|$)", vlm_response, re.DOTALL | re.IGNORECASE)
            adequate_match = re.search(r"ADEQUATE:\s*(Yes|No)", vlm_response, re.IGNORECASE)
            
            if desc_match:
                description = desc_match.group(1).strip()
            else:
                # Fallback: use entire response as description
                description = vlm_response.strip()
            
            if not adequate_match or not logprobs_data:
                self.logger.warning("Could not find ADEQUATE field or missing logprobs data")
                return description, None
            
            # DEBUG: Log token structure for debugging
            self.logger.debug(f"🔬 DEBUG - logprobs_data type: {type(logprobs_data)}")
            
            # Extract the actual Yes/No answer from the text
            actual_answer = adequate_match.group(1).lower()
            self.logger.debug(f"🔬 DEBUG - Actual answer from text: {actual_answer}")
            
            # Extract logprobs for Yes/No tokens
            if hasattr(logprobs_data, 'content') and logprobs_data.content:
                # Find the position of "ADEQUATE:" in the response
                adequate_pos = vlm_response.upper().find("ADEQUATE:")
                if adequate_pos == -1:
                    self.logger.warning("Could not find ADEQUATE position in response")
                    return description, None
                
                # Reconstruct text from tokens to find which token corresponds to Yes/No
                current_pos = 0
                target_token_data = None
                
                for i, token_data in enumerate(logprobs_data.content):
                    if hasattr(token_data, 'token'):
                        token_text = token_data.token
                        token_end_pos = current_pos + len(token_text)
                        
                        # Check if this token overlaps with the Yes/No answer area
                        # Look for tokens after "ADEQUATE:" position
                        if current_pos >= adequate_pos + 9:  # "ADEQUATE:" is 9 chars
                            token_lower = token_text.strip().lower()
                            # Check for Yes/No in the token (exact or partial match)
                            if ('yes' in token_lower or 'no' in token_lower or 
                                token_lower == 'yes' or token_lower == 'no'):
                                target_token_data = token_data
                                self.logger.debug(f"🔬 DEBUG - Found target token: '{token_text}' at position {current_pos}")
                                break
                        
                        current_pos = token_end_pos
                
                # Alternative search: look for any token containing yes/no
                if not target_token_data:
                    for token_data in logprobs_data.content:
                        if hasattr(token_data, 'token'):
                            token_lower = token_data.token.strip().lower()
                            if token_lower in ['yes', 'no'] or 'yes' in token_lower or 'no' in token_lower:
                                target_token_data = token_data
                                self.logger.debug(f"🔬 DEBUG - Found alternative target token: '{token_data.token}'")
                                break
                
                if target_token_data and hasattr(target_token_data, 'top_logprobs'):
                    # Extract p(Yes) and p(No) from top_logprobs
                    p_yes = None
                    p_no = None
                    
                    self.logger.debug(f"🔬 DEBUG - Examining top_logprobs for token: '{target_token_data.token}'")
                    
                    for logprob_entry in target_token_data.top_logprobs:
                        if hasattr(logprob_entry, 'token') and hasattr(logprob_entry, 'logprob'):
                            token = logprob_entry.token.strip().lower()
                            prob = math.exp(logprob_entry.logprob)
                            
                            self.logger.debug(f"🔬 DEBUG - Logprob token: '{logprob_entry.token}' -> prob: {prob:.4f}")
                            
                            if token == 'yes' or 'yes' in token:
                                p_yes = prob
                            elif token == 'no' or 'no' in token:
                                p_no = prob
                    
                    if p_yes is not None and p_no is not None:
                        # Calculate confidence as p(Yes) / (p(Yes) + p(No)) * 100
                        confidence = (p_yes / (p_yes + p_no)) * 100
                        self.logger.debug(f"🎯 Logprobs: p(Yes)={p_yes:.4f}, p(No)={p_no:.4f}, confidence={confidence:.1f}%")
                        return description, round(confidence, 1)
                    elif p_yes is not None:
                        # Only Yes found, high confidence if actual answer is yes
                        confidence = 85.0 if actual_answer == 'yes' else 15.0
                        self.logger.debug(f"🎯 Only p(Yes)={p_yes:.4f} found, using confidence={confidence}%")
                        return description, confidence
                    elif p_no is not None:
                        # Only No found, high confidence if actual answer is no
                        confidence = 85.0 if actual_answer == 'no' else 15.0
                        self.logger.debug(f"🎯 Only p(No)={p_no:.4f} found, using confidence={confidence}%")
                        return description, confidence
                    else:
                        self.logger.warning("Could not find Yes/No probabilities in logprobs")
                        # Fallback: use a simple heuristic based on actual answer
                        confidence = 75.0 if actual_answer == 'yes' else 25.0
                        self.logger.debug(f"🎯 Fallback confidence based on answer '{actual_answer}': {confidence}%")
                        return description, confidence
                else:
                    self.logger.warning("Could not find target token data in logprobs")
                    # Fallback: use a simple heuristic
                    confidence = 75.0 if actual_answer == 'yes' else 25.0
                    self.logger.debug(f"🎯 No token data fallback confidence: {confidence}%")
                    return description, confidence
            else:
                self.logger.warning("Logprobs data has unexpected structure")
                return description, None
                
        except Exception as e:
            self.logger.warning(f"Error parsing logprobs confidence response: {e}")
            return vlm_response.strip(), None

    def _get_confidence_aware_reasoner_prompt(self, question: str, description: str, confidence: int, dataset: str) -> str:
        """Get confidence-aware reasoner prompt based on dataset type."""
        # Get format instruction
        if not self.prompts or 'reasoner_format_instruction' not in self.prompts:
            raise ValueError(
                "Reasoner format instruction not loaded. "
                "Check that prompt template file exists and contains 'reasoner_format_instruction'."
            )

        format_instruction = self.prompts['reasoner_format_instruction']
        
        # Determine prompt type based on dataset (matching original logic)
        prompt_key = None
        if dataset is not None:
            # Check for MCQ datasets
            if self._is_mcq_dataset(dataset):
                prompt_key = 'reasoner_confidence_mcq_prompt'
            # Check for math datasets
            elif self._is_math_dataset(dataset):
                prompt_key = 'reasoner_confidence_math_prompt'
            # Check for detail datasets
            elif self._is_detail_dataset(dataset):
                prompt_key = 'reasoner_confidence_detail_prompt'
            else:
                # Default VQA
                prompt_key = 'reasoner_confidence_default_prompt'
        else:
            # Default VQA when dataset is None
            prompt_key = 'reasoner_confidence_default_prompt'
        
        if not self.prompts or prompt_key not in self.prompts:
            self.logger.warning(f"Confidence-aware prompt '{prompt_key}' not found, falling back to regular prompt")
            return self._get_reasoner_prompt(question, description, dataset)
        
        prompt_template = self.prompts[prompt_key]
        
        # Format the prompt with variables (escape {} characters for math prompts)
        escaped_template = prompt_template.replace("boxed{}", "boxed{{}}")
        return escaped_template.format(
            description=description,
            confidence=confidence,
            question=question,
            format_instruction=format_instruction,
            answer='{answer}'
        )



