"""
Configuration management for GLEAM-AI.

This module provides dataclass-based configuration classes for managing
model, training, data, and active-learning parameters, with type annotations
and basic value-range validation.
"""

from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional, Union
import yaml
from pathlib import Path
import logging

# Module-level logger named after this module, so its level and handlers can
# be configured through the standard logging hierarchy.
logger = logging.getLogger(__name__)


@dataclass
class ModelConfig:
    """Configuration for STNP model parameters.

    Fields without a default are required when building from a mapping via
    :meth:`from_dict`. Keys that are not fields are kept in ``extra`` and
    re-emitted by :meth:`to_dict`, so ``from_dict(cfg.to_dict())`` reproduces
    the same configuration.
    """
    
    # Core model dimensions
    x_dim: int
    y_dim: int
    seq_len: int
    
    # Additional dimensions
    xt_dim: int
    in_channels: int
    out_channels: int
    
    # Embedding and encoding dimensions
    embed_out_dim: int
    z_dim: int
    r_dim: int
    
    # RNN parameters
    encoder_num_rnn: int
    decoder_num_rnn: int
    num_rnn: int  # For backward compatibility
    
    # Graph neural network parameters
    max_diffusion_step: int
    num_nodes: int
    
    # Decoder hidden dimensions
    decoder_hidden_dims: List[int] = field(default_factory=lambda: [64, 32])
    hidden_dims: List[int] = field(default_factory=lambda: [64, 32])  # Alias for backward compatibility
    
    # Additional parameters
    context_percentage: float = 0.3
    NUM_COMP: int = 4  # Number of output compartments
    POPULATION_SCALER: float = 1_000_000.0  # Population scaling factor
    
    # Additional parameters for backward compatibility
    extra: Dict[str, Any] = field(default_factory=dict)
    
    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ModelConfig':
        """Create ModelConfig from dictionary.

        Args:
            data: Mapping of parameter names to values. Keys naming a field
                are passed to the constructor; every other key is stored in
                ``extra``.

        Returns:
            A new ModelConfig instance.

        Raises:
            TypeError: If a required (non-defaulted) field is missing.
        """
        # Extract known parameters
        known_params = {
            'x_dim', 'y_dim', 'seq_len', 'xt_dim', 'in_channels', 'out_channels',
            'embed_out_dim', 'z_dim', 'r_dim', 'encoder_num_rnn', 'decoder_num_rnn',
            'num_rnn', 'max_diffusion_step', 'num_nodes', 'decoder_hidden_dims',
            'hidden_dims', 'context_percentage', 'NUM_COMP', 'POPULATION_SCALER'
        }
        
        model_params = {k: v for k, v in data.items() if k in known_params}
        extra_params = {k: v for k, v in data.items() if k not in known_params}
        
        # Create instance with known parameters
        instance = cls(**model_params)
        # Store extra parameters
        instance.extra = extra_params
        
        return instance
    
    def validate(self) -> bool:
        """Check that every parameter is within its allowed range.

        Returns:
            True if all checks pass, False at the first failing check. The
            name of the failing parameter is not reported to the caller.
        """
        try:
            if self.x_dim <= 0:
                raise ValueError("x_dim must be positive")
            if self.y_dim <= 0:
                raise ValueError("y_dim must be positive")
            if self.seq_len <= 0:
                raise ValueError("seq_len must be positive")
            if self.xt_dim <= 0:
                raise ValueError("xt_dim must be positive")
            if self.in_channels <= 0:
                raise ValueError("in_channels must be positive")
            if self.out_channels <= 0:
                raise ValueError("out_channels must be positive")
            if self.embed_out_dim <= 0:
                raise ValueError("embed_out_dim must be positive")
            if self.z_dim <= 0:
                raise ValueError("z_dim must be positive")
            if self.r_dim <= 0:
                raise ValueError("r_dim must be positive")
            if self.encoder_num_rnn <= 0:
                raise ValueError("encoder_num_rnn must be positive")
            if self.decoder_num_rnn <= 0:
                raise ValueError("decoder_num_rnn must be positive")
            if self.num_rnn <= 0:
                raise ValueError("num_rnn must be positive")
            if self.max_diffusion_step <= 0:
                raise ValueError("max_diffusion_step must be positive")
            if self.num_nodes <= 0:
                raise ValueError("num_nodes must be positive")
            if not self.decoder_hidden_dims or any(dim <= 0 for dim in self.decoder_hidden_dims):
                raise ValueError("decoder_hidden_dims must contain positive values")
            if not self.hidden_dims or any(dim <= 0 for dim in self.hidden_dims):
                raise ValueError("hidden_dims must contain positive values")
            if self.context_percentage <= 0 or self.context_percentage >= 1:
                raise ValueError("context_percentage must be between 0 and 1")
            if self.NUM_COMP <= 0:
                raise ValueError("NUM_COMP must be positive")
            # A zero or negative scaler is never a meaningful population scale.
            if self.POPULATION_SCALER <= 0:
                raise ValueError("POPULATION_SCALER must be positive")
            return True
        except ValueError:
            return False
    
    def to_dict(self) -> Dict[str, Any]:
        """Convert configuration to a flat dictionary accepted by :meth:`from_dict`.

        Entries from ``extra`` are merged in last and therefore win on a key
        collision.
        """
        return {
            'x_dim': self.x_dim,
            'y_dim': self.y_dim,
            'seq_len': self.seq_len,
            'xt_dim': self.xt_dim,
            'in_channels': self.in_channels,
            'out_channels': self.out_channels,
            'embed_out_dim': self.embed_out_dim,
            'z_dim': self.z_dim,
            'r_dim': self.r_dim,
            'encoder_num_rnn': self.encoder_num_rnn,
            'decoder_num_rnn': self.decoder_num_rnn,
            'num_rnn': self.num_rnn,
            'max_diffusion_step': self.max_diffusion_step,
            'num_nodes': self.num_nodes,
            'decoder_hidden_dims': self.decoder_hidden_dims,
            'hidden_dims': self.hidden_dims,
            'context_percentage': self.context_percentage,
            'NUM_COMP': self.NUM_COMP,
            # Previously omitted, which silently reset it on a round trip.
            'POPULATION_SCALER': self.POPULATION_SCALER,
            **self.extra
        }


