import yaml
import os
from typing import Dict, Any, Optional, Union
import logging
from pathlib import Path

# Set up logging
logger = logging.getLogger(__name__)

class ConfigLoader:
    """Configuration loader and validator for medical keyword prediction training.

    The YAML file is read once at construction time and kept in ``self.config``
    as a nested dictionary. Values can be looked up with dot-separated key
    paths such as ``"training.learning_rate"``.
    """

    # Sentinel that distinguishes "no default supplied" from an explicit
    # ``default=None`` in ``get``.
    _MISSING = object()

    def __init__(self, config_path: Optional[str] = None):
        """Load and validate the configuration.

        Args:
            config_path: Path to the YAML file. Defaults to
                ``training_config.yaml`` in the same directory as this module.

        Raises:
            FileNotFoundError: If the file does not exist.
            yaml.YAMLError: If the file is not valid YAML.
            ValueError: If the file is empty, its top level is not a mapping,
                or a required section/field is missing or malformed.
        """
        if config_path is None:
            config_dir = os.path.dirname(os.path.abspath(__file__))
            config_path = os.path.join(config_dir, "training_config.yaml")

        self.config_path = config_path
        self.config = self._load_config()
        self._validate_config()

        logger.info(f"Configuration loaded from: {config_path}")

    def _load_config(self) -> Dict[str, Any]:
        """Read and parse the YAML file at ``self.config_path``.

        Returns:
            The parsed top-level mapping.

        Raises:
            FileNotFoundError, yaml.YAMLError: Logged, then re-raised.
            ValueError: If the file is empty or its top level is not a mapping.
        """
        try:
            with open(self.config_path, 'r', encoding='utf-8') as f:
                config = yaml.safe_load(f)

            if config is None:
                raise ValueError("Configuration file is empty")
            # A scalar or list at the top level would turn the later
            # ``section in self.config`` checks into substring/element tests,
            # so reject it here with a clear message.
            if not isinstance(config, dict):
                raise ValueError(
                    f"Configuration root must be a mapping, got {type(config).__name__}"
                )

            return config

        except FileNotFoundError:
            logger.error(f"Configuration file not found: {self.config_path}")
            raise
        except yaml.YAMLError as e:
            logger.error(f"Error parsing YAML configuration: {e}")
            raise
        except Exception as e:
            logger.error(f"Error loading configuration: {e}")
            raise

    def _validate_config(self) -> None:
        """Check that required sections and fields are present.

        Raises:
            ValueError: If a required section is missing, if ``model`` or
                ``training`` is not a mapping, or if one of their required
                fields is missing.
        """
        required_sections = ['model', 'training', 'data', 'checkpointing']
        for section in required_sections:
            if section not in self.config:
                raise ValueError(f"Missing required configuration section: {section}")

        required_fields = {
            'model': ['model_name', 'max_length', 'device'],
            'training': ['learning_rate', 'batch_size', 'num_epochs'],
        }
        for section, fields in required_fields.items():
            section_config = self.config[section]
            # An empty YAML section (``model:`` with nothing under it) parses
            # as None; report it as a validation error instead of a TypeError.
            if not isinstance(section_config, dict):
                raise ValueError(f"Configuration section '{section}' must be a mapping")
            for field in fields:
                if field not in section_config:
                    raise ValueError(f"Missing required {section} field: {field}")

        logger.info("Configuration validation passed")

    def get(self, key_path: str, default: Any = _MISSING) -> Any:
        """Look up a value by dot-separated path, e.g. ``"training.batch_size"``.

        Args:
            key_path: Keys joined with ``.``, walked from the top-level mapping.
            default: Returned when any segment is missing or an intermediate
                value cannot be indexed by a key. May be any value, including
                ``None``. If omitted, a missing key raises.

        Returns:
            The value at ``key_path``, or ``default``.

        Raises:
            KeyError: If the path cannot be resolved and no default was given.
        """
        value: Any = self.config
        try:
            for key in key_path.split('.'):
                value = value[key]
        except (KeyError, TypeError):
            if default is not self._MISSING:
                return default
            raise KeyError(f"Configuration key not found: {key_path}") from None
        return value

    def get_model_config(self) -> Dict[str, Any]:
        """Return the ``model`` section."""
        return self.config['model']

    def get_training_config(self) -> Dict[str, Any]:
        """Return the ``training`` section."""
        return self.config['training']

    def get_data_config(self) -> Dict[str, Any]:
        """Return the ``data`` section."""
        return self.config['data']

    def get_checkpointing_config(self) -> Dict[str, Any]:
        """Return the ``checkpointing`` section."""
        return self.config['checkpointing']

    def get_metrics_config(self) -> Dict[str, Any]:
        """Return the ``metrics`` section, or ``{'enabled': False}`` if absent."""
        return self.config.get('metrics', {'enabled': False})

    def get_logging_config(self) -> Dict[str, Any]:
        """Return the ``logging`` section, or ``{'enabled': True, 'level': 'INFO'}`` if absent."""
        return self.config.get('logging', {'enabled': True, 'level': 'INFO'})

    def is_short_training(self) -> bool:
        """Return True if ``short_training.enabled`` is present and truthy."""
        return bool(self.get('short_training.enabled', False))

    def is_full_training(self) -> bool:
        """Return True if ``full_training.enabled`` is present and truthy."""
        return bool(self.get('full_training.enabled', False))

    def get_effective_epochs(self) -> int:
        """Return the epoch count for the active training mode.

        Short training takes precedence over full training. If neither is
        enabled, ``training.num_epochs`` is used. The fallbacks used when the
        key is absent are 5 (short), 50 (full) and 10 (standard).
        """
        if self.is_short_training():
            return self.get('short_training.max_epochs', 5)
        elif self.is_full_training():
            return self.get('full_training.max_epochs', 50)
        else:
            return self.get('training.num_epochs', 10)

    def get_effective_batch_limits(self) -> Dict[str, Optional[int]]:
        """Return per-epoch train/val batch caps for the active training mode.

        In short-training mode the caps come from ``short_training``, with
        defaults of 100 (train) and 25 (val). Otherwise both are None, meaning
        no cap.
        """
        if self.is_short_training():
            return {
                'max_train_batches': self.get('short_training.max_train_batches', 100),
                'max_val_batches': self.get('short_training.max_val_batches', 25)
            }
        else:
            return {
                'max_train_batches': None,
                'max_val_batches': None
            }

    def update_config(self, updates: Dict[str, Any]) -> None:
        """Set values in ``self.config`` from a ``{dot.path: value}`` mapping.

        Missing intermediate sections, and sections that are empty in YAML
        (None), are created as new dicts.

        Raises:
            TypeError: If an intermediate path segment holds a non-mapping
                value (e.g. setting ``"training.learning_rate.x"``).
        """
        for key_path, value in updates.items():
            *parents, leaf = key_path.split('.')
            section = self.config

            for key in parents:
                child = section.get(key)
                if child is None:
                    child = section[key] = {}
                elif not isinstance(child, dict):
                    raise TypeError(
                        f"Cannot set '{key_path}': '{key}' holds a "
                        f"{type(child).__name__}, not a mapping"
                    )
                section = child

            section[leaf] = value

        logger.info(f"Configuration updated with {len(updates)} changes")

    def save_config(self, output_path: Optional[str] = None) -> None:
        """Write the current configuration as YAML.

        Args:
            output_path: Destination file. Defaults to the path the
                configuration was loaded from, which overwrites it.

        Raises:
            Exception: Any I/O or serialization error, logged and re-raised.
        """
        if output_path is None:
            output_path = self.config_path

        try:
            output_dir = os.path.dirname(output_path)
            # A bare filename has no directory part; os.makedirs('') raises
            # FileNotFoundError, so only create a directory when there is one.
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)

            with open(output_path, 'w', encoding='utf-8') as f:
                yaml.dump(self.config, f, default_flow_style=False, indent=2)

            logger.info(f"Configuration saved to: {output_path}")

        except Exception as e:
            logger.error(f"Error saving configuration: {e}")
            raise

    def get_summary(self) -> Dict[str, Any]:
        """Return a flat dict of the key settings.

        The batch-limit keys are included only in short-training mode.
        """
        summary = {
            'model_name': self.get('model.model_name'),
            'learning_rate': self.get('training.learning_rate'),
            'batch_size': self.get('training.batch_size'),
            'num_epochs': self.get_effective_epochs(),
            'device': self.get('model.device'),
            'metrics_enabled': self.get('metrics.enabled', False),
            'checkpointing_enabled': self.get('checkpointing.enabled', True),
            'training_mode': 'short' if self.is_short_training() else 'full' if self.is_full_training() else 'standard'
        }

        if self.is_short_training():
            summary.update(self.get_effective_batch_limits())

        return summary

    def print_summary(self) -> None:
        """Print ``get_summary()`` as an aligned, human-readable table."""
        summary = self.get_summary()

        print("="*60)
        print("TRAINING CONFIGURATION SUMMARY")
        print("="*60)

        for key, value in summary.items():
            display_key = key.replace('_', ' ').title()
            print(f"{display_key:25}: {value}")

        print("="*60)


