#!/usr/bin/env python3
"""
Bayesian Physics-Informed Neural Network (Bayesian PINN) implementation.

This module provides an ensemble-based Bayesian PINN for uncertainty quantification,
used as a baseline comparison method for the R-PIT framework.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Any, Optional, Tuple
import numpy as np
from abc import ABC, abstractmethod

from .base_pinn import BasePINN


class BayesianPINN(BasePINN):
    """
    Ensemble-based Bayesian PINN for uncertainty quantification.

    Several independently initialised networks are evaluated on the same
    input. Their mean is used as the prediction and their sample variance
    across members as the uncertainty estimate.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_layers: Optional[List[int]] = None,
        activation: str = "tanh",
        n_ensemble: int = 5,
        dropout_rate: float = 0.1,
        weight_decay: float = 1e-4,
        **kwargs
    ):
        """
        Initialize Bayesian PINN with ensemble of networks.

        Args:
            input_dim: Input dimension
            output_dim: Output dimension
            hidden_layers: Hidden layer dimensions; ``None`` selects the
                default ``[50, 50, 50]``
            activation: Activation function
            n_ensemble: Number of networks in ensemble (must be >= 1)
            dropout_rate: Dropout rate for regularization
            weight_decay: Weight decay for regularization. It is only stored
                and reported by ``get_model_info``; nothing in this class
                applies it.
            **kwargs: Forwarded unchanged to the base-class constructor

        Raises:
            ValueError: If ``n_ensemble`` is smaller than 1.
        """
        if n_ensemble < 1:
            # An empty ensemble would otherwise only fail later, inside
            # torch.stack, with a much less helpful message.
            raise ValueError(f"n_ensemble must be >= 1, got {n_ensemble}")

        # Build a fresh list per instance: a list literal used as the default
        # in the signature would be one object shared by every model.
        hidden_layers = [50, 50, 50] if hidden_layers is None else list(hidden_layers)

        # Store ensemble-specific parameters before calling super().
        # NOTE(review): presumably the base constructor calls
        # _build_network(), which reads dropout_rate - confirm in BasePINN.
        self.n_ensemble = n_ensemble
        self.dropout_rate = dropout_rate
        self.weight_decay = weight_decay

        super().__init__(input_dim, output_dim, hidden_layers, activation, **kwargs)

        # Create ensemble of networks
        self.ensemble = nn.ModuleList([
            self._build_network() for _ in range(n_ensemble)
        ])

        # Give every member its own random initial weights
        self._initialize_ensemble()

        # Move ensemble to the correct device
        # NOTE(review): assumes BasePINN.__init__ sets self.device - confirm.
        self.ensemble.to(self.device)

    @staticmethod
    def _has_non_finite(tensor: torch.Tensor) -> bool:
        """Return True if *tensor* contains any NaN or +/-Inf entry."""
        return not bool(torch.isfinite(tensor).all())

    def to(self, device):
        """Override to method to ensure ensemble is moved to device."""
        super().to(device)
        # The guard covers calls made before __init__ has created the ensemble.
        if hasattr(self, 'ensemble'):
            self.ensemble.to(device)
        return self

    def _build_network(self) -> nn.Module:
        """
        Build a single network in the ensemble.

        The result is an ``nn.Sequential`` of Linear + activation pairs.
        Dropout follows every hidden-to-hidden layer but not the first
        (input) layer, and the output layer is a plain Linear.

        Returns:
            The assembled network
        """
        layers = []

        # Input layer
        layers.append(nn.Linear(self.input_dim, self.hidden_layers[0]))
        layers.append(self._get_activation())

        # Hidden layers
        for i in range(len(self.hidden_layers) - 1):
            layers.append(nn.Linear(self.hidden_layers[i], self.hidden_layers[i + 1]))
            layers.append(self._get_activation())
            layers.append(nn.Dropout(self.dropout_rate))

        # Output layer
        layers.append(nn.Linear(self.hidden_layers[-1], self.output_dim))

        return nn.Sequential(*layers)

    def _initialize_ensemble(self):
        """
        Re-initialize every Linear layer of every member with Xavier-normal
        weights (gain 0.5) and zero biases.

        All weights are drawn from a private ``torch.Generator`` whose seed is
        taken from the global CPU RNG. This keeps initialisation reproducible
        after ``torch.manual_seed(...)`` while still differing between models
        and between members. The global RNG state is restored afterwards, so
        building a model does not shift the caller's random stream, and the
        CUDA generators are never touched.
        """
        # Draw the seed from the global RNG, then put the global state back.
        saved_state = torch.get_rng_state()
        base_seed = int(torch.randint(0, 2**31 - 1, (1,), device="cpu").item())
        torch.set_rng_state(saved_state)

        generator = torch.Generator(device="cpu")
        generator.manual_seed(base_seed)

        gain = 0.5  # Smaller than the Xavier default to keep early outputs small
        with torch.no_grad():
            for network in self.ensemble:
                for layer in network.modules():
                    if not isinstance(layer, nn.Linear):
                        continue
                    fan_out, fan_in = layer.weight.shape
                    # Xavier/Glorot normal: std = gain * sqrt(2 / (fan_in + fan_out))
                    std = gain * (2.0 / (fan_in + fan_out)) ** 0.5
                    # Sample on CPU with the private generator, then copy so the
                    # layer keeps its own device and dtype.
                    sample = torch.empty(fan_out, fan_in, device="cpu").normal_(
                        0.0, std, generator=generator
                    )
                    layer.weight.copy_(sample)
                    if layer.bias is not None:
                        layer.bias.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through ensemble.

        A member whose output contains NaN/Inf is reported and replaced by
        zeros before averaging; if the mean itself is non-finite it is
        replaced by zeros as well.

        Args:
            x: Input tensor

        Returns:
            Mean prediction from ensemble
        """
        predictions = []
        for i, network in enumerate(self.ensemble):
            pred = network(x)

            # Check for NaN/Inf in individual predictions
            if self._has_non_finite(pred):
                print(f"⚠️  NaN/Inf detected in ensemble member {i} prediction")
                print(f"   Input range: [{x.min():.6f}, {x.max():.6f}]")
                print(f"   Prediction range: [{pred.min():.6f}, {pred.max():.6f}]")
                # Return safe prediction
                pred = torch.zeros_like(pred)

            predictions.append(pred)

        # Stack and compute mean
        mean_pred = torch.stack(predictions, dim=0).mean(dim=0)

        # Final check for NaN/Inf
        if self._has_non_finite(mean_pred):
            print("⚠️  NaN/Inf detected in final mean prediction")
            mean_pred = torch.zeros_like(mean_pred)

        return mean_pred

    def forward_with_uncertainty(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass with uncertainty estimation.

        If any member produces NaN/Inf the method returns immediately with a
        zero mean and unit variance. Otherwise the variance is the sample
        variance across members, clamped to at least 1e-6.

        Args:
            x: Input tensor

        Returns:
            Tuple of (mean, variance) predictions
        """
        predictions = []
        for network in self.ensemble:
            pred = network(x)
            # Check for NaN/Inf in individual predictions
            if self._has_non_finite(pred):
                print("⚠️  NaN/Inf detected in ensemble member prediction")
                # Return safe values
                return torch.zeros_like(pred), torch.ones_like(pred)
            predictions.append(pred)

        stacked = torch.stack(predictions, dim=0)  # [n_ensemble, batch_size, output_dim]
        mean = stacked.mean(dim=0)

        if stacked.shape[0] > 1:
            # Ensure variance is not zero (which can cause numerical issues)
            variance = torch.clamp(stacked.var(dim=0), min=1e-6)
        else:
            # The sample variance of a single member is undefined (NaN). Use
            # the same unit-variance fallback as the non-finite case below,
            # without emitting a warning on every call.
            variance = torch.ones_like(mean)

        # Final check for NaN/Inf
        if self._has_non_finite(mean):
            print("⚠️  NaN/Inf detected in mean calculation")
            mean = torch.zeros_like(mean)

        if self._has_non_finite(variance):
            print("⚠️  NaN/Inf detected in variance calculation")
            variance = torch.ones_like(variance)

        return mean, variance

    def compute_total_loss(self, x_collocation: torch.Tensor, x_data: torch.Tensor, y_data: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Compute total loss for Bayesian PINN.

        The total is the unweighted sum of data, uncertainty and physics
        losses. If that sum is NaN/Inf it is replaced by a constant 1.0 that
        is not connected to the model parameters.

        Args:
            x_collocation: Collocation points for physics loss
            x_data: Data points
            y_data: Target values

        Returns:
            Dictionary with keys ``total_loss``, ``physics_loss``,
            ``data_loss`` and ``uncertainty_loss``
        """
        # Compute physics loss
        physics_loss = self.compute_physics_loss(x_collocation)

        # Compute data loss
        data_loss = self.compute_data_loss(x_data, y_data)

        # Compute uncertainty loss
        uncertainty_loss = self.compute_uncertainty_loss(x_data, y_data)

        # For Bayesian PINN, we use data loss + uncertainty loss + physics loss
        total_loss = data_loss + uncertainty_loss + physics_loss

        # Check for NaN in losses
        if torch.isnan(total_loss) or torch.isinf(total_loss):
            print("⚠️  NaN/Inf detected in total loss computation")
            print(f"   Physics loss: {physics_loss.item():.6f}")
            print(f"   Data loss: {data_loss.item():.6f}")
            print(f"   Uncertainty loss: {uncertainty_loss.item():.6f}")
            # Return safe loss
            total_loss = torch.tensor(1.0, device=total_loss.device, requires_grad=True)

        return {
            'total_loss': total_loss,
            'physics_loss': physics_loss,
            'data_loss': data_loss,
            'uncertainty_loss': uncertainty_loss
        }

    def compute_physics_loss(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute physics loss for ensemble.

        Args:
            x: Collocation points

        Returns:
            Physics loss (a constant zero in this base implementation)
        """
        # This is a placeholder - specific implementations will override
        # For now, return zero loss
        return torch.tensor(0.0, device=x.device, requires_grad=True)

    def compute_data_loss(self, x_data: torch.Tensor, y_data: torch.Tensor) -> torch.Tensor:
        """
        Compute data loss for ensemble.

        Args:
            x_data: Input data points
            y_data: Target data points

        Returns:
            Mean squared error between the ensemble mean and the targets
        """
        mean_pred, _ = self.forward_with_uncertainty(x_data)
        return F.mse_loss(mean_pred, y_data)

    def compute_uncertainty_loss(self, x_data: torch.Tensor, y_data: torch.Tensor) -> torch.Tensor:
        """
        Compute uncertainty-aware loss using negative log-likelihood.

        The Gaussian NLL is evaluated per element with the ensemble variance
        clamped to [1e-4, 1e2], then itself clamped to [0, 50] and averaged.
        If any intermediate quantity is non-finite, or the computation raises,
        a constant 1.0 not connected to the model parameters is returned.

        Args:
            x_data: Input data points
            y_data: Target data points

        Returns:
            Uncertainty loss
        """
        mean_pred, variance = self.forward_with_uncertainty(x_data)

        def safe_loss(reference: torch.Tensor) -> torch.Tensor:
            # Constant fallback placed on the same device as *reference*.
            return torch.tensor(1.0, device=reference.device, requires_grad=True)

        # Check for NaN or Inf in predictions
        if self._has_non_finite(mean_pred):
            print("⚠️  NaN/Inf detected in mean predictions")
            return safe_loss(mean_pred)

        if self._has_non_finite(variance):
            print("⚠️  NaN/Inf detected in variance predictions")
            return safe_loss(variance)

        # Ensure variance is positive and in reasonable range to avoid numerical issues
        # Use more conservative range to prevent overflow in inverse Poisson problem
        variance = torch.clamp(variance, min=1e-4, max=1e2)

        # Negative log-likelihood loss with improved numerical stability
        try:
            # Compute log variance first to avoid overflow
            log_var = torch.log(variance)
            if self._has_non_finite(log_var):
                print("⚠️  NaN/Inf detected in log variance")
                return safe_loss(log_var)

            squared_error = (y_data - mean_pred)**2
            if self._has_non_finite(squared_error):
                print("⚠️  NaN/Inf detected in squared error")
                return safe_loss(squared_error)

            normalized_error = squared_error / variance
            if self._has_non_finite(normalized_error):
                print("⚠️  NaN/Inf detected in normalized error")
                return safe_loss(normalized_error)

            # Gaussian NLL: 0.5 * (log(var) + log(2*pi) + err^2 / var)
            nll_loss = 0.5 * (log_var + np.log(2 * np.pi) + normalized_error)
            if self._has_non_finite(nll_loss):
                print("⚠️  NaN/Inf detected in NLL loss")
                return safe_loss(nll_loss)

            # Clamp loss to reasonable range to prevent overflow
            # Use more conservative range for inverse Poisson problem
            nll_loss = torch.clamp(nll_loss, min=0.0, max=50.0)

            return nll_loss.mean()

        except Exception as e:
            # Best-effort by design: report the problem and keep training alive.
            print(f"⚠️  Error in uncertainty loss computation: {e}")
            return safe_loss(mean_pred)

    def get_ensemble_predictions(self, x: torch.Tensor) -> torch.Tensor:
        """
        Get predictions from all ensemble members.

        Unlike ``forward``, no NaN/Inf sanitising is applied here.

        Args:
            x: Input tensor

        Returns:
            Predictions from all ensemble members [n_ensemble, batch_size, output_dim]
        """
        return torch.stack([network(x) for network in self.ensemble], dim=0)

    def get_model_info(self) -> Dict[str, Any]:
        """Get model information, extending the base-class info with ensemble settings."""
        info = super().get_model_info()
        info.update({
            "model_type": "BayesianPINN",
            "n_ensemble": self.n_ensemble,
            "dropout_rate": self.dropout_rate,
            "weight_decay": self.weight_decay,
        })
        return info


class ProblemSpecificBayesianPINN(BayesianPINN):
    """
    Common parent for Bayesian PINNs that are tied to one concrete problem.

    Subclasses supply the governing-equation residual by implementing
    ``compute_physics_loss``.
    """

    def __init__(self, *args, **kwargs):
        # Pure pass-through: every argument goes to BayesianPINN unchanged.
        super().__init__(*args, **kwargs)

    # NOTE(review): @abstractmethod is only enforced when the class hierarchy
    # uses an ABCMeta metaclass - confirm that BasePINN derives from ABC.
    @abstractmethod
    def compute_physics_loss(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the problem-specific physics loss at collocation points ``x``."""


class LorenzBayesianPINN(ProblemSpecificBayesianPINN):
    """
    Bayesian PINN for Lorenz system.

    The physics loss penalises the residuals of

        dx/dt = sigma * (y - x)
        dy/dt = x * (rho - z) - y
        dz/dt = x * y - beta * z

    where the three network outputs are read as (x, y, z).
    """

    def __init__(
        self,
        sigma: float = 10.0,
        rho: float = 28.0,
        beta: float = 8.0/3.0,
        *args,
        **kwargs
    ):
        """
        Initialize the Lorenz model.

        Args:
            sigma: Lorenz parameter sigma
            rho: Lorenz parameter rho
            beta: Lorenz parameter beta
            *args: Forwarded to the parent constructor
            **kwargs: Forwarded to the parent constructor

        Note:
            ``sigma``, ``rho`` and ``beta`` come first in the signature, so
            positional arguments bind to them before reaching ``*args``.
            Pass the network settings (``input_dim`` etc.) by keyword.
        """
        super().__init__(*args, **kwargs)
        self.sigma = sigma
        self.rho = rho
        self.beta = beta

    def compute_physics_loss(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute Lorenz system physics loss for ensemble.

        The time derivative of each state is taken with its own autograd
        call. One call on the whole ``[batch, 3]`` output with
        ``grad_outputs=ones`` would return d(x + y + z)/dt in the shape of the
        input, not one derivative per state.

        Args:
            x: Collocation points. ``requires_grad`` is switched on in place.
                Assumes time is the first (normally the only) input column -
                TODO confirm against the data pipeline.

        Returns:
            Mean squared residual averaged over the members that produced
            finite values, or a constant 1.0 if no member did.

        Raises:
            ValueError: If the networks produce fewer than three outputs.
        """
        x.requires_grad_(True)

        # Get ensemble predictions
        ensemble_preds = self.get_ensemble_predictions(x)  # [n_ensemble, batch_size, 3]
        if ensemble_preds.shape[-1] < 3:
            raise ValueError(
                "Lorenz physics loss needs 3 network outputs (x, y, z), "
                f"got {ensemble_preds.shape[-1]}"
            )

        total_loss = 0.0
        valid_predictions = 0

        for pred in ensemble_preds:
            # Check for NaN/Inf in prediction
            if not torch.isfinite(pred).all():
                print("⚠️  NaN/Inf detected in ensemble prediction, skipping...")
                continue

            # Split predictions into the three state variables
            states = [pred[:, k:k + 1] for k in range(3)]

            try:
                # One gradient per state so the derivatives stay separate.
                # create_graph=True keeps the graph so the loss can be
                # back-propagated to the network weights.
                grads = [
                    torch.autograd.grad(
                        state, x,
                        grad_outputs=torch.ones_like(state),
                        create_graph=True,
                        allow_unused=True
                    )[0]
                    for state in states
                ]
            except RuntimeError as e:
                print(f"⚠️  Error in physics loss computation: {e}")
                continue

            # Check for NaN/Inf in gradients
            if any(g is None or not torch.isfinite(g).all() for g in grads):
                print("⚠️  NaN/Inf detected in gradients, skipping...")
                continue

            x_comp, y_comp, z_comp = states
            dx_dt, dy_dt, dz_dt = (g[:, 0:1] for g in grads)

            # Lorenz equations
            residual_x = dx_dt - self.sigma * (y_comp - x_comp)
            residual_y = dy_dt - (x_comp * (self.rho - z_comp) - y_comp)
            residual_z = dz_dt - (x_comp * y_comp - self.beta * z_comp)

            # Check for NaN/Inf in residuals
            residuals = torch.cat([residual_x, residual_y, residual_z], dim=1)
            if not torch.isfinite(residuals).all():
                print("⚠️  NaN/Inf detected in residuals, skipping...")
                continue

            # Physics loss for this ensemble member
            physics_loss = torch.mean(residuals**2)

            # Final check
            if not torch.isfinite(physics_loss):
                print("⚠️  NaN/Inf detected in physics loss, skipping...")
                continue

            total_loss = total_loss + physics_loss
            valid_predictions += 1

        if valid_predictions == 0:
            print("⚠️  No valid predictions for physics loss, returning safe value")
            return torch.tensor(1.0, device=x.device, requires_grad=True)

        return total_loss / valid_predictions


class BurgersBayesianPINN(ProblemSpecificBayesianPINN):
    """
    Bayesian PINN for 2D Burgers equation.

    The two network outputs are read as the velocity components (u, v) and
    the three input columns as (x, y, t).
    """

    def __init__(self, nu: float = 0.01, *args, **kwargs):
        """
        Initialize the Burgers model.

        Args:
            nu: Viscosity coefficient used in the diffusion term
            *args: Forwarded to the parent constructor
            **kwargs: Forwarded to the parent constructor
        """
        super().__init__(*args, **kwargs)
        self.nu = nu

    def compute_physics_loss(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute 2D Burgers equation physics loss for ensemble.

        For every ensemble member the residuals

            u_t + u * u_x + v * u_y - nu * (u_xx + u_yy)
            v_t + u * v_x + v * v_y - nu * (v_xx + v_yy)

        are evaluated with autograd. Members that yield a missing or
        non-finite quantity, or that raise, are reported and skipped.

        Args:
            x: Collocation points with columns (x, y, t). ``requires_grad``
                is switched on in place.

        Returns:
            Mean squared residual averaged over the usable members, or a
            constant 1.0 if no member was usable.
        """
        x.requires_grad_(True)

        def grad_wrt_inputs(field):
            # Gradient of a single-column field with respect to all input columns.
            return torch.autograd.grad(
                outputs=field,
                inputs=x,
                grad_outputs=torch.ones_like(field),
                create_graph=True,
                retain_graph=True,
                allow_unused=True
            )[0]

        def unusable(tensor):
            # True for a missing gradient or one containing NaN/Inf.
            return tensor is None or not bool(torch.isfinite(tensor).all())

        member_losses = []

        # Stacked member outputs: [n_ensemble, batch_size, 2]
        for member_pred in self.get_ensemble_predictions(x):
            if unusable(member_pred):
                print("⚠️  NaN/Inf detected in ensemble prediction, skipping...")
                continue

            try:
                u = member_pred[:, 0:1]
                v = member_pred[:, 1:2]

                # First derivatives of both velocity components
                grad_u = grad_wrt_inputs(u)
                grad_v = grad_wrt_inputs(v)
                if any(unusable(g) for g in (grad_u, grad_v)):
                    print("⚠️  NaN/Inf detected in gradients, skipping...")
                    continue

                # Columns follow the input order (x, y, t)
                u_x, u_y, u_t = grad_u[:, 0:1], grad_u[:, 1:2], grad_u[:, 2:3]
                v_x, v_y, v_t = grad_v[:, 0:1], grad_v[:, 1:2], grad_v[:, 2:3]

                # Second derivatives for the Laplacian
                grad_u_x = grad_wrt_inputs(u_x)
                grad_u_y = grad_wrt_inputs(u_y)
                grad_v_x = grad_wrt_inputs(v_x)
                grad_v_y = grad_wrt_inputs(v_y)
                if any(unusable(g) for g in (grad_u_x, grad_u_y, grad_v_x, grad_v_y)):
                    print("⚠️  NaN/Inf detected in second derivatives, skipping...")
                    continue

                # Keep only the pure second derivatives
                u_xx = grad_u_x[:, 0:1]
                u_yy = grad_u_y[:, 1:2]
                v_xx = grad_v_x[:, 0:1]
                v_yy = grad_v_y[:, 1:2]

                # Burgers equation residuals
                residual_u = u_t + u * u_x + v * u_y - self.nu * (u_xx + u_yy)
                residual_v = v_t + u * v_x + v * v_y - self.nu * (v_xx + v_yy)

                residuals = torch.cat([residual_u, residual_v], dim=1)
                if unusable(residuals):
                    print("⚠️  NaN/Inf detected in residuals, skipping...")
                    continue

                # Physics loss for this ensemble member
                member_loss = torch.mean(residuals**2)
                if unusable(member_loss):
                    print("⚠️  NaN/Inf detected in physics loss, skipping...")
                    continue

                member_losses.append(member_loss)

            except Exception as e:
                print(f"⚠️  Error in physics loss computation: {e}")
                continue

        if not member_losses:
            print("⚠️  No valid predictions for physics loss, returning safe value")
            return torch.tensor(1.0, device=x.device, requires_grad=True)

        return sum(member_losses) / len(member_losses)


class InversePoissonBayesianPINN(ProblemSpecificBayesianPINN):
    """
    Bayesian PINN for 1D inverse Poisson problem.

    The two network outputs are read as the solution ``u`` and the source
    term ``f``; the physics loss penalises the residual of -d²u/dx² = f.
    """

    def __init__(self, *args, **kwargs):
        # Pure pass-through to the parent constructor.
        super().__init__(*args, **kwargs)

    @staticmethod
    def _input_gradient(field: torch.Tensor, x: torch.Tensor) -> Optional[torch.Tensor]:
        """Return d(sum(field))/dx with the graph kept, or None if field does not depend on x."""
        return torch.autograd.grad(
            outputs=field,
            inputs=x,
            grad_outputs=torch.ones_like(field),
            create_graph=True,
            allow_unused=True
        )[0]

    def compute_physics_loss(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute 1D Poisson equation physics loss for ensemble.

        The second derivative is obtained by differentiating each member's
        own forward pass twice with autograd, so ``u``, ``du/dx`` and
        ``d²u/dx²`` all come from the same network and the same dropout mask.
        A finite-difference stencil built from extra forward passes would mix
        independent dropout masks, and dividing by a tiny ``h²`` amplifies
        float32 round-off.

        Args:
            x: Collocation points. ``requires_grad`` is switched on in place.
                Assumes the spatial coordinate is the first input column -
                TODO confirm against the data pipeline.

        Returns:
            Mean squared residual averaged over the members that produced
            finite values, or a constant 1.0 if no member did.
        """
        x.requires_grad_(True)

        # Get ensemble predictions
        ensemble_preds = self.get_ensemble_predictions(x)  # [n_ensemble, batch_size, 2]

        total_loss = 0.0
        valid_predictions = 0

        for pred in ensemble_preds:
            # Non-finite members are skipped silently to reduce noise
            if not torch.isfinite(pred).all():
                continue

            # Split predictions: u and f
            u = pred[:, 0:1]
            f = pred[:, 1:2]

            try:
                du_dx = self._input_gradient(u, x)
                if du_dx is None or not du_dx.requires_grad:
                    # u has no (further) dependence on x, so its second
                    # derivative is identically zero.
                    d2u_dx2 = torch.zeros_like(u)
                else:
                    d2u_dx2 = self._input_gradient(du_dx[:, 0:1], x)
                    if d2u_dx2 is None:
                        d2u_dx2 = torch.zeros_like(u)
                    else:
                        d2u_dx2 = d2u_dx2[:, 0:1]
            except RuntimeError as e:
                # Report autograd failures instead of hiding them.
                print(f"⚠️  Error in physics loss computation: {e}")
                continue

            # Check for NaN/Inf in second derivative
            if not torch.isfinite(d2u_dx2).all():
                continue

            # Poisson equation residual: -d²u/dx² = f
            residual = -d2u_dx2 - f
            if not torch.isfinite(residual).all():
                continue

            # Physics loss for this ensemble member
            physics_loss = torch.mean(residual**2)
            if not torch.isfinite(physics_loss):
                continue

            total_loss = total_loss + physics_loss
            valid_predictions += 1

        if valid_predictions == 0:
            return torch.tensor(1.0, device=x.device, requires_grad=True)

        return total_loss / valid_predictions

    def _get_single_prediction(self, x: torch.Tensor, pred_idx: int = 0) -> torch.Tensor:
        """
        Get the ``u`` component predicted by a single ensemble member.

        No longer used by ``compute_physics_loss``; kept so existing callers
        continue to work.

        Args:
            x: Input tensor
            pred_idx: Index of the ensemble member (not of the output column)

        Returns:
            First output column of the selected member, or zeros of shape
            ``[x.shape[0], 1]`` when ``pred_idx`` is past the last member.
        """
        if pred_idx < len(self.ensemble):
            return self.ensemble[pred_idx](x)[:, 0:1]  # Return u component only
        else:
            return torch.zeros(x.shape[0], 1, device=x.device)


def create_bayesian_problem_model(
    problem_type: str,
    input_dim: int,
    output_dim: int,
    hidden_layers: Optional[List[int]] = None,
    activation: str = "tanh",
    n_ensemble: int = 5,
    dropout_rate: float = 0.1,
    **kwargs
) -> ProblemSpecificBayesianPINN:
    """
    Factory function to create problem-specific Bayesian PINN models.

    Args:
        problem_type: Type of problem ("lorenz", "burgers", "inverse_poisson")
        input_dim: Input dimension
        output_dim: Output dimension
        hidden_layers: Hidden layer dimensions; ``None`` selects the default
            ``[50, 50, 50]``
        activation: Activation function
        n_ensemble: Number of ensemble members
        dropout_rate: Dropout rate
        **kwargs: Additional problem-specific parameters, forwarded to the
            selected model's constructor

    Returns:
        Problem-specific Bayesian PINN model

    Raises:
        ValueError: If ``problem_type`` is not one of the supported names.
    """
    model_classes = {
        "lorenz": LorenzBayesianPINN,
        "burgers": BurgersBayesianPINN,
        "inverse_poisson": InversePoissonBayesianPINN,
    }

    model_cls = model_classes.get(problem_type)
    if model_cls is None:
        raise ValueError(f"Unknown problem type: {problem_type}")

    # Build a fresh list per call: a list literal used as the default in the
    # signature would be one object shared by every model created with it.
    if hidden_layers is None:
        hidden_layers = [50, 50, 50]

    return model_cls(
        input_dim=input_dim,
        output_dim=output_dim,
        hidden_layers=hidden_layers,
        activation=activation,
        n_ensemble=n_ensemble,
        dropout_rate=dropout_rate,
        **kwargs
    )
