import requests
import time
from typing import List, Any, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed


def correctness_reward(*, completions: List[str], ground_truth: List[str], **kwargs) -> List[float]:
    """Score each completion 1.0 or 0.0 against its reference answer.

    A pair counts as a match when both values, converted with ``str``, are
    equal after stripping surrounding whitespace and lower-casing.

    Args:
        completions: Generated strings (keyword-only, must be non-empty).
        ground_truth: Reference answers, one per completion (keyword-only,
            must be non-empty).
        **kwargs: Extra dataset columns; ignored.

    Returns:
        One float reward per completion, in input order.

    Raises:
        ValueError: If either list is empty, the lengths differ, or a
            completion / ground-truth entry is None.
    """
    # Guard clauses first: reject unusable input before scoring anything.
    if not completions:
        raise ValueError("completions list cannot be empty")
    if not ground_truth:
        raise ValueError("ground_truth list cannot be empty")
    if len(completions) != len(ground_truth):
        raise ValueError(f"Length mismatch: completions({len(completions)}) != ground_truth({len(ground_truth)})")

    def _canonical(value) -> str:
        # Case- and surrounding-whitespace-insensitive form used for comparison.
        return str(value).strip().lower()

    scores: List[float] = []
    for candidate, reference in zip(completions, ground_truth):
        if candidate is None:
            raise ValueError("Found None completion - all completions must be strings")
        if reference is None:
            raise ValueError("Found None ground truth - all ground truth values must be strings")

        matched = _canonical(candidate) == _canonical(reference)
        scores.append(1.0 if matched else 0.0)
    return scores


def length_reward(completions: List[str], **kwargs) -> List[float]:
    """Toy reward: score each completion by its ``len()``, returned as a float.

    Args:
        completions: Generated completions; anything supporting ``len()``.
        **kwargs: Extra dataset columns; ignored.

    Returns:
        One float per completion, in input order (empty list for empty input).
    """
    return [float(size) for size in map(len, completions)]


def pipeline_correctness_reward(
    completions: List[str],
    ground_truth: List[str],
    questions: List[str],
    reasoner_server_url: str,
    reasoner_model_name: str,
    # Reasoner generation parameters (match trajectory generation)
    reasoner_max_tokens: int,
    reasoner_temperature: float,
    reasoner_top_p: float,
    reasoner_top_k: int,
    scaffold_max_iterations: int,
    **kwargs
) -> List[float]:
    """
    Full pipeline reward using split adaptive scaffold.

    Each captioner description is run through the adaptive reasoning scaffold
    to obtain a final answer, which is then scored against the ground truth
    with ``_compute_correctness_reward``.

    Args:
        completions: List of captioner descriptions (not final answers)
        ground_truth: List of correct final answers (REQUIRED)
        questions: List of questions for each sample (REQUIRED)
        reasoner_server_url: URL of the reasoner server (REQUIRED)
        reasoner_model_name: Name of the reasoner model on the server (REQUIRED)
        reasoner_max_tokens: Max tokens for reasoner generation (REQUIRED)
        reasoner_temperature: Temperature for reasoner generation (REQUIRED)
        reasoner_top_p: Top-p for reasoner generation (REQUIRED)
        reasoner_top_k: Top-k for reasoner generation (REQUIRED, -1 to disable)
        scaffold_max_iterations: Max iterations for adaptive reasoning (REQUIRED)
        **kwargs: Additional data from dataset. Must contain the VLM settings
            ``vlm_server_url``, ``vlm_model_name``, ``vlm_max_tokens``,
            ``vlm_temperature``, ``vlm_top_p`` and ``vlm_top_k``, because
            ``_get_adaptive_scaffold`` requires them.

    Returns:
        List of rewards based on final pipeline answers

    Raises:
        ValueError: If an input list is empty, list lengths differ, a VLM
            setting is missing, or a reasoning result has no ``"answer"``.
        Exception: Any error raised while reasoning over a sample is logged
            and re-raised unchanged (fail fast).
    """
    # Validate required parameters - fail fast
    if not completions:
        raise ValueError("completions list cannot be empty")
    if not ground_truth:
        raise ValueError("ground_truth list cannot be empty")
    if not questions:
        raise ValueError("questions list cannot be empty")
    if len(completions) != len(ground_truth):
        raise ValueError(f"Length mismatch: completions({len(completions)}) != ground_truth({len(ground_truth)})")
    if len(completions) != len(questions):
        raise ValueError(f"Length mismatch: completions({len(completions)}) != questions({len(questions)})")

    # _get_adaptive_scaffold declares the VLM settings without defaults, so
    # calling it with reasoner settings only always raised TypeError. Pull the
    # VLM settings from the injected kwargs and report missing ones clearly.
    vlm_param_names = (
        'vlm_server_url', 'vlm_model_name', 'vlm_max_tokens',
        'vlm_temperature', 'vlm_top_p', 'vlm_top_k',
    )
    vlm_params = {name: kwargs.get(name) for name in vlm_param_names}
    missing_vlm_params = [name for name, value in vlm_params.items() if value is None]
    if missing_vlm_params:
        raise ValueError(
            f"Missing required VLM parameters: {missing_vlm_params}. "
            f"These should be injected by TRL adapter from config."
        )

    # Create adaptive scaffold for reasoning with proper generation parameters
    scaffold = _get_adaptive_scaffold(
        **vlm_params,
        reasoner_server_url=reasoner_server_url,
        reasoner_model_name=reasoner_model_name,
        max_iterations=scaffold_max_iterations,
        reasoner_max_tokens=reasoner_max_tokens,
        reasoner_temperature=reasoner_temperature,
        reasoner_top_p=reasoner_top_p,
        reasoner_top_k=reasoner_top_k
    )

    rewards = []

    for i, (description, gt, question) in enumerate(zip(completions, ground_truth, questions)):
        try:
            # Run description through adaptive reasoning to get final answer.
            # generation_kwargs=None means "use the scaffold's configured parameters".
            reasoning_result = scaffold.reason_from_description(
                description=description,
                question=question,
                generation_kwargs=None
            )

            # Fail fast if reasoning result doesn't have required fields
            if "answer" not in reasoning_result:
                raise ValueError(f"Reasoning result missing 'answer' field: {reasoning_result.keys()}")

            final_answer = reasoning_result["answer"]

            # Compare final answer with ground truth
            reward = _compute_correctness_reward(final_answer, gt)
            rewards.append(reward)

            if i < 3:  # Log first few for debugging
                # Optional fields are read with .get so that diagnostics alone
                # can never abort reward computation.
                print(f"  Sample {i}: Description → Adaptive Reasoning → Reward")
                print(f"    Desc: {str(description)[:100]}...")
                print(f"    Final: {final_answer}")
                print(f"    GT: {gt}")
                print(f"    Reward: {reward}")
                print(f"    Success: {reasoning_result.get('success')}")
                print(f"    Iterations: {reasoning_result.get('iterations')}")

        except Exception as e:
            print(f"Error in pipeline reward for sample {i}: {e}")
            raise  # Fail fast - don't hide errors with 0.0 rewards

    return rewards


def math_correctness_reward(completions: List[str], ground_truth: List[str], **kwargs) -> List[float]:
    """Reward = 1 if completion is mathematically equivalent to ground_truth via math_verify.verify.

    Both completion and ground_truth are converted with ``str`` and parsed with
    ``math_verify.parse``; the parsed values are compared with
    ``math_verify.verify(parsed_gt, parsed_pred)``.

    Args:
        completions: List of generated mathematical answers (REQUIRED).
        ground_truth: List of correct mathematical answers (REQUIRED).
        **kwargs: Unused extra columns from dataset.

    Returns:
        List of float rewards (1.0 or 0.0) aligned with completions.

    Raises:
        ValueError: If either list is empty, the lengths differ, any entry is
            None, or parsing/verification raises (the original exception is
            chained as ``__cause__``).
        ImportError: If ``math_verify`` is not installed (only reached once
            the inputs have passed validation).
    """
    # Validate required parameters - fail fast, before paying for the import.
    if not completions:
        raise ValueError("completions list cannot be empty")
    if not ground_truth:
        raise ValueError("ground_truth list cannot be empty")
    if len(completions) != len(ground_truth):
        raise ValueError(f"Length mismatch: completions({len(completions)}) != ground_truth({len(ground_truth)})")

    for comp, gt in zip(completions, ground_truth):
        if comp is None:
            raise ValueError("Found None completion - all completions must be strings")
        if gt is None:
            raise ValueError("Found None ground truth - all ground truth values must be strings")

    from math_verify import parse, verify  # Imported lazily to avoid heavy cost when not used

    rewards: List[float] = []
    for comp, gt in zip(completions, ground_truth):
        try:
            parsed_pred = parse(str(comp))
            parsed_gt = parse(str(gt))
            is_correct = verify(parsed_gt, parsed_pred)
        except Exception as e:
            # Fail fast - don't hide parsing errors with fallbacks. Chain the
            # cause so the underlying traceback is preserved.
            raise ValueError(f"Math verification failed for comp='{comp}', gt='{gt}': {e}") from e
        rewards.append(1.0 if is_correct else 0.0)
    return rewards