def test_config_loader():
    """Smoke-test ConfigLoader against the default configuration file.

    Loads the default config, then exercises dotted-key access, section
    access, the training-mode helpers and the summary printer. Progress is
    printed after each step.

    Returns:
        True if every step succeeded. False if any step raised; the traceback
        is printed in that case.
    """
    rule = "=" * 60

    print(rule)
    print("TESTING CONFIGURATION LOADER (TASK 6.1)")
    print(rule)

    try:
        print("Step 1: Loading configuration...")
        loader = ConfigLoader()
        print("   Configuration loaded successfully")

        print("Step 2: Testing configuration access...")
        # Resolve every value before printing, so a missing key aborts the
        # step before any of its lines are shown.
        basics = (
            ("Model", loader.get('model.model_name')),
            ("Learning rate", loader.get('training.learning_rate')),
            ("Batch size", loader.get('training.batch_size')),
        )
        for label, value in basics:
            print(f"   {label}: {value}")
        print("   Basic access working")

        print("Step 3: Testing section access...")
        sections = (
            ("Model", loader.get_model_config()),
            ("Training", loader.get_training_config()),
        )
        for label, section in sections:
            print(f"   {label} config keys: {len(section)}")
        print("   Section access working")

        print("Step 4: Testing training mode detection...")
        mode_checks = (
            ("Short training", loader.is_short_training()),
            ("Effective epochs", loader.get_effective_epochs()),
            ("Batch limits", loader.get_effective_batch_limits()),
        )
        for label, value in mode_checks:
            print(f"   {label}: {value}")
        print("   Training mode detection working")

        print("Step 5: Testing configuration summary...")
        loader.print_summary()
        print("   Configuration summary working")

        print("\nTASK 6.1 CONFIGURATION SYSTEM COMPLETE!")
        print("Configuration file: training_config.yaml")
        print("Configuration loader: config_loader.py")
        print("All hyperparameters from optimization work captured")

        return True

    except Exception as e:
        import traceback
        print(f"Configuration loader test failed: {e}")
        traceback.print_exc()
        return False


