"""
Framework detector for automatic training framework selection.

This module automatically detects available training frameworks (TRL, Unsloth)
and selects the most appropriate one based on configuration and availability.
"""

import importlib
from enum import Enum
from typing import Any, Dict, List, Optional, Type

from .base_trainer import BaseTrainingAdapter, TrainingConfig
from .trl_adapter import TRLTrainingAdapter
from .unsloth_adapter import UnslothTrainingAdapter


class FrameworkPriority(Enum):
    """
    Priority tiers consulted when auto-selecting a training framework.

    A smaller ``value`` ranks higher: HIGH is chosen over MEDIUM, which is
    chosen over LOW.
    """
    HIGH = 1    # ranked first
    MEDIUM = 2  # ranked after HIGH
    LOW = 3     # ranked last


class FrameworkDetector:
    """
    Framework detector and selector for training adapters.

    Detects which registered training frameworks are available (as reported
    by each adapter's ``is_available()``). Selects the most appropriate one
    based on:
    1. User preference (if specified, available, and compatible)
    2. Training method compatibility
    3. Framework availability
    4. Configuration hints (sequence length, batch size, training method)
    5. Default priority order

    Adapter instances and availability results are cached per framework
    name. Call ``clear_cache()`` to force re-detection.
    """

    def __init__(self):
        """Initialize the adapter registry, priorities, compatibility table and caches."""
        # Framework name -> adapter class. Instances are created lazily by
        # ``_get_adapter_instance`` and then cached.
        self._adapters_registry: Dict[str, Type[BaseTrainingAdapter]] = {
            "trl": TRLTrainingAdapter,
            "unsloth": UnslothTrainingAdapter,
        }

        # Framework priorities (lower enum value = higher priority).
        self._framework_priorities: Dict[str, FrameworkPriority] = {
            "unsloth": FrameworkPriority.HIGH,    # Prefer Unsloth for memory efficiency
            "trl": FrameworkPriority.MEDIUM,      # TRL as fallback
        }

        # Training method -> frameworks that support it. List order is
        # preserved when filtering by availability and breaks priority ties.
        self._method_compatibility: Dict[str, List[str]] = {
            "sft": ["unsloth", "trl"],
            "ppo": ["trl", "unsloth"],
            "dpo": ["trl", "unsloth"],
            "grpo": ["unsloth", "trl"],
        }

        # Per-framework caches for adapter instances and availability results.
        self._adapter_cache: Dict[str, BaseTrainingAdapter] = {}
        self._availability_cache: Dict[str, bool] = {}

    def detect_available_frameworks(self) -> List[str]:
        """
        Detect all available training frameworks.

        Returns:
            Names of registered frameworks whose adapter reports itself
            available, in registry order.
        """
        return [
            framework_name
            for framework_name in self._adapters_registry
            if self.is_framework_available(framework_name)
        ]

    def is_framework_available(self, framework_name: str) -> bool:
        """
        Check if a specific framework is available.

        The result (including a negative result for unknown names or adapters
        that raise) is cached until ``clear_cache()`` or re-registration.

        Args:
            framework_name: Name of the framework to check

        Returns:
            True if framework is available, False otherwise
        """
        if framework_name in self._availability_cache:
            return self._availability_cache[framework_name]

        if framework_name not in self._adapters_registry:
            self._availability_cache[framework_name] = False
            return False

        try:
            adapter = self._get_adapter_instance(framework_name)
            available = adapter.is_available()
            self._availability_cache[framework_name] = available
            return available
        except Exception as e:
            # Any failure while constructing or probing the adapter is treated
            # as "not available" rather than propagated.
            print(f"Error checking {framework_name} availability: {e}")
            self._availability_cache[framework_name] = False
            return False

    def select_framework(
        self,
        config: TrainingConfig,
        preferred_framework: Optional[str] = None
    ) -> str:
        """
        Select the best available framework for the given configuration.

        Prints a warning when the preferred framework cannot be used, and
        prints the auto-selected framework.

        Args:
            config: Training configuration
            preferred_framework: User's preferred framework (optional)

        Returns:
            Name of selected framework

        Raises:
            RuntimeError: If the training method is unknown or no compatible
                framework is available
        """
        # Get compatible frameworks for the training method
        compatible_frameworks = self._method_compatibility.get(
            config.training_method, []
        )

        if not compatible_frameworks:
            raise RuntimeError(
                f"No frameworks support training method: {config.training_method}"
            )

        # Filter by availability (preserving compatibility-list order)
        available_frameworks = [
            fw for fw in compatible_frameworks
            if self.is_framework_available(fw)
        ]

        if not available_frameworks:
            available_all = self.detect_available_frameworks()
            raise RuntimeError(
                f"No available frameworks support {config.training_method}. "
                f"Available frameworks: {available_all}, "
                f"Compatible frameworks: {compatible_frameworks}"
            )

        # If user specified a preference, use it if available and compatible
        if preferred_framework:
            if preferred_framework in available_frameworks:
                return preferred_framework
            print(f"Warning: Preferred framework '{preferred_framework}' is not "
                  f"available or compatible. Falling back to auto-selection.")

        selected = self._select_by_priority(available_frameworks, config)

        print(f"Auto-selected framework: {selected} for {config.training_method} training")
        return selected

    def get_adapter(
        self,
        framework_name: str,
        config: TrainingConfig
    ) -> BaseTrainingAdapter:
        """
        Get a training adapter instance for the specified framework.

        The returned instance is the cached one shared by this detector.

        Args:
            framework_name: Name of the framework
            config: Training configuration, validated via the adapter's
                ``validate_config``

        Returns:
            Training adapter instance

        Raises:
            ValueError: If the framework is not available or the adapter
                reports configuration errors
        """
        if not self.is_framework_available(framework_name):
            available = self.detect_available_frameworks()
            raise ValueError(
                f"Framework '{framework_name}' is not available. "
                f"Available frameworks: {available}"
            )

        adapter = self._get_adapter_instance(framework_name)

        # Validate configuration for this framework
        errors = adapter.validate_config(config)
        if errors:
            error_msg = "\n".join(f"  - {error}" for error in errors)
            raise ValueError(f"Configuration errors for {framework_name}:\n{error_msg}")

        return adapter

    def get_auto_adapter(self, config: TrainingConfig) -> BaseTrainingAdapter:
        """
        Automatically select and return the best adapter for the configuration.

        Args:
            config: Training configuration

        Returns:
            Best available training adapter

        Raises:
            RuntimeError: If no compatible framework is available
            ValueError: If the selected adapter rejects the configuration
        """
        framework_name = self.select_framework(config)
        return self.get_adapter(framework_name, config)

    def _get_adapter_instance(self, framework_name: str) -> BaseTrainingAdapter:
        """
        Get the cached adapter instance, creating it on first use.

        Raises:
            ValueError: If ``framework_name`` is not registered
        """
        if framework_name not in self._adapter_cache:
            if framework_name not in self._adapters_registry:
                raise ValueError(f"Unknown framework: {framework_name}")

            adapter_class = self._adapters_registry[framework_name]
            self._adapter_cache[framework_name] = adapter_class()

        return self._adapter_cache[framework_name]

    def _select_by_priority(
        self,
        available_frameworks: List[str],
        config: TrainingConfig
    ) -> str:
        """
        Select a framework using configuration hints, then static priority.

        Rules are checked in order. The first rule whose target framework is
        available wins:
        1. Long sequences (> 4096) or large batches (> 8): Unsloth.
        2. PPO / DPO: TRL.
        3. GRPO: Unsloth.
        4. Otherwise, the framework with the highest ``FrameworkPriority``
           (lowest enum value). Ties keep the order of ``available_frameworks``.

        Args:
            available_frameworks: List of available compatible frameworks
            config: Training configuration

        Returns:
            Selected framework name

        Raises:
            RuntimeError: If ``available_frameworks`` is empty
        """
        if not available_frameworks:
            raise RuntimeError("No available frameworks to select from")

        # For memory-intensive setups, prefer Unsloth
        if config.max_sequence_length > 4096 or config.batch_size > 8:
            if "unsloth" in available_frameworks:
                return "unsloth"

        # For PPO/DPO, prefer TRL
        if config.training_method in ("ppo", "dpo"):
            if "trl" in available_frameworks:
                return "trl"

        # For GRPO, prefer Unsloth; TRL is also compatible and acts as fallback
        if config.training_method == "grpo":
            if "unsloth" in available_frameworks:
                return "unsloth"

        # Default: highest priority wins. min() returns the first of equal
        # minima, matching a stable sort followed by taking element 0.
        return min(
            available_frameworks,
            key=lambda fw: self._framework_priorities.get(
                fw, FrameworkPriority.LOW
            ).value,
        )

    def get_framework_info(self) -> Dict[str, Dict[str, Any]]:
        """
        Get information about all registered frameworks.

        Returns:
            Mapping of framework name to a dict with keys 'available',
            'priority' (enum member name), 'supported_methods' and
            'adapter_class'. Available frameworks may also include
            'memory_info'. Frameworks whose adapter cannot be created get
            {'available': False, 'error': <message>}.
        """
        info = {}

        for framework_name in self._adapters_registry:
            try:
                adapter = self._get_adapter_instance(framework_name)
                available = self.is_framework_available(framework_name)

                info[framework_name] = {
                    'available': available,
                    'priority': self._framework_priorities.get(
                        framework_name, FrameworkPriority.LOW
                    ).name,
                    'supported_methods': [
                        method for method, frameworks in self._method_compatibility.items()
                        if framework_name in frameworks
                    ],
                    'adapter_class': adapter.__class__.__name__,
                }

                if available:
                    # Memory info is best-effort: any error just omits the key.
                    # Catch Exception (not a bare except) so KeyboardInterrupt
                    # and SystemExit still propagate.
                    try:
                        info[framework_name]['memory_info'] = adapter.get_memory_info()
                    except Exception:
                        pass

            except Exception as e:
                info[framework_name] = {
                    'available': False,
                    'error': str(e)
                }

        return info

    def recommend_framework(self, config: TrainingConfig) -> Dict[str, Any]:
        """
        Get framework recommendation with reasoning.

        Never raises for selection failures. They are reported with
        'recommended': None and 'configuration_valid': False.

        Args:
            config: Training configuration

        Returns:
            Dictionary with keys 'recommended', 'reasoning',
            'compatible_frameworks', 'available_frameworks',
            'configuration_valid' (plus 'error' on failure).
        """
        # Copy so callers mutating the result cannot corrupt the detector's
        # internal compatibility table.
        compatible = list(self._method_compatibility.get(config.training_method, []))

        try:
            recommended = self.select_framework(config)
            available = self.detect_available_frameworks()

            reasoning = []

            # Explain the selection, mirroring the rules in _select_by_priority
            if recommended == "unsloth":
                reasoning.append("Unsloth recommended for memory efficiency and speed")
                if config.max_sequence_length > 4096:
                    reasoning.append(f"Long sequences ({config.max_sequence_length}) benefit from Unsloth's optimizations")
                if config.batch_size > 8:
                    reasoning.append(f"Large batch size ({config.batch_size}) benefits from Unsloth's memory optimizations")
                if config.use_lora:
                    reasoning.append("LoRA training is highly optimized in Unsloth")
            elif recommended == "trl":
                reasoning.append("TRL recommended for its comprehensive RL support")
                if config.training_method in ["ppo", "dpo"]:
                    reasoning.append(f"{config.training_method.upper()} is specialized in TRL")

            return {
                'recommended': recommended,
                'reasoning': reasoning,
                'compatible_frameworks': compatible,
                'available_frameworks': available,
                'configuration_valid': True
            }

        except Exception as e:
            return {
                'recommended': None,
                'reasoning': [str(e)],
                'compatible_frameworks': compatible,
                'available_frameworks': self.detect_available_frameworks(),
                'configuration_valid': False,
                'error': str(e)
            }

    def register_framework(
        self,
        name: str,
        adapter_class: Type[BaseTrainingAdapter],
        priority: FrameworkPriority = FrameworkPriority.LOW,
        supported_methods: Optional[List[str]] = None
    ) -> None:
        """
        Register (or re-register) a training framework adapter.

        Supported methods are only ever added to the compatibility table.
        Re-registering does not remove methods listed previously. Cached
        availability and adapter instance for ``name`` are discarded.

        Args:
            name: Framework name
            adapter_class: Adapter class
            priority: Framework priority
            supported_methods: List of supported training methods
        """
        self._adapters_registry[name] = adapter_class
        self._framework_priorities[name] = priority

        if supported_methods:
            for method in supported_methods:
                frameworks = self._method_compatibility.setdefault(method, [])
                if name not in frameworks:
                    frameworks.append(name)

        # Clear cache for this framework
        self._availability_cache.pop(name, None)
        self._adapter_cache.pop(name, None)

    def clear_cache(self) -> None:
        """Clear all cached adapter instances and availability info."""
        self._adapter_cache.clear()
        self._availability_cache.clear()


# Global instance for easy access. Created at import time and shared by
# every importer, so its adapter/availability caches are shared as well.
framework_detector = FrameworkDetector()