def pipeline_math_correctness_reward(
    # GRPO trainer passes these as positional/keyword args
    completions: List[str],
    ground_truth: List[str] = None,
    questions: List[str] = None,
    prompts: List[str] = None,  # GRPO trainer passes this
    # VLM server parameters (injected by TRL adapter)
    vlm_server_url: str = None,
    vlm_model_name: str = None,
    vlm_max_tokens: int = None,
    vlm_temperature: float = None,
    vlm_top_p: float = None,
    vlm_top_k: int = None,
    # Reasoner generation parameters (injected by TRL adapter)
    reasoner_server_url: str = None,
    reasoner_model_name: str = None,
    reasoner_max_tokens: int = None,
    reasoner_temperature: float = None,
    reasoner_top_p: float = None,
    reasoner_top_k: int = None,
    scaffold_max_iterations: int = None,
    # Template consistency (should match training config)
    prompt_template_name: str = "adaptive_math_v1",
    **kwargs
) -> List[float]:
    """
    Full pipeline reward with math verification using split adaptive scaffold.

    Captioner descriptions → Adaptive reasoning → Final answers → Math verification reward

    Reasoning runs in a thread pool (one scaffold per task); math verification
    then runs sequentially in the calling thread, in input order.

    Args:
        completions: Captioner outputs (REQUIRED). Each entry is either a plain
            string or a message list whose first element has a ``'content'`` key.
        ground_truth: List of correct mathematical answers (REQUIRED, from dataset)
        questions: List of questions for each sample (can be None, extracted from prompts)
        prompts: List of prompt messages (from GRPO trainer, used if questions is None)
        vlm_server_url: URL of the VLM server (injected by TRL adapter)
        vlm_model_name: Name of the VLM model (injected by TRL adapter)
        vlm_max_tokens: Max tokens for VLM (injected by TRL adapter)
        vlm_temperature: Temperature for VLM (injected by TRL adapter)
        vlm_top_p: Top-p for VLM (injected by TRL adapter)
        vlm_top_k: Top-k for VLM (injected by TRL adapter)
        reasoner_server_url: URL of the reasoner server (injected by TRL adapter)
        reasoner_model_name: Name of the reasoner model (injected by TRL adapter)
        reasoner_max_tokens: Max tokens for reasoner (injected by TRL adapter)
        reasoner_temperature: Temperature for reasoner (injected by TRL adapter)
        reasoner_top_p: Top-p for reasoner (injected by TRL adapter)
        reasoner_top_k: Top-k for reasoner (injected by TRL adapter)
        scaffold_max_iterations: Max iterations for adaptive reasoning (injected by TRL adapter)
        prompt_template_name: Template name for scaffold (injected by TRL adapter)
        **kwargs: Additional data from dataset. Recognised keys:
            ``question`` (overrides ``questions`` when not None),
            ``reward_parallel_workers`` (default 32) and
            ``reward_enable_parallel`` (default True).

    Returns:
        List of rewards based on mathematical correctness

    Raises:
        ValueError: On empty/mismatched inputs, when neither questions nor
            prompts are available, or when an injected VLM/reasoner parameter
            is None.
        Exception: Any error from reasoning or math verification is logged and
            re-raised (fail fast).
    """
    # Validate required parameters - fail fast
    if not completions:
        raise ValueError("completions list cannot be empty")

    # `ground_truth` is a named parameter, so it can never arrive via **kwargs;
    # a kwargs lookup here would always be a no-op.
    if not ground_truth:
        raise ValueError("ground_truth list cannot be empty - ensure dataset has 'ground_truth' field")

    # A singular 'question' dataset column takes precedence over `questions`.
    # ('questions' itself is a named parameter and cannot appear in kwargs.)
    if kwargs.get('question') is not None:
        questions = kwargs['question']

    if questions is None:
        if prompts is None:
            raise ValueError("Either 'questions' or 'prompts' must be provided")

        # Extract questions from prompt structure
        extracted_questions = []
        for prompt in prompts:
            try:
                if isinstance(prompt, list):
                    # List of messages format - use the first user message
                    user_message = next(
                        (msg for msg in prompt
                         if isinstance(msg, dict) and msg.get("role") == "user"),
                        None,
                    )
                    if user_message:
                        content = user_message.get("content", "")
                        if isinstance(content, list):
                            # Structured content - keep only the text parts
                            question = "\n".join(
                                item.get("text", "")
                                for item in content
                                if isinstance(item, dict) and item.get("type") == "text"
                            )
                        else:
                            question = str(content)
                    else:
                        question = str(prompt)  # Fallback: no user message found
                else:
                    # String format
                    question = str(prompt)
            except Exception as e:
                # Deliberate best effort: fall back to the raw prompt text.
                print(f"Warning: Failed to extract question from prompt {prompt}: {e}")
                question = str(prompt)
            extracted_questions.append(question)

        questions = extracted_questions

    # Conversational completions arrive as [[{'content': ...}], ...]; plain
    # strings (as the annotation advertises) are accepted unchanged.
    completions = [c if isinstance(c, str) else c[0]['content'] for c in completions]

    # Validate parameter compatibility
    if len(completions) != len(ground_truth):
        raise ValueError(f"Length mismatch: completions({len(completions)}) != ground_truth({len(ground_truth)})")
    if len(completions) != len(questions):
        raise ValueError(f"Length mismatch: completions({len(completions)}) != questions({len(questions)})")

    # Validate injected parameters from TRL adapter. Explicit dicts are used
    # instead of locals()[name] inside a comprehension: before Python 3.12 a
    # comprehension has its own scope, so that lookup raises KeyError.
    vlm_params = {
        'vlm_server_url': vlm_server_url,
        'vlm_model_name': vlm_model_name,
        'vlm_max_tokens': vlm_max_tokens,
        'vlm_temperature': vlm_temperature,
        'vlm_top_p': vlm_top_p,
        'vlm_top_k': vlm_top_k,
    }
    reasoner_params = {
        'reasoner_server_url': reasoner_server_url,
        'reasoner_model_name': reasoner_model_name,
        'reasoner_max_tokens': reasoner_max_tokens,
        'reasoner_temperature': reasoner_temperature,
        'reasoner_top_p': reasoner_top_p,
        'reasoner_top_k': reasoner_top_k,
        'scaffold_max_iterations': scaffold_max_iterations,
    }

    missing_vlm_params = [name for name, value in vlm_params.items() if value is None]
    missing_reasoner_params = [name for name, value in reasoner_params.items() if value is None]

    if missing_vlm_params:
        raise ValueError(
            f"Missing required VLM parameters: {missing_vlm_params}. "
            f"These should be injected by TRL adapter from config."
        )

    if missing_reasoner_params:
        raise ValueError(
            f"Missing required reasoner parameters: {missing_reasoner_params}. "
            f"These should be injected by TRL adapter from config."
        )

    # Single source of truth for scaffold construction (main thread + workers).
    scaffold_kwargs = dict(
        vlm_server_url=vlm_server_url,
        vlm_model_name=vlm_model_name,
        vlm_max_tokens=vlm_max_tokens,
        vlm_temperature=vlm_temperature,
        vlm_top_p=vlm_top_p,
        vlm_top_k=vlm_top_k,
        vlm_prompt_template_name=prompt_template_name,
        reasoner_server_url=reasoner_server_url,
        reasoner_model_name=reasoner_model_name,
        max_iterations=scaffold_max_iterations,
        reasoner_max_tokens=reasoner_max_tokens,
        reasoner_temperature=reasoner_temperature,
        reasoner_top_p=reasoner_top_p,
        reasoner_top_k=reasoner_top_k,
        prompt_template_name=prompt_template_name,  # Pass template name for consistency
    )

    # Build one scaffold eagerly so configuration errors surface in the main
    # thread before any worker is started. The instance itself is not needed:
    # every task constructs its own below.
    _get_adaptive_scaffold(**scaffold_kwargs)

    # Parallel processing configuration - use config values if available
    max_workers = kwargs.get('reward_parallel_workers', 32)
    enable_parallel = kwargs.get('reward_enable_parallel', True)

    if not enable_parallel:
        # Sequential processing fallback
        print(f"🔄 Processing {len(completions)} completions sequentially (parallel disabled)")
        max_workers = 1
    else:
        # Never more workers than tasks, and never fewer than one
        # (ThreadPoolExecutor rejects max_workers <= 0).
        max_workers = max(1, min(len(completions), max_workers))
        print(f"🚀 Processing {len(completions)} completions with {max_workers} workers (servers support 64 concurrent connections)")

    def process_single_completion(args_tuple):
        """Process a single completion through the adaptive scaffold."""
        i, (description, gt, question) = args_tuple
        try:
            # Each task gets its own scaffold instance to avoid sharing state
            # between threads.
            thread_scaffold = _get_adaptive_scaffold(**scaffold_kwargs)

            # Run description through adaptive reasoning to get final answer
            reasoning_result = thread_scaffold.reason_from_description(
                description=description,
                question=question,
                generation_kwargs=None  # Use scaffold's configured parameters
            )

            final_answer = reasoning_result["answer"]

            # Return reasoning result without math verification (do that in main thread)
            return i, reasoning_result, final_answer, gt, question

        except Exception as e:
            print(f"Error in math pipeline reward for sample {i}: {e}")
            raise  # Fail fast - don't hide errors with 0.0 rewards

    # Prepare arguments for parallel processing
    completion_args = list(enumerate(zip(completions, ground_truth, questions)))

    # Process completions in parallel (reasoning only)
    print(f"🚀 Processing {len(completions)} completions with {max_workers} workers...")

    reasoning_results = [None] * len(completions)  # Filled by index to maintain order

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks
        future_to_args = {
            executor.submit(process_single_completion, args): args
            for args in completion_args
        }

        # Collect reasoning results as they complete
        completed_count = 0
        for future in as_completed(future_to_args):
            try:
                i, reasoning_result, final_answer, gt, question = future.result()
            except Exception as e:
                print(f"Error processing completion: {e}")
                # Drop tasks that have not started yet so the failure is not
                # followed by more requests; cancel() is a no-op for
                # running or finished futures.
                for pending in future_to_args:
                    pending.cancel()
                raise  # Fail fast

            reasoning_results[i] = (reasoning_result, final_answer, gt, question)
            completed_count += 1

            # Progress update
            if completed_count % 4 == 0 or completed_count == len(completions):
                print(f"    Completed {completed_count}/{len(completions)} reasoning tasks...")

    print(f"✅ Completed all {len(completions)} reasoning tasks")

    # Now do math verification sequentially in main thread
    print(f"🔢 Computing math verification rewards sequentially...")
    rewards = []

    for i, (reasoning_result, final_answer, gt, question) in enumerate(reasoning_results):
        try:
            reward = _compute_math_reward(final_answer, gt)
            rewards.append(reward)

            # Debug logging for first few samples; optional fields are read
            # defensively so diagnostics cannot abort reward computation.
            if i < 3:
                print(f"  Math Sample {i}: {final_answer} vs {gt} → {reward}")
                print(f"    Question: {str(question)[:100]}...")
                print(f"    Success: {reasoning_result.get('success')}")
                print(f"    Iterations: {reasoning_result.get('iterations')}")
                print(f"    Reasoning: {str(reasoning_result.get('reasoning', ''))[:200]}...")

        except Exception as e:
            print(f"Error computing math reward for sample {i}: {e}")
            raise  # Fail fast

    print(f"✅ Completed all {len(completions)} math verification rewards")
    return rewards


