"""
Base configuration class for R-PIT experiments.
"""
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List
import yaml
import os


@dataclass
class BaseConfig:
    """Base configuration class for all R-PIT experiments.

    Holds experiment, model, training, R-PIT, data, evaluation and path
    settings as plain dataclass fields. On construction (and whenever a
    path field is changed through :meth:`update`) the configured
    directories are created if they do not already exist.
    """

    # Experiment settings
    experiment_name: str = "rpit_experiment"
    seed: int = 42
    device: str = "cuda"  # Device name; "cuda" for GPU experiments

    # Model architecture
    hidden_layers: List[int] = field(default_factory=lambda: [128, 128, 128, 128])
    activation: str = "tanh"
    output_dim: int = 1

    # Training parameters
    learning_rate: float = 1e-3
    num_epochs: int = 50000
    batch_size: int = 1000
    optimizer: str = "adam"

    # R-PIT specific parameters
    lambda_sens: float = 0.1  # Sensitivity regularization weight
    lambda_var: float = 1.0   # Variance loss weight
    noise_std: float = 0.1    # Noise injection standard deviation

    # Data parameters
    num_collocation_points: int = 10000
    num_data_points: int = 1000
    num_test_points: int = 2000

    # Evaluation parameters
    num_random_seeds: int = 5
    save_frequency: int = 1000
    log_frequency: int = 100

    # Paths
    data_dir: str = "data"
    results_dir: str = "data/results"
    checkpoint_dir: str = "checkpoints"

    # Names of the fields that hold directory paths. Deliberately left
    # unannotated so the dataclass machinery does not treat it as a field.
    _PATH_FIELDS = ("data_dir", "results_dir", "checkpoint_dir")

    def __post_init__(self) -> None:
        """Create the configured directories once the fields are set."""
        self._ensure_directories()

    def _ensure_directories(self) -> None:
        """Create every directory named by a path field if it is missing."""
        for name in self._PATH_FIELDS:
            os.makedirs(getattr(self, name), exist_ok=True)

    @classmethod
    def from_yaml(cls, config_path: str) -> 'BaseConfig':
        """Load configuration from YAML file.

        Args:
            config_path: Path of the YAML file to read.

        Returns:
            A new instance of ``cls`` built from the file's top-level
            mapping. An empty file yields an instance with all defaults.

        Raises:
            TypeError: If the top-level YAML value is not a mapping, or if
                it contains keys that are not constructor parameters.
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            config_dict = yaml.safe_load(f)
        # safe_load returns None for an empty document; treat that as
        # "no overrides" instead of failing inside cls(**None).
        if config_dict is None:
            config_dict = {}
        if not isinstance(config_dict, dict):
            raise TypeError(
                f"Top-level YAML value in {config_path!r} must be a mapping, "
                f"got {type(config_dict).__name__}"
            )
        return cls(**config_dict)

    def to_yaml(self, config_path: str) -> None:
        """Save configuration to YAML file.

        Args:
            config_path: Path of the YAML file to write (overwritten if it
                already exists).
        """
        config_dict = {
            name: getattr(self, name)
            for name in self.__dataclass_fields__
        }
        # NOTE(review): yaml.dump may emit Python-specific tags (e.g. for
        # tuples) that yaml.safe_load in from_yaml refuses to read; consider
        # yaml.safe_dump if strict round-tripping is required.
        with open(config_path, 'w', encoding='utf-8') as f:
            yaml.dump(config_dict, f, default_flow_style=False)

    def update(self, **kwargs) -> None:
        """Update configuration parameters.

        All keys are validated before any value is assigned, so a failed
        call leaves the configuration untouched. If a path field is among
        the updated keys, the configured directories are (re-)created.

        Args:
            **kwargs: Field names mapped to their new values.

        Raises:
            ValueError: If a key is not a declared configuration field.
        """
        declared = self.__dataclass_fields__
        # Check against declared fields rather than hasattr(): hasattr()
        # would also accept method names and let them be overwritten.
        for key in kwargs:
            if key not in declared:
                raise ValueError(f"Unknown configuration parameter: {key}")
        for key, value in kwargs.items():
            setattr(self, key, value)
        if any(name in kwargs for name in self._PATH_FIELDS):
            self._ensure_directories()

    def get_model_config(self) -> Dict[str, Any]:
        """Get model-specific configuration.

        Returns:
            Dict with ``hidden_layers``, ``activation`` and ``output_dim``.
            ``hidden_layers`` is a copy, so mutating it does not alter
            this configuration.
        """
        return {
            'hidden_layers': list(self.hidden_layers),
            'activation': self.activation,
            'output_dim': self.output_dim,
        }

    def get_training_config(self) -> Dict[str, Any]:
        """Get training-specific configuration.

        Returns:
            Dict with ``learning_rate``, ``num_epochs``, ``batch_size``
            and ``optimizer``.
        """
        return {
            'learning_rate': self.learning_rate,
            'num_epochs': self.num_epochs,
            'batch_size': self.batch_size,
            'optimizer': self.optimizer,
        }

    def get_rpit_config(self) -> Dict[str, Any]:
        """Get R-PIT specific configuration.

        Returns:
            Dict with ``lambda_sens``, ``lambda_var`` and ``noise_std``.
        """
        return {
            'lambda_sens': self.lambda_sens,
            'lambda_var': self.lambda_var,
            'noise_std': self.noise_std,
        }
