"""
Abstract base class for training adapters.

This module defines the interface that all training framework adapters must implement,
enabling framework-agnostic training of VLM components in reasoning scaffolds.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List, Union, Iterator
from pathlib import Path
import torch
import torch.nn as nn
from transformers import PreTrainedModel


@dataclass
class TrainingConfig:
    """
    Configuration for training a reasoning scaffold.

    This dataclass contains all the necessary configuration for training,
    regardless of the underlying framework (TRL/Unsloth). Only
    ``training_model_name`` and ``scaffold_type`` are required; every other
    field has a default.
    """
    # Model configuration (required fields - no defaults)
    training_model_name: str
    scaffold_type: str  # "adaptive", "two_stage" or "three_stage"

    # Training parameters
    learning_rate: float = 5e-6
    batch_size: int = 8
    gradient_accumulation_steps: int = 1
    max_epochs: int = 3
    warmup_ratio: float = 0.05
    weight_decay: float = 0.01
    lr_scheduler_type: str = "cosine_with_restarts"
    max_grad_norm: float = 1.0

    # Training type and method
    training_method: str = "sft"  # "sft", "ppo", "dpo", "grpo"
    use_lora: bool = True
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.1

    # Data configuration
    dataset_path: str = ""
    eval_dataset_path: Optional[str] = None
    dataset_type: str = "auto"  # "auto", "trajectory", "virl", "sft"
    max_sequence_length: int = 2048
    data_collator_type: str = "default"

    # Prompt templating
    prompt_template_name: str = "adaptive_math_v1"  # Scaffold prompt template for consistent VLM prompting

    # Optimization
    optimizer: str = "adamw"
    scheduler: str = "cosine"
    gradient_checkpointing: bool = True
    mixed_precision: bool = True

    # Freezing strategy
    freeze_reasoner: bool = True
    freeze_vlm_layers: Optional[List[str]] = None  # Substrings of parameter names to freeze

    # Monitoring and output
    logging_steps: int = 10
    save_steps: int = 500
    eval_steps: int = 500
    output_dir: str = "./training_output"
    log_completions: bool = True

    # Checkpoint resumption
    resume_from_checkpoint: Optional[str] = None  # Path to checkpoint directory

    # Experiment tracking
    experiment_name: Optional[str] = None
    run_name: Optional[str] = None
    report_to: Optional[List[str]] = None  # e.g., ["wandb", "tensorboard"]
    tags: Optional[Dict[str, str]] = None

    # Framework-specific configs (passed through)
    framework_config: Optional[Dict[str, Any]] = None

    # ---------------- RL/GRPO-specific parameters -----------------
    # These fields are optional and only used for GRPO training
    beta: float = 0.0
    grpo_num_iterations: int = 1  # maps to num_iterations in GRPOConfig
    epsilon: float = 0.2
    epsilon_high: Optional[float] = None
    delta: Optional[float] = None
    loss_type: str = "bnpo"
    scale_rewards: bool = True
    num_generations: int = 8
    max_completion_length: Optional[int] = None
    grpo_steps_per_generation: Optional[int] = None

    # Generation method configuration
    use_vllm: bool = False
    use_transformers_paged: bool = False  # Enable transformers paged attention for generation (alternative to vLLM)
    cache_implementation: Optional[str] = None  # Cache implementation for transformers generation ("static", "sliding_window", etc.)
    vllm_mode: str = "server"  # or "colocate"
    vllm_server_base_url: Optional[str] = None
    vllm_server_host: str = "0.0.0.0"
    vllm_server_port: int = 8000
    vllm_server_timeout: float = 600.0
    vllm_gpu_memory_utilization: float = 0.3
    vllm_tensor_parallel_size: int = 1
    vllm_model_impl: str = "vllm"  # "vllm" or "transformers"
    vllm_enable_sleep_mode: bool = False
    vllm_guided_decoding_regex: Optional[str] = None
    vllm_importance_sampling_correction: bool = True
    vllm_importance_sampling_cap: float = 2.0

    # VLM model configuration (for full pipeline)
    vlm_model_type: str = "openai"
    vlm_model_name: Optional[str] = None
    vlm_api_base: Optional[str] = None
    vlm_temperature: Optional[float] = None
    vlm_top_p: Optional[float] = None
    vlm_top_k: Optional[int] = None
    vlm_max_tokens: Optional[int] = None
    vlm_api_key: Optional[str] = None

    # Reasoner model configuration (for full pipeline)
    # NOTE: `reasoner_model_name` was previously declared twice; the duplicate
    # was redundant (the dataclass kept a single field at this position).
    reasoner_model_type: str = "openai"
    reasoner_model_name: Optional[str] = None
    reasoner_api_base: Optional[str] = None
    reasoner_api_key: Optional[str] = None
    reasoner_max_tokens: Optional[int] = None
    reasoner_temperature: Optional[float] = None
    reasoner_top_p: Optional[float] = None
    reasoner_top_k: Optional[int] = None

    # Scaffold-specific configuration for GRPO pipeline
    scaffold_max_iterations: int = 4
    reasoner_server_port: Optional[int] = None  # Separate port for reasoner in full pipeline
    reasoner_server_host: Optional[str] = None  # Reasoner server host
    question_penalty: Optional[float] = None  # Penalty for asking clarifying questions (three-stage scaffold)

    # Parallel reward computation configuration
    reward_parallel_workers: int = 32  # Max workers for parallel reward computation
    reward_enable_parallel: bool = True  # Enable parallel reward computation
    use_curriculum_learning: bool = True
    enable_dapo_filtering: bool = True
    curriculum_binning_strategy: str = "equal_width"
    curriculum_adaptive_thresholds: bool = True
    curriculum_bin_edges: List[float] = field(default_factory=list)
    curriculum_success_thresholds: List[float] = field(default_factory=list)
    curriculum_flat_thresholds: List[float] = field(default_factory=list)
    curriculum_num_bins: int = 15
    curriculum_min_weight: float = 0.02
    curriculum_decay_factor: float = 0.2
    curriculum_stats_window: int = 150

    # Static filtering configuration
    static_filtering_enabled: bool = False
    difficulty_0_leakage_percent: float = 0.0
    difficulty_1_leakage_percent: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to a dictionary, omitting fields whose value is ``None``.

        Container values (lists/dicts) are deep-copied, so mutating the
        returned dictionary does not affect this config. Because every
        ``Optional`` field defaults to ``None``, dropping ``None`` values
        still round-trips through :meth:`from_dict`.

        Returns:
            Mapping of field name to (copied) value for all non-None fields.
        """
        from dataclasses import asdict

        return {key: value for key, value in asdict(self).items() if value is not None}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'TrainingConfig':
        """
        Create a config from a dictionary of field values.

        Args:
            data: Mapping of field name to value. It must contain the
                required fields (``training_model_name``, ``scaffold_type``).

        Returns:
            A new config instance.

        Raises:
            TypeError: If ``data`` contains keys that are not fields of this
                class (the message lists them), or lacks required fields.
        """
        from dataclasses import fields

        known = {f.name for f in fields(cls)}
        unknown = sorted(set(data) - known)
        if unknown:
            # Same exception type as cls(**data) would raise, but name every bad key.
            raise TypeError(
                f"Unknown {cls.__name__} field(s): {', '.join(map(str, unknown))}"
            )
        return cls(**data)


