"""
Reinforcement Learning (RL) training script for reasoning frameworks.

This script provides training using RL methods like PPO, DPO, and GRPO
for improving reasoning performance through reward signals.
"""

import argparse
import logging
import os
import sys
import yaml
from pathlib import Path
from typing import Dict, Any, Optional, Callable

import torch
from transformers import AutoTokenizer, AutoImageProcessor
from transformers.utils import logging as hf_logging

# Restrict the transformers library's own logger to error-level output.
hf_logging.set_verbosity_error()

# Add framework to path: the third ancestor of this file's resolved path is
# placed at the front of sys.path so the `reasoning_frameworks` imports below
# can resolve (assumes that directory contains the package -- TODO confirm
# against the repository layout).
current_file = Path(__file__).resolve()
framework_root = current_file.parent.parent.parent
sys.path.insert(0, str(framework_root))

# Import framework components
from reasoning_frameworks.training.adapters import (
    BaseTrainingAdapter, TRLTrainingAdapter, UnslothTrainingAdapter
)
from reasoning_frameworks.training.adapters.framework_detector import FrameworkDetector
from reasoning_frameworks.training.adapters.base_trainer import TrainingConfig
from reasoning_frameworks.training.data import (
    ReasoningTrajectoryDataset, ViRL39KDataset, SFTDataset
)


def setup_logging(log_level: str = "INFO") -> None:
    """Configure root logging to write to the console and to ``rl_training.log``.

    Args:
        log_level: Name of a ``logging`` level (case-insensitive), e.g. "debug".

    Raises:
        AttributeError: If ``log_level`` does not name an attribute of the
            ``logging`` module.
    """
    # Resolve the level first so an unknown name fails before any handler
    # (and therefore the log file) is created.
    numeric_level = getattr(logging, log_level.upper())

    console_handler = logging.StreamHandler()
    file_handler = logging.FileHandler('rl_training.log')

    logging.basicConfig(
        level=numeric_level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[console_handler, file_handler],
    )

import warnings
# Install an "ignore" filter for the gradient-checkpointing/caching notice
# named in the pattern below. The pattern is a regex matched against the start
# of the warning text; the filter only affects messages that are emitted
# through the `warnings` module.
warnings.filterwarnings(
    "ignore",
    message=r"Caching is incompatible with gradient checkpointing in Qwen2DecoderLayer.*",
    # category=UserWarning,  # if the message were emitted via warnings.warn
)


def load_config(config_path: str) -> Dict[str, Any]:
    """Load training configuration from a YAML file.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The parsed top-level mapping. An empty (or comment-only) file yields
        an empty dict rather than ``None`` so callers can always treat the
        result as a mapping.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        ValueError: If the top-level YAML node is not a mapping.
    """
    # Read as UTF-8 explicitly instead of relying on the platform default.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # safe_load returns None for an empty document; downstream code does
    # `'training' in config_dict`, which would fail with an opaque TypeError.
    if config is None:
        return {}

    if not isinstance(config, dict):
        raise ValueError(
            f"Configuration file {config_path} must contain a YAML mapping at the "
            f"top level, got {type(config).__name__}"
        )
    return config


def create_training_config(args: argparse.Namespace, config_dict: Dict[str, Any]) -> TrainingConfig:
    """
    Build a TrainingConfig by layering explicit CLI values over YAML values.

    The YAML settings are taken from the ``training`` section when one exists,
    otherwise from the top level of ``config_dict``. Every parsed CLI argument
    whose name is a TrainingConfig dataclass field, and which
    ``_was_cli_arg_provided`` reports as explicitly given, replaces the YAML
    value (each override is echoed to stdout). The merged mapping is validated
    by ``_validate_required_rl_config`` before the TrainingConfig is created,
    so missing or invalid settings raise ValueError here.
    """
    # Work on a shallow copy so the caller's mapping is never modified.
    source_section = config_dict['training'] if 'training' in config_dict else config_dict
    merged_settings = source_section.copy()

    for option_name, option_value in vars(args).items():
        if option_name not in TrainingConfig.__dataclass_fields__:
            continue
        if not _was_cli_arg_provided(args, option_name, option_value):
            continue
        merged_settings[option_name] = option_value
        print(f"CLI override: {option_name} = {option_value}")

    # Fail fast before constructing the dataclass.
    _validate_required_rl_config(merged_settings)

    return TrainingConfig.from_dict(merged_settings)