@dataclass
class TrainingConfig:
    """Training-loop hyper-parameters.

    Keys given to :meth:`from_dict` that are not fields of this class are
    preserved untouched in ``extra`` and merged back in by :meth:`to_dict`.
    """
    
    # Training loop parameters
    max_epochs: int
    lr: float
    lr_encoder: float
    lr_decoder: float
    lr_milestones: List[int]
    lr_gamma: float
    
    # Batch sizes
    train_batch_size: int
    val_batch_size: int
    
    # Early stopping parameters
    patience: int
    min_delta: float = 0.001
    
    # Optimization parameters
    gradient_clip_val: float = 1.0
    weight_decay: float = 0.0
    
    # Device and parallelization
    device: str = "auto"  # "auto", "cpu", "cuda", "mps"
    num_workers: int = 4
    
    # Additional parameters for backward compatibility
    extra: Dict[str, Any] = field(default_factory=dict)
    
    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'TrainingConfig':
        """Build a TrainingConfig, routing unrecognised keys into ``extra``."""
        recognised = frozenset((
            'max_epochs', 'lr', 'lr_encoder', 'lr_decoder', 'lr_milestones',
            'lr_gamma', 'train_batch_size', 'val_batch_size', 'patience',
            'min_delta', 'gradient_clip_val', 'weight_decay', 'device',
            'num_workers',
        ))
        ctor_kwargs: Dict[str, Any] = {}
        leftovers: Dict[str, Any] = {}
        for key, value in data.items():
            # Single pass: drop each entry into the bucket it belongs to.
            (ctor_kwargs if key in recognised else leftovers)[key] = value
        config = cls(**ctor_kwargs)
        config.extra = leftovers
        return config
    
    def validate(self) -> bool:
        """Return True when every numeric setting lies in its allowed range.

        Checks run in declaration order and stop at the first violation.
        """
        must_be_positive = (
            'max_epochs', 'lr', 'lr_encoder', 'lr_decoder', 'lr_gamma',
            'train_batch_size', 'val_batch_size', 'patience',
        )
        must_be_non_negative = (
            'min_delta', 'gradient_clip_val', 'weight_decay', 'num_workers',
        )
        try:
            # Test the *violation* (<= 0 / < 0) rather than the requirement so
            # values such as NaN are treated exactly as a plain comparison would.
            if any(getattr(self, name) <= 0 for name in must_be_positive):
                return False
            return not any(getattr(self, name) < 0 for name in must_be_non_negative)
        except ValueError:
            return False
    
    def to_dict(self) -> Dict[str, Any]:
        """Flatten the settings (plus ``extra``) into a plain dictionary."""
        ordered_names = (
            'max_epochs', 'lr', 'lr_encoder', 'lr_decoder', 'lr_milestones',
            'lr_gamma', 'train_batch_size', 'val_batch_size', 'patience',
            'min_delta', 'gradient_clip_val', 'weight_decay', 'device',
            'num_workers',
        )
        declared = {name: getattr(self, name) for name in ordered_names}
        return {**declared, **self.extra}


