"""
Configuration for 2D Burgers equation experiments.
"""
from dataclasses import dataclass
from typing import List, Optional

from .base_config import BaseConfig


@dataclass
class BurgersConfig(BaseConfig):
    """Configuration for 2D Burgers equation experiments.

    Extends ``BaseConfig`` with the PDE coefficient, the space-time domain
    bounds, data-sampling sizes, and network/training settings declared
    below. The ``get_*`` helpers bundle related fields into plain dicts.

    ``hidden_layers`` may be left as ``None``; ``__post_init__`` then
    replaces it with four hidden layers of width 128 before delegating to
    the base-class ``__post_init__``.
    """

    # Experiment settings
    experiment_name: str = "burgers_2d_rpit"

    # Burgers equation parameters
    nu: float = 0.01  # Viscosity coefficient

    # Spatial domain bounds
    x_start: float = 0.0
    x_end: float = 1.0
    y_start: float = 0.0
    y_end: float = 1.0

    # Temporal domain bounds
    t_start: float = 0.0
    t_end: float = 1.0

    # Training data parameters
    num_initial_points: int = 1000    # Initial condition points
    num_boundary_points: int = 400    # Boundary condition points
    num_interior_points: int = 500    # Interior observation points
    noise_std: float = 0.1            # Data noise standard deviation

    # Collocation points
    num_collocation_points: int = 10000
    # NOTE(review): presumably per-axis grid resolutions for x, y and t;
    # how they relate to num_collocation_points is not visible here --
    # confirm against the code that consumes them.
    num_x: int = 50
    num_y: int = 50
    num_t: int = 20

    # Model architecture (for 2D Burgers: input=(x,y,t), output=(u,v))
    output_dim: int = 2
    # None is a sentinel meaning "use the default sizes"; it is resolved in
    # __post_init__, so the annotation must admit None explicitly.
    hidden_layers: Optional[List[int]] = None

    # Training parameters
    num_epochs: int = 20000
    learning_rate: float = 1e-3
    log_frequency: int = 500

    def __post_init__(self):
        """Resolve the ``hidden_layers`` sentinel, then run the base hook.

        A fresh list is created per instance, so instances never share the
        default layer sizes.
        """
        if self.hidden_layers is None:
            self.hidden_layers = [128, 128, 128, 128]
        # Runs after the default is filled in, so the base class sees a
        # populated hidden_layers.
        # NOTE(review): assumes BaseConfig defines __post_init__ -- confirm.
        super().__post_init__()

    def get_burgers_params(self) -> dict:
        """Return the Burgers equation parameters.

        Returns:
            A dict with the single key ``'nu'``.
        """
        return {
            'nu': self.nu,
        }

    def get_spatial_config(self) -> dict:
        """Return the spatial domain bounds.

        Returns:
            A dict with keys ``'x_start'``, ``'x_end'``, ``'y_start'`` and
            ``'y_end'``.
        """
        return {
            'x_start': self.x_start,
            'x_end': self.x_end,
            'y_start': self.y_start,
            'y_end': self.y_end,
        }

    def get_temporal_config(self) -> dict:
        """Return the temporal domain bounds.

        Returns:
            A dict with keys ``'t_start'`` and ``'t_end'``.
        """
        return {
            't_start': self.t_start,
            't_end': self.t_end,
        }

    def get_data_config(self) -> dict:
        """Return the data generation settings.

        Returns:
            A dict with keys ``'num_initial_points'``,
            ``'num_boundary_points'``, ``'num_interior_points'`` and
            ``'noise_std'``.
        """
        return {
            'num_initial_points': self.num_initial_points,
            'num_boundary_points': self.num_boundary_points,
            'num_interior_points': self.num_interior_points,
            'noise_std': self.noise_std,
        }