def _was_cli_arg_provided(args: argparse.Namespace, arg_name: str, arg_value) -> bool:
    """
    Determine if a CLI argument was explicitly provided by the user.

    This distinguishes user-provided values from argparse defaults so that
    only explicit CLI values override the YAML configuration.

    Args:
        args: The parsed namespace (currently unused; kept for interface
            compatibility).
        arg_name: Name of the argument being inspected.
        arg_value: Value argparse stored for that argument.

    Returns:
        True when the value should be treated as explicitly supplied.
    """
    # Boolean store_true flags default to False, so only True means the flag
    # was actually passed on the command line.
    # NOTE: 'cache_implementation' used to be listed here, but it is a
    # string-valued option; `arg_value is True` could never hold for it, so an
    # explicit --cache_implementation was silently discarded.
    store_true_flags = ('use_lora', 'use_vllm', 'use_transformers_paged')
    if arg_name in store_true_flags:
        return arg_value is True

    # Every other option is declared without a default, so argparse leaves it
    # as None unless the user supplied a value.
    return arg_value is not None


def _validate_required_rl_config(config_data: Dict[str, Any]) -> None:
    """
    Check that an RL training configuration is complete and self-consistent.

    Checks run in this order and the first failure raises ValueError:
    required top-level fields, training method, scaffold type, GRPO-specific
    settings (reward function, generation backend, reasoner parameters), and
    finally training-method/dataset-type compatibility. A confirmation line is
    printed when everything passes.
    """
    mandatory_keys = (
        'training_model_name',
        'reasoner_model_name',
        'scaffold_type',
        'training_method',
        'dataset_path',
        'prompt_template_name',  # required so training prompts stay consistent
    )

    # A key counts as missing when it is absent or holds a falsy value.
    missing_fields = [key for key in mandatory_keys if not config_data.get(key)]
    if missing_fields:
        raise ValueError(
            f"Missing required RL training configuration fields: {missing_fields}. "
            f"These must be specified in the config file."
        )

    # Training method must be one of the supported RL algorithms.
    training_method = config_data.get('training_method')
    valid_training_methods = ["ppo", "dpo", "grpo"]
    if training_method not in valid_training_methods:
        raise ValueError(
            f"Invalid training_method '{training_method}'. "
            f"Must be one of: {valid_training_methods}"
        )

    # Scaffold type must be one of the known scaffolds.
    scaffold_type = config_data.get('scaffold_type')
    valid_scaffold_types = ["adaptive", "two_stage", "three_stage"]
    if scaffold_type not in valid_scaffold_types:
        raise ValueError(
            f"Invalid scaffold_type '{scaffold_type}'. "
            f"Must be one of: {valid_scaffold_types}"
        )

    if training_method == "grpo":
        # GRPO needs a reward function reference inside framework_config.
        framework_config = config_data.get('framework_config')
        if not framework_config:
            raise ValueError(
                "GRPO training requires framework_config with reward_function specified. "
                "Add: framework_config: { reward_function: 'path.to.your.reward_function' }"
            )

        if not ('reward_function' in framework_config and framework_config['reward_function']):
            raise ValueError(
                "GRPO training requires framework_config.reward_function to be specified. "
                "Example: reward_function: 'reasoning_frameworks.training.reward_functions.pipeline_math_correctness_reward'"
            )

        # The two accelerated generation backends are mutually exclusive.
        vllm_enabled = config_data.get('use_vllm', False)
        paged_enabled = config_data.get('use_transformers_paged', False)
        if vllm_enabled and paged_enabled:
            raise ValueError(
                "Cannot use both vLLM and transformers paged attention simultaneously. "
                "Choose one: use_vllm: true OR use_transformers_paged: true"
            )

        # Reward functions are classified by substrings of their reference.
        reward_func = framework_config['reward_function']
        is_pipeline = 'pipeline' in reward_func
        is_two_stage = 'two_stage' in reward_func
        is_three_stage = 'three_stage' in reward_func

        if is_pipeline or is_two_stage or is_three_stage:
            needed_params = [
                'reasoner_server_port', 'reasoner_model_name', 'reasoner_max_tokens',
                'reasoner_temperature', 'reasoner_top_p', 'reasoner_top_k'
            ]

            # Iterative variants also need an iteration cap (two-stage is single-shot).
            if is_pipeline or is_three_stage:
                needed_params.append('scaffold_max_iterations')

            # Only the three-stage variant uses a question penalty.
            if is_three_stage:
                needed_params.append('question_penalty')

            # Here a parameter is missing only when absent or None (0 is valid).
            missing_reasoner_params = [
                name for name in needed_params if config_data.get(name) is None
            ]
            if missing_reasoner_params:
                raise ValueError(
                    f"Reward function requires configuration parameters. "
                    f"Missing parameters: {missing_reasoner_params}"
                )

            print("✅ Reward function configuration validation passed")

    # Cross-check the training method against the dataset type.
    _validate_training_dataset_compatibility(config_data)

    print(f"✅ RL configuration validated: {training_method.upper()} training on {scaffold_type} scaffold")


