"""
Configuration for 1D Nonlinear Poisson equation benchmark.
"""

from dataclasses import dataclass
from typing import Dict, Optional, Tuple

import numpy as np
import torch

from configs.base_config import BaseConfig


@dataclass
class NonlinearPoissonConfig(BaseConfig):
    """
    Configuration for 1D Nonlinear Poisson equation:
    d/dx[k(u,x) * du/dx] - w = 0,  x ∈ [-1, 1]

    The sampled PDE parameters are ``num_chebyshev_coeffs`` Chebyshev
    coefficients ``xi`` and a scalar ``w``; their uniform prior ranges are
    stored in ``parameter_ranges`` under the keys ``'xi_0'``,
    ``'xi_1'``, ... and ``'w'``.
    """

    # Built-in prior ranges used when ``parameter_ranges`` does not override
    # them. These are plain (unannotated) class attributes, so they are not
    # dataclass fields.
    _DEFAULT_XI_RANGE = (-1.0, 1.0)
    _DEFAULT_W_RANGE = (1.0, 2.0)

    # Experiment Metadata
    experiment_name: str = "nonlinear_poisson_1d"
    problem_type: str = "nonlinear_poisson"

    # Architecture Settings
    grid_resolution: Tuple[int, ...] = (256,)
    spatial_dim: int = 1
    output_dim: int = 1
    use_white_noise_kernel: bool = False

    # Domain and Boundary Conditions
    domain_bounds: Tuple[Tuple[float, float], ...] = ((-1.0, 1.0),)
    boundary_condition_type: str = "dirichlet"
    use_hard_boundary_constraints: bool = True

    # PDE Parameters
    num_chebyshev_coeffs: int = 5
    # Uniform prior ranges keyed 'xi_0'..'xi_{K-1}' and 'w'. None means "use
    # the built-in ranges"; explicitly supplied entries override the built-in
    # ones (see __post_init__).
    parameter_ranges: Optional[Dict[str, Tuple[float, float]]] = None
    # NOTE(review): not read in this file; equals num_chebyshev_coeffs - 1 by
    # default — confirm whether callers expect the two to stay in sync.
    chebyshev_degree: int = 4

    # Training Settings
    num_train_tasks: int = 1_000
    num_val_tasks: int = 200
    num_test_tasks: int = 1_000
    batch_size: int = 1
    num_iterations: int = 20_000
    num_context: int = 50
    num_target: int = 200

    # Physics Constraint Settings
    num_collocation_interior: int = 512
    num_collocation_boundary: int = 2

    # Physics loss weights: small initial weight, increased toward
    # physics_weight_final over physics_warmup_iterations.
    physics_weight: float = 0.0001  # Start very small
    physics_weight_final: float = 0.01  # Final weight still modest
    physics_warmup_iterations: int = 5000  # Gradual increase
    boundary_weight: float = 0.0

    # Numerical Solver Settings
    solver_grid_size: int = 1024
    solver_tolerance: float = 1e-8
    solver_max_iterations: int = 100

    # Data Generation Settings
    observation_noise_std: float = 0.01
    cache_dataset: bool = True
    dataset_cache_path: str = "./data/nonlinear_poisson_cache.h5"

    # Optimizer Settings
    learning_rate: float = 1e-4  # Reduced from default
    lr_min: float = 1e-6
    grad_clip_norm: float = 1.0  # Gradient clipping is essential for stability

    def __post_init__(self):
        """Fill in ``parameter_ranges`` and run the base-class post-init.

        Built-in ranges are created for every Chebyshev coefficient and for
        ``w``. Any entries the caller passed explicitly override the
        corresponding built-in ones; previously they were silently discarded.
        The result is a fresh dict, so the caller's mapping is never aliased.
        """
        ranges = {
            f'xi_{i}': self._DEFAULT_XI_RANGE
            for i in range(self.num_chebyshev_coeffs)
        }
        ranges['w'] = self._DEFAULT_W_RANGE
        if self.parameter_ranges is not None:
            ranges.update(self.parameter_ranges)
        self.parameter_ranges = ranges
        super().__post_init__()

    def get_hard_constraint_function(self):
        """Return a function D(x) that is (numerically) zero at both domain ends.

        D(x) = cos(π (x - c) / L), where c and L are the midpoint and length
        of ``domain_bounds[0]``. For the default domain [-1, 1] this reduces
        exactly to D(x) = cos(πx/2).

        Returns:
            A callable mapping a tensor ``x`` to an elementwise tensor ``D(x)``.
        """
        lower, upper = self.domain_bounds[0]
        center = 0.5 * (lower + upper)
        length = upper - lower

        def D(x: torch.Tensor) -> torch.Tensor:
            return torch.cos(np.pi * (x - center) / length)
        return D

    def sample_parameters(self, n_samples: int, device: str = 'cpu') -> Dict[str, torch.Tensor]:
        """Sample PDE parameters uniformly from ``parameter_ranges``.

        Missing keys fall back to the built-in ranges. ``xi`` is drawn before
        ``w``, so the random-number consumption order is fixed.

        Args:
            n_samples: Number of parameter sets to draw.
            device: Torch device on which to allocate the tensors.

        Returns:
            ``{'xi': (n_samples, num_chebyshev_coeffs), 'w': (n_samples, 1)}``
            tensors of the default floating dtype.
        """
        ranges = self.parameter_ranges or {}
        num_coeffs = self.num_chebyshev_coeffs
        dtype = torch.get_default_dtype()

        xi_bounds = [
            ranges.get(f'xi_{i}', self._DEFAULT_XI_RANGE) for i in range(num_coeffs)
        ]
        xi_lo = torch.tensor([lo for lo, _ in xi_bounds], dtype=dtype, device=device)
        xi_hi = torch.tensor([hi for _, hi in xi_bounds], dtype=dtype, device=device)
        xi = torch.rand(n_samples, num_coeffs, device=device) * (xi_hi - xi_lo) + xi_lo

        w_lo, w_hi = ranges.get('w', self._DEFAULT_W_RANGE)
        w = torch.rand(n_samples, 1, device=device) * (w_hi - w_lo) + w_lo
        return {'xi': xi, 'w': w}

    def get_chebyshev_basis(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate Chebyshev polynomials T_0 .. T_{K-1} at ``x``.

        ``x`` is clamped to [-1, 1] (the Chebyshev interval). The polynomials
        are built with the three-term recurrence
        T_{n+1}(x) = 2x T_n(x) - T_{n-1}(x).

        Args:
            x: Evaluation points. NOTE(review): the per-degree values are
                concatenated along the last dim, so this presumably expects a
                trailing singleton dim, e.g. (N, 1) -> (N, K). A 1-D input
                would be concatenated end-to-end — confirm with callers.

        Returns:
            The K = ``num_chebyshev_coeffs`` basis values concatenated along
            the last dimension. T_0 is always included.
        """
        x_clamped = torch.clamp(x, -1.0, 1.0)
        basis = [torch.ones_like(x_clamped)]
        if self.num_chebyshev_coeffs > 1:
            basis.append(x_clamped)
        for _ in range(2, self.num_chebyshev_coeffs):
            basis.append(2.0 * x_clamped * basis[-1] - basis[-2])
        return torch.cat(basis, dim=-1)

    def validate(self):
        """Additional validation for Nonlinear Poisson config.

        Raises:
            AssertionError: if any setting is out of range. Assertions are used
                for consistency with the existing checks in this method.
        """
        super().validate()

        assert self.num_chebyshev_coeffs > 0, "num_chebyshev_coeffs must be positive"
        assert self.solver_grid_size > 0, "solver_grid_size must be positive"
        assert self.physics_weight >= 0, "physics_weight must be non-negative"
        assert self.physics_weight_final >= self.physics_weight, "physics_weight_final must be >= physics_weight"
        assert self.physics_warmup_iterations >= 0, "physics_warmup_iterations must be non-negative"
        for name, (low, high) in (self.parameter_ranges or {}).items():
            assert low <= high, f"parameter_ranges['{name}'] must satisfy low <= high, got ({low}, {high})"
        # get_hard_constraint_function divides by the domain length.
        lower, upper = self.domain_bounds[0]
        assert lower < upper, f"domain_bounds[0] must satisfy lower < upper, got ({lower}, {upper})"

        print("✓ Nonlinear Poisson config validated")
        print(f"  - Chebyshev coefficients: {self.num_chebyshev_coeffs}")
        print(f"  - Domain: {self.domain_bounds}")
        print(f"  - Grid resolution: {self.grid_resolution}")
        print(f"  - Physics weight: {self.physics_weight} → {self.physics_weight_final} (warmup: {self.physics_warmup_iterations})")