"""
TRL training adapter for framework-agnostic VLM training.

This module implements TRL-specific training logic while maintaining
the same interface as other training frameworks.
"""

import time
import logging
from typing import Dict, Any, Optional, List, Union
from pathlib import Path
import time

import torch
from transformers import PreTrainedModel, AutoModel

from .base_trainer import BaseTrainingAdapter, TrainingConfig, TrainingResult

# Debug TRL import issue
import sys
import os

# Import-time diagnostics for the TRL installation. Everything in this
# section runs as a side effect of importing this module and only prints;
# no failure here is raised to the importer.
#
# Step 1: report every sys.path entry that contains a `trl` directory, and
# whether a `trl.egg-link` file sits next to it.
# NOTE(review): `trl.egg-link` is presumably the marker of a legacy
# `pip install -e` install - confirm this matches how TRL is installed here.
print(f"[TRLAdapter] Checking for TRL installations...")
for path in sys.path:
    trl_path = os.path.join(path, 'trl')
    if os.path.exists(trl_path) and os.path.isdir(trl_path):
        # An egg-link beside the package directory is treated as "editable"
        egg_link = os.path.join(os.path.dirname(trl_path), 'trl.egg-link')
        is_editable = os.path.exists(egg_link)
        print(f"[TRLAdapter] Found TRL at: {trl_path} (editable: {is_editable})")

# Step 2: actually import TRL and report where it came from. On success the
# names `trl` (and, if importable, `GRPOTrainer`) stay bound at module level.
try:
    import trl
    print(f"[TRLAdapter] TRL imported successfully from: {trl.__file__}")
    if hasattr(trl, '__version__'):
        print(f"[TRLAdapter] TRL version: {trl.__version__}")
    
    # Compare the import location against a hard-coded, machine-specific
    # checkout path; a mismatch only produces a warning.
    expected_path = "/home/<ANONYMIZED>/Vision-R1/trl"
    if trl.__file__ and expected_path in trl.__file__:
        print(f"[TRLAdapter] ✅ Using correct editable TRL install")
    elif trl.__file__:
        print(f"[TRLAdapter] ⚠️  WARNING: Using unexpected TRL from {trl.__file__}")
        print(f"[TRLAdapter] Expected: {expected_path}")
    else:
        print(f"[TRLAdapter] TRL.__file__ is None (lazy loading)")
    
    # Step 3: import GRPOTrainer from the package root first, then fall back
    # to the explicit submodule path if that raises ImportError.
    try:
        from trl import GRPOTrainer
        print(f"[TRLAdapter] GRPOTrainer imported successfully through lazy loader")
        # Inspect the constructor signature to see whether this TRL build
        # exposes a `data_collator` parameter on GRPOTrainer.__init__
        import inspect
        sig = inspect.signature(GRPOTrainer.__init__)
        params = list(sig.parameters.keys())
        print(f"[TRLAdapter] GRPOTrainer.__init__ parameters: {params}")
        if 'data_collator' in params:
            print(f"[TRLAdapter] ✅ data_collator parameter found in GRPOTrainer")
        else:
            print(f"[TRLAdapter] ❌ data_collator parameter NOT found in GRPOTrainer")
    except ImportError as e:
        print(f"[TRLAdapter] Failed to import GRPOTrainer through lazy loader: {e}")
        # Fallback: import straight from the trainer submodule
        try:
            from trl.trainer.grpo_trainer import GRPOTrainer
            print(f"[TRLAdapter] GRPOTrainer imported successfully through direct import")
            # Print the constructor parameters for the directly imported class
            import inspect
            sig = inspect.signature(GRPOTrainer.__init__)
            params = list(sig.parameters.keys())
            print(f"[TRLAdapter] GRPOTrainer.__init__ parameters (direct): {params}")
        except ImportError as e2:
            # Both import routes failed; GRPOTrainer stays undefined here.
            print(f"[TRLAdapter] Failed to import GRPOTrainer directly: {e2}")
except ImportError as e:
    # TRL itself is missing; the module still finishes importing.
    print(f"[TRLAdapter] Failed to import TRL: {e}")

# End of the TRL import diagnostics. A failed import is only printed above,
# not raised, so TRL may still be unavailable at this point.