def _validate_training_dataset_compatibility(config_data: Dict[str, Any]) -> None:
    """
    Reject training-method / dataset-type pairings this script does not allow.

    Accepted pairings:
      * grpo -> 'virl' or 'hf'
      * ppo  -> 'trajectory' or 'virl'
      * dpo  -> 'trajectory'

    Any other training method passes through unchecked.

    Raises:
        ValueError: If the configured pairing is not one of the above.
    """
    method = config_data.get('training_method')
    dataset_type = config_data.get('dataset_type')

    if method == 'grpo':
        # GRPO consumes prompt + ground-truth data ('virl') or a pre-converted
        # HuggingFace dataset ('hf').
        if dataset_type in ('virl', 'hf'):
            return
        raise ValueError(
            f"GRPO training requires dataset_type='virl' or 'hf' for optimal compatibility. "
            f"Got dataset_type='{dataset_type}'. "
            f"Use prepare_grpo_data.py to format your dataset for GRPO training, "
            f"or convert_grpo_to_hf.py to create pre-converted HuggingFace datasets."
        )

    if method == 'ppo':
        # PPO is expected to run on full-conversation trajectories (or virl).
        if dataset_type in ('trajectory', 'virl'):
            return
        raise ValueError(
            f"PPO training requires dataset_type='trajectory' or 'virl' for optimal performance. "
            f"Got dataset_type='{dataset_type}'. Use trajectory format for better PPO training."
        )

    if method == 'dpo':
        # DPO needs paired preference data, carried by the trajectory format.
        if dataset_type == 'trajectory':
            return
        raise ValueError(
            f"DPO training requires dataset_type='trajectory' with preference pairs. "
            f"Got dataset_type='{dataset_type}'."
        )


