"""
Configuration utilities for training system.

This module provides utilities for loading, validating, and managing
training configurations with support for inheritance, validation,
and environment variable substitution.
"""

import os
import re
import yaml
import logging
from pathlib import Path
from typing import Dict, Any, List, Optional, Union
from dataclasses import dataclass, field
from datetime import datetime
import copy

logger = logging.getLogger(__name__)


@dataclass
class ConfigValidationError(Exception):
    """Exception raised for configuration validation errors.

    Attributes:
        message: Human-readable description of the problem; also what
            ``str(exc)`` returns.
        field_path: Dot-separated path of the offending field ("" when the
            error is not tied to a single field).
        config_file: File the configuration came from, if known.
    """
    message: str
    field_path: str
    config_file: Optional[str] = None

    def __post_init__(self):
        # The dataclass-generated __init__ never calls Exception.__init__, so
        # ``args`` would only contain whatever was passed positionally (and be
        # empty for keyword construction). Populate it explicitly so the
        # exception can be unpickled (unpickling calls ``cls(*args)``).
        super().__init__(self.message, self.field_path, self.config_file)

    def __str__(self) -> str:
        # Without this, str() renders the whole args tuple, e.g.
        # "('Required field missing: x', 'x', None)".
        return self.message


@dataclass
class ConfigWarning:
    """A non-fatal finding produced while validating a configuration.

    Unlike ``ConfigValidationError``, a warning does not stop loading; it is
    collected and returned to the caller, who decides how to report it.
    """
    message: str  # human-readable description of the finding
    field_path: str  # dot-separated path of the config field concerned
    severity: str = field(default="warning")  # "warning", "info", "critical"