def _get_adaptive_scaffold(
    vlm_server_url: str,
    vlm_model_name: str,
    vlm_max_tokens: int,
    vlm_temperature: float,
    vlm_top_p: float,
    vlm_top_k: int,
    vlm_prompt_template_name: str = "adaptive_math_v1",
    reasoner_server_url: str = None,
    reasoner_model_name: str = None,
    max_iterations: int = 5,
    reasoner_max_tokens: int = 100000,
    reasoner_temperature: float = 0.6,
    reasoner_top_p: float = 0.95,
    reasoner_top_k: int = -1,
    prompt_template_name: str = "adaptive_math_v1"
):
    """
    Create adaptive scaffold with server-based reasoner for reward function.
    
    Args:
        reasoner_server_url: URL of the reasoner server (REQUIRED)
        reasoner_model_name: Name of the reasoner model on the server (REQUIRED)
        max_iterations: Maximum adaptive reasoning iterations
        reasoner_max_tokens: Max tokens for reasoner
        reasoner_temperature: Temperature for reasoner  
        reasoner_top_p: Top-p for reasoner
        reasoner_top_k: Top-k for reasoner (-1 to disable)
        prompt_template_name: Template name to use (should match training)
        
    Returns:
        AdaptiveScaffold instance configured for reasoning-only mode
    """
    from ..scaffolds.adaptive import AdaptiveScaffold
    from ..core.reasoner_interface import ReasonerInterface, ReasonerConfig
    from ..core.vlm_interface import VLMInterface, VLMConfig
    
    # Validate required parameters
    if not reasoner_server_url:
        raise ValueError("reasoner_server_url cannot be empty")
    if not reasoner_model_name:
        raise ValueError("reasoner_model_name cannot be empty")
    
    # Handle top_k parameter (convert -1 to None)
    top_k = None if reasoner_top_k <= 0 else reasoner_top_k
    
    # Create reasoner configuration for the server (match trajectory generation params)
    reasoner_config = ReasonerConfig(
        model_name=reasoner_model_name,  # Use actual model name from server
        model_type="openai",  # Generic API type
        api_base=reasoner_server_url,
        api_key="EMPTY",  # Not needed for local server
        timeout=300,  # 10 minute timeout for complex math reasoning
        max_tokens=reasoner_max_tokens,
        temperature=reasoner_temperature,
        top_p=reasoner_top_p,
        top_k=top_k
    )
    
    vlm_config = VLMConfig(
        model_name=vlm_model_name,
        model_type="openai",
        api_base=vlm_server_url,
        api_key="EMPTY",
        timeout=300,  # 10 minute timeout for complex reasoning
        max_tokens=vlm_max_tokens,
        temperature=vlm_temperature,
        top_p=vlm_top_p,
        top_k=vlm_top_k
    )
    
    vlm = VLMInterface.create(vlm_config)
    
    # Create reasoner interface using the factory method
    reasoner = ReasonerInterface.create(reasoner_config)
    
    # Create scaffold with dummy VLM (won't be used in reason_from_description)
    scaffold = AdaptiveScaffold(
        vlm=vlm,  # Not used in reasoning-only mode
        reasoner=reasoner,
        max_iterations=max_iterations,
        prompt_template_name=prompt_template_name  # Use consistent template
    )
    
    return scaffold

def _compute_correctness_reward(final_answer: str, ground_truth: str) -> float:
    """Simple correctness reward computation."""
    if not final_answer or not ground_truth:
        return 0.0
    
    # Normalize for comparison
    answer_norm = final_answer.strip().lower()
    gt_norm = ground_truth.strip().lower()
    
    if answer_norm == gt_norm:
        return 1.0
    elif answer_norm in gt_norm or gt_norm in answer_norm:
        return 0.5  # Partial credit
    else:
        return 0.0


def _compute_math_reward(final_answer: str, ground_truth: str) -> float:
    """Math verification reward computation."""
    try:
        from math_verify import parse, verify
    except ImportError as e:
        raise ImportError(
            "math_verify package is required for mathematical reward computation. "
            "Install with: pip install math_verify"
        ) from e
        
    try:
        # if the final answer does NOT start with $$ and ends with $$
        if final_answer:
            if not final_answer.startswith("$$") and not final_answer.endswith("$$"):
                final_answer = f"$${final_answer}$$"
            if not ground_truth.startswith("$$") and not ground_truth.endswith("$$"):
                ground_truth = f"$${ground_truth}$$"
            parsed_answer = parse(str(final_answer))
            parsed_gt = parse(str(ground_truth))
            is_correct = verify(parsed_gt, parsed_answer)
        else:
            is_correct = False
        return 1.0 if is_correct else 0.0
        
    except Exception as e:
        # Fail fast - don't hide parsing errors with fallbacks
        raise ValueError(
            f"Math verification failed for answer='{final_answer}', ground_truth='{ground_truth}': {e}"
        ) from e