def _register_lora_adapters_legacy(config: TrainingConfig) -> None:
    """
    Legacy function for manual LoRA adapter registration.

    NOTE: This function is no longer needed when using 'trl vllm-serve'.
    TRL automatically handles LoRA adapter registration and weight synchronization.
    This function is kept for reference/debugging purposes only.

    Steps performed:
      1. Poll ``{server}/health`` until the vLLM server answers 200.
      2. Write a LoRA config plus placeholder (all-zero) adapter weights into a
         persistent adapter directory.
      3. Ask the server to unload any adapter with the same name, then load the
         newly written one, and finally list the registered adapters.

    Args:
        config: Training configuration. ``vllm_server_host``,
            ``vllm_server_port``, ``lora_r``, ``lora_alpha`` and
            ``lora_dropout`` are read from it, each with a fallback default.

    Raises:
        RuntimeError: If the server does not become healthy in time, or if
            creating/uploading the adapter fails (the original error is
            chained as ``__cause__``).
    """
    import json
    import time

    import requests
    import requests.exceptions
    from peft import LoraConfig
    from safetensors.torch import save_file

    logger = logging.getLogger(__name__)

    # Build the vLLM server URL
    vllm_server_host = getattr(config, 'vllm_server_host', 'localhost')
    vllm_server_port = getattr(config, 'vllm_server_port', 8000)
    server_url = f"http://{vllm_server_host}:{vllm_server_port}"

    # Get LoRA parameters from config
    lora_r = getattr(config, 'lora_r', 16)
    lora_alpha = getattr(config, 'lora_alpha', 32)
    lora_dropout = getattr(config, 'lora_dropout', 0.05)
    adapter_name = "caption_lora"  # Standard name for GRPO training

    logger.info(f"🔧 Creating and registering LoRA adapter '{adapter_name}' on {server_url}")
    logger.info(f"   Config: r={lora_r}, alpha={lora_alpha}, dropout={lora_dropout}")

    # Wait for server to be ready
    max_wait = 60
    logger.info("⏳ Waiting for vLLM server to be ready...")
    for i in range(max_wait):
        try:
            response = requests.get(f"{server_url}/health", timeout=10)
            if response.status_code == 200:
                logger.info("✅ vLLM server is ready!")
                break
        except requests.exceptions.RequestException as exc:
            # Only network-level failures are expected while the server boots.
            # A bare `except:` here would also swallow KeyboardInterrupt and
            # SystemExit, making the wait loop impossible to interrupt.
            logger.debug(f"   Health check attempt {i} failed: {exc}")

        if i % 10 == 0:
            logger.info(f"   Still waiting... ({i}/{max_wait}s)")
        time.sleep(1)
    else:
        # Loop finished without `break`: the server never reported healthy.
        raise RuntimeError(f"vLLM server {server_url} not ready after {max_wait}s")

    # Create persistent directory for LoRA adapter (shared filesystem).
    # parents=True so a missing intermediate directory does not abort the run.
    adapter_base_dir = Path("/scratch/<ANONYMIZED>/grpo_lora_adapters")
    adapter_base_dir.mkdir(parents=True, exist_ok=True)
    adapter_path = adapter_base_dir / adapter_name
    adapter_path.mkdir(exist_ok=True)

    try:
        logger.info("📦 Creating LoRA adapter files...")

        # Create LoRA configuration
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            # Target modules for language model (not vision)
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"]
        )

        # Save LoRA config
        lora_config.save_pretrained(str(adapter_path))
        logger.info(f"   Saved LoRA config to {adapter_path}")

        # List files created so far
        files_created = list(adapter_path.glob("*"))
        logger.info(f"   Files after config save: {[f.name for f in files_created]}")

        # Placeholder weights: a single layer's q_proj A/B matrices filled
        # with zeros, written so the adapter directory contains a weights file.
        dummy_tensors = {
            "base_model.model.language_model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.zeros(lora_r, 1024),
            "base_model.model.language_model.model.layers.0.self_attn.q_proj.lora_B.weight": torch.zeros(1024, lora_r),
        }

        adapter_weights_path = adapter_path / "adapter_model.safetensors"
        save_file(dummy_tensors, str(adapter_weights_path))
        logger.info(f"   Created adapter weights at {adapter_weights_path}")

        # Also create traditional .bin format for compatibility
        adapter_bin_path = adapter_path / "adapter_model.bin"
        torch.save(dummy_tensors, str(adapter_bin_path))
        logger.info(f"   Created adapter weights (bin) at {adapter_bin_path}")

        # List all files in the adapter directory
        all_files = list(adapter_path.glob("*"))
        logger.info(f"   All files in adapter directory: {[f.name for f in all_files]}")

        # Check file sizes
        for f in all_files:
            if f.is_file():
                logger.info(f"     {f.name}: {f.stat().st_size} bytes")

        # Wait a moment for files to be visible on shared filesystem
        time.sleep(2)

        # Unload any existing adapter to clear cached state. This is
        # best-effort: any failure is logged and otherwise ignored.
        logger.info("🧹 Unloading any existing LoRA adapter...")
        try:
            unload_response = requests.post(
                f"{server_url}/v1/unload_lora_adapter",
                headers={"Content-Type": "application/json"},
                data=json.dumps({"lora_name": adapter_name}),
                timeout=30
            )
            if unload_response.status_code == 200:
                logger.info("   Unloaded existing adapter")
            else:
                logger.info("   No existing adapter to unload")
        except Exception as e:
            logger.info(f"   No existing adapter to unload: {e}")

        # Upload to vLLM server
        logger.info("📤 Uploading LoRA adapter to vLLM server...")
        logger.info(f"   Adapter path: {adapter_path}")

        upload_payload = {
            "lora_name": adapter_name,
            "lora_path": str(adapter_path)
        }

        response = requests.post(
            f"{server_url}/v1/load_lora_adapter",
            headers={"Content-Type": "application/json"},
            data=json.dumps(upload_payload),
            timeout=60
        )
        logger.debug(f"LoRA adapter upload response: {repr(response.text)}")

        if response.status_code == 200:
            # Handle both JSON and non-JSON responses from vLLM server
            try:
                if response.text.strip():
                    result = response.json()
                    logger.info(f"✅ LoRA adapter uploaded successfully: {result}")
                else:
                    # Empty response - assume success if status is 200
                    logger.info(f"✅ LoRA adapter '{adapter_name}' uploaded successfully (empty 200 response)")
            except (ValueError, requests.exceptions.JSONDecodeError):
                # Response is not JSON; a 200 status is treated as success.
                response_text = response.text.strip()
                logger.debug(f"Non-JSON response from vLLM server: {repr(response_text)}")
                logger.info(f"✅ LoRA adapter '{adapter_name}' uploaded successfully (non-JSON response)")

            # Verify by listing adapters (also handle non-JSON responses).
            # Verification is informational only and never fails the upload.
            try:
                list_response = requests.get(f"{server_url}/v1/list_lora_adapters", timeout=10)
                logger.debug(f"List adapters response: {repr(list_response.text)}")
                if list_response.status_code == 200:
                    try:
                        if list_response.text.strip():
                            adapters = list_response.json()
                            logger.info(f"📋 Current adapters: {adapters}")
                        else:
                            logger.info("📋 Adapter list endpoint returned empty response")
                    except (ValueError, requests.exceptions.JSONDecodeError):
                        logger.info("📋 Adapter list endpoint returned non-JSON response")
                else:
                    logger.warning(f"Could not list adapters: {list_response.status_code}")
            except Exception as e:
                logger.warning(f"Could not verify adapter registration: {e}")

        else:
            raise RuntimeError(f"Failed to upload LoRA adapter: {response.status_code} - {response.text}")

    except Exception as e:
        logger.error(f"Error creating LoRA adapter: {e}")
        # Chain the original exception so its traceback is preserved.
        raise RuntimeError(f"Failed to create/upload LoRA adapter: {e}") from e
    # Note: Keeping adapter directory persistent for vLLM server access
    # The directory will be reused/overwritten on next run

    logger.info(f"🎉 LoRA adapter '{adapter_name}' ready for GRPO training!")


