"""
Configuration for 1D Inverse Poisson problem experiments.
"""
from dataclasses import dataclass
from typing import List, Optional

from .base_config import BaseConfig


@dataclass
class InverseConfig(BaseConfig):
    """Configuration for 1D Inverse Poisson problem experiments.

    Extends ``BaseConfig`` with the spatial domain, data-corruption,
    training-data, architecture and R-PIT settings used by the inverse
    Poisson experiments. Helper methods group related fields into plain
    dicts for downstream consumers.
    """

    # Experiment settings
    experiment_name: str = "inverse_poisson_rpit"

    # Spatial domain [x_start, x_end]
    x_start: float = 0.0
    x_end: float = 1.0

    # Data corruption parameters. Ratios are fractions in [0, 1], not
    # percentages (0.3 means 30%).
    corruption_level: float = 0.3     # Fraction of observations to corrupt
    outlier_std: float = 3.0          # Standard deviation for outliers
    missing_data_ratio: float = 0.1   # Fraction of observations treated as missing

    # Training data parameters
    num_boundary_points: int = 2      # Boundary condition points
    num_interior_points: int = 20     # Interior observation points
    # Noise standard deviation. This single field is reported by
    # get_data_config() and presumably also drives R-PIT noise injection.
    # 0.2 is the value this class has always effectively used.
    # TODO(review): split into separate data-noise / injection-noise fields
    # if the two are meant to differ.
    noise_std: float = 0.2

    # Collocation points
    num_collocation_points: int = 1000

    # Model architecture (for 1D inverse problem: input=x, output=u)
    output_dim: int = 1
    # None means "use the default" ([64, 64, 64, 64]), filled in by __post_init__.
    # A fresh list is created per instance, so instances never share it.
    hidden_layers: Optional[List[int]] = None

    # Training parameters
    num_epochs: int = 10000
    learning_rate: float = 1e-3
    log_frequency: int = 200

    # R-PIT specific parameters (higher for robustness; presumably these
    # override BaseConfig defaults -- confirm against base_config)
    lambda_sens: float = 0.5  # Higher sensitivity regularization
    lambda_var: float = 2.0   # Higher variance loss weight

    def __post_init__(self):
        """Fill in the default hidden layers, then run BaseConfig's post-init."""
        if self.hidden_layers is None:
            self.hidden_layers = [64, 64, 64, 64]  # Smaller network for 1D problem
        super().__post_init__()

    def get_spatial_config(self) -> dict:
        """Return the spatial domain bounds as ``{'x_start', 'x_end'}``."""
        return {
            'x_start': self.x_start,
            'x_end': self.x_end,
        }

    def get_corruption_config(self) -> dict:
        """Return the data-corruption settings.

        Keys: ``corruption_level``, ``outlier_std``, ``missing_data_ratio``.
        """
        return {
            'corruption_level': self.corruption_level,
            'outlier_std': self.outlier_std,
            'missing_data_ratio': self.missing_data_ratio,
        }

    def get_data_config(self) -> dict:
        """Return the training-data generation settings.

        Keys: ``num_boundary_points``, ``num_interior_points``, ``noise_std``.
        """
        return {
            'num_boundary_points': self.num_boundary_points,
            'num_interior_points': self.num_interior_points,
            'noise_std': self.noise_std,
        }