"""
Three-stage reasoning scaffold implementation.

This scaffold implements a conservative three-stage approach:
1. The VLM provides an initial description of the image.
2. The reasoner decides whether more information is needed (an adaptive-style decision).
3. If needed, the VLM provides one focused description, and then the reasoner solves the problem.

This is a middle ground between the two-stage scaffold (no questions) and the adaptive
scaffold (multiple questions): exactly ONE clarifying question is allowed when needed.
"""

import time
from typing import Union, Dict, Any, Optional

from ..core.base_scaffold import BaseReasoningScaffold, ReasoningResult
from ..core.vlm_interface import VLMInterface
from ..core.reasoner_interface import ReasonerInterface
from ..prompts.manager import PromptManager
from ..utils.logging import get_logger
from ..utils.unified_logger import UnifiedReasoningLogger


class ThreeStageScaffold(BaseReasoningScaffold):
    """
    Three-stage reasoning scaffold: VLM analysis → Reasoner decision → Optional VLM clarification → Final answer.
    
    This scaffold implements a conservative three-stage approach:
    1. VLM analyzes the image and question (tunable component)
    2. Reasoner decides if more information is needed using adaptive-style decision making
    3. If needed, VLM provides one focused description, then reasoner provides final answer
    
    Unlike adaptive scaffold (multiple iterations), three-stage allows exactly ONE clarifying question.
    This provides a middle-ground between two-stage (no questions) and adaptive (multiple questions).
    """
    
    def __init__(
        self,
        vlm: VLMInterface,
        reasoner: ReasonerInterface,
        prompt_manager: Optional[PromptManager] = None,
        logger: Optional[UnifiedReasoningLogger] = None,
        prompt_template_name: str = "three_stage_math_v1",
        enable_vlm_confidence: bool = False,
        use_confidence_in_reasoner: bool = False,
        **kwargs
    ):
        """
        Build a three-stage scaffold around a VLM and a reasoner.

        The prompt collection named by ``prompt_template_name`` is loaded
        eagerly. If loading fails, a warning is logged and ``self.prompts`` is
        left as ``None``; ``solve`` then reports a prompt error instead of
        running.

        Args:
            vlm: VLM interface (tunable component)
            reasoner: Reasoner interface (frozen component)
            prompt_manager: Prompt manager for template handling; a fresh
                ``PromptManager()`` is created when none (or a falsy one) is given
            logger: Optional unified logger used for per-session step logging
            prompt_template_name: Name of the prompt collection to load
            enable_vlm_confidence: Enable VLM confidence estimation (experimental)
            use_confidence_in_reasoner: Use VLM confidence in reasoner prompts (experimental)
            **kwargs: Accepted for interface compatibility; not used here
        """
        super().__init__(vlm, reasoner, "three_stage")

        self.prompt_manager = prompt_manager if prompt_manager else PromptManager()
        self.prompt_template_name = prompt_template_name

        # Experimental confidence-estimation switches (read by the stage helpers / step logs).
        self.enable_vlm_confidence = enable_vlm_confidence
        self.use_confidence_in_reasoner = use_confidence_in_reasoner

        # Optional structured session logger, plus the plain module logger.
        self.unified_logger = logger
        self.logger = get_logger(__name__)

        template = self.prompt_template_name
        try:
            loaded = self.prompt_manager.load_prompt_collection(template)
            self.logger.info(f"Loaded {template} prompt collection")
        except Exception as exc:
            # Best-effort: a missing template must not prevent construction.
            self.logger.warning(f"Could not load prompt templates: {exc}")
            loaded = None
        self.prompts = loaded

    def reason(
        self,
        image: Union[str, Any],
        question: str,
        **kwargs
    ) -> ReasoningResult:
        """
        Perform reasoning on the given image and question (BaseReasoningScaffold interface).

        Thin adapter around ``solve``: forwards the call and converts the
        ``solve`` result dict into a ``ReasoningResult``. Any exception raised
        while doing so is caught and reported as a failed result.

        Args:
            image: Input image. It is converted with ``str()`` and passed to
                ``solve`` as ``image_path``; a falsy value becomes ``""``.
            question: Question to answer about the image
            **kwargs: Additional parameters. ``dataset``, ``sample_id``,
                ``original_index``, ``ground_truth`` and ``generation_kwargs``
                are forwarded explicitly. Everything else (e.g.
                ``deny_clarifications``) is passed through to ``solve``.

        Returns:
            ReasoningResult containing the final answer, the reasoner's
            reasoning and, in ``debug_info``, the termination reason and
            clarification-tracking fields reported by ``solve``.
        """
        start_time = time.time()

        # Pop (not get) the parameters forwarded explicitly below. Leaving them
        # in ``kwargs`` too would make ``self.solve(..., dataset=..., **kwargs)``
        # raise "got multiple values for keyword argument".
        dataset = kwargs.pop('dataset', 'Unknown')
        sample_id = kwargs.pop('sample_id', None)
        original_index = kwargs.pop('original_index', None)
        ground_truth = kwargs.pop('ground_truth', None)
        generation_kwargs = kwargs.pop('generation_kwargs', None)

        # NOTE(review): a non-path image object (e.g. a PIL Image) is stringified
        # here, which yields its repr rather than a usable path. Confirm callers
        # only pass paths.
        image_path = str(image) if image else ""

        try:
            solve_result = self.solve(
                image_path=image_path,
                question=question,
                dataset=dataset,
                sample_id=sample_id,
                original_index=original_index,
                ground_truth=ground_truth,
                generation_kwargs=generation_kwargs,
                **kwargs
            )

            processing_time = time.time() - start_time
            success = solve_result.get('success', False)

            reasoning_steps = []
            if 'reasoning' in solve_result:
                reasoning_steps = [solve_result['reasoning']]

            return ReasoningResult(
                final_answer=solve_result.get('answer', ''),
                reasoning_steps=reasoning_steps,
                success=success,
                vlm_initial_response=None,
                vlm_intermediate_responses=None,
                reasoner_analysis=solve_result.get('reasoning', ''),
                # Separate list so the two fields never alias each other.
                reasoner_steps=list(reasoning_steps),
                scaffold_type="three_stage",
                total_iterations=solve_result.get('iterations', 1),
                processing_time=processing_time,
                debug_info={
                    'termination_reason': solve_result.get('termination_reason', 'unknown'),
                    'total_time': solve_result.get('total_time', processing_time),
                    # Clarification-experiment tracking, as reported by solve().
                    'needs_more_info': solve_result.get('needs_more_info', False),
                    'clarifying_question': solve_result.get('clarifying_question', ''),
                    'clarification_denied': solve_result.get('clarification_denied', False),
                },
                error_message=None if success else solve_result.get('error')
            )

        except Exception as e:
            processing_time = time.time() - start_time
            error_msg = f"Error in three-stage reasoning: {str(e)}"
            self.logger.error(error_msg, exc_info=True)

            return ReasoningResult(
                final_answer="",
                reasoning_steps=[],
                success=False,
                scaffold_type="three_stage",
                total_iterations=0,
                processing_time=processing_time,
                error_message=error_msg
            )

    def solve(
        self,
        image_path: str,
        question: str,
        dataset: str = "Unknown",
        sample_id: Optional[str] = None,
        original_index: Optional[str] = None,
        ground_truth: Optional[str] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
        vlm_generation_kwargs: Optional[Dict[str, Any]] = None,
        reasoner_generation_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Solve a visual reasoning problem using the three-stage scaffold.

        Flow:
            Stage 1   - VLM initial description of the image.
            Stage 2   - Reasoner decision: ``SOLVED`` (answer directly) or
                        ``NEED_MORE_INFO`` (one clarification request).
            Stage 2.5 - One focused VLM answer to that request (or a fixed
                        "not available" message when clarifications are denied).
            Stage 3   - Final reasoning over the combined description.

        Args:
            image_path: Image path, forwarded to the VLM stages and the logger.
            question: Question to answer about the image.
            dataset: Dataset name recorded in the logging session.
            sample_id: Sample identifier; also used to build the session id.
            original_index: Original dataset index, recorded in the logging session.
            ground_truth: Ground-truth answer, recorded in the logging session.
            generation_kwargs: Shared generation parameters, used for any stage
                whose stage-specific kwargs are missing or empty.
            vlm_generation_kwargs: Generation parameters for the VLM stages.
            reasoner_generation_kwargs: Generation parameters for the reasoner stages.
            **kwargs: Experiment flags. ``deny_clarifications`` (bool) replaces
                the Stage 2.5 VLM call with a fixed "not available" message.
                ``track_clarification_requests`` (bool) is only logged here.

        Returns:
            Dict with ``answer``, ``success``, ``total_time``, ``iterations``,
            ``termination_reason``, ``needs_more_info``, ``clarifying_question``
            and ``clarification_denied`` on every path. ``reasoning`` is added
            where the reasoner produced one, and ``error`` on error paths.
            Unexpected exceptions are caught and reported with
            ``termination_reason='unexpected_error'``.
        """
        # Established before the ``try`` so the error handler can always use them.
        start_time = time.time()
        session_id = f"three_stage_{sample_id or 'unknown'}_{int(start_time)}"
        session_started = False

        try:
            # Extract clarification experiment flags from kwargs
            deny_clarifications = kwargs.get('deny_clarifications', False)
            track_clarification_requests = kwargs.get('track_clarification_requests', False)

            # 🔍 DEBUG: Log clarification experiment configuration
            self.logger.info("🔍 DEBUG: === THREE-STAGE SOLVE START ===")
            self.logger.info(f"🔍 DEBUG: Sample ID: {sample_id}")
            self.logger.info(f"🔍 DEBUG: Question: {question[:100]}...")
            self.logger.info(f"🔍 DEBUG: Dataset: {dataset}")
            self.logger.info(f"🔍 DEBUG: deny_clarifications: {deny_clarifications}")
            self.logger.info(f"🔍 DEBUG: track_clarification_requests: {track_clarification_requests}")
            self.logger.info(f"🔍 DEBUG: All kwargs: {list(kwargs.keys())}")
            self.logger.info(f"🔍 DEBUG: Prompt template: {getattr(self, 'prompt_template_name', 'N/A')}")

            # Start logging session if logger is available
            if self.unified_logger:
                self.unified_logger.start_session(
                    session_id=session_id,
                    sample_id=sample_id or "unknown",
                    question=question,
                    image_path=image_path,
                    reasoning_approach="three_stage",
                    original_index=original_index,
                    dataset=dataset,
                    ground_truth=ground_truth
                )
                session_started = True

            # Stage-specific kwargs take precedence; otherwise fall back to the
            # shared ``generation_kwargs`` (an empty dict also falls back).
            vlm_generation_kwargs = vlm_generation_kwargs or generation_kwargs or {}
            reasoner_generation_kwargs = reasoner_generation_kwargs or generation_kwargs or {}

            # Check if prompts are loaded
            if not self.prompts:
                error_msg = "Prompt templates not loaded - cannot proceed with three-stage reasoning"
                self.logger.error(error_msg)
                return self._finish_solve(
                    session_id,
                    start_time,
                    answer="Error: Prompt templates not loaded",
                    success=False,
                    termination_reason="prompt_error",
                    iterations=0,
                    error=error_msg,
                    is_error=True,
                )

            # Stage 1: VLM Initial Description
            initial_desc_start = time.time()
            initial_desc_result = self._stage1_initial_description(
                image_path, question, vlm_generation_kwargs
            )
            initial_desc_time = time.time() - initial_desc_start

            if self.unified_logger:
                self.unified_logger.log_step(
                    step_name="Stage 1: VLM Initial Description",
                    step_type="vlm",
                    input_data={"image_path": image_path, "question": question},
                    output_data={
                        "description": initial_desc_result.get("description", ""),
                        "confidence": initial_desc_result.get("confidence"),
                        "used_confidence_prompt": initial_desc_result.get("used_confidence_prompt", False)
                    },
                    runtime=initial_desc_time,
                    success=initial_desc_result.get("success", False),
                    error=initial_desc_result.get("error"),
                    metadata={
                        "stage": 1,
                        "prompt": initial_desc_result.get("prompt", ""),
                        "confidence_experiment": {
                            "enable_vlm_confidence": self.enable_vlm_confidence,
                            "use_confidence_in_reasoner": self.use_confidence_in_reasoner
                        }
                    },
                    session_id=session_id
                )

            if not initial_desc_result.get('success', False):
                return self._finish_solve(
                    session_id,
                    start_time,
                    answer=f"Error in initial VLM description: {initial_desc_result.get('error', 'Unknown error')}",
                    success=False,
                    termination_reason="initial_vlm_failed",
                    iterations=0,
                    error=initial_desc_result.get('error'),
                    is_error=True,
                )

            initial_description = initial_desc_result['description']
            confidence_score = initial_desc_result.get('confidence')

            # Stage 2: Reasoner Decision (adaptive-style)
            decision_start = time.time()
            decision_result = self._stage2_adaptive_decision(
                question, initial_description, reasoner_generation_kwargs, confidence=confidence_score
            )
            decision_time = time.time() - decision_start

            if self.unified_logger:
                self.unified_logger.log_step(
                    step_name="Stage 2: Reasoner Decision",
                    step_type="reasoner",
                    input_data={
                        "question": question,
                        "description": initial_description,
                        "confidence": confidence_score
                    },
                    output_data={
                        "status": decision_result.get("status", ""),
                        "answer": decision_result.get("answer", ""),
                        "request": decision_result.get("request", ""),
                        "reasoning": decision_result.get("reasoning", "")
                    },
                    runtime=decision_time,
                    success=decision_result.get("success", False),
                    error=decision_result.get("error"),
                    metadata={"stage": 2, "prompt": decision_result.get("prompt", "")},
                    session_id=session_id
                )

            if not decision_result.get('success', False):
                return self._finish_solve(
                    session_id,
                    start_time,
                    answer=f"Error in decision reasoning: {decision_result.get('error', 'Unknown error')}",
                    success=False,
                    termination_reason="decision_failed",
                    iterations=1,
                    error=decision_result.get('error'),
                    is_error=True,
                )

            status = decision_result.get('status')

            # Check if we can solve directly
            if status == 'SOLVED':
                final_answer = decision_result.get('answer', '')

                # 🔍 DEBUG: Log when we solve directly (no clarification)
                self.logger.info("🔍 DEBUG: DECISION OUTCOME: SOLVED directly (no clarification needed)")
                self.logger.info(f"🔍 DEBUG: Final answer: {repr(final_answer)}")
                self.logger.info("🔍 DEBUG: Clarification requested: False")

                return self._finish_solve(
                    session_id,
                    start_time,
                    answer=final_answer,
                    success=True,
                    termination_reason="solved_directly",
                    iterations=1,
                    reasoning=decision_result.get('reasoning', ''),
                )

            if status != 'NEED_MORE_INFO':
                # Unknown (or missing) status: fall back to whatever answer the reasoner gave.
                return self._finish_solve(
                    session_id,
                    start_time,
                    answer=decision_result.get('answer', 'Unable to determine answer'),
                    success=False,
                    termination_reason="unknown_status",
                    iterations=1,
                    reasoning=decision_result.get('reasoning', ''),
                )

            # Stage 2.5: VLM Focused Information
            info_request = decision_result.get('request', '')

            # 🔍 DEBUG: Log when clarification is needed
            self.logger.info("🔍 DEBUG: DECISION OUTCOME: NEED_MORE_INFO (clarification requested)")
            self.logger.info(f"🔍 DEBUG: Info request: {repr(info_request)}")
            self.logger.info("🔍 DEBUG: Clarification requested: True")

            if not info_request or info_request == 'N/A':
                # No usable request: return the reasoner's best guess with the available info.
                final_answer = decision_result.get('answer', 'Unable to solve with available information')
                self.logger.info(f"🔍 DEBUG: No valid request provided, terminating with: {repr(final_answer)}")
                return self._finish_solve(
                    session_id,
                    start_time,
                    answer=final_answer,
                    success=False,
                    termination_reason="no_valid_request",
                    iterations=1,
                    reasoning=decision_result.get('reasoning', ''),
                    needs_more_info=True,
                    clarifying_question=info_request or '',
                )

            # Get focused information from VLM (or deny if experiment flag is set)
            focused_start = time.time()
            if deny_clarifications:
                # Deny clarification request - provide a generic "no additional info" response
                self.logger.info(f"🔍 DEBUG: Clarification request DENIED for experiment: {info_request}")
                self.logger.info("🔍 DEBUG: deny_clarifications flag is True")
                focused_result = {
                    'success': True,
                    'focused_info': "Additional information is not available at this time. Please reason with the information provided in the initial description.",
                    'prompt': 'N/A (clarification denied)',
                    'denied': True
                }
            else:
                self.logger.info("🔍 DEBUG: Clarification request ALLOWED - getting focused info")
                self.logger.info("🔍 DEBUG: deny_clarifications flag is False")
                focused_result = self._stage2_5_focused_info(
                    image_path, question, info_request, initial_description, vlm_generation_kwargs
                )
                focused_result['denied'] = False
            focused_time = time.time() - focused_start

            if self.unified_logger:
                self.unified_logger.log_step(
                    step_name="Stage 2.5: VLM Focused Information",
                    step_type="vlm",
                    input_data={
                        "image_path": image_path,
                        "question": question,
                        "info_request": info_request
                    },
                    output_data={"focused_info": focused_result.get("focused_info", "")},
                    runtime=focused_time,
                    success=focused_result.get("success", False),
                    error=focused_result.get("error"),
                    metadata={"stage": 2.5, "prompt": focused_result.get("prompt", "")},
                    session_id=session_id
                )

            if focused_result.get('success', False):
                # Combine descriptions with request context
                combined_description = f"{initial_description}\n\nAdditional Information Requested:\n{info_request}\n\nAdditional Information Provided:\n{focused_result['focused_info']}"
            else:
                # Use original description if focused info failed
                combined_description = initial_description

            # Stage 3: Final Reasoning with combined information
            final_start = time.time()
            final_result = self._stage3_final_reasoning(
                question, combined_description, reasoner_generation_kwargs
            )
            final_time = time.time() - final_start

            if self.unified_logger:
                self.unified_logger.log_step(
                    step_name="Stage 3: Final Reasoning",
                    step_type="reasoner",
                    input_data={
                        "question": question,
                        "combined_description": combined_description
                    },
                    output_data={
                        "answer": final_result.get("answer", ""),
                        "reasoning": final_result.get("reasoning", "")
                    },
                    runtime=final_time,
                    success=final_result.get("success", False),
                    error=final_result.get("error"),
                    metadata={"stage": 3, "prompt": final_result.get("prompt", "")},
                    session_id=session_id
                )

            clarification_denied = focused_result.get('denied', False)

            if not final_result.get('success', False):
                # A clarification was still requested, so keep tracking fields accurate.
                return self._finish_solve(
                    session_id,
                    start_time,
                    answer=f"Error in final reasoning: {final_result.get('error', 'Unknown error')}",
                    success=False,
                    termination_reason="final_reasoning_failed",
                    iterations=2,
                    error=final_result.get('error'),
                    is_error=True,
                    needs_more_info=True,
                    clarifying_question=info_request,
                    clarification_denied=clarification_denied,
                )

            return self._finish_solve(
                session_id,
                start_time,
                answer=final_result.get('answer', ''),
                success=True,
                termination_reason="solved_with_clarification",
                iterations=2,
                reasoning=final_result.get('reasoning', ''),
                needs_more_info=True,
                clarifying_question=info_request,
                clarification_denied=clarification_denied,
            )

        except Exception as e:
            error_msg = f"Unexpected error in three-stage reasoning: {str(e)}"
            self.logger.error(error_msg, exc_info=True)

            # Only close a session that was actually opened, and never let a
            # logger failure escape from this error path.
            if self.unified_logger and session_started:
                try:
                    self.unified_logger.finish_session(
                        final_answer=f"Error: {error_msg}",
                        success=False,
                        termination_reason="unexpected_error",
                        error=error_msg,
                        session_id=session_id
                    )
                except Exception:
                    self.logger.warning(
                        f"Could not finish unified logging session {session_id}",
                        exc_info=True
                    )

            return {
                'answer': f"Error: {error_msg}",
                'success': False,
                'error': error_msg,
                'total_time': time.time() - start_time,
                'iterations': 0,
                'termination_reason': 'unexpected_error',
                'needs_more_info': False,  # Error case, no clarification
                'clarifying_question': '',
                'clarification_denied': False
            }

    def _finish_solve(
        self,
        session_id: str,
        start_time: float,
        *,
        answer: str,
        success: bool,
        termination_reason: str,
        iterations: int,
        reasoning: Optional[str] = None,
        error: Optional[str] = None,
        is_error: bool = False,
        needs_more_info: bool = False,
        clarifying_question: str = '',
        clarification_denied: bool = False,
    ) -> Dict[str, Any]:
        """
        Close the unified-logger session (if a logger is configured) and build a ``solve`` result.

        Every regular ``solve`` exit goes through here, so all results share the
        same keys.

        Args:
            session_id: Unified-logger session to finish.
            start_time: ``time.time()`` at the start of ``solve``; used for ``total_time``.
            answer: Final answer (or error text) returned and logged.
            success: Whether the run succeeded.
            termination_reason: Machine-readable reason the run ended.
            iterations: Number of reasoner rounds performed.
            reasoning: Reasoner reasoning; the ``reasoning`` key is included only when not ``None``.
            error: Error detail for error exits.
            is_error: When True, ``error`` is passed to ``finish_session`` and
                included in the result under ``error``.
            needs_more_info: Whether the reasoner asked a clarifying question.
            clarifying_question: The question it asked, if any.
            clarification_denied: Whether the clarification was denied by experiment flag.

        Returns:
            The result dict returned by ``solve``.
        """
        if self.unified_logger:
            finish_kwargs: Dict[str, Any] = {
                'final_answer': answer,
                'success': success,
                'termination_reason': termination_reason,
                'session_id': session_id,
            }
            if is_error:
                finish_kwargs['error'] = error
            self.unified_logger.finish_session(**finish_kwargs)

        result: Dict[str, Any] = {
            'answer': answer,
            'success': success,
            'total_time': time.time() - start_time,
            'iterations': iterations,
            'termination_reason': termination_reason,
            'needs_more_info': needs_more_info,
            'clarifying_question': clarifying_question,
            'clarification_denied': clarification_denied,
        }
        if reasoning is not None:
            result['reasoning'] = reasoning
        if is_error:
            result['error'] = error
        return result

    def reason_from_description(
        self,
        description: str,
        question: str,
        dataset: str = "MathDataset",
        prompt_template_name: str = "three_stage_math_v1",
        generation_kwargs: Optional[Dict[str, Any]] = None,
        image_path: Optional[str] = None,
        vlm_generation_kwargs: Optional[Dict[str, Any]] = None,
        deny_clarifications: bool = False,
        track_clarification_requests: bool = False
    ) -> Dict[str, Any]:
        """
        Run three-stage reasoning from a given description (Part 2 for reward function).
        
        This method starts from stage 2 with a provided description and may request
        additional VLM information if needed:
        1. Takes a description as input
        2. Runs stage 2 adaptive decision 
        3. If more info needed, calls VLM for focused description (requires image_path)
        4. Runs final reasoning and extracts answer
        
        Args:
            description: Image description (from external captioner)
            question: Question about the image
            dataset: Dataset type for prompt selection (default: "MathDataset")
            prompt_template_name: Template name to use (default: "three_stage_math_v1")
            generation_kwargs: Generation parameters for reasoner
            image_path: Path to image (required if VLM clarification needed)
            vlm_generation_kwargs: Generation parameters for VLM (if clarification needed)
            
        Returns:
            Dict with 'answer', 'success', 'reasoning', etc.
        """
        try:
            start_time = time.time()
            
            # Load the specified prompt template (math-specific)
            if prompt_template_name != self.prompt_template_name:
                try:
                    temp_prompts = self.prompt_manager.load_prompt_collection(prompt_template_name)
                    self.logger.info(f"Loaded {prompt_template_name} prompt collection for reasoning")
                    # Temporarily use specified prompts
                    original_prompts = self.prompts
                    self.prompts = temp_prompts
                except Exception as e:
                    self.logger.warning(f"Could not load {prompt_template_name} prompts, using default: {e}")
                    original_prompts = None
            else:
                original_prompts = None
            
            # Stage 2: Adaptive Decision (starting from description)
            decision_result = self._stage2_adaptive_decision(
                question=question,
                description=description,
                generation_kwargs=generation_kwargs
            )
            
            if not decision_result['success']:
                if original_prompts:
                    self.prompts = original_prompts
                return {
                    'answer': f"Error in decision reasoning: {decision_result.get('error', 'Unknown error')}",
                    'success': False,
                    'error': decision_result.get('error'),
                    'total_time': time.time() - start_time,
                    'iterations': 1,
                    'reasoning': ''
                }
            
            # Check status and handle accordingly
            if decision_result['status'] == 'SOLVED':
                # Stage 2 decided it can solve directly
                final_answer = decision_result.get('answer', '')
                
                # Restore original prompts if we switched
                if original_prompts:
                    self.prompts = original_prompts
                
                return {
                    'answer': final_answer,
                    'success': True,
                    'reasoning': decision_result.get('reasoning', ''),
                    'total_time': time.time() - start_time,
                    'iterations': 1,
                    'termination_reason': 'solved_directly',
                    'needs_more_info': False  # No clarifying question asked
                }
            
            elif decision_result['status'] == 'NEED_MORE_INFO':
                info_request = decision_result.get('request', '')
                
                if not info_request or info_request == 'N/A' or not image_path:
                    # Cannot get more info in reward function context, use current answer
                    final_answer = decision_result.get('answer', 'Unable to solve with available information')
                    
                    # Restore original prompts if we switched
                    if original_prompts:
                        self.prompts = original_prompts
                    
                    return {
                        'answer': final_answer,
                        'success': True,  # Mark as success even if limited info
                        'reasoning': decision_result.get('reasoning', ''),
                        'total_time': time.time() - start_time,
                        'iterations': 1,
                        'termination_reason': 'limited_info_available',
                        'needs_more_info': True  # Clarifying question was requested but couldn't be fulfilled
                    }
                
                # Get focused information from VLM (if image_path provided and not denied)
                if deny_clarifications:
                    # Deny clarification request - provide a generic "no additional info" response
                    self.logger.info(f"Clarification request denied for experiment: {info_request}")
                    focused_result = {
                        'success': True,
                        'focused_info': "Additional information is not available at this time. Please reason with the information provided in the initial description.",
                        'prompt': 'N/A (clarification denied)',
                        'denied': True
                    }
                else:
                    focused_result = self._stage2_5_focused_info(
                        image_path, question, info_request, description, vlm_generation_kwargs
                    )
                    focused_result['denied'] = False
                
                if focused_result['success']:
                    # Combine descriptions
                    combined_description = f"{description}\n\nAdditional Information:\n{focused_result['focused_info']}"
                else:
                    # Use original description if focused info failed
                    combined_description = description
                
                # Stage 3: Final Reasoning with combined information
                final_result = self._stage3_final_reasoning(
                    question, combined_description, generation_kwargs
                )
                
                if not final_result['success']:
                    if original_prompts:
                        self.prompts = original_prompts
                    return {
                        'answer': f"Error in final reasoning: {final_result.get('error', 'Unknown error')}",
                        'success': False,
                        'error': final_result.get('error'),
                        'total_time': time.time() - start_time,
                        'iterations': 2,
                        'needs_more_info': True  # Clarifying question was asked
                    }
                
                final_answer = final_result.get('answer', '')
                
                # Restore original prompts if we switched
                if original_prompts:
                    self.prompts = original_prompts
                
                return {
                    'answer': final_answer,
                    'success': True,
                    'reasoning': final_result.get('reasoning', ''),
                    'total_time': time.time() - start_time,
                    'iterations': 2,
                    'termination_reason': 'solved_with_clarification',
                    'needs_more_info': True,  # Clarifying question was asked
                    'clarifying_question': info_request,  # Log what question was asked
                    'additional_info': focused_result.get('focused_info', '') if focused_result['success'] else None,
                    'clarification_denied': focused_result.get('denied', False)
                }
            
            else:
                # Unknown status
                final_answer = decision_result.get('answer', 'Unable to determine answer')
                
                # Restore original prompts if we switched
                if original_prompts:
                    self.prompts = original_prompts
                
                return {
                    'answer': final_answer,
                    'success': True,  # Mark as success with best effort
                    'reasoning': decision_result.get('reasoning', ''),
                    'total_time': time.time() - start_time,
                    'iterations': 1,
                    'termination_reason': 'unknown_status',
                    'needs_more_info': False  # No clarifying question in unknown status
                }
            
        except Exception as e:
            # Restore original prompts on error
            if 'original_prompts' in locals() and original_prompts:
                self.prompts = original_prompts
                
            error_msg = f"Unexpected error in three-stage reasoning from description: {str(e)}"
            self.logger.error(error_msg, exc_info=True)
            
            return {
                'answer': f"Error: {error_msg}",
                'success': False,
                'error': error_msg,
                'total_time': time.time() - start_time,
                'iterations': 0
            }

    def _stage1_initial_description(
        self,
        image_path: str,
        question: str,
        generation_kwargs: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Stage 1: Generate initial image description."""
        try:
            # Choose prompt based on confidence experiment flags
            if self.enable_vlm_confidence:
                # Try to use confidence prompt from template
                if self.prompts and 'vlm_confidence_prompt' in self.prompts:
                    prompt_key = 'vlm_confidence_prompt'
                    use_confidence_prompt = True
                else:
                    self.logger.warning("VLM confidence prompt not found, falling back to regular prompt")
                    prompt_key = 'vlm_initial_description_prompt'
                    use_confidence_prompt = False
            else:
                prompt_key = 'vlm_initial_description_prompt'
                use_confidence_prompt = False
            
            if not self.prompts or prompt_key not in self.prompts:
                raise ValueError(
                    f"VLM prompt template '{prompt_key}' not loaded. "
                    "Cannot proceed without proper prompt template. "
                    "Check that template file exists and contains required prompt."
                )
            
            prompt = self.prompts[prompt_key]
            
            # Format the prompt with the question (escape {} characters for math prompts)
            prompt_template = prompt.replace("boxed{}", "boxed{{}}")
            formatted_prompt = prompt_template.format(question=question)
            
            # Generate response
            result = self.vlm.generate(
                image=image_path,
                prompt=formatted_prompt,
                **(generation_kwargs or {})
            )
            
            # Parse confidence if using confidence prompt
            confidence_score = None
            description = result
            
            if self.enable_vlm_confidence and use_confidence_prompt:
                try:
                    description, confidence_score = self._parse_vlm_confidence_response(result)
                except Exception as e:
                    self.logger.warning(f"Error parsing confidence from VLM response: {e}")
                    # Safe fallback: use entire response as description
                    description = result
                    confidence_score = None
            
            return {
                'success': True,
                'description': description,
                'confidence': confidence_score,
                'raw_vlm_response': result,
                'prompt': formatted_prompt,
                'used_confidence_prompt': use_confidence_prompt
            }
            
        except Exception as e:
            self.logger.error(f"Error in stage 1 description: {e}", exc_info=True)
            return {
                'success': False,
                'error': str(e),
                'description': '',
                'confidence': None,
                'raw_vlm_response': '',
                'prompt': '',
                'used_confidence_prompt': False
            }

    def _stage2_adaptive_decision(
        self,
        question: str,
        description: str,
        generation_kwargs: Optional[Dict[str, Any]] = None,
        confidence: Optional[int] = None
    ) -> Dict[str, Any]:
        """Stage 2: Adaptive decision using unified prompt (similar to adaptive scaffold)."""
        try:
            if not self.prompts or 'reasoner_adaptive_decision_prompt' not in self.prompts:
                raise ValueError(
                    "Adaptive decision prompt template not loaded. "
                    "Cannot proceed without proper prompt template. "
                    "Check that template file exists and contains 'reasoner_adaptive_decision_prompt'."
                )
            
            prompt = self.prompts['reasoner_adaptive_decision_prompt']
            
            # Format the prompt with variables (escape {} characters for math prompts)
            escaped_template = prompt.replace("boxed{}", "boxed{{}}")
            formatted_prompt = escaped_template.format(
                description=description,
                question=question,
                answer='{answer}'
            )
            
            # 🔍 DEBUG: Log the full prompt being sent to reasoner
            self.logger.info("🔍 DEBUG: Stage 2 Adaptive Decision")
            self.logger.info(f"🔍 DEBUG: Question: {question[:100]}...")
            self.logger.info(f"🔍 DEBUG: Description length: {len(description)} chars")
            self.logger.info(f"🔍 DEBUG: Generation kwargs: {generation_kwargs}")
            self.logger.info(f"🔍 DEBUG: Formatted prompt (first 500 chars): {formatted_prompt[:500]}...")
            
            result = self.reasoner.reason(
                context=formatted_prompt,
                **(generation_kwargs or {})
            )
            
            # 🔍 DEBUG: Log the raw reasoner response
            self.logger.info(f"🔍 DEBUG: Raw reasoner response: {repr(result)}")
            self.logger.info(f"🔍 DEBUG: Response type: {type(result)}")
            self.logger.info(f"🔍 DEBUG: Response length: {len(str(result))} chars")
            
            # Parse the result using adaptive-style parsing
            parsed_result = self._parse_adaptive_decision_response(result)
            
            # 🔍 DEBUG: Log the parsed result
            self.logger.info(f"🔍 DEBUG: Parsed status: {parsed_result.get('status', 'N/A')}")
            self.logger.info(f"🔍 DEBUG: Parsed answer: {repr(parsed_result.get('answer', 'N/A'))}")
            self.logger.info(f"🔍 DEBUG: Parsed request: {repr(parsed_result.get('request', 'N/A'))}")
            self.logger.info(f"🔍 DEBUG: Parsed reasoning (first 200 chars): {str(parsed_result.get('reasoning', ''))[:200]}...")
            
            # Return with prompt for debugging
            parsed_result['prompt'] = formatted_prompt
            parsed_result['success'] = True
            
            return parsed_result
            
        except Exception as e:
            self.logger.error(f"Error in stage 2 adaptive decision: {e}", exc_info=True)
            return {
                'success': False,
                'error': str(e),
                'status': 'ERROR',
                'answer': '',
                'request': '',
                'reasoning': '',
                'prompt': ''
            }

    def _parse_adaptive_decision_response(self, result: str) -> Dict[str, Any]:
        """Parse adaptive decision response using format similar to adaptive scaffold."""
        import re
        
        # 🔍 DEBUG: Log parsing details
        self.logger.info("🔍 DEBUG: Parsing adaptive decision response")
        self.logger.info(f"🔍 DEBUG: Input result (first 300 chars): {result[:300]}...")
        
        # Parse using adaptive-style format
        reasoning = re.search(r"Reasoning:\s*(.*?)(?=\n(?:\*\*)?Status:|$)", result, re.DOTALL | re.IGNORECASE)
        status_val = re.search(r"(?:\*\*)?Status(?:\*\*)?:\s*(SOLVED|NEED_MORE_INFO)", result, re.IGNORECASE)
        answer_val = re.search(r"(?:\*\*)?Answer(?:\*\*)?:\s*(.*?)(?=\n(?:\*\*)?Request:|$)", result, re.DOTALL | re.IGNORECASE)
        request_val = re.search(r"(?:\*\*)?Request(?:\*\*)?:\s*(.*)", result, re.DOTALL | re.IGNORECASE)

        # 🔍 DEBUG: Log regex matches
        self.logger.info(f"🔍 DEBUG: Reasoning match found: {reasoning is not None}")
        self.logger.info(f"🔍 DEBUG: Status match found: {status_val is not None}")
        if status_val:
            self.logger.info(f"🔍 DEBUG: Status match content: {repr(status_val.group(1))}")
        self.logger.info(f"🔍 DEBUG: Answer match found: {answer_val is not None}")
        if answer_val:
            self.logger.info(f"🔍 DEBUG: Answer match content: {repr(answer_val.group(1)[:100])}...")
        self.logger.info(f"🔍 DEBUG: Request match found: {request_val is not None}")
        if request_val:
            self.logger.info(f"🔍 DEBUG: Request match content: {repr(request_val.group(1)[:100])}...")

        parsed_reasoning = reasoning.group(1).strip() if reasoning else result
        
        # Determine status
        if status_val:
            parsed_status = self._parse_status(status_val.group(1))
            self.logger.info(f"🔍 DEBUG: Status from explicit Status field: {parsed_status}")
        else:
            # If no explicit status, infer from content
            has_answer = answer_val and answer_val.group(1).strip()
            parsed_status = 'SOLVED' if has_answer else 'NEED_MORE_INFO'
            self.logger.info(f"🔍 DEBUG: Status inferred from content (has_answer={has_answer}): {parsed_status}")
        
        def _clean_field(s: str) -> str:
            return s.strip().strip('*').strip().upper()

        # Parse answer
        parsed_answer_str = answer_val.group(1).strip() if answer_val else "N/A"
        if _clean_field(parsed_answer_str) == "N/A":
            parsed_answer_str = "N/A"
        parsed_answer = parsed_answer_str if parsed_status == 'SOLVED' and parsed_answer_str != "N/A" else None

        # Parse request
        parsed_request_str = request_val.group(1).strip() if request_val else "N/A"
        if _clean_field(parsed_request_str) == "N/A":
            parsed_request_str = "N/A"
        parsed_request = parsed_request_str if parsed_status == 'NEED_MORE_INFO' and parsed_request_str != "N/A" else None

        # If status is SOLVED but answer is still None, try to extract from reasoning
        if parsed_status == 'SOLVED' and not parsed_answer:
            extracted_answer = self._extract_answer_from_reasoning(parsed_reasoning)
            if extracted_answer.upper() != "N/A" and extracted_answer != "Unable to determine answer":
                parsed_answer = extracted_answer
                self.logger.info(f"🔍 DEBUG: Extracted answer from reasoning: {repr(extracted_answer)}")

        # If status is NEED_MORE_INFO but request is None, provide default
        if parsed_status == 'NEED_MORE_INFO' and not parsed_request:
            parsed_request = "Please provide more specific details about the visual elements in the image."
            self.logger.info("🔍 DEBUG: Using default request for NEED_MORE_INFO")
        
        # 🔍 DEBUG: Log final parsed result
        self.logger.info(f"🔍 DEBUG: Final parsed status: {parsed_status}")
        self.logger.info(f"🔍 DEBUG: Final parsed answer: {repr(parsed_answer)}")
        self.logger.info(f"🔍 DEBUG: Final parsed request: {repr(parsed_request)}")
            
        return {
            'reasoning': parsed_reasoning,
            'status': parsed_status,
            'answer': parsed_answer,
            'request': parsed_request
        }

    def _parse_status(self, status_str: str) -> str:
        """Parse and validate status string."""
        status = str(status_str).upper().strip()
        if 'SOLVED' in status:
            return 'SOLVED'
        elif 'NEED_MORE_INFO' in status or 'NEED MORE INFO' in status:
            return 'NEED_MORE_INFO'
        else:
            # Default to NEED_MORE_INFO if unclear
            return 'NEED_MORE_INFO'

    def _extract_answer_from_reasoning(self, reasoning: str) -> str:
        """Extract answer from reasoning text if not explicitly provided."""
        import re
        
        # Look for common answer patterns
        patterns = [
            r'(?:answer|solution|result)(?:\s*is)?:?\s*(.+?)(?:\n|$)',
            r'(?:therefore|thus|so),?\s*(.+?)(?:\n|$)',
            r'(?:final answer|conclusion):?\s*(.+?)(?:\n|$)'
        ]
        
        for pattern in patterns:
            match = re.search(pattern, reasoning, re.IGNORECASE)
            if match:
                extracted = match.group(1).strip()
                if "unable to determine" not in extracted.lower() and "cannot determine" not in extracted.lower():
                    return extracted
        
        # If no pattern matches, extract the last meaningful sentence
        sentences = [s.strip() for s in reasoning.split('.') if s.strip()]
        if sentences:
            last_sentence = sentences[-1]
            if "unable to determine" not in last_sentence.lower() and "cannot determine" not in last_sentence.lower():
                return last_sentence
        
        # Final fallback
        return "Based on the available information, more details are needed for a definitive answer"

    def _stage2_5_focused_info(
        self,
        image_path: str,
        question: str,
        info_request: str,
        previous_description: str,
        generation_kwargs: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Stage 2.5: Get focused information from VLM based on reasoner request."""
        try:
            if not self.prompts or 'vlm_focused_description_prompt' not in self.prompts:
                raise ValueError(
                    "Focused description prompt template not loaded. "
                    "Cannot proceed without proper prompt template. "
                    "Check that template file exists and contains 'vlm_focused_description_prompt'."
                )
            
            prompt = self.prompts['vlm_focused_description_prompt']
            
            formatted_prompt = prompt.format(
                question=question,
                focus_request=info_request,
                previous_descriptions=previous_description
            )
            
            result = self.vlm.generate(
                image=image_path,
                prompt=formatted_prompt,
                **(generation_kwargs or {})
            )
        
            return {
                'success': True,
                'focused_info': result,
                'prompt': formatted_prompt
            }
            
        except Exception as e:
            self.logger.error(f"Error getting focused information: {e}", exc_info=True)
            return {
                'success': False,
                'error': str(e),
                'focused_info': '',
                'prompt': ''
            }

    def _stage3_final_reasoning(
        self,
        question: str,
        combined_description: str,
        generation_kwargs: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Stage 3: Final reasoning with combined information."""
        try:
            if not self.prompts or 'reasoner_final_prompt' not in self.prompts:
                raise ValueError(
                    "Final reasoning prompt template not loaded. "
                    "Cannot proceed without proper prompt template. "
                    "Check that template file exists and contains 'reasoner_final_prompt'."
                )
            
            prompt = self.prompts['reasoner_final_prompt']
            
            # Format the prompt with variables (escape {} characters for math prompts)
            escaped_template = prompt.replace("boxed{}", "boxed{{}}")
            formatted_prompt = escaped_template.format(
                description=combined_description,
                question=question,
                answer='{answer}',
                letter='{letter}'
            )
            
            result = self.reasoner.reason(
                context=formatted_prompt,
                **(generation_kwargs or {})
            )
            
            # Extract answer from result (similar to two-stage parsing)
            reasoning_content, final_answer = self._parse_reasoner_response(result)
            
            return {
                'success': True,
                'answer': final_answer,
                'reasoning': reasoning_content,
                'raw_output': result,
                'prompt': formatted_prompt
            }
            
        except Exception as e:
            self.logger.error(f"Error in stage 3 final reasoning: {e}", exc_info=True)
            return {
                'success': False,
                'error': str(e),
                'answer': '',
                'reasoning': '',
                'raw_output': '',
                'prompt': ''
            }

    def _parse_reasoner_response(self, reasoner_response: str) -> tuple[str, str]:
        """Parse reasoner response to extract reasoning content and final answer."""
        import re
        
        BOXED_PREFIX_RE = re.compile(r"\\{1,4}boxed\{")

        def extract_boxed(text: str) -> str | None:
            m = BOXED_PREFIX_RE.search(text)
            if not m:
                return None
            i = m.end()
            depth = 1
            start = i
            while i < len(text) and depth:
                if text[i] == '{':
                    depth += 1
                elif text[i] == '}':
                    depth -= 1
                    if depth == 0:
                        return text[start:i].strip()
                i += 1
            return None
        
        # Extract reasoning content
        reasoning_content = "[No Reasoning Available]"
        
        # Method 1: Parse <think> tags from content
        think_match = re.search(r"<think>(.*?)</think>", reasoner_response, re.DOTALL | re.IGNORECASE)
        if think_match:
            reasoning_content = think_match.group(1).strip()
        
        # Extract final answer from between <answer> tags, if present
        final_answer_text = reasoner_response
        answer_match = re.search(r"<answer>(.*?)</answer>", reasoner_response, re.DOTALL | re.IGNORECASE)
        if answer_match:
            answer_content = answer_match.group(1).strip()
            # Further parse "Final Answer: " from within <answer> if present
            final_answer_match = re.search(r"Final Answer:(.*)", answer_content, re.IGNORECASE | re.DOTALL)
            if final_answer_match:
                final_answer_text = final_answer_match.group(1).strip()
            else:
                boxed_match = extract_boxed(answer_content)
                if boxed_match:
                    final_answer_text = boxed_match
                else:
                    # If no "Final Answer:" prefix, use entire <answer> content
                    final_answer_text = answer_content

        return reasoning_content, final_answer_text

    def _parse_vlm_confidence_response(self, vlm_response: str) -> tuple[str, Optional[int]]:
        """Parse VLM response to extract description and confidence score."""
        import re
        
        try:
            # Look for the structured format: DESCRIPTION: ... CONFIDENCE: ...
            desc_match = re.search(r"DESCRIPTION:\s*(.*?)(?=CONFIDENCE:|$)", vlm_response, re.DOTALL | re.IGNORECASE)
            conf_match = re.search(r"CONFIDENCE:\s*(\d+)", vlm_response, re.IGNORECASE)
            
            if desc_match and conf_match:
                description = desc_match.group(1).strip()
                confidence = int(conf_match.group(1))
                
                # Validate confidence range
                if 0 <= confidence <= 100:
                    return description, confidence
                else:
                    self.logger.warning(f"Confidence score out of range (0-100): {confidence}")
                    return description, None
            else:
                # Fallback: if structured format not found, use entire response as description
                return vlm_response.strip(), None
                
        except Exception as e:
            self.logger.warning(f"Error parsing VLM confidence response: {e}")
            # Safe fallback: return entire response as description with no confidence
            return vlm_response.strip(), None 