def two_stage_math_correctness_reward(
    # GRPO trainer passes these as positional/keyword args
    completions: List[str],
    ground_truth: List[str] = None,
    questions: List[str] = None,
    prompts: List[str] = None,  # GRPO trainer passes this
    # Reasoner generation parameters (injected by TRL adapter)
    reasoner_server_url: str = None,
    reasoner_model_name: str = None,
    reasoner_max_tokens: int = None,
    reasoner_temperature: float = None,
    reasoner_top_p: float = None,
    reasoner_top_k: int = None,
    # Template consistency (should match training config)
    prompt_template_name: str = "two_stage_math_v1",
    **kwargs
) -> List[float]:
    """
    Two-stage pipeline reward with math verification.

    Captioner descriptions → Two-stage reasoning → Final answers → Math verification reward

    Reasoning runs in a thread pool (one scaffold per task); math verification
    then runs sequentially in the calling thread, in input order.

    Args:
        completions: Captioner outputs (REQUIRED). Each entry is either a plain
            string or a message list whose first element has a ``'content'`` key.
        ground_truth: List of correct mathematical answers (REQUIRED, from dataset)
        questions: List of questions for each sample (can be None, extracted from prompts)
        prompts: List of prompt messages (from GRPO trainer, used if questions is None)
        reasoner_server_url: URL of the reasoner server (injected by TRL adapter)
        reasoner_model_name: Name of the reasoner model (injected by TRL adapter)
        reasoner_max_tokens: Max tokens for reasoner (injected by TRL adapter)
        reasoner_temperature: Temperature for reasoner (injected by TRL adapter)
        reasoner_top_p: Top-p for reasoner (injected by TRL adapter)
        reasoner_top_k: Top-k for reasoner (injected by TRL adapter)
        prompt_template_name: Template name for scaffold (injected by TRL adapter)
        **kwargs: Additional data from dataset. Recognised keys:
            ``question`` (overrides ``questions`` when not None),
            ``reward_parallel_workers`` (default 32) and
            ``reward_enable_parallel`` (default True).

    Returns:
        List of rewards based on mathematical correctness

    Raises:
        ValueError: On empty/mismatched inputs, when neither questions nor
            prompts are available, or when an injected reasoner parameter is None.
        Exception: Any error from reasoning or math verification is logged and
            re-raised (fail fast).
    """
    # Validate required parameters - fail fast
    if not completions:
        raise ValueError("completions list cannot be empty")

    # `ground_truth` is a named parameter, so it can never arrive via **kwargs;
    # a kwargs lookup here would always be a no-op.
    if not ground_truth:
        raise ValueError("ground_truth list cannot be empty - ensure dataset has 'ground_truth' field")

    # A singular 'question' dataset column takes precedence over `questions`.
    # ('questions' itself is a named parameter and cannot appear in kwargs.)
    if kwargs.get('question') is not None:
        questions = kwargs['question']

    if questions is None:
        if prompts is None:
            raise ValueError("Either 'questions' or 'prompts' must be provided")

        # Extract questions from prompt structure
        extracted_questions = []
        for prompt in prompts:
            try:
                if isinstance(prompt, list):
                    # List of messages format - use the first user message
                    user_message = next(
                        (msg for msg in prompt
                         if isinstance(msg, dict) and msg.get("role") == "user"),
                        None,
                    )
                    if user_message:
                        content = user_message.get("content", "")
                        if isinstance(content, list):
                            # Structured content - keep only the text parts
                            question = "\n".join(
                                item.get("text", "")
                                for item in content
                                if isinstance(item, dict) and item.get("type") == "text"
                            )
                        else:
                            question = str(content)
                    else:
                        question = str(prompt)  # Fallback: no user message found
                else:
                    # String format
                    question = str(prompt)
            except Exception as e:
                # Deliberate best effort: fall back to the raw prompt text.
                print(f"Warning: Failed to extract question from prompt {prompt}: {e}")
                question = str(prompt)
            extracted_questions.append(question)

        questions = extracted_questions

    # Conversational completions arrive as [[{'content': ...}], ...]; plain
    # strings (as the annotation advertises) are accepted unchanged.
    completions = [c if isinstance(c, str) else c[0]['content'] for c in completions]

    # Validate parameter compatibility
    if len(completions) != len(ground_truth):
        raise ValueError(f"Length mismatch: completions({len(completions)}) != ground_truth({len(ground_truth)})")
    if len(completions) != len(questions):
        raise ValueError(f"Length mismatch: completions({len(completions)}) != questions({len(questions)})")

    # Validate injected parameters from TRL adapter. An explicit dict is used
    # instead of locals()[name] inside a comprehension: before Python 3.12 a
    # comprehension has its own scope, so that lookup raises KeyError.
    reasoner_params = {
        'reasoner_server_url': reasoner_server_url,
        'reasoner_model_name': reasoner_model_name,
        'reasoner_max_tokens': reasoner_max_tokens,
        'reasoner_temperature': reasoner_temperature,
        'reasoner_top_p': reasoner_top_p,
        'reasoner_top_k': reasoner_top_k,
    }

    missing_reasoner_params = [name for name, value in reasoner_params.items() if value is None]

    if missing_reasoner_params:
        raise ValueError(
            f"Missing required reasoner parameters: {missing_reasoner_params}. "
            f"These should be injected by TRL adapter from config."
        )

    # Single source of truth for scaffold construction (main thread + workers).
    scaffold_kwargs = dict(reasoner_params, prompt_template_name=prompt_template_name)

    # Build one scaffold eagerly so configuration errors surface in the main
    # thread before any worker is started. The instance itself is not needed:
    # every task constructs its own below.
    _get_two_stage_scaffold(**scaffold_kwargs)

    # Parallel processing configuration - use config values if available
    max_workers = kwargs.get('reward_parallel_workers', 32)
    enable_parallel = kwargs.get('reward_enable_parallel', True)

    if not enable_parallel:
        # Sequential processing fallback
        print(f"🔄 Processing {len(completions)} completions sequentially (parallel disabled)")
        max_workers = 1
    else:
        # Never more workers than tasks, and never fewer than one
        # (ThreadPoolExecutor rejects max_workers <= 0).
        max_workers = max(1, min(len(completions), max_workers))
        print(f"🚀 Processing {len(completions)} completions with {max_workers} workers")

    def process_single_completion(args_tuple):
        """Process a single completion through the two-stage scaffold."""
        i, (description, gt, question) = args_tuple
        try:
            # Each task gets its own scaffold instance to avoid sharing state
            # between threads.
            thread_scaffold = _get_two_stage_scaffold(**scaffold_kwargs)

            # Run description through two-stage reasoning to get final answer
            reasoning_result = thread_scaffold.reason_from_description(
                description=description,
                question=question,
                dataset="MathDataset",  # Ensure math prompts are used
                prompt_template_name=prompt_template_name,
                generation_kwargs=None  # Use scaffold's configured parameters
            )

            final_answer = reasoning_result["answer"]

            # Return reasoning result for math verification in main thread
            return i, reasoning_result, final_answer, gt, question

        except Exception as e:
            print(f"Error in two-stage math pipeline reward for sample {i}: {e}")
            raise  # Fail fast - don't hide errors with 0.0 rewards

    # Prepare arguments for parallel processing
    completion_args = list(enumerate(zip(completions, ground_truth, questions)))

    # Process completions in parallel (reasoning only)
    print(f"🚀 Processing {len(completions)} completions with {max_workers} workers...")

    reasoning_results = [None] * len(completions)  # Filled by index to maintain order

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks
        future_to_args = {
            executor.submit(process_single_completion, args): args
            for args in completion_args
        }

        # Collect reasoning results as they complete
        completed_count = 0
        for future in as_completed(future_to_args):
            try:
                i, reasoning_result, final_answer, gt, question = future.result()
            except Exception as e:
                print(f"Error processing completion: {e}")
                # Drop tasks that have not started yet so the failure is not
                # followed by more requests; cancel() is a no-op for
                # running or finished futures.
                for pending in future_to_args:
                    pending.cancel()
                raise  # Fail fast

            reasoning_results[i] = (reasoning_result, final_answer, gt, question)
            completed_count += 1

            # Progress update
            if completed_count % 4 == 0 or completed_count == len(completions):
                print(f"    Completed {completed_count}/{len(completions)} reasoning tasks...")

    print(f"✅ Completed all {len(completions)} reasoning tasks")

    # Now do math verification sequentially in main thread
    print(f"🔢 Computing math verification rewards sequentially...")
    rewards = []

    for i, (reasoning_result, final_answer, gt, question) in enumerate(reasoning_results):
        try:
            reward = _compute_math_reward(final_answer, gt)
            rewards.append(reward)

            # Debug logging for first few samples; optional fields are read
            # defensively so diagnostics cannot abort reward computation.
            if i < 3:
                print(f"  Two-Stage Math Sample {i}: {final_answer} vs {gt} → {reward}")
                print(f"    Question: {str(question)[:100]}...")
                print(f"    Success: {reasoning_result.get('success')}")
                print(f"    Reasoning: {str(reasoning_result.get('reasoning', ''))[:200]}...")

        except Exception as e:
            print(f"Error computing math reward for sample {i}: {e}")
            raise  # Fail fast

    print(f"✅ Completed all {len(completions)} two-stage math verification rewards")
    return rewards