@dataclass
class DataConfig:
    """Dataset column layout and split sizes.

    Unrecognised keys handed to :meth:`from_dict` are kept verbatim in
    ``extra`` and re-emitted by :meth:`to_dict`.
    """
    
    # Column names
    x_col_names: List[str]
    frac_pops_names: List[str]
    initial_col_names: List[str]
    output_compartments: List[str]
    
    # Data sizes
    initial_val_size: int
    initial_test_size: int
    
    # Additional parameters for backward compatibility
    extra: Dict[str, Any] = field(default_factory=dict)
    
    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'DataConfig':
        """Build a DataConfig, sending keys it does not declare to ``extra``."""
        recognised = frozenset((
            'x_col_names', 'frac_pops_names', 'initial_col_names',
            'output_compartments', 'initial_val_size', 'initial_test_size',
        ))
        # Bucket True holds constructor arguments, bucket False the leftovers.
        buckets: Dict[bool, Dict[str, Any]] = {True: {}, False: {}}
        for key, value in data.items():
            buckets[key in recognised][key] = value
        result = cls(**buckets[True])
        result.extra = buckets[False]
        return result
    
    def validate(self) -> bool:
        """Return True if all column lists are non-empty and sizes are positive."""
        column_lists = (
            'x_col_names', 'frac_pops_names', 'initial_col_names',
            'output_compartments',
        )
        size_fields = ('initial_val_size', 'initial_test_size')
        try:
            if not all(getattr(self, name) for name in column_lists):
                return False
            return not any(getattr(self, name) <= 0 for name in size_fields)
        except ValueError:
            return False
    
    def to_dict(self) -> Dict[str, Any]:
        """Flatten the settings (plus ``extra``) into a plain dictionary."""
        ordered_names = (
            'x_col_names', 'frac_pops_names', 'initial_col_names',
            'output_compartments', 'initial_val_size', 'initial_test_size',
        )
        return {**{name: getattr(self, name) for name in ordered_names}, **self.extra}


@dataclass
class ActiveLearningConfig:
    """Settings that drive the active-learning loop.

    Keys passed to :meth:`from_dict` that are not declared here are stored in
    ``extra`` and merged back by :meth:`to_dict`.
    """
    
    # Active learning parameters
    max_iter: int
    initial_train_size: int
    batch_size_stat_compute: int
    retriever_num_workers: int
    pool_num_workers: int
    every_step: int
    lr_gamma: float
    
    # Additional parameters for backward compatibility
    extra: Dict[str, Any] = field(default_factory=dict)
    
    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ActiveLearningConfig':
        """Build an ActiveLearningConfig; unknown keys end up in ``extra``."""
        recognised = frozenset((
            'max_iter', 'initial_train_size', 'batch_size_stat_compute',
            'retriever_num_workers', 'pool_num_workers', 'every_step',
            'lr_gamma',
        ))
        ctor_kwargs: Dict[str, Any] = {}
        leftovers: Dict[str, Any] = {}
        for key, value in data.items():
            if key in recognised:
                ctor_kwargs[key] = value
            else:
                leftovers[key] = value
        built = cls(**ctor_kwargs)
        built.extra = leftovers
        return built
    
    def validate(self) -> bool:
        """Return True if counts are positive and worker counts non-negative.

        Rules are applied in declaration order and stop at the first failure.
        """
        rules = (
            ('max_iter', 'positive'),
            ('initial_train_size', 'positive'),
            ('batch_size_stat_compute', 'positive'),
            ('retriever_num_workers', 'non_negative'),
            ('pool_num_workers', 'non_negative'),
            ('every_step', 'positive'),
            ('lr_gamma', 'positive'),
        )
        try:
            for name, rule in rules:
                value = getattr(self, name)
                violated = value <= 0 if rule == 'positive' else value < 0
                if violated:
                    return False
            return True
        except ValueError:
            return False
    
    def to_dict(self) -> Dict[str, Any]:
        """Flatten the settings (plus ``extra``) into a plain dictionary."""
        ordered_names = (
            'max_iter', 'initial_train_size', 'batch_size_stat_compute',
            'retriever_num_workers', 'pool_num_workers', 'every_step',
            'lr_gamma',
        )
        declared = {name: getattr(self, name) for name in ordered_names}
        return {**declared, **self.extra}


