"""
Base configuration class for BSNP.
Defines all shared hyperparameters and provides a template for problem-specific configs.
"""

from dataclasses import dataclass, field
from typing import Tuple, Optional, Dict, Any
import torch


@dataclass
class BaseConfig:
    """
    Base configuration for Physics-Informed Convolutional Neural Processes.

    All hyperparameters mentioned in the paper are defined here with default values
    matching the shared settings in Appendix D.

    The configuration is validated automatically on construction (see
    ``__post_init__``). ``update()`` does not re-validate; call ``validate()``
    explicitly after changing values if the new state must be checked.
    """

    # ============================================================================
    # Experiment Metadata
    # ============================================================================
    experiment_name: str = "pi_convnp_base"
    problem_type: str = "base"  # Override in subclasses: 'poisson', 'burgers', 'navier_stokes'
    # NOTE: this default is evaluated once, when the class body is executed.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    seed: int = 42

    # ============================================================================
    # Problem Dimensions (task-specific, override as needed)
    # ============================================================================
    spatial_dim: int = 1  # Dimensionality of domain Ω
    observation_dim: int = 1  # Dimension of observations y
    output_dim: int = 1  # d_u in paper (scalar field by default)

    # ============================================================================
    # Architecture Settings (Shared across all benchmarks)
    # ============================================================================

    # Latent variable dimension (Section 4.1)
    latent_dim: int = 64  # d_z in paper

    # Feature embedding dimensions
    observation_encoder_dim: int = 64  # φ_y output dimension
    parameter_encoder_dim: int = 64    # φ_λ output dimension
    latent_encoder_dim: int = 64       # φ_z output dimension

    # Grid settings (problem-specific, override in subclasses)
    grid_resolution: Tuple[int, ...] = (256,)  # N_g in paper

    # Convolutional backbone (Section 4.2)
    conv_channels: int = 64  # Channel dimension throughout conv blocks
    conv_num_blocks: int = 6  # Number of residual blocks
    conv_kernel_size: int = 3  # Kernel size for convolutions
    conv_activation: str = "swish"  # Activation function

    # Kernel settings (Section 4.2, Eq. 4-7)
    kernel_type: str = "rbf"  # Type of kernel κ_ρ
    kernel_lengthscale_init: float = 0.1  # Initial ρ value
    kernel_learnable: bool = True  # Whether ρ is learnable

    # Output settings (Eq. 19-20)
    sigma_min: float = 1e-4  # Minimum predictive variance
    epsilon: float = 1e-6  # Numerical stabilizer

    # White noise kernel (used in Burgers benchmark)
    use_white_noise_kernel: bool = False
    white_noise_scale_init: float = -5.0  # log-scale initialization

    # ============================================================================
    # Training Settings (Shared)
    # ============================================================================

    # Optimization
    optimizer: str = "adam"
    learning_rate: float = 1e-3
    lr_min: float = 1e-5  # Minimum LR for cosine decay
    beta1: float = 0.9  # Adam β_1
    beta2: float = 0.999  # Adam β_2
    weight_decay: float = 0.0
    grad_clip_norm: Optional[float] = 1.0  # Gradient clipping

    # Training schedule
    num_iterations: int = 200_000  # Total gradient steps
    batch_size: int = 1  # Tasks per gradient step (override for NS)

    # Learning rate scheduling
    use_lr_scheduler: bool = True
    scheduler_type: str = "cosine"  # 'cosine', 'step', 'exponential'
    warmup_iterations: int = 1000  # Linear warmup steps

    # ============================================================================
    # Data Settings
    # ============================================================================

    # Context and target set sizes (per task)
    num_context: int = 50  # N_c in paper
    num_target: int = 200  # N_t in paper

    # Observation noise (Eq. 2)
    observation_noise_std: float = 0.0  # σ_n (set to 0 for noiseless case)

    # Task sampling
    num_train_tasks: int = 10_000  # Size of training task distribution
    num_val_tasks: int = 1_000
    num_test_tasks: int = 1_000

    # ============================================================================
    # Physics Constraint Settings (Section 4.3)
    # ============================================================================

    # Collocation points (resampled every iteration, Eq. 14-15)
    num_collocation_interior: int = 1024  # N_r
    num_collocation_boundary: int = 256   # N_∂

    # Collocation sampling distribution
    collocation_interior_sampling: str = "uniform"  # p_r distribution
    collocation_boundary_sampling: str = "uniform"  # p_∂ distribution

    # Physics loss weights (Eq. 30)
    physics_weight: float = 1.0  # β in paper
    boundary_weight: float = 1.0  # β_∂ in paper

    # Boundary condition enforcement
    use_hard_boundary_constraints: bool = True  # Use Eq. 31 parameterization

    # Mean-field constraint (Section 4.3, key contribution)
    apply_physics_to_mean_only: bool = True  # Apply G_λ to μ_θ only

    # ============================================================================
    # Logging and Checkpointing
    # ============================================================================

    # Logging intervals
    log_interval: int = 100  # Log metrics every N iterations
    eval_interval: int = 1000  # Evaluate on validation set every N iterations
    checkpoint_interval: int = 5000  # Save checkpoint every N iterations

    # Paths
    log_dir: str = "./logs"
    checkpoint_dir: str = "./checkpoints"
    data_dir: str = "./data"

    # Wandb settings (optional)
    use_wandb: bool = False
    wandb_project: str = "BSNP"
    wandb_entity: Optional[str] = None

    # Visualization
    num_vis_samples: int = 5  # Number of samples to visualize

    # ============================================================================
    # Evaluation Settings
    # ============================================================================

    # Uncertainty quantification
    num_samples_nll: int = 100  # MC samples for NLL estimation
    confidence_level: float = 0.9  # For empirical coverage probability (1-α)

    # Metrics to compute
    compute_mnse: bool = True
    compute_nll: bool = True
    compute_ecp: bool = True

    # ============================================================================
    # Problem-Specific Parameters (Override in subclasses)
    # ============================================================================

    # PDE parameter distribution (p(λ) in Eq. 11)
    # default_factory gives every instance its own dict (no shared mutable default).
    parameter_ranges: Dict[str, Tuple[float, float]] = field(default_factory=dict)

    # Domain specification (Ω): one (lower, upper) pair per axis
    domain_bounds: Tuple[Tuple[float, float], ...] = ((-1.0, 1.0),)

    # Boundary condition type
    boundary_condition_type: str = "dirichlet"  # 'dirichlet', 'neumann', 'robin', 'periodic'

    # ============================================================================
    # Utility Methods
    # ============================================================================

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert config to dictionary for logging.

        Returns:
            A new dict mapping every instance attribute whose name does not
            start with an underscore to its current value. Values are not
            copied, so mutable values (e.g. ``parameter_ranges``) are shared
            with the config.
        """
        return {k: v for k, v in self.__dict__.items() if not k.startswith('_')}

    def update(self, **kwargs):
        """
        Update config with new values.

        All keys are checked before any value is assigned, so a rejected key
        leaves the config untouched. The new values are not re-validated;
        call ``validate()`` afterwards if that is required.

        Args:
            **kwargs: Attribute names and the values to assign to them.

        Raises:
            ValueError: If a key does not name an existing config attribute.
                Methods and underscore-prefixed names are not considered
                config attributes and are rejected as well.
        """
        cls = type(self)
        for key in kwargs:
            # Anything already stored on the instance is a legitimate setting.
            if key in self.__dict__:
                continue
            # Otherwise only accept plain (non-callable, public) class-level
            # attributes, so that methods such as `validate` or dunder names
            # cannot be overwritten by accident.
            if (
                key.startswith('_')
                or not hasattr(self, key)
                or callable(getattr(cls, key, None))
            ):
                raise ValueError(f"Config has no attribute '{key}'")

        for key, value in kwargs.items():
            setattr(self, key, value)

    @staticmethod
    def _require(condition: bool, message: str) -> None:
        """Raise ``AssertionError(message)`` unless ``condition`` is true.

        Used instead of the ``assert`` statement so that validation still runs
        when Python is started with ``-O`` (which strips ``assert``).
        """
        if not condition:
            raise AssertionError(message)

    def validate(self):
        """
        Validate configuration parameters.

        If a CUDA device is requested but CUDA is not available, ``device`` is
        reset to ``"cpu"`` and a warning is printed. A confirmation line is
        printed when all checks pass.

        Returns:
            ``True`` when every check passes.

        Raises:
            AssertionError: If any parameter is outside its allowed range.
        """
        # Check device availability. Match indexed devices such as "cuda:0"
        # too, not only the bare "cuda" string.
        if str(self.device).startswith("cuda") and not torch.cuda.is_available():
            print("Warning: CUDA not available, falling back to CPU")
            self.device = "cpu"

        require = self._require

        # Check dimensions
        require(self.spatial_dim > 0, "spatial_dim must be positive")
        require(self.observation_dim > 0, "observation_dim must be positive")
        require(self.output_dim > 0, "output_dim must be positive")
        require(self.latent_dim > 0, "latent_dim must be positive")
        require(self.conv_channels > 0, "conv_channels must be positive")
        require(self.num_context > 0, "num_context must be positive")
        require(self.num_target > 0, "num_target must be positive")

        # Check physics weights
        require(self.physics_weight >= 0, "physics_weight must be non-negative")
        require(self.boundary_weight >= 0, "boundary_weight must be non-negative")

        # Check collocation points
        require(self.num_collocation_interior > 0, "num_collocation_interior must be positive")
        require(self.num_collocation_boundary > 0, "num_collocation_boundary must be positive")

        # Check numerical stability parameters
        require(self.sigma_min > 0, "sigma_min must be positive")
        require(self.epsilon > 0, "epsilon must be positive")

        # Check learning rate
        require(0 < self.learning_rate <= 1, "learning_rate must be in (0, 1]")
        require(0 < self.lr_min < self.learning_rate, "lr_min must be less than learning_rate")

        print(f"✓ Configuration validated successfully for {self.experiment_name}")
        return True

    def __post_init__(self):
        """Post-initialization processing: validate the freshly built config."""
        self.validate()

    def get_grid_shape(self) -> Tuple[int, ...]:
        """Get the shape of the latent grid (returns ``grid_resolution``)."""
        return self.grid_resolution

    def get_domain_volume(self) -> float:
        """
        Compute the volume/area/length of the domain.

        Returns:
            The product of ``upper - lower`` over every ``(lower, upper)`` pair
            in ``domain_bounds`` (``1.0`` if ``domain_bounds`` is empty).
        """
        volume = 1.0
        for lower, upper in self.domain_bounds:
            volume *= (upper - lower)
        return volume