def _get_two_stage_scaffold(
    reasoner_server_url: str,
    reasoner_model_name: str,
    reasoner_max_tokens: int,
    reasoner_temperature: float,
    reasoner_top_p: float,
    reasoner_top_k: int,
    prompt_template_name: str = "two_stage_math_v1"
):
    """
    Create two-stage scaffold with server-based reasoner for reward function.
    
    Args:
        reasoner_server_url: URL of the reasoner server (REQUIRED)
        reasoner_model_name: Name of the reasoner model on the server (REQUIRED)
        reasoner_max_tokens: Max tokens for reasoner
        reasoner_temperature: Temperature for reasoner  
        reasoner_top_p: Top-p for reasoner
        reasoner_top_k: Top-k for reasoner (-1 to disable)
        prompt_template_name: Template name to use (should match training)
        
    Returns:
        TwoStageScaffold instance configured for reasoning-only mode
    """
    from ..scaffolds.two_stage import TwoStageScaffold
    from ..core.reasoner_interface import ReasonerInterface, ReasonerConfig
    from ..core.vlm_interface import VLMInterface, VLMConfig
    
    # Validate required parameters
    if not reasoner_server_url:
        raise ValueError("reasoner_server_url cannot be empty")
    if not reasoner_model_name:
        raise ValueError("reasoner_model_name cannot be empty")
    
    # Handle top_k parameter (convert -1 to None)
    top_k = None if reasoner_top_k <= 0 else reasoner_top_k
    
    # Create reasoner configuration for the server
    reasoner_config = ReasonerConfig(
        model_name=reasoner_model_name,  # Use actual model name from server
        model_type="openai",  # Generic API type
        api_base=reasoner_server_url,
        api_key="EMPTY",  # Not needed for local server
        timeout=300,  # 10 minute timeout for complex math reasoning
        max_tokens=reasoner_max_tokens,
        temperature=reasoner_temperature,
        top_p=reasoner_top_p,
        top_k=top_k
    )
    
    # Create dummy VLM config (won't be used in reasoning-only mode)
    vlm_config = VLMConfig(
        model_name="dummy",
        model_type="openai",
        api_base="http://dummy:8000/v1",
        api_key="EMPTY",
        timeout=30,
        max_tokens=1000,
        temperature=0.0,
        top_p=1.0,
        top_k=None
    )
    
    # Create interfaces
    vlm = VLMInterface.create(vlm_config)
    reasoner = ReasonerInterface.create(reasoner_config)
    
    # Create scaffold (VLM won't be used in reason_from_description)
    scaffold = TwoStageScaffold(
        vlm=vlm,  # Not used in reasoning-only mode
        reasoner=reasoner
    )
    
    return scaffold


def _extract_three_stage_question_and_image(prompt):
    """Pull the question text and image path out of a single prompt.

    Supported prompt shapes:
      * a list of message dicts - the first dict with ``role == "user"`` is
        used; if its ``content`` is a list, the ``"text"`` items are joined
        with newlines and the ``"url"`` of an ``"image"`` item is taken as the
        image path (when several images are present, the last one wins);
      * anything else - converted to ``str`` and used as the question.

    Returns:
        Tuple ``(question, image_path)``; ``image_path`` is None when no image
        item with a truthy ``"url"`` was found.
    """
    if not isinstance(prompt, list):
        # String format - no structured content that could carry an image
        return str(prompt), None

    # List of messages format - find the first user message
    user_message = next(
        (msg for msg in prompt if isinstance(msg, dict) and msg.get("role") == "user"),
        None,
    )
    if not user_message:
        return str(prompt), None  # Fallback

    content = user_message.get("content", "")
    if not isinstance(content, list):
        return str(content), None

    # Structured content - extract text parts and image path
    text_parts = []
    image_path = None
    for item in content:
        if not isinstance(item, dict):
            continue
        item_type = item.get("type")
        if item_type == "text":
            text_parts.append(item.get("text", ""))
        elif item_type == "image":
            url = item.get("url")
            if url:
                image_path = str(url)
    return "\n".join(text_parts), image_path


def _log_three_stage_metrics_to_wandb(log_data, step=None):
    """Best-effort logging of ``log_data`` to the active wandb run.

    Does nothing when wandb is not installed or no run is active. Any other
    failure is reported as a warning and never propagated, so metric logging
    cannot break reward computation.
    """
    try:
        import wandb
        if wandb.run is None:
            return
        if step is not None:
            wandb.log(log_data, step=step)
        else:
            wandb.log(log_data)
    except ImportError:
        # wandb is optional - nothing to log to
        pass
    except Exception as e:
        print(f"Warning: Failed to log three-stage metrics to wandb: {e}")