class TRLTrainingAdapter(BaseTrainingAdapter):
    """
    TRL (Transformer Reinforcement Learning) training adapter.
    
    This adapter provides integration with the TRL library for training
    VLM components using SFT, PPO, DPO, and other methods.
    """
    
    def __init__(self):
        """Create the adapter and register it under the framework name "trl"."""
        super().__init__("trl")
        # Cached outcome of is_available(); None means "not probed yet".
        self._trl_available = None
        # Names of the packages this adapter relies on.
        self._required_modules = ["trl", "peft", "accelerate"]
        self.logger = logging.getLogger(__name__)
    
    def is_available(self) -> bool:
        """
        Check if TRL and dependencies are available.
        
        Returns:
            True if TRL can be used, False otherwise
        """
        if self._trl_available is not None:
            return self._trl_available
        
        try:
            import trl
            import peft
            import accelerate
            from transformers import TrainingArguments
            
            # bitsandbytes is only required for quantization, not for regular PEFT
            # We'll check for it only when quantization is actually used
            
            self._trl_available = True
            return True
            
        except ImportError as e:
            print(f"TRL not available: {e}")
            self._trl_available = False
            return False
    
    def initialize_model(
        self,
        config: TrainingConfig,
        scaffold: 'BaseReasoningScaffold'
    ) -> PreTrainedModel:
        """
        Initialize model for TRL training with optional quantization and LoRA.

        Loads the model named by ``config.training_model_name`` (or the Qwen
        wrapper model when ``config.use_qwen_wrapper`` and
        ``config.qwen_wrapper_model`` are set) together with its processor
        and tokenizer, then applies k-bit preparation, LoRA, parameter
        freezing and gradient checkpointing as configured. The results are
        stored on ``self.model``, ``self.tokenizer``, ``self.processor`` and
        ``self.current_config``.

        Args:
            config: Training configuration. When the Qwen wrapper is used,
                ``config.training_model_name`` is overwritten with
                ``config.qwen_wrapper_model``.
            scaffold: Reasoning scaffold to train (not read by this method)

        Returns:
            Model prepared for TRL training

        Raises:
            RuntimeError: If TRL is unavailable, if quantization is requested
                but bitsandbytes cannot be imported, if the model/processor
                cannot be loaded, or if no trainable parameters remain.
        """
        if not self.is_available():
            raise RuntimeError("TRL is not available in this environment")

        from transformers import (
            AutoTokenizer,
            AutoModelForCausalLM,
            AutoModelForVision2Seq,
            AutoProcessor
        )
        from peft import LoraConfig, get_peft_model, TaskType

        # Quantization is opt-in; 4-bit wins when both flags are set.
        load_in_4bit = bool(getattr(config, 'load_in_4bit', False))
        load_in_8bit = bool(getattr(config, 'load_in_8bit', False))

        bnb_config = None
        if load_in_4bit or load_in_8bit:
            try:
                from transformers import BitsAndBytesConfig
                # Imported only to verify the quantization stack is installed.
                import bitsandbytes  # noqa: F401
                from peft import prepare_model_for_kbit_training  # noqa: F401

                if load_in_4bit:
                    compute_dtype_name = getattr(config, 'bnb_4bit_compute_dtype', 'bfloat16')
                    bnb_config = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_quant_type=getattr(config, 'bnb_4bit_quant_type', 'nf4'),
                        bnb_4bit_use_double_quant=getattr(config, 'bnb_4bit_use_double_quant', True),
                        bnb_4bit_compute_dtype=torch.bfloat16 if compute_dtype_name == 'bfloat16' else torch.float16,
                    )
                else:
                    bnb_config = BitsAndBytesConfig(load_in_8bit=True)

            except ImportError as e:
                # Chain the cause so the missing package is visible in the traceback.
                raise RuntimeError("Quantization requested but bitsandbytes is not available. "
                                 "Install with: pip install bitsandbytes") from e

        # Load the model plus its processor/tokenizer
        try:
            model_dtype = torch.bfloat16 if config.mixed_precision else torch.float32
            # Keyword arguments shared by every direct from_pretrained() call below.
            hf_load_kwargs = {
                "quantization_config": bnb_config,
                "torch_dtype": model_dtype,
                "trust_remote_code": True,
                "low_cpu_mem_usage": True,
            }

            # Check if we should use Qwen wrapper for InternVL3 compatibility
            use_qwen_wrapper = getattr(config, 'use_qwen_wrapper', False)
            qwen_wrapper_model = getattr(config, 'qwen_wrapper_model', None)
            use_wrapper = bool(use_qwen_wrapper and qwen_wrapper_model)

            if use_wrapper:
                print(f"Loading Qwen model with wrapper: {qwen_wrapper_model}")
                # Under multi-process launch (WORLD_SIZE > 1) no device_map is
                # passed; otherwise "auto" is used.
                try:
                    is_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1
                except (TypeError, ValueError):
                    is_distributed = False
                wrapper_device_map = None if is_distributed else "auto"

                # Prefer minimal wrapper (native Qwen tensors: 2D features + image_grid_thw)
                # Fall back to drop-in wrapper if minimal wrapper is unavailable
                try:
                    from qwen_min_wrapper import load_qwen_min_wrapper
                    model, processor = load_qwen_min_wrapper(
                        qwen_wrapper_model,
                        torch_dtype=model_dtype,
                        device_map=wrapper_device_map,
                        quantization_config=bnb_config,
                    )
                    print("✅ Loaded Qwen minimal wrapper")
                except Exception as e:
                    print(f"⚠️  Minimal wrapper unavailable ({e}); falling back to drop-in wrapper")
                    from qwen_dropin_wrapper import load_qwen_as_internvl
                    model, processor = load_qwen_as_internvl(
                        qwen_wrapper_model,
                        torch_dtype=model_dtype,
                        device_map=wrapper_device_map,
                        quantization_config=bnb_config,
                    )

                # Extract tokenizer from processor for TRL compatibility
                tokenizer = processor.tokenizer if hasattr(processor, 'tokenizer') else processor
                print(f"✅ Loaded Qwen wrapper: model={type(model)}, processor={type(processor)}")

                # Point training_model_name at the actual Qwen model from here on
                config.training_model_name = qwen_wrapper_model

            elif "internvl" in config.training_model_name.lower():
                print(f"Loading InternVL3 model: {config.training_model_name}")

                if "-hf" in config.training_model_name.lower():
                    # The HuggingFace-compatible ("-hf") checkpoints are loaded
                    # through the dedicated InternVLForConditionalGeneration class.
                    from transformers import InternVLForConditionalGeneration
                    model = InternVLForConditionalGeneration.from_pretrained(
                        config.training_model_name,
                        **hf_load_kwargs
                    )
                else:
                    # Non-HF checkpoints go through AutoModel with remote code
                    from transformers import AutoModel
                    model = AutoModel.from_pretrained(
                        config.training_model_name,
                        use_flash_attn=True,  # InternVL3 supports flash attention
                        **hf_load_kwargs
                    )
                    # Override forward method to handle TRL's dual input format
                    # TRL passes both input_ids and inputs_embeds, but InternVL3 expects only input_ids
                    original_forward = model.forward
                    forward_logger = logging.getLogger(__name__)

                    def trl_compatible_forward(*args, **kwargs):
                        # Only keyword names are logged, and only at DEBUG level:
                        # dumping whole tensors on every forward pass is very slow.
                        if forward_logger.isEnabledFor(logging.DEBUG):
                            forward_logger.debug(
                                "trl_compatible_forward kwargs: %s", sorted(kwargs)
                            )
                        # Drop the arguments the underlying forward() cannot accept
                        kwargs.pop('inputs_embeds', None)
                        kwargs.pop('logits_to_keep', None)
                        return original_forward(*args, **kwargs)

                    model.forward = trl_compatible_forward
            else:
                model_name_lower = config.training_model_name.lower()
                if ("qwen" in model_name_lower) and ("vl" in model_name_lower or "vision" in model_name_lower):
                    # Qwen2.5-VL and similar VLMs must not be loaded with AutoModelForCausalLM
                    # Prefer the dedicated class if available, otherwise use AutoModelForVision2Seq
                    try:
                        from transformers import Qwen2_5_VLForConditionalGeneration as vlm_class  # type: ignore
                    except Exception:
                        try:
                            from transformers import Qwen2VLForConditionalGeneration as vlm_class  # type: ignore
                        except Exception:
                            vlm_class = AutoModelForVision2Seq

                    model = vlm_class.from_pretrained(
                        config.training_model_name,
                        **hf_load_kwargs
                    )
                else:
                    model = AutoModelForCausalLM.from_pretrained(
                        config.training_model_name,
                        **hf_load_kwargs
                    )

            # Load processor and tokenizer for VLM (the wrapper path already has both)
            if not use_wrapper:
                try:
                    processor = AutoProcessor.from_pretrained(config.training_model_name, trust_remote_code=True)
                    print(f"✅ Loaded processor: {type(processor)}")

                    # Extract tokenizer from processor
                    if hasattr(processor, 'tokenizer'):
                        tokenizer = processor.tokenizer
                        print(f"✅ Extracted tokenizer from processor: {type(tokenizer)}")
                    else:
                        # Fallback: load tokenizer separately
                        tokenizer = AutoTokenizer.from_pretrained(config.training_model_name, trust_remote_code=True)
                        print(f"⚠️  Loaded tokenizer separately: {type(tokenizer)}")
                except Exception as e:
                    print(f"❌ Failed to load processor, falling back to tokenizer: {e}")
                    import traceback
                    traceback.print_exc()
                    # Fallback to tokenizer only
                    tokenizer = AutoTokenizer.from_pretrained(config.training_model_name, trust_remote_code=True)
                    processor = None  # Don't set processor to tokenizer - keep them separate
                    print(f"⚠️  Using tokenizer only: {type(tokenizer)}")

        except Exception as e:
            print(f"Failed to load model: {e}")
            raise RuntimeError(f"Could not load model {config.training_model_name}: {e}") from e

        # Add padding token if needed
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            if hasattr(model.config, 'pad_token_id'):
                model.config.pad_token_id = model.config.eos_token_id

        # Prepare model for k-bit training if using quantization
        if bnb_config is not None:
            from peft import prepare_model_for_kbit_training
            model = prepare_model_for_kbit_training(model)

        # Apply LoRA if requested (skip for GRPO as it handles PEFT internally)
        if config.use_lora and config.training_method != "grpo":
            # Determine target modules based on model architecture
            target_modules = self._get_target_modules_for_model(model, config)
            print(f"target_modules: {target_modules}")
            lora_config = LoraConfig(
                r=config.lora_r,
                lora_alpha=config.lora_alpha,
                target_modules=target_modules,
                lora_dropout=config.lora_dropout,
                bias="none",
                task_type=TaskType.CAUSAL_LM,
            )
            model = get_peft_model(model, lora_config)
        elif config.use_lora and config.training_method == "grpo":
            print("🔧 Skipping PEFT application in TRL adapter - GRPO trainer will handle it internally")

        # Set up parameter freezing
        self.setup_freezing(model, config)

        # Enable gradient checkpointing if configured
        if config.gradient_checkpointing:
            model.enable_input_require_grads()
            model.gradient_checkpointing_enable()

        # Ensure model is in training mode
        model.train()

        # Fail early if freezing/LoRA left nothing to optimize
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        if len(trainable_params) == 0:
            raise RuntimeError("No trainable parameters found after model setup!")

        # Debug model config type
        if hasattr(model, 'config'):
            print(f"🔍 Model config type: {type(model.config)}")
            if isinstance(model.config, dict):
                print(f"🔍 Config keys: {list(model.config.keys())}")
            else:
                print(f"🔍 Config class: {model.config.__class__.__name__}")

        self.model = model
        self.tokenizer = tokenizer
        self.processor = processor
        self.current_config = config

        # First line reports the number of trainable tensors, the later one
        # the number of trainable scalar parameters.
        print(f"Trainable parameters: {len(trainable_params)}")
        total_params = sum(p.numel() for p in model.parameters())
        trainable_count = sum(p.numel() for p in trainable_params)
        trainable_pct = trainable_count / total_params * 100 if total_params else 0.0
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {trainable_count:,} ({trainable_pct:.2f}%)")
        return model
    
    def _get_target_modules_for_model(self, model, config) -> List[str]:
        """
        Get appropriate target modules for LoRA based on model architecture.
        
        Args:
            model: The model instance
            config: Training configuration
            
        Returns:
            List of target module names
        """
        # Default target modules for common architectures
        default_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "lm_head"]
        
        # Check if config specifies target modules
        if hasattr(config, 'lora_target_modules') and config.lora_target_modules:
            return config.lora_target_modules
        
        # Auto-detect based on model architecture
        model_name = str(type(model).__name__).lower()
        
        if "qwen" in model_name or "internvl" in model_name:
            # Qwen models typically have these modules
            modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "lm_head"]
            
            # For VLM models, we might want to include vision components
            if "vl" in model_name or "vision" in model_name:
                # Add vision-related modules if needed
                # Note: Be careful about training vision encoder - often better to keep frozen
                pass
            
            return modules
        
        elif "llama" in model_name:
            return ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "lm_head"]
        
        else:
            # Use default for unknown architectures
            return default_modules
    
    def prepare_trainer(
        self,
        model: PreTrainedModel,
        train_dataset,
        eval_dataset,
        config: TrainingConfig
    ):
        """
        Build the TRL trainer that matches ``config.training_method``.

        A ``TrainingArguments`` object is assembled from ``config`` (bf16 and
        tf32 are always switched on, fp16 always off), any matching entries
        of ``config.framework_config`` are applied on top, and the trainer is
        then created by the ``_create_<method>_trainer`` helper for "sft",
        "ppo", "dpo" or "grpo". The config and the trainer are stored on
        ``self.current_config`` and ``self.trainer``.

        Args:
            model: Initialized model
            train_dataset: Training dataset
            eval_dataset: Evaluation dataset (optional); when falsy,
                evaluation and best-model tracking are disabled
            config: Training configuration

        Returns:
            TRL trainer object

        Raises:
            ValueError: If ``config.training_method`` is not supported
        """
        from transformers import TrainingArguments

        # Dataset conversion and template loading read the config from here
        self.current_config = config

        # Evaluation-related arguments all hinge on whether an eval set exists
        has_eval = bool(eval_dataset)

        training_args = TrainingArguments(
            output_dir=config.output_dir,
            per_device_train_batch_size=config.batch_size,
            per_device_eval_batch_size=config.batch_size,
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            learning_rate=config.learning_rate,
            weight_decay=config.weight_decay,
            warmup_ratio=config.warmup_ratio,
            num_train_epochs=config.max_epochs,
            logging_steps=config.logging_steps,
            save_steps=config.save_steps,
            eval_steps=config.eval_steps if has_eval else None,
            eval_strategy="steps" if has_eval else "no",
            save_strategy="steps",
            load_best_model_at_end=has_eval,
            metric_for_best_model="eval_loss" if has_eval else None,
            greater_is_better=False,

            # Precision: bfloat16 + TF32, never fp16
            fp16=False,
            bf16=True,
            tf32=True,

            gradient_checkpointing=config.gradient_checkpointing,
            gradient_checkpointing_kwargs={'use_reentrant':True},

            # Data loading: keep every column so image data reaches the model
            dataloader_pin_memory=False,
            dataloader_num_workers=2,
            remove_unused_columns=False,

            # Optimizer: config may override the optimizer name only
            optim=getattr(config, 'optim', 'adamw_torch'),
            adam_beta1=0.9,
            adam_beta2=0.95,
            adam_epsilon=1e-7,
            max_grad_norm=config.max_grad_norm,

            # Learning-rate schedule
            lr_scheduler_type=getattr(config, 'lr_scheduler_type', 'cosine_with_restarts'),
            warmup_steps=getattr(config, 'warmup_steps', 0),

            # Batching
            group_by_length=False,
            length_column_name="length",

            # Experiment tracking
            report_to=getattr(config, 'report_to', ["wandb", "tensorboard"]),
            run_name=getattr(config, 'run_name', None),

            # NEFTune noise (off unless configured)
            neftune_noise_alpha=getattr(config, 'neftune_noise_alpha', None),

            # Distributed training
            ddp_find_unused_parameters=False,
            ddp_broadcast_buffers=False,

            # Checkpoint resumption
            resume_from_checkpoint=getattr(config, 'resume_from_checkpoint', None),
        )

        # Framework-specific overrides: only attributes that already exist on
        # the arguments object are touched.
        overrides = config.framework_config
        if overrides:
            for option, setting in overrides.items():
                if hasattr(training_args, option):
                    setattr(training_args, option, setting)

        # Training method -> name of the factory helper that builds its trainer
        trainer_factories = (
            ("sft", "_create_sft_trainer"),
            ("ppo", "_create_ppo_trainer"),
            ("dpo", "_create_dpo_trainer"),
            ("grpo", "_create_grpo_trainer"),
        )
        factory_name = next(
            (attr for method, attr in trainer_factories if config.training_method == method),
            None,
        )
        if factory_name is None:
            raise ValueError(f"Unsupported training method for TRL: {config.training_method}")

        create_trainer = getattr(self, factory_name)
        trainer = create_trainer(
            model, train_dataset, eval_dataset, training_args, config
        )

        self.trainer = trainer

        # The TokenLossLogger diagnostic callback is intentionally not attached.
        return trainer
    
    def _convert_to_hf_dataset(self, dataset):
        """Convert PyTorch dataset to HuggingFace dataset format for TRL compatibility."""
        if dataset is None:
            return None
            
        try:
            from datasets import Dataset as HFDataset
            from transformers import ProcessorMixin
            
            # Load scaffold template for consistent prompting
            scaffold_template = self._load_scaffold_template(self.current_config)
            vlm_prompt_template = scaffold_template["vlm_initial_description_prompt"]
            
            # Check if we're using InternVL3 processor that needs special handling
            is_internvl_processor = (
                self.processor is not None and 
                isinstance(self.processor, ProcessorMixin) and
                hasattr(self.processor, 'apply_chat_template')
            )
            
            # Handle both PyTorch datasets AND HuggingFace datasets that need scaffold template application
            if isinstance(dataset, HFDataset):
                # Convert HF dataset to list, apply templates, then back to HF dataset
                dataset_list = dataset.to_list()
                
                # Apply scaffold template to each sample
                samples = []
                for i, sample in enumerate(dataset_list):
                    updated_sample = self._apply_scaffold_template_to_sample(
                        sample, vlm_prompt_template, is_internvl_processor
                    )
                    # Normalize prompt structure to ensure consistency
                    # updated_sample = self._normalize_prompt_structure(updated_sample)
                    samples.append(updated_sample)
                

                
                hf_dataset = HFDataset.from_list(samples)
                return hf_dataset
            
            # Convert PyTorch dataset to HuggingFace format
            
            # Collect all samples from the PyTorch dataset
            samples = []
            for i in range(len(dataset)):
                sample = dataset[i]
                updated_sample = self._apply_scaffold_template_to_sample(
                    sample, vlm_prompt_template, is_internvl_processor
                )
                samples.append(updated_sample)
            
            # Create HuggingFace dataset
            hf_dataset = HFDataset.from_list(samples)
            print(f"✅ Successfully converted to HuggingFace dataset with {len(hf_dataset)} samples")
            print(f"📝 Applied scaffold template '{self.current_config.prompt_template_name}' VLM prompts for consistent captioner training")
            return hf_dataset
            
        except Exception as e:
            print(f"❌ Error converting dataset to HuggingFace format: {e}")
            import traceback
            traceback.print_exc()
            raise  # Fail fast
    
    def _apply_scaffold_template_to_sample(self, sample, vlm_prompt_template, is_internvl_processor):
        """Apply scaffold template by replacing the user's question text."""
        if not isinstance(sample, dict):
            return sample
        
        # Extract question from the sample
        question = self._extract_question_from_sample(sample)
        if not question:
            return sample
        
        # Format scaffold template with question
        formatted_scaffold_instructions = vlm_prompt_template.format(question=question)
        
        # Get or create prompt messages
        if "prompt" in sample:
            prompt_messages = sample["prompt"]
            # Handle nested list structure
            if isinstance(prompt_messages, list) and len(prompt_messages) > 0 and isinstance(prompt_messages[0], list):
                prompt_messages = prompt_messages[0]
        elif "messages" in sample:
            prompt_messages = sample["messages"]
        else:
            # Create new structure
            prompt_messages = [
                {"role": "system", "content": [{"type": "text", "text": "You are a helpful AI assistant."}]},
                {"role": "user", "content": [{"type": "text", "text": question}]}
            ]
        
        # Apply scaffold template to user message text (replace the question)
        updated_messages = []
        
        for msg in prompt_messages:
            if msg.get("role") == "user":
                # Update user message - replace question text with scaffold template
                content = msg.get("content", [])
                if isinstance(content, list):
                    updated_content = []
                    for item in content:
                        if isinstance(item, dict) and item.get("type") == "text":
                            # Replace the text content with scaffold-formatted instructions
                            updated_content.append({
                                            "type": "text", 
                                "text": formatted_scaffold_instructions,
                                "url": item.get("url")
                                        })
                        else:
                            # Keep non-text items (like images) unchanged
                            updated_content.append(item)
                    
                    updated_messages.append({
                        "role": "user",
                        "content": updated_content
                                    })
                else:
                    # Handle case where content is not a list (shouldn't happen with normalized data)
                    updated_messages.append({
                        "role": "user", 
                        "content": [{"type": "text", "text": formatted_scaffold_instructions}]
                                    })
            else:
                # Keep system and other messages unchanged
                updated_messages.append(msg)
        
        # Create updated sample
        updated_sample = sample.copy()
        updated_sample["prompt"] = updated_messages
        if "question" not in updated_sample:
            updated_sample["question"] = question
        
        return updated_sample
    
    def _extract_question_from_sample(self, sample):
        """Extract question text from various sample formats."""
        # Try explicit question field first
        if "question" in sample:
            return str(sample["question"])
        
        # Try to extract from messages or prompt
        messages = sample.get("messages") or sample.get("prompt", [])
        if isinstance(messages, list) and len(messages) > 0:
            # Handle nested list structure
            if isinstance(messages[0], list):
                messages = messages[0]
            
            # Find user message
            for msg in messages:
                if isinstance(msg, dict) and msg.get("role") == "user":
                    content = msg.get("content", "")
                    return self._extract_text_from_content(content)
        
        return ""
    
    def _extract_text_from_content(self, content):
        """Extract text from content that can be either string or list format."""
        if isinstance(content, str):
            return content
        elif isinstance(content, list):
            # Extract text from structured content
            text_parts = []
            for item in content:
                if isinstance(item, dict) and item.get("type") == "text":
                    text_parts.append(item.get("text", ""))
                elif isinstance(item, str):
                    text_parts.append(item)
            return "\n".join(text_parts)
        else:
            return str(content)
    
    def train(
        self,
        trainer,
        config: TrainingConfig
    ) -> TrainingResult:
        """
        Execute training using TRL trainer.

        Runs ``trainer.train()`` (resuming from ``config.resume_from_checkpoint``
        when that attribute is set), gathers the loss history from
        ``trainer.state.log_history``, saves the model to
        ``<config.output_dir>/final_model`` and, when ``config.use_lora`` is
        truthy, additionally saves ``trainer.model`` to
        ``<config.output_dir>/adapter``.

        Any exception raised along the way is printed with its traceback and
        converted into a failed ``TrainingResult`` instead of propagating.

        Args:
            trainer: TRL trainer instance
            config: Training configuration

        Returns:
            Training results. On failure ``success`` is False, ``final_loss`` is
            ``inf`` and ``error_message`` carries the exception type and text.
        """
        # Start the clock outside the try block so the error path below can
        # always compute the elapsed time.
        start_time = time.time()

        try:
            # Start training
            resume_checkpoint = getattr(config, 'resume_from_checkpoint', None)
            if resume_checkpoint:
                self.logger.info(f"Resuming training from checkpoint: {resume_checkpoint}")
                train_result = trainer.train(resume_from_checkpoint=resume_checkpoint)
            else:
                train_result = trainer.train()

            training_time = time.time() - start_time

            # Extract metrics
            final_loss = train_result.training_loss
            step_losses = []
            summary_losses = []
            eval_losses = []

            # Get training history if available
            if hasattr(trainer.state, 'log_history'):
                for log_entry in trainer.state.log_history:
                    # The Hugging Face Trainer logs per-step losses under
                    # 'loss'; 'train_loss' only appears in the single
                    # end-of-training summary entry.
                    if 'loss' in log_entry:
                        step_losses.append(log_entry['loss'])
                    if 'train_loss' in log_entry:
                        summary_losses.append(log_entry['train_loss'])
                    if 'eval_loss' in log_entry:
                        eval_losses.append(log_entry['eval_loss'])

            # Prefer the real per-step curve; fall back to the summary value(s)
            # when no step-level loss was logged.
            train_losses = step_losses or summary_losses
            eval_metrics = {'loss': eval_losses} if eval_losses else None

            # Save model
            model_path = Path(config.output_dir) / "final_model"
            trainer.save_model(model_path)

            # Save adapter if using PEFT
            adapter_path = None
            if config.use_lora:
                adapter_path = Path(config.output_dir) / "adapter"
                if hasattr(trainer.model, 'save_pretrained'):
                    trainer.model.save_pretrained(adapter_path)

            # Lower eval loss is better, so the best metric is the minimum seen
            best_eval_metric = min(eval_losses) if eval_losses else None

            return TrainingResult(
                final_loss=final_loss,
                best_eval_metric=best_eval_metric,
                training_time=training_time,
                model_path=str(model_path),
                adapter_path=str(adapter_path) if adapter_path else None,
                train_losses=train_losses,
                eval_metrics=eval_metrics,
                config_used=config,
                framework="trl",
                success=True
            )

        except Exception as e:
            # Top-level boundary: report the failure through the result object
            # rather than letting it escape to the caller.
            import traceback
            print(f"[TRLAdapter] Training error: {e}")
            print(f"[TRLAdapter] Full traceback:")
            traceback.print_exc()

            training_time = time.time() - start_time

            return TrainingResult(
                final_loss=float('inf'),
                training_time=training_time,
                config_used=config,
                framework="trl",
                success=False,
                error_message=f"{type(e).__name__}: {str(e)}"
            )
    
    def _create_sft_trainer(self, model, train_dataset, eval_dataset, 
                           training_args, config):
        """Create SFT trainer optimized for VLM following TRL's official pattern.

        Steps performed:
          * disable column pruning and TRL's dataset preparation on
            ``training_args``;
          * convert the datasets with ``_convert_to_hf_dataset`` (printing debug
            information about the first sample before and after conversion);
          * build the multimodal data collator and make sure the tokenizer has a
            pad token (falls back to the EOS token);
          * construct ``SFTTrainer``.

        If an ``ImportError`` is raised anywhere in that sequence, a plain
        ``transformers.Trainer`` with a causal-LM collator is returned instead.
        NOTE(review): that fallback receives the *unconverted* datasets and is
        triggered by any ImportError inside the try body, not only a missing
        TRL install - confirm this is intended.

        Args:
            model: Model to train. Gains a ``tensor_parallel = False`` attribute
                when it has none.
            train_dataset: Training dataset.
            eval_dataset: Optional evaluation dataset.
            training_args: Training arguments object; mutated in place.
            config: Training configuration (not read by this method).

        Returns:
            An ``SFTTrainer``, or a ``transformers.Trainer`` on ImportError.
        """
        try:
            from trl import SFTTrainer
            
            # Set VLM-specific training arguments
            training_args.remove_unused_columns = False
            training_args.dataset_kwargs = {"skip_prepare_dataset": True}
            
            # Convert PyTorch datasets to HuggingFace format if needed
            print(f" [DEBUG] Converting train_dataset: {type(train_dataset)}")
            if hasattr(train_dataset, '__len__'):
                print(f"🔍 [DEBUG] Train dataset length: {len(train_dataset)}")
            if hasattr(train_dataset, '__getitem__'):
                try:
                    sample_0 = train_dataset[0]
                    print(f"🔍 [DEBUG] Sample 0 keys: {list(sample_0.keys()) if isinstance(sample_0, dict) else type(sample_0)}")
                    if isinstance(sample_0, dict) and 'messages' in sample_0:
                        print(f"🔍 [DEBUG] Sample 0 messages structure:")
                        self._debug_print_message_structure(sample_0['messages'])
                except Exception as e:
                    # Debug output is best-effort; never let it abort trainer creation
                    print(f"🔍 [DEBUG] Error accessing sample 0: {e}")
            
            hf_train_dataset = self._convert_to_hf_dataset(train_dataset)
            
            print(f"🔍 [DEBUG] After conversion: {type(hf_train_dataset)}")
            if hf_train_dataset is not None and len(hf_train_dataset) > 0:
                try:
                    converted_sample_0 = hf_train_dataset[0]
                    print(f"🔍 [DEBUG] Converted sample 0 keys: {list(converted_sample_0.keys())}")
                    if 'prompt' in converted_sample_0:
                        print(f"🔍 [DEBUG] Converted prompt structure:")
                        self._debug_print_message_structure(converted_sample_0['prompt'])
                except Exception as e:
                    # Debug output is best-effort; never let it abort trainer creation
                    print(f"🔍 [DEBUG] Error accessing converted sample 0: {e}")
            
            hf_eval_dataset = self._convert_to_hf_dataset(eval_dataset) if eval_dataset else None
            
            # Create custom data collator for multimodal data
            data_collator = self._create_multimodal_data_collator(model)
            
            # For TRL, we need to pass the tokenizer, not the full processor
            tokenizer_for_trl = self.tokenizer
            
            # Ensure tokenizer has a pad_token (required by TRL)
            if tokenizer_for_trl.pad_token is None:
                tokenizer_for_trl.pad_token = tokenizer_for_trl.eos_token
                print(f"Set pad_token to eos_token for TRL compatibility: {tokenizer_for_trl.pad_token}")

            # Build SFTTrainer kwargs following official pattern
            trainer_kwargs = {
                "model": model,
                "args": training_args,
                "data_collator": data_collator,
                "train_dataset": hf_train_dataset,
                "eval_dataset": hf_eval_dataset,
                "processing_class": tokenizer_for_trl,
            }
            
            # Workaround: Add missing tensor_parallel attribute if it doesn't exist
            # This is needed because TRL sometimes expects this attribute on models
            if not hasattr(model, 'tensor_parallel'):
                model.tensor_parallel = False
                
            # Create trainer with VLM-compatible settings
            trainer = SFTTrainer(**trainer_kwargs)
            
            return trainer
            
        except ImportError as e:
            print(f"Failed to create TRL SFTTrainer: {e}")
            # Fallback to standard Trainer
            from transformers import Trainer, DataCollatorForLanguageModeling
            
            data_collator = DataCollatorForLanguageModeling(
                tokenizer=self.tokenizer,
                mlm=False,
            )
            
            return Trainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                data_collator=data_collator,
                tokenizer=self.tokenizer,
            )

    def _debug_print_message_structure(self, messages):
        """Print role and content shape of every chat message (debug aid only)."""
        for i, msg in enumerate(messages):
            role = msg.get('role', 'unknown')
            content = msg.get('content', '')
            print(f"🔍 [DEBUG]   Message {i} - Role: {role}, Content type: {type(content)}")
            if isinstance(content, str):
                print(f"🔍 [DEBUG]   Message {i} - Content preview: {content[:100]}...")
            elif isinstance(content, list):
                print(f"🔍 [DEBUG]   Message {i} - Content list length: {len(content)}")
    
    def _create_ppo_trainer(self, model, train_dataset, eval_dataset,
                           training_args, config):
        """Build a TRL ``PPOTrainer`` for ``model``.

        Only the learning rate, batch size and gradient accumulation steps are
        carried over from ``config`` into a fresh ``PPOConfig``. The datasets
        and ``training_args`` are part of the signature shared with the sibling
        trainer factories but are not used here.
        """
        from trl import PPOConfig, PPOTrainer

        # The subset of hyper-parameters that is forwarded to PPOConfig
        ppo_settings = {
            "learning_rate": config.learning_rate,
            "batch_size": config.batch_size,
            "gradient_accumulation_steps": config.gradient_accumulation_steps,
        }

        trainer = PPOTrainer(
            config=PPOConfig(**ppo_settings),
            model=model,
            tokenizer=self.tokenizer,
        )
        return trainer
    
    def _create_dpo_trainer(self, model, train_dataset, eval_dataset,
                           training_args, config):
        """Build a TRL ``DPOTrainer`` for ``model``.

        The reference model that DPO needs is loaded from
        ``config.training_model_name`` through the policy model's own class.
        """
        from trl import DPOTrainer

        # Same class as the policy model, loaded again from the configured name
        reference_model = type(model).from_pretrained(config.training_model_name)

        dpo_kwargs = {
            "model": model,
            "ref_model": reference_model,
            "train_dataset": train_dataset,
            "eval_dataset": eval_dataset,
            "args": training_args,
            "tokenizer": self.tokenizer,
        }
        return DPOTrainer(**dpo_kwargs)
    
    def _import_reward_function(self, reward_func_ref):
        """Dynamically import reward function from string reference or return if already callable."""
        if callable(reward_func_ref):
            return reward_func_ref
        
        if isinstance(reward_func_ref, str):
            # Import from string like "module.submodule.function_name"
            module_path, func_name = reward_func_ref.rsplit('.', 1)
            try:
                import importlib
                module = importlib.import_module(module_path)
                return getattr(module, func_name)
            except (ImportError, AttributeError) as e:
                raise ImportError(f"Could not import reward function '{reward_func_ref}': {e}")
        
        raise ValueError(f"Reward function must be callable or string reference, got {type(reward_func_ref)}")

    def _create_grpo_trainer(self, model, train_dataset, eval_dataset,
                            training_args, config):
        """Create GRPO trainer using TRL's built-in vLLM server mode.

        The method
          1. resolves the reward function named in
             ``config.framework_config['reward_function']`` and validates the
             VLM / reasoner server settings it needs,
          2. wraps the reward function(s) so those settings (plus
             ``prompt_template_name``) are injected as keyword arguments on
             every call,
          3. translates ``training_args`` / ``config`` into a ``GRPOConfig``,
          4. builds the ``GRPOTrainer`` on the converted datasets.

        Args:
            model: Model to train.
            train_dataset: Training dataset (converted with ``_convert_to_hf_dataset``).
            eval_dataset: Optional evaluation dataset; converted the same way when truthy.
            training_args: Source of the generic training hyper-parameters.
            config: Training configuration carrying reward, server and GRPO settings.

        Returns:
            A configured ``GRPOTrainer``.

        Raises:
            ValueError: If no reward function is configured or a required
                config value is missing.
            ImportError: If a string reward function reference cannot be imported.
        """
        # Import directly from modules to avoid lazy loading issues with editable installs
        from trl.trainer.grpo_trainer import GRPOTrainer
        from trl.trainer.grpo_config import GRPOConfig

        # Load scaffold template for prompt consistency
        scaffold_template = self._load_scaffold_template(config)

        # Create data collator for multimodal/scaffold prompting.
        # GRPO trainer explicitly sets data_collator=identity and bypasses custom collators,
        # so the result is not handed to the trainer below.
        # NOTE(review): the call is kept in case the factory validates its inputs or
        # has side effects - confirm whether it can be dropped.
        self._create_multimodal_data_collator(model, scaffold_template)

        # Bound up front so the code after the reward-function check never depends
        # on names that are only assigned inside the nested branch below.
        reward_funcs = None
        server_params = {}
        is_two_stage_reward = False
        is_three_stage_reward = False

        # Extract reward function(s) if provided via framework_config
        if hasattr(config, 'framework_config') and config.framework_config:
            reward_func_ref = config.framework_config.get('reward_function')
            if reward_func_ref:
                reward_funcs = self._import_reward_function(reward_func_ref)

                # Check if this is a two-stage or three-stage reward function
                # (detected purely from the text of the reference)
                is_two_stage_reward = 'two_stage' in str(reward_func_ref).lower()
                is_three_stage_reward = 'three_stage' in str(reward_func_ref).lower()

                # Three-stage rewards need VLM server for clarifying questions, two-stage rewards don't
                if not is_two_stage_reward or is_three_stage_reward:
                    # Build VLM server URL from API base - REQUIRED for pipeline rewards, no fallbacks
                    if hasattr(config, 'vlm_api_base') and config.vlm_api_base:
                        server_params['vlm_server_url'] = config.vlm_api_base
                    elif hasattr(config, 'vlm_server_host') and hasattr(config, 'vlm_server_port'):
                        if not config.vlm_server_host or not config.vlm_server_port:
                            raise ValueError("vlm_server_host and vlm_server_port must both be specified in config")
                        server_params['vlm_server_url'] = f"http://{config.vlm_server_host}:{config.vlm_server_port}/v1"
                    else:
                        raise ValueError(
                            "VLM server configuration missing. Must specify either 'vlm_api_base' "
                            "or both 'vlm_server_host' and 'vlm_server_port' in config."
                        )

                    # Extract VLM parameters for adaptive scaffold - ALL REQUIRED, no fallbacks
                    required_vlm_params = [
                        'vlm_model_name', 'vlm_max_tokens', 'vlm_temperature', 'vlm_top_p', 'vlm_top_k'
                    ]
                    missing_vlm_params = []
                    for param in required_vlm_params:
                        if hasattr(config, param) and getattr(config, param) is not None:
                            server_params[param] = getattr(config, param)
                        else:
                            missing_vlm_params.append(param)

                    if missing_vlm_params:
                        scaffold_name = "three-stage" if is_three_stage_reward else "adaptive"
                        raise ValueError(
                            f"Missing required VLM parameters in config: {missing_vlm_params}. "
                            f"All VLM parameters must be explicitly specified for {scaffold_name} scaffold."
                        )
                else:
                    print(f"🔧 Two-stage reward function detected - skipping VLM server validation")

                # Build reasoner server URL from API base or host/port - REQUIRED, no fallbacks
                if hasattr(config, 'reasoner_api_base') and config.reasoner_api_base:
                    server_params['reasoner_server_url'] = config.reasoner_api_base
                elif hasattr(config, 'reasoner_server_host') and hasattr(config, 'reasoner_server_port'):
                    if not config.reasoner_server_host or not config.reasoner_server_port:
                        raise ValueError("reasoner_server_host and reasoner_server_port must both be specified in config")
                    server_params['reasoner_server_url'] = f"http://{config.reasoner_server_host}:{config.reasoner_server_port}/v1"
                elif hasattr(config, 'reasoner_server_url') and config.reasoner_server_url:
                    server_params['reasoner_server_url'] = config.reasoner_server_url
                else:
                    raise ValueError(
                        "Reasoner server configuration missing. Must specify either 'reasoner_api_base', "
                        "'reasoner_server_url', or both 'reasoner_server_host' and 'reasoner_server_port' in config."
                    )

                # Extract reasoner parameters - requirements depend on scaffold type
                required_reasoner_params = [
                    'reasoner_model_name', 'reasoner_max_tokens',
                    'reasoner_temperature', 'reasoner_top_p', 'reasoner_top_k'
                ]

                # Only pipeline (adaptive) scaffolds need scaffold_max_iterations
                # Two-stage is single-shot, three-stage has fixed one-decision behavior
                is_pipeline_reward = (
                    'pipeline' in str(config.framework_config.get('reward_function', '')) or
                    (not is_two_stage_reward and not is_three_stage_reward)  # Default to pipeline for backward compatibility
                )

                if is_pipeline_reward:
                    required_reasoner_params.append('scaffold_max_iterations')

                missing_reasoner_params = []
                for param in required_reasoner_params:
                    if hasattr(config, param) and getattr(config, param) is not None:
                        server_params[param] = getattr(config, param)
                    else:
                        missing_reasoner_params.append(param)

                if missing_reasoner_params:
                    if is_three_stage_reward:
                        scaffold_type = "three-stage"
                    elif is_two_stage_reward:
                        scaffold_type = "two-stage"
                    else:
                        scaffold_type = "adaptive"
                    raise ValueError(
                        f"Missing required reasoner parameters in config: {missing_reasoner_params}. "
                        f"All reasoner parameters must be explicitly specified for {scaffold_type} scaffold."
                    )

        if reward_funcs is None:
            raise ValueError("GRPO training requires a reward function provided via config.framework_config['reward_function']")

        # Wrap reward function to inject server URLs, generation parameters, and parallel config
        # Add parallel reward computation configuration - REQUIRED, no fallbacks
        if not hasattr(config, 'reward_parallel_workers') or config.reward_parallel_workers is None:
            raise ValueError("reward_parallel_workers must be specified in config for parallel reward computation")
        if not hasattr(config, 'reward_enable_parallel') or config.reward_enable_parallel is None:
            raise ValueError("reward_enable_parallel must be specified in config for parallel reward computation")

        server_params['reward_parallel_workers'] = config.reward_parallel_workers
        server_params['reward_enable_parallel'] = config.reward_enable_parallel

        # Add question_penalty parameter for three-stage reward function
        if is_three_stage_reward:
            if not hasattr(config, 'question_penalty') or config.question_penalty is None:
                raise ValueError(
                    "question_penalty must be specified in config for three-stage reward function. "
                    "Add 'question_penalty: 0.1' to your training configuration."
                )
            server_params['question_penalty'] = config.question_penalty
            print(f"🎯 Three-stage question penalty configured: {config.question_penalty}")

        # Add debug_data_dir parameter if specified in config
        if hasattr(config, 'debug_data_dir') and config.debug_data_dir is not None:
            server_params['debug_data_dir'] = config.debug_data_dir
            print(f"💾 Debug data directory configured: {config.debug_data_dir}")

        # Add current training step for wandb logging
        if hasattr(config, 'current_step') and config.current_step is not None:
            server_params['current_step'] = config.current_step
        elif hasattr(config, 'step') and config.step is not None:
            server_params['step'] = config.step

        # Add sample identification for trajectory grouping
        if hasattr(config, 'sample_ids') and config.sample_ids is not None:
            server_params['sample_ids'] = config.sample_ids
        elif hasattr(config, 'original_indices') and config.original_indices is not None:
            server_params['sample_ids'] = config.original_indices
        elif hasattr(config, 'indices') and config.indices is not None:
            server_params['sample_ids'] = config.indices

        # Add batch identification
        if hasattr(config, 'batch_id') and config.batch_id is not None:
            server_params['batch_id'] = config.batch_id

        if callable(reward_funcs):
            original_func = reward_funcs
            def wrapped_reward_func(*args, **kwargs):
                kwargs.update(server_params)
                # Pass template name for scaffold consistency
                kwargs['prompt_template_name'] = config.prompt_template_name
                return original_func(*args, **kwargs)
            reward_funcs = wrapped_reward_func
        elif isinstance(reward_funcs, list):
            def _bind_reward_func(func):
                # A factory gives every wrapper its own `func` binding. Defining
                # the wrapper directly inside the loop made all wrappers share
                # the loop variable, so each one called the LAST function.
                def wrapped_func(*args, **kwargs):
                    kwargs.update(server_params)
                    kwargs['prompt_template_name'] = config.prompt_template_name
                    return func(*args, **kwargs)
                return wrapped_func
            # NOTE(review): all wrappers share the name 'wrapped_func'; if the trainer
            # keys reward metrics by __name__ they may collide - confirm.
            reward_funcs = [_bind_reward_func(func) for func in reward_funcs]

        # Build vLLM server URL for captioner (TRL will handle this automatically).
        # This validation runs regardless of config.use_vllm.
        if not hasattr(config, 'vllm_server_host') or not hasattr(config, 'vllm_server_port'):
            raise ValueError("vllm_server_host and vllm_server_port must be specified in config for vLLM mode")

        if not config.vllm_server_host or not config.vllm_server_port:
            raise ValueError("vllm_server_host and vllm_server_port cannot be empty")

        vllm_server_base_url = f"http://{config.vllm_server_host}:{config.vllm_server_port}"

        print(f"🚀 GRPO Training with Scaffold Template: {config.prompt_template_name}")

        # Map common TrainingArguments to GRPOConfig initialisation fields
        grpo_args_kwargs = {
            "output_dir": training_args.output_dir,
            "per_device_train_batch_size": training_args.per_device_train_batch_size,
            "per_device_eval_batch_size": training_args.per_device_eval_batch_size,
            "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
            # Decouple generation-time batch from optimizer accumulation by allowing
            # explicit control of steps_per_generation via config. If not provided,
            # default to gradient_accumulation_steps (TRL's usual behavior).
            "steps_per_generation": getattr(config, "grpo_steps_per_generation", training_args.gradient_accumulation_steps),
            "learning_rate": training_args.learning_rate,
            "num_train_epochs": training_args.num_train_epochs,
            "logging_steps": training_args.logging_steps,
            "save_steps": training_args.save_steps,
            "bf16": getattr(training_args, "bf16", False),
            "fp16": getattr(training_args, "fp16", False),
            "gradient_checkpointing": training_args.gradient_checkpointing,
            # Additional GRPO-specific overrides from framework_config
            "beta": getattr(config, "beta", 0.0),
            "num_iterations": getattr(config, "grpo_num_iterations", 1),
            "epsilon": getattr(config, "epsilon", 0.2),
            "epsilon_high": getattr(config, "epsilon_high", None),
            "delta": getattr(config, "delta", None),
            "loss_type": getattr(config, "loss_type", "bnpo"),
            "scale_rewards": getattr(config, "scale_rewards", True),
            "num_generations": getattr(config, "num_generations", 8),
            "torch_compile": getattr(config, "torch_compile", False),
            "gradient_checkpointing_kwargs": {"use_reentrant": True},
            # Critical for multimodal pixel value alignment - see issue with KL divergence spikes
            "generation_batch_size": getattr(config, "generation_batch_size", None),
            "shuffle_dataset": getattr(config, "shuffle_dataset", True),
            "log_completions": getattr(config, "log_completions", True),
            # Static filtering parameters
            "static_filtering_enabled": getattr(config, "static_filtering_enabled", False),
            "difficulty_0_leakage_percent": getattr(config, "difficulty_0_leakage_percent", 0.0),
            "difficulty_1_leakage_percent": getattr(config, "difficulty_1_leakage_percent", 0.0),
        }
        # Currently disabled and therefore NOT forwarded to GRPOConfig:
        # mask_truncated_completions, use_curriculum_learning, enable_dapo_filtering
        # and the curriculum_* options (binning_strategy, adaptive_thresholds,
        # bin_edges, success_thresholds, flat_thresholds, num_bins, min_weight).

        # TRL vLLM integration parameters
        if getattr(config, 'use_vllm', False):
            grpo_args_kwargs.update({
                "use_vllm": True,
                "vllm_server_base_url": vllm_server_base_url,
                "vllm_mode": getattr(config, "vllm_mode", "server"),
                "vllm_server_timeout": getattr(config, "vllm_server_timeout", 240.0),
                "vllm_gpu_memory_utilization": getattr(config, "vllm_gpu_memory_utilization", 0.3),
                "vllm_tensor_parallel_size": getattr(config, "vllm_tensor_parallel_size", 1),
                "vllm_model_impl": getattr(config, "vllm_model_impl", "vllm"),
                "vllm_enable_sleep_mode": getattr(config, "vllm_enable_sleep_mode", True),
                "vllm_guided_decoding_regex": getattr(config, "vllm_guided_decoding_regex", None),
                "vllm_importance_sampling_correction": getattr(config, "vllm_importance_sampling_correction", True),
                "vllm_importance_sampling_cap": getattr(config, "vllm_importance_sampling_cap", 2.0),
            })
        elif getattr(config, 'use_transformers_paged', False):
            grpo_args_kwargs.update({
                "use_vllm": False,
                "use_transformers_paged": True,
            })
        else:
            grpo_args_kwargs.update({
                "use_vllm": False,
                "use_transformers_paged": False,
            })

        if hasattr(config, 'max_completion_length'):
            # The literal string "None" is accepted as a way of spelling "no limit"
            if config.max_completion_length == "None":
                grpo_args_kwargs["max_completion_length"] = None
            else:
                grpo_args_kwargs["max_completion_length"] = config.max_completion_length

        # Add checkpoint resumption to GRPO config
        if hasattr(config, 'resume_from_checkpoint') and config.resume_from_checkpoint:
            grpo_args_kwargs["resume_from_checkpoint"] = config.resume_from_checkpoint

        grpo_config = GRPOConfig(**grpo_args_kwargs)
        grpo_config.remove_unused_columns = False
        grpo_config.dataset_kwargs = {"skip_prepare_dataset": True}

        # For GRPO with multimodal data, we need to pass the full processor
        from transformers import ProcessorMixin
        if self.processor is not None and isinstance(self.processor, ProcessorMixin):
            processing_class_for_trl = self.processor
            print(f"[GRPO] Using multimodal processor: {type(processing_class_for_trl)}")
        else:
            processing_class_for_trl = self.tokenizer
            print(f"[GRPO] Using tokenizer: {type(processing_class_for_trl)}")

        # Ensure the processing class has a pad_token (required by TRL).
        # A processor exposes its tokenizer as `.tokenizer`; a bare tokenizer is used directly.
        tokenizer_part = processing_class_for_trl.tokenizer if hasattr(processing_class_for_trl, 'tokenizer') else processing_class_for_trl
        if tokenizer_part.pad_token is None:
            tokenizer_part.pad_token = tokenizer_part.eos_token

        # Create PEFT config if using LoRA
        peft_config = None
        if config.use_lora:
            from peft import LoraConfig, TaskType

            peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                r=config.lora_r,
                lora_alpha=config.lora_alpha,
                lora_dropout=config.lora_dropout,
                target_modules=self._get_target_modules_for_model(model, config),
                bias="none",
            )
            print(f"📦 PEFT config: r={config.lora_r}, alpha={config.lora_alpha}")

        # Convert datasets to HuggingFace format with scaffold template applied
        hf_train_dataset = self._convert_to_hf_dataset(train_dataset)
        hf_eval_dataset = self._convert_to_hf_dataset(eval_dataset) if eval_dataset else None

        # Build kwargs for GRPOTrainer
        trainer_kwargs = {
            "model": model,
            "reward_funcs": reward_funcs,
            "train_dataset": hf_train_dataset,
            "eval_dataset": hf_eval_dataset,
            "args": grpo_config,
            "processing_class": processing_class_for_trl,
            "peft_config": peft_config,
        }

        print(f"[TRLAdapter] GRPO trainer bypasses data_collator - using dataset-level prompt formatting")

        return GRPOTrainer(**trainer_kwargs)
    
    def _load_scaffold_template(self, config):
        """Load scaffold prompt template for consistent VLM prompting."""
        import yaml
        from pathlib import Path
        
        # Get template name from config - REQUIRED, no defaults
        if not hasattr(config, 'prompt_template_name') or not config.prompt_template_name:
            raise ValueError(
                "prompt_template_name must be specified in config for scaffold-consistent training. "
                "Add 'prompt_template_name: \"adaptive_math_v1\"' to your training config."
            )
        
        template_name = config.prompt_template_name
        
        # Construct path to template file
        current_file = Path(__file__).resolve()
        templates_dir = current_file.parent.parent.parent / "prompts" / "templates"
        template_path = templates_dir / f"{template_name}.yaml"
        
        if not template_path.exists():
            raise FileNotFoundError(
                f"Scaffold template '{template_name}' not found at {template_path}. "
                f"Available templates: {list(templates_dir.glob('*.yaml'))}"
            )
        
        with open(template_path, 'r') as f:
            template = yaml.safe_load(f)
        
        # Validate required fields in template based on scaffold type
        # Three-stage scaffolds need VLM prompts since they can ask clarifying questions
        is_two_stage_scaffold = (
            'two_stage' in template_name.lower() or 
            (hasattr(config, 'scaffold_type') and config.scaffold_type == 'two_stage') or
            (hasattr(config, 'framework_config') and config.framework_config and 
             'two_stage' in str(config.framework_config.get('reward_function', '')))
        )
        
        is_three_stage_scaffold = (
            'three_stage' in template_name.lower() or 
            (hasattr(config, 'scaffold_type') and config.scaffold_type == 'three_stage') or
            (hasattr(config, 'framework_config') and config.framework_config and 
             'three_stage' in str(config.framework_config.get('reward_function', '')))
        )
        
        if not is_two_stage_scaffold and "vlm_initial_description_prompt" not in template:
            raise ValueError(
                f"Template '{template_name}' missing required field 'vlm_initial_description_prompt'. "
                f"This field is required for VLM training consistency."
            )
        
        print(f"📝 Loaded scaffold template: {template_path}")
        return template
    
    def _create_multimodal_data_collator(self, model=None, scaffold_template=None):
        """Create a data collator that preserves dataset fields for GRPO trainer."""
        
        def collate_fn(examples):
            """
            Simple collator that preserves dataset fields for GRPO trainer.
            GRPO trainer handles prompt processing internally, so we just need to preserve
            the fields for the reward function.
            """
            
            # Since GRPO trainer handles prompt processing internally using the dataset's
            # "prompt" field, we just return the examples as-is but preserve important
            # fields for the reward function
            
            # Ensure all examples have the required fields
            for i, example in enumerate(examples):
                if "prompt" not in example:
                    raise ValueError(f"Example {i} missing 'prompt' field required by GRPO trainer")
                if "question" not in example:
                    raise ValueError(f"Example {i} missing 'question' field for reward function")
                if "ground_truth" not in example:
                    print(f"Warning: Example {i} missing 'ground_truth' field for reward function")
            
            # GRPO trainer expects examples to be returned as-is
            # It will handle the prompt processing using processor.apply_chat_template
            return examples
        
        return collate_fn
    
    def _normalize_prompt_structure(self, sample):
        """
        Normalize prompt structure to ensure PyArrow compatibility.
        Ensures system messages are strings and user messages are lists for consistency.
        """
        if not isinstance(sample, dict) or "prompt" not in sample:
            return sample
        
        prompt_messages = sample["prompt"]
        if not isinstance(prompt_messages, list):
            return sample
        
        # Normalize each message to have consistent content structure
        normalized_messages = []
        
        for msg in prompt_messages:
            if not isinstance(msg, dict):
                continue
            
            role = msg.get("role", "")
            content = msg.get("content", "")
            
            # Normalize content - ALL content fields should be lists for consistency
            if isinstance(content, str):
                # Convert string to structured format
                normalized_content = [{"type": "text", "text": content}]
            elif isinstance(content, list):
                # Already structured - ensure all items are properly formatted
                normalized_content = []
                for item in content:
                    if isinstance(item, dict):
                        # Ensure proper dict structure
                        if "type" in item:
                            normalized_content.append(item)
                        else:
                            # Convert malformed dict to text
                            normalized_content.append({"type": "text", "text": str(item)})
                    else:
                        # Convert non-dict items to text
                        normalized_content.append({"type": "text", "text": str(item)})
            else:
                # Other types, convert to text
                normalized_content = [{"type": "text", "text": str(content)}]
            
            normalized_messages.append({
                        "role": role,
                "content": normalized_content
            })
        
        # Update sample with normalized messages
        normalized_sample = sample.copy()
        normalized_sample["prompt"] = normalized_messages
        
        return normalized_sample 