class ConfigLoader:
    """Loads and processes YAML configuration files with inheritance.

    For each file, ``load_config`` parses it with ``yaml.safe_load``, merges it
    over the file named by its ``base_config`` key (recursively), replaces
    ``${VAR}`` / ``${VAR:default}`` references in string values with
    environment variables, and attaches a ``_metadata`` section. Processed
    configs are cached per resolved path; callers always receive deep copies,
    so mutating a returned config never affects the cache.
    """

    # Matches ${VAR_NAME} or ${VAR_NAME:default_value}. The variable name stops
    # at the first ':'; the default may be empty and may itself contain ':'.
    _ENV_VAR_PATTERN = re.compile(r'\$\{([^}:]+)(?::([^}]*))?\}')

    def __init__(self, config_dir: Optional[Union[str, Path]] = None):
        """
        Initialize config loader.

        Args:
            config_dir: Base directory that relative paths given to
                ``load_config`` are joined onto. Defaults to ``configs``
                (relative to the current working directory).
        """
        self.config_dir = Path(config_dir) if config_dir else Path("configs")
        self.loaded_configs = {}  # Cache: resolved path -> processed config
        self.warnings = []
        # Resolved paths whose loading has started but not finished. Used to
        # detect base_config cycles (A -> B -> A) instead of recursing until
        # RecursionError.
        self._in_progress = set()

    def load_config(self, config_path: Union[str, Path]) -> Dict[str, Any]:
        """
        Load configuration with inheritance support.

        Args:
            config_path: Path to configuration file. Relative paths are joined
                onto ``self.config_dir``.

        Returns:
            Loaded and processed configuration (a fresh copy on every call)

        Raises:
            FileNotFoundError: If the file does not exist.
            TypeError: If the YAML document's top level is not a mapping.
            ValueError: If ``base_config`` references form a cycle.
        """
        config_path = Path(config_path)

        # Relative paths are interpreted against the configured base directory
        if not config_path.is_absolute():
            config_path = self.config_dir / config_path

        logger.info(f"Loading configuration from {config_path}")

        # Resolve so equivalent spellings (``a/../b.yaml``, symlinks) share one
        # cache entry and one cycle-detection key.
        cache_key = str(config_path.resolve())
        if cache_key in self.loaded_configs:
            logger.debug(f"Using cached config for {config_path}")
            return copy.deepcopy(self.loaded_configs[cache_key])

        if cache_key in self._in_progress:
            raise ValueError(f"Circular base_config inheritance detected at {config_path}")

        if not config_path.exists():
            raise FileNotFoundError(f"Configuration file not found: {config_path}")

        self._in_progress.add(cache_key)
        try:
            config = self._read_yaml(config_path)
            # Process inheritance (may recurse into load_config for the base)
            config = self._process_inheritance(config, config_path.parent)
        finally:
            self._in_progress.discard(cache_key)

        # Process environment variables
        config = self._process_environment_variables(config)

        # Add metadata
        config = self._add_metadata(config, config_path)

        # Cache a private copy so later mutation of the returned dict is harmless
        self.loaded_configs[cache_key] = copy.deepcopy(config)

        logger.info(f"Configuration loaded successfully from {config_path}")
        return config

    @staticmethod
    def _read_yaml(config_path: Path) -> Dict[str, Any]:
        """Parse one YAML file into a dict; an empty document yields ``{}``.

        Raises:
            TypeError: If the document's top level is not a mapping.
        """
        # safe_load: never construct arbitrary Python objects from config files.
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)

        if config is None:
            return {}
        if not isinstance(config, dict):
            raise TypeError(
                f"Configuration file {config_path} must contain a mapping at the top level, "
                f"got {type(config).__name__}"
            )
        return config

    def _process_inheritance(self, config: Dict[str, Any], config_dir: Path) -> Dict[str, Any]:
        """Merge ``config`` over the file named by its ``base_config`` key, if any.

        ``base_config`` is removed from the result. A relative base path is
        interpreted relative to ``config_dir`` (the including file's directory).
        """
        if "base_config" not in config:
            return config

        base_config_path = config.pop("base_config")  # Remove from final config

        # Make the base path absolute so load_config does not join it onto
        # self.config_dir a second time; with a relative config_dir (such as
        # the default "configs") that produced "configs/configs/base.yaml".
        base_path = (config_dir / base_config_path).absolute()
        logger.info(f"Loading base configuration: {base_path}")

        base_config = self.load_config(base_path)

        # Merge configurations (current config overrides base)
        return self._deep_merge(base_config, config)

    def _deep_merge(self, base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
        """Return a new dict with ``override`` merged over a deep copy of ``base``.

        Keys whose values are dicts on both sides are merged recursively; any
        other override value (lists included) replaces the base value. Neither
        input is modified.
        """
        result = copy.deepcopy(base)

        for key, value in override.items():
            if key in result and isinstance(result[key], dict) and isinstance(value, dict):
                result[key] = self._deep_merge(result[key], value)
            else:
                result[key] = copy.deepcopy(value)

        return result

    def _process_environment_variables(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Process environment variable substitutions."""
        return self._recursive_env_substitute(config)

    def _recursive_env_substitute(self, obj: Any) -> Any:
        """Recursively substitute environment variables in string values.

        Dict keys are left untouched; non-string scalars pass through unchanged.
        """
        if isinstance(obj, dict):
            return {k: self._recursive_env_substitute(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [self._recursive_env_substitute(item) for item in obj]
        elif isinstance(obj, str):
            return self._substitute_env_vars(obj)
        else:
            return obj

    def _substitute_env_vars(self, text: str) -> str:
        """Replace ``${VAR}`` / ``${VAR:default}`` in ``text``.

        An unset variable without a default becomes "". The result is always a
        string, e.g. ``"${PORT:8080}"`` yields ``"8080"``, not an int.
        """
        def replace_env(match):
            var_name = match.group(1)
            default_value = match.group(2) if match.group(2) is not None else ""
            return os.environ.get(var_name, default_value)

        return self._ENV_VAR_PATTERN.sub(replace_env, text)

    def _add_metadata(self, config: Dict[str, Any], config_path: Path) -> Dict[str, Any]:
        """Attach load provenance under ``config["_metadata"]`` (in place) and return config.

        An existing ``_metadata`` dict is updated; any other existing value
        under that reserved key is replaced.
        """
        metadata = {
            "config_file": str(config_path),
            "loaded_at": datetime.now().isoformat(),
            "git_commit": self._get_git_commit(),
            "git_branch": self._get_git_branch(),
        }

        existing = config.get("_metadata")
        if isinstance(existing, dict):
            existing.update(metadata)
        else:
            config["_metadata"] = metadata

        return config

    def _get_git_commit(self) -> Optional[str]:
        """Return ``git rev-parse HEAD`` for the current working directory, or None."""
        return self._run_git("rev-parse", "HEAD")

    def _get_git_branch(self) -> Optional[str]:
        """Return the current branch name (``git rev-parse --abbrev-ref HEAD``), or None."""
        return self._run_git("rev-parse", "--abbrev-ref", "HEAD")

    @staticmethod
    def _run_git(*args: str) -> Optional[str]:
        """Run ``git <args>`` in the current working directory; return stripped stdout.

        Best-effort by design: if git cannot be launched, times out (5 s), emits
        undecodable output, or exits non-zero (e.g. not inside a repository),
        None is returned instead of failing the config load.
        """
        import subprocess
        try:
            result = subprocess.run(
                ["git", *args],
                capture_output=True,
                text=True,
                timeout=5
            )
        except (OSError, ValueError, subprocess.SubprocessError) as exc:
            logger.debug(f"git {' '.join(args)} failed: {exc}")
            return None
        if result.returncode != 0:
            return None
        return result.stdout.strip()


class ConfigValidator:
    """Validates configuration dictionaries against per-type rules.

    Each ``validate_config`` call resets and refills two lists:
    ``errors`` (``ConfigValidationError`` records; if any are found they are
    combined into one raised ``ConfigValidationError``) and ``warnings``
    (``ConfigWarning`` records, returned to the caller). The config itself is
    only read, never modified.
    """

    def __init__(self):
        self.warnings = []  # ConfigWarning findings from the last validate_config call
        self.errors = []  # ConfigValidationError findings from the last validate_config call

    def validate_config(self, config: Dict[str, Any], config_type: str = "training") -> List["ConfigWarning"]:
        """
        Validate configuration.

        Args:
            config: Configuration to validate
            config_type: Type of configuration ("training", "model", "evaluation").
                Any other value applies no rules and yields a single warning.

        Returns:
            List of validation warnings

        Raises:
            ConfigValidationError: If any errors were found; the message lists
                every ``field_path: message`` pair, separated by "; ".
        """
        self.warnings = []
        self.errors = []

        if config_type == "training":
            self._validate_training_config(config)
        elif config_type == "model":
            self._validate_model_config(config)
        elif config_type == "evaluation":
            self._validate_evaluation_config(config)
        else:
            # Surface unknown types (e.g. a typo) rather than silently
            # skipping every validation rule.
            self.warnings.append(ConfigWarning(
                f"Unknown config type {config_type!r}; no validation rules were applied",
                "config_type",
                "warning"
            ))

        if self.errors:
            error_msg = "; ".join(f"{e.field_path}: {e.message}" for e in self.errors)
            raise ConfigValidationError(f"Configuration validation failed: {error_msg}", "")

        return self.warnings

    def _validate_training_config(self, config: Dict[str, Any]):
        """Validate a training configuration: required fields, method, value sanity, framework and RL rules."""
        self._require_fields(config, [
            "training_method",
            "model.vlm_model_name",
            "data.train_data_path",
        ])

        # Validate training method
        training_method = self._get_nested_value(config, "training_method")
        valid_methods = ["sft", "ppo", "grpo", "dpo"]
        if training_method and training_method not in valid_methods:
            self.errors.append(ConfigValidationError(
                f"Invalid training method: {training_method}. Valid options: {valid_methods}",
                "training_method"
            ))

        # Numeric fields may arrive as strings: PyYAML loads exponent floats
        # without a decimal point (``5e-5``) as str, and ${VAR} substitution
        # always yields str. Comparing those directly raised TypeError.
        batch_size = self._get_numeric_value(config, "training.per_device_train_batch_size")
        if batch_size is not None and batch_size > 8:
            self.warnings.append(ConfigWarning(
                f"Large batch size ({batch_size}) may cause memory issues",
                "training.per_device_train_batch_size",
                "warning"
            ))

        lr = self._get_numeric_value(config, "training.learning_rate")
        if lr is not None and lr > 1e-3:
            self.warnings.append(ConfigWarning(
                f"High learning rate ({lr}) may cause training instability",
                "training.learning_rate",
                "warning"
            ))

        # Validate framework compatibility
        self._validate_framework_compatibility(config)

        # NOTE(review): DPO is grouped with PPO/GRPO here and therefore also
        # requires reward_model.reward_type — confirm that is intended.
        if training_method in ["ppo", "grpo", "dpo"]:
            self._validate_rl_config(config)

    def _validate_rl_config(self, config: Dict[str, Any]):
        """Validate RL-specific configuration (reward model; PPO epoch count)."""
        if not self._check_field_exists(config, "reward_model.reward_type"):
            self.errors.append(ConfigValidationError(
                "RL training requires reward_model.reward_type",
                "reward_model.reward_type"
            ))

        training_method = self._get_nested_value(config, "training_method")
        if training_method == "ppo":
            ppo_epochs = self._get_numeric_value(config, "training.num_ppo_epochs")
            if ppo_epochs is not None and ppo_epochs > 10:
                self.warnings.append(ConfigWarning(
                    f"High PPO epochs ({ppo_epochs}) may cause overfitting",
                    "training.num_ppo_epochs"
                ))

    def _validate_framework_compatibility(self, config: Dict[str, Any]):
        """Warn about framework/training-method pairings that are better served elsewhere."""
        framework = self._get_nested_value(config, "framework")
        training_method = self._get_nested_value(config, "training_method")

        if framework == "unsloth" and training_method == "ppo":
            self.warnings.append(ConfigWarning(
                "PPO is better supported by TRL framework",
                "framework",
                "warning"
            ))

        if framework == "trl" and training_method == "grpo":
            self.warnings.append(ConfigWarning(
                "GRPO is better supported by Unsloth framework",
                "framework",
                "warning"
            ))

    def _validate_model_config(self, config: Dict[str, Any]):
        """Validate model configuration (required fields only)."""
        self._require_fields(config, ["model_name", "model_type", "architecture"])

    def _validate_evaluation_config(self, config: Dict[str, Any]):
        """Validate evaluation configuration (required fields only)."""
        self._require_fields(config, ["benchmarks", "models"])

    def _require_fields(self, config: Dict[str, Any], field_paths: List[str]):
        """Record a 'Required field missing' error for each path that is absent or None."""
        for field_path in field_paths:
            if not self._check_field_exists(config, field_path):
                self.errors.append(ConfigValidationError(
                    f"Required field missing: {field_path}",
                    field_path
                ))

    def _check_field_exists(self, config: Dict[str, Any], field_path: str) -> bool:
        """Return True if the dot-path resolves to a non-None value."""
        return self._get_nested_value(config, field_path) is not None

    def _get_numeric_value(self, config: Dict[str, Any], field_path: str) -> Any:
        """Return the value at ``field_path`` as a number, or None if absent.

        Real numbers are returned unchanged; numeric strings are parsed as int,
        else float. Empty/blank strings count as absent (as falsy values were
        skipped before). Any other value records an error and returns None.
        """
        import numbers

        value = self._get_nested_value(config, field_path)
        if value is None or isinstance(value, numbers.Real):
            return value
        if isinstance(value, str):
            if not value.strip():
                return None
            for parse in (int, float):
                try:
                    return parse(value)
                except ValueError:
                    continue
        self.errors.append(ConfigValidationError(
            f"Expected a numeric value, got {value!r}",
            field_path
        ))
        return None

    def _get_nested_value(self, config: Dict[str, Any], field_path: str) -> Any:
        """Get nested value from config using dot notation; None if any step is missing or not a dict."""
        keys = field_path.split(".")
        value = config

        for key in keys:
            if isinstance(value, dict) and key in value:
                value = value[key]
            else:
                return None

        return value


def load_and_validate_config(
    config_path: Union[str, Path],
    config_type: str = "training",
    config_dir: Optional[Union[str, Path]] = None
) -> Dict[str, Any]:
    """
    Load a configuration file, validate it, and log every validation warning.

    Args:
        config_path: Path to configuration file; relative paths are joined
            onto ``config_dir`` by ``ConfigLoader``
        config_type: Type of configuration, forwarded to ``ConfigValidator``
        config_dir: Base directory for configurations

    Returns:
        The loaded configuration (validation only reads it)

    Raises:
        ConfigValidationError: Propagated from the validator when errors are found.
    """
    loaded = ConfigLoader(config_dir).load_config(config_path)

    findings = ConfigValidator().validate_config(loaded, config_type)

    # Map each finding's severity to a log level; anything other than
    # "critical" or "warning" is logged at INFO.
    for finding in findings:
        where, what = finding.field_path, finding.message
        if finding.severity == "critical":
            logger.error(f"Critical warning in {where}: {what}")
            continue
        if finding.severity == "warning":
            logger.warning(f"Warning in {where}: {what}")
            continue
        logger.info(f"Info for {where}: {what}")

    return loaded


def save_config(config: Dict[str, Any], output_path: Union[str, Path]) -> None:
    """
    Save configuration to a YAML file that ``yaml.safe_load`` can read back.

    The top-level ``_metadata`` section is dropped; ``config`` itself is not
    modified. Parent directories are created as needed.

    Args:
        config: Configuration to save
        output_path: Path to save configuration

    Raises:
        yaml.representer.RepresenterError: If ``config`` contains values with
            no plain-YAML representation (arbitrary Python objects). Nothing
            is written in that case.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # A shallow filter suffices: dumping never mutates the values.
    config_to_save = {key: value for key, value in config.items() if key != "_metadata"}

    # safe_dump rather than dump: dump emits Python-specific tags (e.g.
    # !!python/tuple) that yaml.safe_load -- which ConfigLoader uses --
    # rejects, producing files that cannot be loaded again. Serialising to a
    # string first means a representation error cannot truncate an existing file.
    text = yaml.safe_dump(config_to_save, default_flow_style=False, indent=2)
    output_path.write_text(text, encoding="utf-8")

    logger.info(f"Configuration saved to {output_path}")


def merge_configs(base_config: Dict[str, Any], override_config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Recursively merge ``override_config`` on top of ``base_config``.

    Delegates to ``ConfigLoader._deep_merge``: keys holding dicts on both
    sides are merged key by key, while any other override value (lists
    included) replaces the base value. Neither input is modified.

    Args:
        base_config: Base configuration
        override_config: Override configuration

    Returns:
        A new merged configuration
    """
    merger = ConfigLoader()
    merged = merger._deep_merge(base_config, override_config)
    return merged


def get_config_value(config: Dict[str, Any], field_path: str, default: Any = None) -> Any:
    """
    Look up a value in a nested configuration using dot notation.

    ``"training.learning_rate"`` reads ``config["training"]["learning_rate"]``.
    A missing key, a non-dict encountered along the path, or a stored
    ``None`` all yield ``default``.

    Args:
        config: Configuration dictionary
        field_path: Dot-separated field path
        default: Value returned when nothing (non-None) is found

    Returns:
        Configuration value or default
    """
    found = ConfigValidator()._get_nested_value(config, field_path)
    if found is None:
        return default
    return found


def set_config_value(config: Dict[str, Any], field_path: str, value: Any) -> None:
    """
    Set a value in a nested configuration using dot notation, in place.

    Missing intermediate sections are created as empty dicts. Sections that
    are present but ``None`` are also replaced with empty dicts; in YAML, a
    key written with no value (e.g. ``lora:``) loads as None.

    Args:
        config: Configuration dictionary (modified in place)
        field_path: Dot-separated field path
        value: Value to set

    Raises:
        TypeError: If an intermediate key holds a value that is neither a dict
            nor None; overwriting it would silently discard data.
    """
    keys = field_path.split(".")
    current = config

    # Navigate to (creating if needed) the parent of the target field
    for depth, key in enumerate(keys[:-1]):
        child = current.get(key)
        if child is None:
            child = current[key] = {}
        elif not isinstance(child, dict):
            section = ".".join(keys[:depth + 1])
            raise TypeError(
                f"Cannot set '{field_path}': '{section}' is a "
                f"{type(child).__name__}, not a mapping"
            )
        current = child

    # Set final value
    current[keys[-1]] = value


def create_config_template(config_type: str) -> Dict[str, Any]:
    """
    Build a fresh configuration skeleton for ``config_type``.

    Supported types are ``"training"`` and ``"evaluation"``. ``None`` marks an
    unset field (see the inline comments for which ones must be filled in).
    A new dict is built on every call, so callers may mutate the result.

    Args:
        config_type: Type of configuration template to create

    Returns:
        Configuration template

    Raises:
        ValueError: For any other config_type
    """
    if config_type not in ("training", "evaluation"):
        raise ValueError(f"Unknown config type: {config_type}")

    if config_type == "evaluation":
        return {
            "experiment_name": "my_evaluation",
            "description": "Evaluation experiment description",
            "models": {
                "model_path": None,
                "model_type": "vlm",
            },
            "benchmarks": {
                "mathvista": {
                    "enabled": True,
                    "data_path": None,
                },
            },
            "evaluation": {
                "batch_size": 1,
                "max_samples": None,
            },
            "output": {
                "output_dir": "./evaluation_results",
                "save_details": True,
            },
        }

    # Only "training" remains at this point.
    return {
        "experiment_name": "my_experiment",
        "description": "Training experiment description",
        "training_method": "sft",  # sft, ppo, grpo, dpo
        "framework": None,  # auto-detect
        "model": {
            "vlm_model_name": None,  # Must be specified
            "torch_dtype": "bfloat16",
            "device_map": "auto",
            "trust_remote_code": True,
        },
        "lora": {
            "use_lora": True,
            "r": 16,
            "lora_alpha": 32,
            "lora_dropout": 0.1,
            "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
        },
        "training": {
            "num_train_epochs": 3,
            "per_device_train_batch_size": 2,
            "learning_rate": 5e-5,
            "max_seq_length": 2048,
        },
        "data": {
            "train_data_path": None,
            "eval_data_path": None,
            "format_type": "trl_vlm",
        },
        "logging": {
            "output_dir": "./outputs",
            "logging_steps": 50,
            "save_steps": 1000,
        },
    }


def validate_environment() -> List[str]:
    """
    Validate training environment and return list of issues.

    Checks the following:
      - PyTorch imports and reports CUDA as available.
      - Every visible GPU has at least 12 GB (decimal, 10**9 bytes) of memory.
      - ``transformers``, ``peft`` and ``datasets`` import.

    Problems are reported, never raised. An import failing with something
    other than ImportError (a broken install), or a CUDA query raising
    RuntimeError, becomes an issue string instead of aborting the check.

    Returns:
        List of environment issues (empty if none were found)
    """
    import importlib

    issues = []

    # Check CUDA availability
    try:
        import torch
    except ImportError:
        issues.append("PyTorch not installed")
    except Exception as exc:  # installed but broken: import raised something else
        issues.append(f"PyTorch failed to import: {exc}")
    else:
        try:
            if not torch.cuda.is_available():
                issues.append("CUDA not available - training will be very slow")
            else:
                gpu_count = torch.cuda.device_count()
                if gpu_count == 0:
                    issues.append("No GPUs detected")
                for i in range(gpu_count):
                    # total_memory is in bytes; dividing by 1e9 gives decimal GB
                    memory_gb = torch.cuda.get_device_properties(i).total_memory / 1e9
                    if memory_gb < 12:
                        issues.append(f"GPU {i} has low memory: {memory_gb:.1f}GB")
        except RuntimeError as exc:
            issues.append(f"CUDA query failed: {exc}")

    # Check required packages
    for package in ("transformers", "peft", "datasets"):
        try:
            importlib.import_module(package)
        except ImportError:
            issues.append(f"Required package not installed: {package}")
        except Exception as exc:  # installed but broken
            issues.append(f"Required package failed to import: {package} ({exc})")

    return issues