def _save_three_stage_debug_data(
    debug_data_dir,
    extra_fields,
    reasoning_results,
    completions,
    image_paths,
    rewards,
    base_rewards,
    question_penalty,
    total_successful,
    successful_with_penalty,
):
    """Write per-generation debug records and a summary as JSON files.

    Always writes ``three_stage_reward_debug_<timestamp>.json``; when sample
    IDs are found in ``extra_fields`` it also writes
    ``three_stage_trajectories_<timestamp>.json`` grouped by sample ID.
    Failures are printed (with traceback) and never raised.

    Args:
        debug_data_dir: Output directory (created if missing).
        extra_fields: The reward function's ``**kwargs`` (source of sample
            IDs, batch ID and the per-record metadata).
        reasoning_results: List of (reasoning_result, final_answer, gt, question).
        completions: Extracted completion texts, aligned with reasoning_results.
        image_paths: Image paths, aligned with reasoning_results.
        rewards: Final rewards (after penalty), aligned with reasoning_results.
        base_rewards: Rewards before penalty, aligned with reasoning_results.
        question_penalty: Configured penalty value.
        total_successful: Number of samples with base reward > 0.
        successful_with_penalty: Number of those that asked a clarifying question.
    """
    try:
        import os
        import json
        import time
        import numpy as np

        os.makedirs(debug_data_dir, exist_ok=True)
        timestamp = int(time.time())

        # Extract sample IDs if available (for grouping 8 generations per sample)
        sample_ids = extra_fields.get(
            'sample_ids', extra_fields.get('original_indices', extra_fields.get('indices'))
        )
        batch_id = extra_fields.get('batch_id', extra_fields.get('step', timestamp))

        # Additional metadata from kwargs; identical for every record, so build it once
        excluded_keys = ['debug_data_dir', 'sample_ids', 'original_indices', 'indices']
        metadata = {
            k: v for k, v in extra_fields.items()
            if k not in excluded_keys
            and not callable(v) and len(str(v)) < 1000  # Avoid saving large objects
        }

        # Collect comprehensive debug data for each generation
        debug_data = []
        for i, (reasoning_result, final_answer, gt, question) in enumerate(reasoning_results):
            base_reward = base_rewards[i]
            asked_question = reasoning_result.get('needs_more_info', False)
            penalty_applied = question_penalty if asked_question else 0.0

            debug_data.append({
                # Sample identification
                "generation_index": i,
                "sample_id": sample_ids[i] if sample_ids and i < len(sample_ids) else None,
                "batch_id": batch_id,
                "timestamp": timestamp,

                # Input data
                "question": question,
                "ground_truth": gt,
                "captioner_description": completions[i],
                "image_path": image_paths[i],

                # Three-stage reasoning process
                "reasoning_process": {
                    "stage1_description": completions[i],  # Initial captioner description
                    "stage2_decision": {
                        "status": "SOLVED" if not asked_question else "NEED_MORE_INFO",
                        "reasoning": reasoning_result.get('reasoning', ''),
                        "requested_clarification": reasoning_result.get('clarifying_question', None),
                    },
                    "stage2_5_clarification": {
                        "question_asked": asked_question,
                        "clarifying_question": reasoning_result.get('clarifying_question', None),
                        "additional_info_provided": reasoning_result.get('additional_info', None),
                    } if asked_question else None,
                    "stage3_final_reasoning": {
                        "reasoning_text": reasoning_result.get('reasoning', ''),
                        "final_answer": final_answer,
                        "success": reasoning_result.get('success', False),
                    }
                },

                # Execution metadata
                "execution": {
                    "success": reasoning_result.get('success', False),
                    "iterations": reasoning_result.get('iterations', 1),
                    "termination_reason": reasoning_result.get('termination_reason', 'unknown'),
                    "total_time_seconds": reasoning_result.get('total_time', 0.0),
                    "error": reasoning_result.get('error', None),
                },

                # Reward calculation
                "reward_calculation": {
                    "base_reward": base_reward,
                    "penalty_applied": penalty_applied,
                    "final_reward": rewards[i],
                    "question_penalty_value": question_penalty,
                    "math_correctness": base_reward > 0,
                },

                "metadata": dict(metadata),
            })

        # Create summary statistics
        summary = {
            "timestamp": timestamp,
            "batch_id": batch_id,
            "total_generations": len(completions),
            "unique_samples": len(set(sample_ids)) if sample_ids else "unknown",
            "generations_per_sample": len(completions) // len(set(sample_ids)) if sample_ids else "unknown",

            # Performance statistics
            "reward_statistics": {
                "avg_reward": float(sum(rewards) / len(rewards)) if rewards else 0.0,
                "min_reward": float(min(rewards)) if rewards else 0.0,
                "max_reward": float(max(rewards)) if rewards else 0.0,
                "std_reward": float(np.std(rewards)) if len(rewards) > 1 else 0.0,
                "positive_reward_count": sum(1 for r in rewards if r > 0),
                "zero_reward_count": sum(1 for r in rewards if r == 0),
            },

            # Penalty statistics
            "penalty_statistics": {
                "question_penalty_value": question_penalty,
                "penalty_ratio": (successful_with_penalty / total_successful) if total_successful > 0 else 0.0,
                "successful_trajectories_total": total_successful,
                "successful_with_penalty": successful_with_penalty,
                "successful_without_penalty": total_successful - successful_with_penalty,
                "clarifying_questions_asked": sum(
                    1 for result in reasoning_results if result[0].get('needs_more_info', False)
                ),
            },

            # Reasoning statistics
            "reasoning_statistics": {
                "successful_reasoning": sum(
                    1 for result in reasoning_results if result[0].get('success', False)
                ),
                "avg_iterations": sum(
                    result[0].get('iterations', 1) for result in reasoning_results
                ) / len(reasoning_results),
                "avg_reasoning_time": sum(
                    result[0].get('total_time', 0.0) for result in reasoning_results
                ) / len(reasoning_results),
            }
        }

        # Save main debug file
        debug_file = os.path.join(debug_data_dir, f"three_stage_reward_debug_{timestamp}.json")
        with open(debug_file, 'w', encoding='utf-8') as f:
            json.dump({
                "summary": summary,
                "generations": debug_data
            }, f, indent=2, ensure_ascii=False)

        # Also save a trajectory-friendly format if we have sample IDs
        if sample_ids:
            # Group by sample ID for easier trajectory analysis
            trajectories = {}
            for sample in debug_data:
                sid = sample.get('sample_id')
                if sid is not None:
                    trajectories.setdefault(sid, []).append(sample)

            trajectory_file = os.path.join(debug_data_dir, f"three_stage_trajectories_{timestamp}.json")
            with open(trajectory_file, 'w', encoding='utf-8') as f:
                json.dump({
                    "summary": summary,
                    "trajectories": trajectories
                }, f, indent=2, ensure_ascii=False)

            print(f"💾 Debug data saved to: {debug_file}")
            print(f"💾 Trajectory data saved to: {trajectory_file}")
            print(f"   📊 {len(trajectories)} unique samples with {len(debug_data)} total generations")
        else:
            print(f"💾 Debug data saved to: {debug_file}")
            print(f"   ⚠️  No sample IDs available - trajectory grouping not possible")

    except Exception as e:
        # Debug output is optional - report the problem but keep the rewards
        print(f"Warning: Failed to save debug data: {e}")
        import traceback
        traceback.print_exc()