def create_reward_function(reward_type: str = "correctness") -> Callable:
    """DEPRECATED: formerly built a reward function for RL training.

    Reward functions are now configured through framework_config.reward_function
    in the YAML file. The function remains only so existing call sites fail
    with a clear message: it never returns.

    Raises:
        ValueError: Always, naming the rejected ``reward_type``.
    """
    deprecation_message = "".join((
        f"create_reward_function is deprecated. reward_type='{reward_type}' is not supported. ",
        "Use framework_config.reward_function in your YAML configuration instead. ",
        "Example: 'reasoning_frameworks.training.reward_functions.pipeline_math_correctness_reward'",
    ))
    raise ValueError(deprecation_message)


def load_rl_dataset(config: TrainingConfig, tokenizer: AutoTokenizer, 
                   image_processor = None, split: str = "train"):
    """Load the dataset named by ``config`` in the form RL training expects.

    Args:
        config: Supplies ``dataset_path``, ``dataset_type``,
            ``max_sequence_length``, ``training_method`` and, optionally,
            ``format_type``.
        tokenizer: Tokenizer handed to the dataset wrappers.
        image_processor: Optional image processor handed to the dataset wrappers.
        split: Split name; for 'hf' datasets it selects ``<split>.hf`` under
            ``dataset_path``.

    Returns:
        A HuggingFace dataset loaded from disk for 'hf', otherwise one of the
        project dataset wrappers.

    Raises:
        ValueError: If the dataset path or type is missing or unknown.
        FileNotFoundError: If the pre-converted 'hf' split directory is absent.
    """
    source_path = config.dataset_path
    if not source_path:
        raise ValueError("Dataset path must be specified in config")

    # The dataset type is mandatory; there is deliberately no default.
    kind = getattr(config, 'dataset_type', None)
    if kind is None:
        raise ValueError(
            "dataset_type must be explicitly specified in config. "
            "Valid options: 'trajectory', 'virl', 'sft', 'hf'"
        )

    valid_dataset_types = ['trajectory', 'virl', 'sft', 'hf']
    if kind not in valid_dataset_types:
        raise ValueError(
            f"Invalid dataset_type '{kind}'. "
            f"Must be one of: {valid_dataset_types}"
        )

    print(f"Loading {kind} dataset for RL training from {source_path}")

    if kind == 'hf':
        # Pre-converted HuggingFace dataset: dataset_path is the directory
        # holding one "<split>.hf" sub-directory per split.
        from datasets import load_from_disk

        split_dir = Path(source_path) / f"{split}.hf"
        if not split_dir.exists():
            raise FileNotFoundError(
                f"Pre-converted HuggingFace dataset not found: {split_dir}. "
                f"Run convert_grpo_to_hf.py first to convert your dataset."
            )

        print(f"Loading pre-converted HuggingFace dataset from {split_dir}")
        loaded = load_from_disk(str(split_dir))
        print(f"✅ Loaded {len(loaded)} samples from pre-converted HF dataset")
        return loaded

    # Keyword arguments shared by all project dataset wrappers.
    shared_kwargs = dict(
        data_path=source_path,
        tokenizer=tokenizer,
        max_length=config.max_sequence_length,
        image_processor=image_processor,
        split=split,
    )

    if kind == 'trajectory':
        return ReasoningTrajectoryDataset(
            format_type=getattr(config, 'format_type', 'scaffold'),
            **shared_kwargs,
        )

    if kind == 'virl':
        # Anything other than GRPO is loaded in the PPO layout.
        rl_framework = "grpo" if config.training_method == "grpo" else "ppo"
        return ViRL39KDataset(rl_framework=rl_framework, **shared_kwargs)

    if kind == 'sft':
        return SFTDataset(
            format_type=getattr(config, 'format_type', 'chatml'),
            **shared_kwargs,
        )

    # Unreachable given the validation above; kept as a final safety net.
    raise ValueError(f"Unsupported dataset_type: {kind}")


