from dataclasses import dataclass
from typing import Optional
from ConfigSpace import ConfigurationSpace, Float, Integer, Categorical, InCondition

@dataclass
class ModelConfig:
    """Configuration for the RegularizedMLP model.

    ``input_size`` and ``num_classes`` are required; every other field has a
    default. Creating an instance performs no checks, so call
    :meth:`validate` explicitly before relying on the numeric settings.
    """
    # Required model dimensions.
    input_size: int
    num_classes: int
    # Optimiser settings.
    learning_rate: float = 1e-4
    batch_size: int = 128
    # Boolean switches without companion values.
    use_batch_norm: bool = False
    use_swa: bool = False
    # Weight decay switch and coefficient.
    use_weight_decay: bool = False
    weight_decay: float = 1e-4
    # Dropout switch, layout name and rate.
    use_dropout: bool = False
    dropout_shape: str = 'funnel'
    dropout_rate: float = 0.5
    # Skip-connection switch, variant name and ShakeDrop probability.
    use_skip: bool = False
    skip_type: str = 'Standard'
    shakedrop_prob: float = 0.5
    # Augmentation name ('None' is a literal string, not the None object)
    # and its magnitude.
    augmentation: str = 'None'
    aug_magnitude: float = 0.5
    # Gradient-norm limit and mixed-precision switch.
    max_grad_norm: float = 1.0
    use_amp: bool = True

    def validate(self) -> None:
        """Validate configuration parameters.

        Only ``dropout_rate``, ``learning_rate``, ``batch_size`` and
        ``weight_decay`` are checked; the remaining fields are accepted
        as given.

        Raises:
            ValueError: If ``dropout_rate`` is outside ``[0, 1]``,
                ``learning_rate`` or ``batch_size`` is not strictly
                positive, or ``weight_decay`` is negative. NaN is rejected
                for every checked field.
        """
        # Each test is phrased as "not (value is in the valid range)" rather
        # than "value is out of range": comparisons with NaN are always
        # False, so the negated form rejects NaN instead of letting it pass.
        if not 0 <= self.dropout_rate <= 1:
            raise ValueError("Dropout rate must be between 0 and 1")
        if not self.learning_rate > 0:
            raise ValueError("Learning rate must be positive")
        if not self.batch_size > 0:
            raise ValueError("Batch size must be positive")
        if not self.weight_decay >= 0:
            raise ValueError("Weight decay must be non-negative")

def get_cocktail_space() -> ConfigurationSpace:
    """Build the search space for the regularization cocktail.

    The space holds 14 hyperparameters: boolean on/off switches, categorical
    variant selectors and float-valued settings. Six ``InCondition`` rules
    make dependent settings active only when their parent takes the listed
    value(s).

    Returns:
        A ``ConfigurationSpace`` with all hyperparameters and conditions
        registered.
    """

    def _switch(name, default):
        # On/off hyperparameter; a fresh choices list is built per call.
        return Categorical(name, [True, False], default=default)

    # NOTE: learning_rate is not part of this space. A log-scaled
    # Float('learning_rate', (1e-4, 1e-2), default=1e-3) was left disabled.

    # Implicit methods
    batch_norm_flag = _switch('use_batch_norm', False)
    swa_flag = _switch('use_swa', False)

    # Weight decay (coefficient searched on a log scale)
    wd_flag = _switch('use_weight_decay', False)
    wd_value = Float('weight_decay', (1e-5, 1e-1), log=True, default=1e-4)

    # Dropout
    drop_flag = _switch('use_dropout', False)
    drop_shape = Categorical(
        'dropout_shape',
        ['funnel', 'long_funnel', 'diamond', 'triangle'],
        default='funnel',
    )
    drop_rate = Float('dropout_rate', (0.0, 0.8), default=0.5)

    # Skip connections
    skip_flag = _switch('use_skip', False)
    skip_kind = Categorical(
        'skip_type',
        ['Standard', 'ShakeShake', 'ShakeDrop'],
        default='Standard',
    )
    shakedrop_p = Float('shakedrop_prob', (0.0, 1.0), default=0.5)

    # Data augmentation
    aug_kind = Categorical('augmentation', ['None', 'MixUp'], default='None')
    aug_strength = Float('aug_magnitude', (0.0, 1.0), default=0.5)

    # Mixed precision and gradient-norm limit (the latter is unconditional)
    amp_flag = _switch('use_amp', True)
    grad_norm_cap = Float('max_grad_norm', (0.1, 10.0), default=1.0)

    space = ConfigurationSpace()

    # Hyperparameters are registered first so the conditions below can
    # refer to them.
    space.add([
        batch_norm_flag, swa_flag,
        wd_flag, wd_value,
        drop_flag, drop_shape, drop_rate,
        skip_flag, skip_kind, shakedrop_p,
        aug_kind, aug_strength,
        amp_flag, grad_norm_cap,
    ])

    # (child, parent, parent values that activate the child)
    activation_rules = [
        (wd_value, wd_flag, [True]),
        (drop_shape, drop_flag, [True]),
        (drop_rate, drop_flag, [True]),
        (skip_kind, skip_flag, [True]),
        (shakedrop_p, skip_kind, ['ShakeDrop']),
        (aug_strength, aug_kind, ['MixUp']),
    ]
    space.add([
        InCondition(child, parent, allowed)
        for child, parent, allowed in activation_rules
    ])

    return space