def three_stage_math_correctness_reward(
    # GRPO trainer passes these as positional/keyword args
    completions: List[Any],
    ground_truth: Optional[List[str]] = None,
    questions: Optional[List[str]] = None,
    prompts: Optional[List[Any]] = None,  # GRPO trainer passes this
    # VLM and Reasoner server parameters (injected by TRL adapter)
    vlm_server_url: Optional[str] = None,
    vlm_model_name: Optional[str] = None,
    vlm_max_tokens: Optional[int] = None,
    vlm_temperature: Optional[float] = None,
    vlm_top_p: Optional[float] = None,
    vlm_top_k: Optional[int] = None,
    reasoner_server_url: Optional[str] = None,
    reasoner_model_name: Optional[str] = None,
    reasoner_max_tokens: Optional[int] = None,
    reasoner_temperature: Optional[float] = None,
    reasoner_top_p: Optional[float] = None,
    reasoner_top_k: Optional[int] = None,
    # Penalty for asking clarifying questions (injected by TRL adapter)
    question_penalty: Optional[float] = None,
    # Template consistency (should match training config)
    prompt_template_name: str = "three_stage_math_v1",
    **kwargs
) -> List[float]:
    """
    Three-stage pipeline reward with math verification and question penalty.

    Captioner descriptions → Three-stage reasoning (with optional VLM clarification) → Final answers → Math verification reward

    Args:
        completions: List of completions from the GRPO trainer (REQUIRED). Each
            entry is indexed as ``completion[0]['content']`` to obtain the
            captioner description.
        ground_truth: List of correct mathematical answers (REQUIRED, from dataset)
        questions: List of questions for each sample (can be None, extracted from prompts)
        prompts: List of prompts (REQUIRED). Every prompt must be a message list whose
            user message has structured content containing an image item with a "url".
        vlm_server_url: URL of the VLM server for clarifying questions (injected by TRL adapter)
        vlm_model_name: Name of the VLM model (injected by TRL adapter)
        vlm_max_tokens: Max tokens for VLM (injected by TRL adapter)
        vlm_temperature: Temperature for VLM (injected by TRL adapter)
        vlm_top_p: Top-p for VLM (injected by TRL adapter)
        vlm_top_k: Top-k for VLM (injected by TRL adapter)
        reasoner_server_url: URL of the reasoner server (injected by TRL adapter)
        reasoner_model_name: Name of the reasoner model (injected by TRL adapter)
        reasoner_max_tokens: Max tokens for reasoner (injected by TRL adapter)
        reasoner_temperature: Temperature for reasoner (injected by TRL adapter)
        reasoner_top_p: Top-p for reasoner (injected by TRL adapter)
        reasoner_top_k: Top-k for reasoner (injected by TRL adapter)
        question_penalty: Penalty applied when reasoner asks clarifying questions (REQUIRED, injected by TRL adapter)
        prompt_template_name: Template name for scaffold (injected by TRL adapter)
        **kwargs: Additional data from dataset. Keys read here: 'question',
            'reward_enable_parallel', 'current_step' / 'step', 'debug_data_dir',
            'sample_ids' / 'original_indices' / 'indices', 'batch_id'.

    Returns:
        List of rewards based on mathematical correctness with question penalties applied

    Raises:
        ValueError: On missing/invalid inputs (empty lists, missing penalty, prompts
            without an image path, length mismatches, missing injected parameters).
        RuntimeError: If the three-stage pipeline fails for any sample.
    """
    # Validate required parameters - fail fast
    if not completions:
        raise ValueError("completions list cannot be empty")
    if not ground_truth:
        raise ValueError("ground_truth list cannot be empty - ensure dataset has 'ground_truth' field")

    # Validate question_penalty parameter - REQUIRED, no fallbacks
    if question_penalty is None:
        raise ValueError(
            "question_penalty parameter is required but not provided. "
            "This should be injected by TRL adapter from config. "
            "Add 'question_penalty: 0.1' to your training configuration."
        )

    if not isinstance(question_penalty, (int, float)) or question_penalty < 0:
        raise ValueError(
            f"question_penalty must be a non-negative number, got: {question_penalty} (type: {type(question_penalty)})"
        )

    print(f"🎯 Three-stage reward with question penalty: {question_penalty}")

    # Extract questions and image paths from prompts - REQUIRED for three-stage scaffold
    if prompts is None:
        raise ValueError(
            "prompts list is required for three-stage reward function. "
            "Three-stage scaffold needs image paths for VLM clarifying questions."
        )

    # A dataset column named 'question' (arriving through kwargs) takes
    # precedence over the `questions` argument.
    if kwargs.get('question') is not None:
        questions = kwargs.get('question')

    # Extract questions and image paths from prompt structure
    extracted_questions = []
    image_paths = []

    for i, prompt in enumerate(prompts):
        try:
            question, image_path = _extract_three_stage_question_and_image(prompt)

            # Fail fast if no image path found for this sample
            if image_path is None:
                content_types = (
                    [type(msg.get('content', '')) if isinstance(msg, dict) else type(msg) for msg in prompt]
                    if isinstance(prompt, list) else type(prompt)
                )
                raise ValueError(
                    f"No image path found in prompt {i} for three-stage reward function. "
                    f"Three-stage scaffold requires image paths for VLM clarifying questions. "
                    f"Prompt structure: {type(prompt)} with content type: {content_types}. "
                    f"Expected structured content with image URL in user message."
                )
        except Exception as e:
            raise ValueError(
                f"Failed to extract question/image from prompt {i}: {e}. "
                f"Three-stage reward function requires both questions and image paths. "
                f"Prompt: {prompt}"
            ) from e

        extracted_questions.append(question)
        image_paths.append(image_path)

    # Use extracted questions if not provided directly
    if questions is None:
        questions = extracted_questions
    elif len(questions) != len(prompts):
        raise ValueError(f"Length mismatch: questions({len(questions)}) != prompts({len(prompts)})")

    # Process completions to extract content
    completions = [completion[0]['content'] for completion in completions]

    # Validate parameter compatibility
    if len(completions) != len(ground_truth):
        raise ValueError(f"Length mismatch: completions({len(completions)}) != ground_truth({len(ground_truth)})")
    if len(completions) != len(questions):
        raise ValueError(f"Length mismatch: completions({len(completions)}) != questions({len(questions)})")
    if len(completions) != len(image_paths):
        raise ValueError(f"Length mismatch: completions({len(completions)}) != image_paths({len(image_paths)})")

    # Validate all image paths are present (fail fast - no dummy paths)
    missing_image_indices = [i for i, path in enumerate(image_paths) if not path]
    if missing_image_indices:
        raise ValueError(
            f"Missing image paths for samples {missing_image_indices}. "
            f"Three-stage scaffold requires image access for VLM clarifying questions. "
            f"No fallbacks or dummy paths allowed - image paths must be present in all prompts."
        )

    # Validate injected parameters from TRL adapter.
    # Explicit name -> value mappings are used instead of locals(): inside a
    # comprehension, locals() does not see the function's variables on
    # Python < 3.12, which turned this check into a KeyError.
    vlm_params = {
        'vlm_server_url': vlm_server_url,
        'vlm_model_name': vlm_model_name,
        'vlm_max_tokens': vlm_max_tokens,
        'vlm_temperature': vlm_temperature,
        'vlm_top_p': vlm_top_p,
        'vlm_top_k': vlm_top_k,
    }
    reasoner_params = {
        'reasoner_server_url': reasoner_server_url,
        'reasoner_model_name': reasoner_model_name,
        'reasoner_max_tokens': reasoner_max_tokens,
        'reasoner_temperature': reasoner_temperature,
        'reasoner_top_p': reasoner_top_p,
        'reasoner_top_k': reasoner_top_k,
    }

    missing_vlm_params = [name for name, value in vlm_params.items() if value is None]
    missing_reasoner_params = [name for name, value in reasoner_params.items() if value is None]

    if missing_vlm_params:
        raise ValueError(
            f"Missing required VLM parameters for three-stage reward: {missing_vlm_params}. "
            f"These should be injected by TRL adapter from config."
        )

    if missing_reasoner_params:
        raise ValueError(
            f"Missing required reasoner parameters for three-stage reward: {missing_reasoner_params}. "
            f"These should be injected by TRL adapter from config."
        )

    scaffold_kwargs = dict(vlm_params)
    scaffold_kwargs.update(reasoner_params)
    scaffold_kwargs['prompt_template_name'] = prompt_template_name

    # Build one scaffold up front so configuration errors surface before any
    # worker thread starts; the workers build their own instances below.
    _get_three_stage_scaffold(**scaffold_kwargs)

    # Log first few image paths for debugging
    print(f"🖼️ Using image paths for three-stage reasoning:")
    for i, path in enumerate(image_paths[:3]):
        print(f"  Sample {i}: {path}")
    if len(image_paths) > 3:
        print(f"  ... and {len(image_paths) - 3} more")

    # Parallel processing configuration
    max_workers = 32
    enable_parallel = kwargs.get('reward_enable_parallel', True)

    if not enable_parallel:
        # Sequential processing fallback
        print(f"🔄 Processing {len(completions)} completions sequentially (parallel disabled)")
        max_workers = 1
    else:
        # Cap at reasonable limits and scale with completion count
        max_workers = min(len(completions), max_workers)
        print(f"🚀 Processing {len(completions)} completions with {max_workers} workers")

    def process_single_completion(args_tuple):
        """Process a single completion through the three-stage scaffold."""
        i, (description, gt, question, image_path) = args_tuple

        # Validate image path exists for this sample (fail fast per sample)
        if not image_path:
            raise ValueError(
                f"Sample {i}: Missing image path for three-stage scaffold. "
                f"VLM clarifying questions require image access. No fallbacks allowed."
            )

        try:
            # Each thread gets its own scaffold instance to avoid conflicts
            thread_scaffold = _get_three_stage_scaffold(**scaffold_kwargs)

            # Run description through three-stage reasoning to get final answer
            reasoning_result = thread_scaffold.reason_from_description(
                description=description,
                question=question,
                image_path=image_path,  # Use validated image path - no fallbacks
                dataset="MathDataset",  # Ensure math prompts are used
                prompt_template_name=prompt_template_name,
                generation_kwargs={},  # Use scaffold's configured reasoner parameters
                vlm_generation_kwargs={}  # Use scaffold's configured VLM parameters
            )

            final_answer = reasoning_result["answer"]

            # Return reasoning result for math verification in main thread
            return i, reasoning_result, final_answer, gt, question

        except Exception as e:
            raise RuntimeError(f"Error in three-stage math pipeline reward for sample {i}: {e}") from e

    # Prepare arguments for parallel processing - include validated image paths
    completion_args = list(enumerate(zip(completions, ground_truth, questions, image_paths)))

    # Process completions in parallel (reasoning only)
    print(f"🚀 Processing {len(completions)} completions with {max_workers} workers...")

    reasoning_results = [None] * len(completions)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks
        futures = [executor.submit(process_single_completion, args) for args in completion_args]

        # Collect reasoning results as they complete
        completed_count = 0
        for future in as_completed(futures):
            try:
                i, reasoning_result, final_answer, gt, question = future.result()
            except Exception as e:
                print(f"Error processing completion: {e}")
                # Fail fast: drop work that has not started yet so leaving the
                # executor context does not wait for the whole queue.
                for pending in futures:
                    pending.cancel()
                raise

            reasoning_results[i] = (reasoning_result, final_answer, gt, question)  # Store all data
            completed_count += 1

            # Progress update
            if completed_count % 4 == 0 or completed_count == len(completions):
                print(f"    Completed {completed_count}/{len(completions)} reasoning tasks...")

    print(f"✅ Completed all {len(completions)} reasoning tasks")

    # Now do math verification sequentially in main thread (very fast)
    print(f"🔢 Computing math verification rewards sequentially...")
    rewards = []
    base_rewards = []  # Pre-penalty rewards, reused by the debug dump

    # Track penalty ratio statistics
    total_successful = 0
    successful_with_penalty = 0

    for i, (reasoning_result, final_answer, gt, question) in enumerate(reasoning_results):
        try:
            # Use math verification in main thread (no signal issues)
            base_reward = _compute_math_reward(final_answer, gt)

            # Apply penalty for asking clarifying questions to incentivize self-contained descriptions
            penalty_applied = 0.0
            asked_question = reasoning_result.get('needs_more_info', False)

            # Track successful trajectories for penalty ratio
            if base_reward > 0:  # Only count successful answers
                total_successful += 1
                if asked_question:
                    successful_with_penalty += 1

            if asked_question:
                penalty_applied = question_penalty
                final_reward = max(0.0, base_reward - penalty_applied)  # Don't go below 0
                print(f"    🔍 Question asked - applying penalty: {base_reward:.3f} - {penalty_applied:.3f} = {final_reward:.3f}")
            else:
                final_reward = base_reward
                print(f"    ✅ No question asked - full reward: {final_reward:.3f}")

            base_rewards.append(base_reward)
            rewards.append(final_reward)

            # Debug logging for first few samples
            if i < 3:
                print(f"  Three-Stage Math Sample {i}: {final_answer} vs {gt} → {final_reward}")
                print(f"    Base reward: {base_reward}, Penalty applied: {penalty_applied}")
                print(f"    Question: {question[:100]}...")
                print(f"    Success: {reasoning_result['success']}")
                print(f"    Asked question: {asked_question}")
                print(f"    Iterations: {reasoning_result.get('iterations', 1)}")

        except Exception as e:
            print(f"Error computing math reward for sample {i}: {e}")
            raise  # Fail fast

    # Log penalty ratio statistics
    if total_successful > 0:
        penalty_ratio = successful_with_penalty / total_successful
        print(f"📊 Penalty ratio: {successful_with_penalty}/{total_successful} = {penalty_ratio:.3f} ({penalty_ratio*100:.1f}%)")
        print(f"   ├─ Successful trajectories total: {total_successful}")
        print(f"   ├─ Successful with penalty: {successful_with_penalty}")
        print(f"   └─ Successful without penalty: {total_successful - successful_with_penalty}")
    else:
        penalty_ratio = 0.0
        print(f"📊 Penalty ratio: 0/0 = N/A (no successful trajectories)")

    # Step used for wandb logging, if the caller provided one
    current_step = kwargs.get('current_step', kwargs.get('step'))

    # Log penalty statistics to wandb if available (also when nothing succeeded)
    _log_three_stage_metrics_to_wandb(
        {
            "three_stage/penalty_ratio": penalty_ratio,
            "three_stage/penalty_ratio_percent": penalty_ratio * 100,
            "three_stage/successful_trajectories_total": total_successful,
            "three_stage/successful_with_penalty": successful_with_penalty,
            "three_stage/successful_without_penalty": total_successful - successful_with_penalty,
            "three_stage/question_penalty_value": question_penalty,
        },
        step=current_step,
    )

    # Log additional reward statistics to wandb (needs numpy; skipped without it)
    try:
        import numpy as np
    except ImportError:
        np = None
    if np is not None:
        rewards_array = np.array(rewards)
        _log_three_stage_metrics_to_wandb(
            {
                "three_stage/avg_reward": np.mean(rewards_array),
                "three_stage/median_reward": np.median(rewards_array),
                "three_stage/min_reward": np.min(rewards_array),
                "three_stage/max_reward": np.max(rewards_array),
                "three_stage/std_reward": np.std(rewards_array),
                "three_stage/total_samples": len(rewards),
                "three_stage/zero_reward_count": np.sum(rewards_array == 0),
                "three_stage/positive_reward_count": np.sum(rewards_array > 0),
            },
            step=current_step,
        )

    # Save comprehensive debug data if requested
    debug_data_dir = kwargs.get('debug_data_dir')
    if debug_data_dir:
        _save_three_stage_debug_data(
            debug_data_dir=debug_data_dir,
            extra_fields=kwargs,
            reasoning_results=reasoning_results,
            completions=completions,
            image_paths=image_paths,
            rewards=rewards,
            base_rewards=base_rewards,
            question_penalty=question_penalty,
            total_successful=total_successful,
            successful_with_penalty=successful_with_penalty,
        )

    print(f"✅ Completed all {len(completions)} three-stage math verification rewards")
    return rewards