def get_rl_training_adapter(config: TrainingConfig) -> BaseTrainingAdapter:
    """Pick the training adapter that will run the configured RL method.

    An explicit ``config.preferred_framework`` ('trl' or 'unsloth') is honoured
    when that framework is available. Otherwise, or when no preference is set,
    TRL is tried first and Unsloth second.

    Raises:
        ValueError: If ``config.training_method`` is not ppo, dpo or grpo.
        RuntimeError: If neither framework is available.
    """
    # Instantiated as in the original flow; the instance itself is not used.
    _detector = FrameworkDetector()

    method = config.training_method
    if method not in ["ppo", "dpo", "grpo"]:
        raise ValueError(f"Training method {method} is not supported for RL training")

    preference = getattr(config, 'preferred_framework', None)

    if preference == 'trl':
        candidate = TRLTrainingAdapter()
        if candidate.is_available():
            return candidate
        print("TRL not available, trying Unsloth...")
    elif preference == 'unsloth':
        candidate = UnslothTrainingAdapter()
        if candidate.is_available():
            return candidate
        print("Unsloth not available, trying TRL...")

    # Auto-detection: TRL is preferred for RL methods, Unsloth is the fallback.
    for adapter_cls in (TRLTrainingAdapter, UnslothTrainingAdapter):
        candidate = adapter_cls()
        if candidate.is_available():
            return candidate

    raise RuntimeError("No RL training framework available. Please install TRL or Unsloth.")