class Config:
    """Main configuration class for GLEAM-AI.

    Parses the ``model``, ``training``, ``data`` and ``active_learning``
    sections into their dedicated config objects; every other top-level key
    is kept in ``extra``.
    """

    # Top-level keys that are parsed into dedicated section objects.
    _SECTIONS = ('model', 'training', 'data', 'active_learning')
    
    def __init__(self, config_path: Union[str, Path, Dict[str, Any]]):
        """
        Initialize configuration from YAML file or dictionary.
        
        Args:
            config_path: Path to YAML file or dictionary with configuration

        Raises:
            FileNotFoundError: If a path is given and the file does not exist.
            ValueError: If the loaded document is not a mapping, or if any
                section fails validation.
            TypeError: If a section lacks required parameters (raised while
                constructing that section).
        """
        self.config_path = config_path
        self._load_config()
        self._validate_all()
    
    def _load_config(self) -> None:
        """Load configuration from YAML file or dictionary."""
        if isinstance(self.config_path, dict):
            data = self.config_path
        else:
            config_path = Path(self.config_path)
            if not config_path.exists():
                raise FileNotFoundError(f"Configuration file not found: {config_path}")
            
            with open(config_path, 'r', encoding='utf-8') as f:
                data = yaml.safe_load(f)

        # safe_load yields None for an empty file and a list/scalar for a
        # non-mapping document; fail with a clear message instead of an
        # AttributeError on `.get` below.
        if not isinstance(data, dict):
            raise ValueError(
                f"Configuration must be a mapping of sections, got {type(data).__name__}"
            )
        
        # Load configuration sections
        self.model = ModelConfig.from_dict(data.get('model', {}))
        self.training = TrainingConfig.from_dict(data.get('training', {}))
        self.data = DataConfig.from_dict(data.get('data', {}))
        self.active_learning = ActiveLearningConfig.from_dict(data.get('active_learning', {}))
        
        # Store any additional configuration
        self.extra = {k: v for k, v in data.items() if k not in self._SECTIONS}
    
    def _validate_all(self) -> None:
        """Validate all configuration sections.

        Every section is checked (not just up to the first failure) so the
        error names all invalid sections at once.

        Raises:
            ValueError: If any section's ``validate()`` returns False.
        """
        failed = [name for name in self._SECTIONS if not getattr(self, name).validate()]
        if failed:
            raise ValueError("Invalid configuration section(s): " + ", ".join(failed))
        logger.info("Configuration validation passed")
    
    def save(self, output_path: Union[str, Path]) -> None:
        """
        Save configuration to YAML file.

        The written layout matches what the constructor reads back, so a
        saved file can be reloaded without changes.
        
        Args:
            output_path: Path where to save the configuration
        """
        config_dict = self.to_dict()
        
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        
        with open(output_path, 'w', encoding='utf-8') as f:
            yaml.dump(config_dict, f, default_flow_style=False, indent=2)
        
        logger.info(f"Configuration saved to {output_path}")
    
    def update(self, **kwargs) -> None:
        """
        Update configuration parameters.

        Keys naming an existing attribute of this object (e.g. ``model``)
        replace that attribute; any other key is stored in ``extra``. All
        sections are re-validated afterwards.
        
        Args:
            **kwargs: Configuration parameters to update

        Raises:
            ValueError: If a section fails validation after the update; the
                new values have already been applied at that point.
        """
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
            else:
                self.extra[key] = value
        
        self._validate_all()

    @staticmethod
    def _section_to_dict(section: Any) -> Dict[str, Any]:
        """Flatten a section object into the mapping its ``from_dict`` expects.

        The section's ``extra`` entries are merged at the top level instead
        of being nested under an ``extra`` key; nesting them would grow one
        level deeper on every save/load cycle.
        """
        declared = {k: v for k, v in vars(section).items() if k != 'extra'}
        return {**declared, **getattr(section, 'extra', {})}
    
    def to_dict(self) -> Dict[str, Any]:
        """Convert configuration to a dictionary in the same layout the constructor accepts."""
        sections = {name: self._section_to_dict(getattr(self, name)) for name in self._SECTIONS}
        return {**sections, **self.extra}
    
    def __repr__(self) -> str:
        """String representation of configuration."""
        return (
            f"Config(model={self.model}, training={self.training}, data={self.data}, "
            f"active_learning={self.active_learning})"
        )


def load_config_from_yaml(config_path: Union[str, Path]) -> Dict[str, Any]:
    """
    Read a YAML configuration file into plain Python objects.
    
    Kept for backward compatibility with the original, dictionary-based
    configuration loading approach.
    
    Args:
        config_path: Location of the YAML configuration file
        
    Returns:
        Whatever ``yaml.safe_load`` produces for the file: a mapping for a
        typical configuration file (``None`` for an empty file).

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
    """
    path = Path(config_path)
    if path.exists():
        with path.open('r') as stream:
            parsed = yaml.safe_load(stream)
        logger.info(f"Configuration loaded from: {path}")
        return parsed
    raise FileNotFoundError(f"Configuration file not found: {path}")