def _normalize_top_k(top_k: Optional[int]) -> Optional[int]:
    """Map a "disabled" top-k setting to ``None``.

    Args:
        top_k: Requested top-k value. ``None`` or any value <= 0 (for example
            the ``-1`` sentinel mentioned in the caller's docs) means top-k
            filtering is disabled.

    Returns:
        ``None`` when top-k is disabled, otherwise ``top_k`` unchanged.
    """
    if top_k is None or top_k <= 0:
        return None
    return top_k


def _get_three_stage_scaffold(
    vlm_server_url: str,
    vlm_model_name: str,
    vlm_max_tokens: int,
    vlm_temperature: float,
    vlm_top_p: float,
    vlm_top_k: int,
    reasoner_server_url: str,
    reasoner_model_name: str,
    reasoner_max_tokens: int,
    reasoner_temperature: float,
    reasoner_top_p: float,
    reasoner_top_k: int,
    prompt_template_name: str = "three_stage_math_v1"
):
    """
    Create a three-stage scaffold backed by a server-based VLM and reasoner.

    Args:
        vlm_server_url: URL of the VLM server (REQUIRED, non-empty)
        vlm_model_name: Name of the VLM model on the server (REQUIRED, non-empty)
        vlm_max_tokens: Max tokens for VLM
        vlm_temperature: Temperature for VLM
        vlm_top_p: Top-p for VLM
        vlm_top_k: Top-k for VLM (None or any value <= 0, e.g. -1, disables it)
        reasoner_server_url: URL of the reasoner server (REQUIRED, non-empty)
        reasoner_model_name: Name of the reasoner model on the server (REQUIRED, non-empty)
        reasoner_max_tokens: Max tokens for reasoner
        reasoner_temperature: Temperature for reasoner
        reasoner_top_p: Top-p for reasoner
        reasoner_top_k: Top-k for reasoner (None or any value <= 0, e.g. -1, disables it)
        prompt_template_name: Template name. NOTE(review): this value is
            currently NOT forwarded to ThreeStageScaffold or to either config;
            it is accepted only so existing callers keep working. Confirm
            whether the scaffold should receive it.

    Returns:
        ThreeStageScaffold instance built with debug_mode=False,
        enable_verification=False and enable_vlm_confidence=False.

    Raises:
        ValueError: If any of the server URLs or model names is empty.
    """
    # Validate required parameters first so bad arguments fail fast, before the
    # project-level imports below are attempted.
    required_params = (
        ("vlm_server_url", vlm_server_url),
        ("vlm_model_name", vlm_model_name),
        ("reasoner_server_url", reasoner_server_url),
        ("reasoner_model_name", reasoner_model_name),
    )
    for param_name, param_value in required_params:
        if not param_value:
            raise ValueError(f"{param_name} cannot be empty")

    # Function-scope imports: only needed once the arguments are known to be valid.
    from ..scaffolds.three_stage import ThreeStageScaffold
    from ..core.reasoner_interface import ReasonerInterface, ReasonerConfig
    from ..core.vlm_interface import VLMInterface, VLMConfig

    # Create VLM configuration for the server
    vlm_config = VLMConfig(
        model_name=vlm_model_name,
        model_type="vllm",  # Use vLLM interface
        api_base=vlm_server_url,
        api_key="EMPTY",
        timeout=900,
        max_tokens=vlm_max_tokens,
        temperature=vlm_temperature,
        top_p=vlm_top_p,
        top_k=_normalize_top_k(vlm_top_k),  # disabled top-k is passed as None
    )

    # Create reasoner configuration for the server
    reasoner_config = ReasonerConfig(
        model_name=reasoner_model_name,
        model_type="api",  # Use generic API interface
        api_base=reasoner_server_url,
        api_key="EMPTY",
        timeout=900,
        max_tokens=reasoner_max_tokens,
        temperature=reasoner_temperature,
        top_p=reasoner_top_p,
        top_k=_normalize_top_k(reasoner_top_k),  # disabled top-k is passed as None
    )

    # Create interfaces
    vlm = VLMInterface.create(vlm_config)
    reasoner = ReasonerInterface.create(reasoner_config)

    # Create three-stage scaffold.
    # NOTE(review): prompt_template_name is not passed here — see docstring.
    scaffold = ThreeStageScaffold(
        vlm=vlm,
        reasoner=reasoner,
        debug_mode=False,  # Reduce output in reward function
        enable_verification=False,  # Skip verification for speed
        enable_vlm_confidence=False,  # Disable confidence experiments in reward function
    )

    return scaffold


# Names exported by a wildcard import (`from <this module> import *`).
# The private helper `_get_three_stage_scaffold` is not listed here.
# NOTE(review): most of these functions are defined outside this excerpt —
# confirm every listed name still exists in the module.
__all__ = [
    "correctness_reward",
    "length_reward", 
    "math_correctness_reward",
    "pipeline_correctness_reward",
    "pipeline_math_correctness_reward",
    "two_stage_math_correctness_reward",
    "three_stage_math_correctness_reward",
] 