@dataclass
class TrainingResult:
    """
    Result of a training run.

    Contains metrics, model artifacts, and metadata from training.
    """
    # Training metrics
    final_loss: float
    best_eval_metric: Optional[float] = None
    training_time: float = 0.0  # presumably wall-clock seconds — TODO confirm with adapters

    # Model artifacts
    model_path: Optional[str] = None
    adapter_path: Optional[str] = None  # For LoRA adapters

    # Training history
    train_losses: Optional[List[float]] = None  # None (not []) when no history was recorded
    eval_metrics: Optional[Dict[str, List[float]]] = None

    # Metadata
    # String forward reference so this class does not depend on TrainingConfig
    # being defined first.
    config_used: Optional['TrainingConfig'] = None
    framework: str = "unknown"
    success: bool = True
    error_message: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to a dictionary for serialization.

        Scalar fields are always present (possibly ``None``). ``train_losses``
        and ``eval_metrics`` are included only when non-empty, and
        ``config_used`` only when set (serialized via its own ``to_dict``).

        Returns:
            JSON-friendly dictionary describing the run.
        """
        result: Dict[str, Any] = {
            'final_loss': self.final_loss,
            'best_eval_metric': self.best_eval_metric,
            'training_time': self.training_time,
            'model_path': self.model_path,
            'adapter_path': self.adapter_path,
            'framework': self.framework,
            'success': self.success,
            'error_message': self.error_message,
        }

        if self.train_losses:
            result['train_losses'] = self.train_losses
        if self.eval_metrics:
            result['eval_metrics'] = self.eval_metrics
        if self.config_used is not None:
            result['config_used'] = self.config_used.to_dict()

        return result


class BaseTrainingAdapter(ABC):
    """
    Abstract base class for training framework adapters.

    This class defines the interface that all training adapters must implement,
    enabling framework-agnostic training of VLM components while keeping
    reasoner components frozen.

    Key Design Principles:
    1. VLM-only tuning: Only VLM parameters are trainable
    2. Framework agnostic: Same interface for TRL, Unsloth, etc.
    3. Scaffold aware: Understands reasoning scaffold structure
    4. Configuration driven: All behavior controlled by TrainingConfig
    """

    def __init__(self, framework_name: str):
        """
        Initialize the training adapter.

        Args:
            framework_name: Name of the training framework (e.g., "trl", "unsloth")
        """
        self.framework_name = framework_name
        self.is_initialized = False

        # Track training state
        self.current_config: Optional[TrainingConfig] = None
        self.model: Optional[PreTrainedModel] = None
        self.trainer: Any = None  # framework-specific trainer object

    @abstractmethod
    def is_available(self) -> bool:
        """
        Check if this training framework is available in the environment.

        Returns:
            True if the framework can be used, False otherwise
        """
        pass

    @abstractmethod
    def initialize_model(
        self,
        config: TrainingConfig,
        scaffold: 'BaseReasoningScaffold'
    ) -> PreTrainedModel:
        """
        Initialize the model for training with the given scaffold.

        This method should:
        1. Load the VLM and reasoner models
        2. Set up the scaffold structure
        3. Freeze reasoner parameters
        4. Configure LoRA or other adapters if needed
        5. Return the prepared model

        Args:
            config: Training configuration
            scaffold: Reasoning scaffold to train. NOTE(review): the
                ``BaseReasoningScaffold`` forward reference is not imported in
                this module; consider a ``TYPE_CHECKING`` import.

        Returns:
            Initialized model ready for training
        """
        pass

    @abstractmethod
    def prepare_trainer(
        self,
        model: PreTrainedModel,
        train_dataset,
        eval_dataset,
        config: TrainingConfig
    ):
        """
        Prepare the trainer object for the specific framework.

        Args:
            model: Initialized model
            train_dataset: Training dataset
            eval_dataset: Evaluation dataset (optional)
            config: Training configuration

        Returns:
            Trainer object ready for training
        """
        pass

    @abstractmethod
    def train(
        self,
        trainer,
        config: TrainingConfig
    ) -> TrainingResult:
        """
        Execute the training process.

        Args:
            trainer: Prepared trainer object
            config: Training configuration

        Returns:
            Training result with metrics and artifacts
        """
        pass

    def setup_freezing(
        self,
        model: PreTrainedModel,
        config: TrainingConfig
    ):
        """
        Set up parameter freezing according to configuration.

        This is a common operation across frameworks, so we provide
        a default implementation here. Parameters are only ever frozen
        (``requires_grad = False``), never unfrozen:

        * if ``config.freeze_reasoner``: every parameter whose name contains
          "reasoner" (case-insensitive) is frozen;
        * every parameter whose name contains any entry of
          ``config.freeze_vlm_layers`` (case-sensitive substring) is frozen.

        Args:
            model: Model to configure (modified in place)
            config: Training configuration
        """
        freeze_reasoner = config.freeze_reasoner
        layer_patterns = list(config.freeze_vlm_layers or [])
        if not freeze_reasoner and not layer_patterns:
            return

        # Single pass over parameters instead of one full walk per layer pattern.
        for name, param in model.named_parameters():
            if freeze_reasoner and 'reasoner' in name.lower():
                param.requires_grad = False
            elif any(pattern in name for pattern in layer_patterns):
                param.requires_grad = False

    def get_trainable_parameters(self, model: PreTrainedModel) -> Iterator[nn.Parameter]:
        """
        Get iterator over trainable parameters.

        Args:
            model: Model to inspect

        Returns:
            Lazy iterator over parameters with ``requires_grad`` set
        """
        for param in model.parameters():
            if param.requires_grad:
                yield param

    def count_parameters(self, model: PreTrainedModel) -> Dict[str, int]:
        """
        Count total and trainable parameters.

        Args:
            model: Model to analyze

        Returns:
            Dictionary with ``total_parameters``, ``trainable_parameters``,
            ``frozen_parameters`` (element counts) and
            ``trainable_percentage`` (float in [0, 100]; 0.0 for an empty model)
        """
        total_params = 0
        trainable_params = 0
        for p in model.parameters():
            n = p.numel()
            total_params += n
            if p.requires_grad:
                trainable_params += n

        return {
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'frozen_parameters': total_params - trainable_params,
            'trainable_percentage': (trainable_params / total_params) * 100 if total_params > 0 else 0.0
        }

    def save_adapter(
        self,
        model: PreTrainedModel,
        output_dir: Union[str, Path],
        config: TrainingConfig
    ) -> str:
        """
        Save adapter weights (LoRA, etc.) if applicable.

        This is a default implementation that subclasses can override. It
        calls ``model.save_pretrained`` into ``<output_dir>/adapter_model``
        when ``config.use_lora`` is set. Saving is best-effort: failures are
        reported as a printed warning rather than raised.

        Args:
            model: Trained model
            output_dir: Directory to save to (created if missing)
            config: Training configuration

        Returns:
            Path to saved adapter, or "" if LoRA is disabled or saving failed
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        if not config.use_lora:
            return ""

        adapter_path = output_dir / "adapter_model"
        try:
            model.save_pretrained(str(adapter_path))
        except Exception as e:
            print(f"Warning: Could not save adapter: {e}")
            return ""
        return str(adapter_path)

    def load_adapter(
        self,
        model: PreTrainedModel,
        adapter_path: Union[str, Path]
    ) -> PreTrainedModel:
        """
        Load adapter weights into model.

        The default implementation is a no-op that returns ``model``
        unchanged; subclasses that support adapters should override it.

        Args:
            model: Base model
            adapter_path: Path to adapter weights

        Returns:
            Model with loaded adapter
        """
        return model

    def validate_config(self, config: TrainingConfig) -> List[str]:
        """
        Validate training configuration for this framework.

        Args:
            config: Configuration to validate

        Returns:
            List of validation error messages (empty if valid)
        """
        errors = []

        # Basic validation
        if config.learning_rate <= 0:
            errors.append("Learning rate must be positive")

        if config.batch_size <= 0:
            errors.append("Batch size must be positive")

        if config.max_epochs <= 0:
            errors.append("Max epochs must be positive")

        if config.scaffold_type not in ("adaptive", "two_stage", "three_stage"):
            errors.append(f"Unsupported scaffold type: {config.scaffold_type}")

        if config.training_method not in ("sft", "ppo", "dpo", "grpo"):
            errors.append(f"Unsupported training method: {config.training_method}")

        # Additional invariants, appended after the original checks so the
        # relative order of pre-existing messages is unchanged.
        if config.gradient_accumulation_steps <= 0:
            errors.append("Gradient accumulation steps must be positive")

        # HF TrainingArguments rejects warmup_ratio outside [0, 1].
        if not 0.0 <= config.warmup_ratio <= 1.0:
            errors.append("Warmup ratio must be between 0 and 1")

        if config.use_lora and config.lora_r <= 0:
            errors.append("LoRA rank (lora_r) must be positive when use_lora is enabled")

        return errors

    def get_memory_info(self) -> Dict[str, float]:
        """
        Get current GPU memory usage information.

        Values come from ``torch.cuda`` for the current device and are in
        gigabytes (bytes / 1e9). All values are 0.0 when CUDA is unavailable.

        Returns:
            Dictionary with memory statistics
        """
        if not torch.cuda.is_available():
            return {'allocated_gb': 0.0, 'reserved_gb': 0.0, 'max_allocated_gb': 0.0}

        return {
            'allocated_gb': torch.cuda.memory_allocated() / 1e9,
            'reserved_gb': torch.cuda.memory_reserved() / 1e9,
            'max_allocated_gb': torch.cuda.max_memory_allocated() / 1e9,
        }

    def __str__(self) -> str:
        return f"{self.__class__.__name__}(framework={self.framework_name})"

    def __repr__(self) -> str:
        return self.__str__()