def train_rl(config: TrainingConfig, train_dataset, eval_dataset = None, 
            reward_function: Optional[Callable] = None) -> None:
    """Execute the RL training process.

    Selects a training adapter, validates the configuration with it,
    initializes the model, prepares the trainer, runs training and logs the
    outcome.

    Args:
        config: Training configuration.
        train_dataset: Dataset used for training.
        eval_dataset: Optional evaluation dataset.
        reward_function: Optional callable; when given it is stored in
            ``config.framework_config['reward_function']`` before the trainer
            is prepared.

    Raises:
        ValueError: If the adapter reports configuration errors.
        RuntimeError: If no adapter is available or training reports failure.
    """
    logger = logging.getLogger(__name__)

    # Load tokenizer and image processor. Neither is passed on below, but
    # loading them here surfaces an unusable model name before the (much more
    # expensive) model initialization starts.
    logger.info(f"Loading tokenizer from {config.training_model_name}")
    tokenizer = AutoTokenizer.from_pretrained(config.training_model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load image processor if dealing with VLM; its absence is not an error.
    image_processor = None
    try:
        image_processor = AutoImageProcessor.from_pretrained(config.training_model_name)
        logger.info(f"Loaded image processor from {config.training_model_name}")
    except Exception as e:
        logger.warning(f"Could not load image processor: {e}")

    # Get RL training adapter
    logger.info("Setting up RL training adapter...")
    adapter = get_rl_training_adapter(config)
    logger.info(f"Using {adapter.framework_name} for RL training")

    # Validate configuration for RL
    errors = adapter.validate_config(config)
    if errors:
        logger.error("Configuration validation failed:")
        for error in errors:
            logger.error(f"  - {error}")
        raise ValueError("Invalid configuration")

    # Note: For TRL vllm-serve, LoRA adapters are handled automatically by TRL
    # No manual registration needed when using trl vllm-serve

    # Add reward function to config if provided
    if reward_function:
        # The attribute may exist but hold None; a plain hasattr() check would
        # then skip initialization and the item assignment below would raise
        # TypeError.
        if getattr(config, 'framework_config', None) is None:
            config.framework_config = {}
        config.framework_config['reward_function'] = reward_function

    # Initialize model
    logger.info("Initializing model for RL training...")
    scaffold = None  # TODO: Load actual scaffold based on config.scaffold_type

    model = adapter.initialize_model(config, scaffold)
    logger.info(f"Model initialized with {sum(p.numel() for p in model.parameters())} total parameters")
    logger.info(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    # Prepare trainer for RL
    logger.info("Preparing RL trainer...")
    trainer = adapter.prepare_trainer(model, train_dataset, eval_dataset, config)

    # Start RL training
    logger.info(f"Starting {config.training_method.upper()} training...")
    result = adapter.train(trainer, config)

    if result.success:
        logger.info("RL training completed successfully!")
        # Guard the numeric format: a missing loss must not turn a successful
        # run into a crash while merely logging the summary.
        if result.final_loss is not None:
            logger.info(f"Final loss: {result.final_loss:.4f}")
        logger.info(f"Training time: {result.training_time:.2f} seconds")
        if result.model_path:
            logger.info(f"Model saved to: {result.model_path}")
        if result.adapter_path:
            logger.info(f"Adapter saved to: {result.adapter_path}")
    else:
        logger.error(f"RL training failed: {result.error_message}")
        raise RuntimeError(f"RL training failed: {result.error_message}")


def _build_rl_arg_parser() -> argparse.ArgumentParser:
    """Assemble the command-line interface used by ``main``."""
    parser = argparse.ArgumentParser(description="Train VLM reasoning components with RL")
    add = parser.add_argument

    # Configuration
    add("--config", type=str, required=True,
        help="Path to RL training configuration YAML file")
    add("--output_dir", type=str,
        help="Output directory for training artifacts")

    # Model configuration
    add("--training_model_name", type=str,
        help="VLM model name or path")
    add("--reasoner_model_name", type=str,
        help="Reasoner model name or path")
    add("--scaffold_type", type=str, choices=["adaptive", "two_stage", "three_stage"],
        help="Type of reasoning scaffold")

    # RL training parameters
    add("--training_method", type=str,
        choices=["ppo", "dpo", "grpo"],
        help="RL training method")
    add("--learning_rate", type=float,
        help="Learning rate")
    add("--batch_size", type=int,
        help="Training batch size")
    add("--max_epochs", type=int,
        help="Maximum training epochs")

    # Reward configuration (DEPRECATED - use config instead)
    add("--reward_type", type=str,
        choices=["correctness", "reasoning_quality", "efficiency"],
        help="DEPRECATED: Use framework_config.reward_function in YAML instead. This argument is ignored.")

    # Data configuration
    add("--dataset_path", type=str,
        help="Path to training dataset")
    add("--eval_dataset_path", type=str,
        help="Path to evaluation dataset")
    add("--max_sequence_length", type=int,
        help="Maximum sequence length")

    # LoRA configuration
    add("--use_lora", action="store_true",
        help="Use LoRA adapters")
    add("--lora_r", type=int,
        help="LoRA rank")

    # Framework selection
    add("--framework", type=str, choices=["trl", "unsloth", "auto"],
        default="auto", help="Training framework preference")

    # vLLM server configuration
    add("--vllm_server_host", type=str,
        help="vLLM server host IP address")
    add("--vllm_server_port", type=int,
        help="vLLM server port")
    add("--reasoner_server_host", type=str,
        help="Reasoner server host IP address (for GRPO pipeline and two-stage reward functions)")
    add("--captioner_server_host", type=str,
        help="Captioner server host IP address (for GRPO pipeline and three-stage reward functions)")

    add("--resume_from_checkpoint", type=str,
        help="Path to checkpoint to resume from")

    # Generation method selection (mutually exclusive)
    generation_group = parser.add_mutually_exclusive_group()
    generation_group.add_argument("--use_vllm", action="store_true",
                                  help="Use vLLM for generation acceleration")
    generation_group.add_argument("--use_transformers_paged", action="store_true",
                                  help="Use transformers paged attention for generation (no vLLM server needed)")

    # Cache implementation for transformers generation
    add("--cache_implementation", type=str,
        choices=["static", "sliding_window", "offloaded_static"],
        help="Cache implementation for transformers generation")

    # Logging
    add("--log_level", type=str, default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Logging level")

    # Dry run
    add("--dry_run", action="store_true",
        help="Perform a dry run without actual training")

    return parser


def main():
    """Entry point: parse the CLI, validate the configuration and run RL training.

    Flow: parse arguments, set up logging, load and merge the YAML
    configuration, log a summary, stop early on ``--dry_run``, otherwise load
    the tokenizer / image processor / datasets and hand over to ``train_rl``.
    """
    args = _build_rl_arg_parser().parse_args()

    # Set up logging
    setup_logging(args.log_level)
    logger = logging.getLogger(__name__)

    # Load the YAML file and merge CLI overrides into a TrainingConfig
    logger.info(f"Loading RL configuration from {args.config}")
    config = create_training_config(args, load_config(args.config))

    # An explicit --framework choice becomes the preferred framework
    if args.framework != "auto":
        config.preferred_framework = args.framework

    # Validate RL training method
    if config.training_method not in ["ppo", "dpo", "grpo"]:
        logger.error(f"Invalid training method for RL: {config.training_method}")
        raise ValueError("RL training requires PPO, DPO, or GRPO")

    logger.info("RL Training configuration:")
    for label, value in (
        ("Training Model", config.training_model_name),
        ("Reasoner Model", config.reasoner_model_name),
        ("Scaffold Type", config.scaffold_type),
        ("RL Method", config.training_method),
    ):
        logger.info(f"  {label}: {value}")

    # The reward function must be referenced in the config; the TRL adapter
    # imports and configures it, so nothing is resolved here.
    # NOTE(review): this requirement is enforced for every training method
    # although the messages mention GRPO only -- confirm that is intended.
    reward_function = None
    framework_config = getattr(config, 'framework_config', None)
    if not framework_config:
        raise ValueError(
            "framework_config is required for GRPO training but not found in config. "
            "Add framework_config with reward_function to your YAML configuration."
        )
    reward_func_ref = framework_config.get('reward_function')
    if not reward_func_ref:
        raise ValueError(
            "framework_config exists but missing 'reward_function'. "
            "For GRPO training, you must specify: framework_config.reward_function"
        )
    logger.info(f"  Reward Function: {reward_func_ref} (from config)")

    # No CLI fallbacks - all configuration must be explicit in YAML
    if args.reward_type:
        logger.warning(
            f"CLI argument --reward_type={args.reward_type} ignored. "
            f"Using reward function from config: {config.framework_config['reward_function']}"
        )

    # Display generation method
    if getattr(config, 'use_vllm', False):
        generation_label = "vLLM (server mode)"
    elif getattr(config, 'use_transformers_paged', False):
        generation_label = "Transformers Paged Attention"
    else:
        generation_label = "Standard Transformers"
    logger.info(f"  Generation Method: {generation_label}")

    # Display prompt template if configured
    if hasattr(config, 'prompt_template_name'):
        logger.info(f"  Prompt Template: {config.prompt_template_name}")

    logger.info(f"  Learning Rate: {config.learning_rate}")
    logger.info(f"  Batch Size: {config.batch_size}")
    logger.info(f"  Max Epochs: {config.max_epochs}")
    logger.info(f"  Use LoRA: {config.use_lora}")
    logger.info(f"  Output Dir: {config.output_dir}")

    if args.dry_run:
        logger.info("Dry run mode - RL configuration validated successfully")
        return

    # Create output directory
    os.makedirs(config.output_dir, exist_ok=True)

    # Load tokenizer for dataset loading
    tokenizer = AutoTokenizer.from_pretrained(config.training_model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load image processor; failure is only logged as a warning
    image_processor = None
    try:
        image_processor = AutoImageProcessor.from_pretrained(config.training_model_name)
    except Exception as e:
        logger.warning(f"Could not load image processor: {e}")

    # Load datasets
    logger.info("Loading training dataset for RL...")
    train_dataset = load_rl_dataset(config, tokenizer, image_processor, split="train")

    eval_dataset = None
    if config.eval_dataset_path:
        logger.info("Loading evaluation dataset...")
        # Copy the config so pointing it at the eval path leaves the training
        # config untouched.
        eval_config = TrainingConfig.from_dict(config.to_dict())
        eval_config.dataset_path = config.eval_dataset_path
        eval_dataset = load_rl_dataset(eval_config, tokenizer, image_processor, split="eval")

    # Print dataset statistics
    logger.info(f"Training dataset: {len(train_dataset)} samples")
    if hasattr(train_dataset, 'get_statistics'):
        logger.info(f"Dataset statistics: {train_dataset.get_statistics()}")

    if eval_dataset:
        logger.info(f"Evaluation dataset: {len(eval_dataset)} samples")

    # Start RL training (reward function will be handled by TRL adapter from config)
    train_rl(config, train_dataset, eval_dataset, reward_function)

    logger.info("RL training completed successfully!")


if __name__ == "__main__":
    main() 