def load_config(config_path: Optional[str] = None) -> Dict[str, Any]:
    """Load and validate a configuration file and return the raw mapping.

    This is a convenience wrapper around ConfigLoader for callers that only
    need the parsed dictionary, not the accessor helpers.

    Args:
        config_path: Path to the YAML file. ``None`` lets ConfigLoader pick
            its default location.

    Returns:
        The validated configuration dictionary.
    """
    return ConfigLoader(config_path).config


def _ensure_section(config: Dict[str, Any], name: str) -> Dict[str, Any]:
    """Return ``config[name]`` as a dict, replacing a missing/empty/non-dict value with ``{}``."""
    section = config.get(name)
    if not isinstance(section, dict):
        section = config[name] = {}
    return section


def update_config_from_args(config: Dict[str, Any], args) -> int:
    """Apply command-line overrides from ``args`` to ``config`` in place.

    An attribute is applied only when it exists on ``args`` and is set:
    ``model_name``, ``learning_rate``, ``batch_size``, ``num_epochs``,
    ``device`` and ``short_training`` when truthy, and the boolean toggles
    ``metrics`` / ``checkpointing`` when not None.

    Overrides go to the keys ConfigLoader reads: ``model.device``,
    ``metrics.enabled``, ``checkpointing.enabled`` and
    ``short_training.enabled``. Existing ``short_training`` settings (e.g.
    ``max_epochs``) are kept. The previous ``training.device`` /
    ``training.metrics_enabled`` / ``training.checkpointing_enabled`` keys
    are still written for backward compatibility.

    Args:
        config: Configuration dict (e.g. from ``load_config``). It must
            contain ``model`` and ``training`` sections.
        args: Any object with optional override attributes, typically an
            ``argparse.Namespace``.

    Returns:
        The number of overrides applied.
    """
    changes = 0

    model_name = getattr(args, 'model_name', None)
    if model_name:
        config['model']['model_name'] = model_name
        changes += 1

    learning_rate = getattr(args, 'learning_rate', None)
    if learning_rate:
        config['training']['learning_rate'] = learning_rate
        changes += 1

    batch_size = getattr(args, 'batch_size', None)
    if batch_size:
        config['training']['batch_size'] = batch_size
        changes += 1

    num_epochs = getattr(args, 'num_epochs', None)
    if num_epochs:
        config['training']['num_epochs'] = num_epochs
        changes += 1

    device = getattr(args, 'device', None)
    if device:
        # ConfigLoader requires and reads ``model.device``.
        config['model']['device'] = device
        # Legacy location, kept for consumers that may still read it.
        config['training']['device'] = device
        changes += 1

    metrics = getattr(args, 'metrics', None)
    if metrics is not None:
        _ensure_section(config, 'metrics')['enabled'] = metrics
        config['training']['metrics_enabled'] = metrics  # legacy location
        changes += 1

    checkpointing = getattr(args, 'checkpointing', None)
    if checkpointing is not None:
        _ensure_section(config, 'checkpointing')['enabled'] = checkpointing
        config['training']['checkpointing_enabled'] = checkpointing  # legacy location
        changes += 1

    if getattr(args, 'short_training', None):
        # Merge rather than replace, so YAML limits such as max_epochs and
        # max_train_batches survive the override.
        _ensure_section(config, 'short_training')['enabled'] = True
        changes += 1

    return changes


if __name__ == "__main__":
    # Use the smoke-test result as the process exit status, so shell scripts
    # and CI can detect a broken configuration.
    raise SystemExit(0 if test_config_loader() else 1)