"""
Problem-specific PINN models with correct physics loss implementations.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, Any, Tuple
from .base_pinn import BasePINN
from .rpit_pinn import RPITPINN


class LorenzStandardPINN(BasePINN):
    """Standard PINN for the Lorenz system with a per-component ODE residual loss."""
    
    def __init__(self, sigma: float = 10.0, rho: float = 28.0, beta: float = 8.0/3.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Lorenz parameters σ, ρ, β used by compute_physics_loss.
        self.sigma = sigma
        self.rho = rho
        self.beta = beta
    
    def compute_data_loss(self, x_data: torch.Tensor, y_data: torch.Tensor) -> torch.Tensor:
        """Compute data loss as the MSE between predictions on ``x_data`` and ``y_data``."""
        predictions = self.forward(x_data)
        return F.mse_loss(predictions, y_data)
    
    def compute_physics_loss(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the Lorenz-system physics (ODE residual) loss.
        
        Lorenz system:
        dx/dt = σ(y - x)
        dy/dt = x(ρ - z) - y
        dz/dt = xy - βz
        
        Args:
            x: Collocation inputs. Time is assumed to be input column 0
               (typically the only column, shape (batch, 1)) — TODO confirm
               against the caller. ``requires_grad`` is enabled on it in place.
        
        Returns:
            Scalar mean of the squared residuals of all three equations.
        """
        x.requires_grad_(True)
        
        # Forward pass; output columns 0, 1, 2 are the predicted state [x, y, z].
        u = self.forward(x)
        
        x_comp = u[:, 0:1]
        y_comp = u[:, 1:2]
        z_comp = u[:, 2:3]
        
        def time_derivative(component: torch.Tensor) -> torch.Tensor:
            # Differentiate ONE output column at a time. A single grad call on
            # the whole of ``u`` with ones as grad_outputs would instead return
            # d(x + y + z)/d(inputs), shaped like the input rather than like u.
            grad = torch.autograd.grad(
                outputs=component,
                inputs=x,
                grad_outputs=torch.ones_like(component),
                create_graph=True,  # keep the graph so the loss is trainable
                retain_graph=True,  # u's graph is reused for the next component
            )[0]
            return grad[:, 0:1]
        
        dx_dt = time_derivative(x_comp)
        dy_dt = time_derivative(y_comp)
        dz_dt = time_derivative(z_comp)
        
        # Residuals of the three Lorenz equations.
        residual_x = dx_dt - self.sigma * (y_comp - x_comp)
        residual_y = dy_dt - (x_comp * (self.rho - z_comp) - y_comp)
        residual_z = dz_dt - (x_comp * y_comp - self.beta * z_comp)
        
        residual = torch.cat([residual_x, residual_y, residual_z], dim=1)
        
        return torch.mean(residual**2)


class LorenzRPITPINN(RPITPINN):
    """R-PIT PINN for the Lorenz system with a per-component ODE residual loss."""
    
    def __init__(self, sigma: float = 10.0, rho: float = 28.0, beta: float = 8.0/3.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Lorenz parameters σ, ρ, β used by compute_physics_loss.
        self.sigma = sigma
        self.rho = rho
        self.beta = beta
    
    def compute_physics_loss(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the Lorenz-system physics (ODE residual) loss for R-PIT.
        
        Lorenz system:
        dx/dt = σ(y - x)
        dy/dt = x(ρ - z) - y
        dz/dt = xy - βz
        
        When ``self.uncertainty_output`` is set, the residual is evaluated on the
        first ``self.base_output_dim`` columns of a ``forward(x, add_noise=False)``
        call (the mean prediction).
        
        Args:
            x: Collocation inputs. Time is assumed to be input column 0
               (typically the only column, shape (batch, 1)) — TODO confirm
               against the caller. ``requires_grad`` is enabled on it in place.
        
        Returns:
            Scalar mean of the squared residuals of all three equations.
        """
        x.requires_grad_(True)
        
        # Forward pass - use mean prediction for physics loss
        if self.uncertainty_output:
            output = self.forward(x, add_noise=False)
            u = output[:, :self.base_output_dim]  # Use mean only
        else:
            u = self.forward(x)
        
        x_comp = u[:, 0:1]
        y_comp = u[:, 1:2]
        z_comp = u[:, 2:3]
        
        def time_derivative(component: torch.Tensor) -> torch.Tensor:
            # Differentiate ONE output column at a time. A single grad call on
            # the whole of ``u`` with ones as grad_outputs would instead return
            # d(x + y + z)/d(inputs), shaped like the input rather than like u.
            grad = torch.autograd.grad(
                outputs=component,
                inputs=x,
                grad_outputs=torch.ones_like(component),
                create_graph=True,  # keep the graph so the loss is trainable
                retain_graph=True,  # u's graph is reused for the next component
            )[0]
            return grad[:, 0:1]
        
        dx_dt = time_derivative(x_comp)
        dy_dt = time_derivative(y_comp)
        dz_dt = time_derivative(z_comp)
        
        # Residuals of the three Lorenz equations.
        residual_x = dx_dt - self.sigma * (y_comp - x_comp)
        residual_y = dy_dt - (x_comp * (self.rho - z_comp) - y_comp)
        residual_z = dz_dt - (x_comp * y_comp - self.beta * z_comp)
        
        residual = torch.cat([residual_x, residual_y, residual_z], dim=1)
        
        return torch.mean(residual**2)


class BurgersStandardPINN(BasePINN):
    """Standard (deterministic) PINN for the 2D viscous Burgers equation."""
    
    def __init__(self, nu: float = 0.01, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Viscosity ν multiplying the Laplacian term of the residual.
        self.nu = nu
    
    def compute_data_loss(self, x_data: torch.Tensor, y_data: torch.Tensor) -> torch.Tensor:
        """Return the mean-squared error between model predictions and targets."""
        return F.mse_loss(self.forward(x_data), y_data)
    
    def compute_physics_loss(self, x: torch.Tensor) -> torch.Tensor:
        """
        Mean squared residual of the 2D Burgers equation ∂u/∂t + u·∇u = ν∇²u.
        
        Input columns are read as [x, y, t] (columns 0, 1, 2); output columns
        0 and 1 are the velocity components (u, v).
        
        Args:
            x: Collocation points; ``requires_grad`` is switched on in place.
        
        Returns:
            Scalar mean of the squared u- and v-momentum residuals.
        """
        x.requires_grad_(True)
        
        def grad_wrt_inputs(quantity: torch.Tensor) -> torch.Tensor:
            # Gradient of a single-column quantity w.r.t. every input column,
            # with the graph kept so it can be differentiated again.
            return torch.autograd.grad(
                outputs=quantity,
                inputs=x,
                grad_outputs=torch.ones_like(quantity),
                create_graph=True,
                retain_graph=True,
            )[0]
        
        prediction = self.forward(x)
        vel_u = prediction[:, 0:1]
        vel_v = prediction[:, 1:2]
        
        grad_u = grad_wrt_inputs(vel_u)
        grad_v = grad_wrt_inputs(vel_v)
        
        u_x, u_y, u_t = grad_u[:, 0:1], grad_u[:, 1:2], grad_u[:, 2:3]
        v_x, v_y, v_t = grad_v[:, 0:1], grad_v[:, 1:2], grad_v[:, 2:3]
        
        # Laplacians from the diagonal second derivatives.
        lap_u = grad_wrt_inputs(u_x)[:, 0:1] + grad_wrt_inputs(u_y)[:, 1:2]
        lap_v = grad_wrt_inputs(v_x)[:, 0:1] + grad_wrt_inputs(v_y)[:, 1:2]
        
        residual_u = u_t + vel_u * u_x + vel_v * u_y - self.nu * lap_u
        residual_v = v_t + vel_u * v_x + vel_v * v_y - self.nu * lap_v
        
        return torch.cat([residual_u, residual_v], dim=1).pow(2).mean()


class BurgersRPITPINN(RPITPINN):
    """R-PIT PINN for the 2D viscous Burgers equation."""
    
    def __init__(self, nu: float = 0.01, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Viscosity ν multiplying the Laplacian term of the residual.
        self.nu = nu
    
    def compute_physics_loss(self, x: torch.Tensor) -> torch.Tensor:
        """
        Mean squared residual of the 2D Burgers equation ∂u/∂t + u·∇u = ν∇²u.
        
        When ``self.uncertainty_output`` is set, the residual uses the first
        ``self.base_output_dim`` columns of ``forward(x, add_noise=False)``
        (the mean prediction). Input columns are read as [x, y, t].
        
        Args:
            x: Collocation points; ``requires_grad`` is switched on in place.
        
        Returns:
            Scalar mean of the squared u- and v-momentum residuals.
        """
        x.requires_grad_(True)
        
        def grad_wrt_inputs(quantity: torch.Tensor) -> torch.Tensor:
            # Gradient of a single-column quantity w.r.t. every input column,
            # with the graph kept so it can be differentiated again.
            return torch.autograd.grad(
                outputs=quantity,
                inputs=x,
                grad_outputs=torch.ones_like(quantity),
                create_graph=True,
                retain_graph=True,
            )[0]
        
        # Physics is enforced on the noise-free mean prediction only.
        if self.uncertainty_output:
            prediction = self.forward(x, add_noise=False)[:, :self.base_output_dim]
        else:
            prediction = self.forward(x)
        
        vel_u = prediction[:, 0:1]
        vel_v = prediction[:, 1:2]
        
        grad_u = grad_wrt_inputs(vel_u)
        grad_v = grad_wrt_inputs(vel_v)
        
        u_x, u_y, u_t = grad_u[:, 0:1], grad_u[:, 1:2], grad_u[:, 2:3]
        v_x, v_y, v_t = grad_v[:, 0:1], grad_v[:, 1:2], grad_v[:, 2:3]
        
        # Laplacians from the diagonal second derivatives.
        lap_u = grad_wrt_inputs(u_x)[:, 0:1] + grad_wrt_inputs(u_y)[:, 1:2]
        lap_v = grad_wrt_inputs(v_x)[:, 0:1] + grad_wrt_inputs(v_y)[:, 1:2]
        
        residual_u = u_t + vel_u * u_x + vel_v * u_y - self.nu * lap_u
        residual_v = v_t + vel_u * v_x + vel_v * v_y - self.nu * lap_v
        
        return torch.cat([residual_u, residual_v], dim=1).pow(2).mean()


class InversePoissonStandardPINN(BasePINN):
    """Standard PINN for the 1D inverse Poisson problem ``-u''(x) = f(x)``."""
    
    def __init__(self, source_strength: float = 1.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): stored but never read in this class — the source term in
        # compute_physics_loss is hard-coded. Confirm whether it should scale f.
        self.source_strength = source_strength
    
    def compute_data_loss(self, x_data: torch.Tensor, y_data: torch.Tensor) -> torch.Tensor:
        """Return the mean-squared error between model predictions and targets."""
        return F.mse_loss(self.forward(x_data), y_data)
    
    def compute_physics_loss(self, x: torch.Tensor) -> torch.Tensor:
        """
        Mean squared residual of the 1D Poisson equation ``-u''(x) = f(x)``.
        
        The source term is fixed to f(x) = sin(2πx) + 0.5·sin(4πx).
        
        Args:
            x: Collocation points; ``requires_grad`` is switched on in place.
        
        Returns:
            Scalar loss ``mean((-u'' - f) ** 2)``.
        """
        x.requires_grad_(True)
        
        def derivative(y: torch.Tensor) -> torch.Tensor:
            # dy/dx, built with create_graph so it can be differentiated again.
            return torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=torch.ones_like(y),
                create_graph=True,
                retain_graph=True,
            )[0]
        
        u_xx = derivative(derivative(self.forward(x)))
        source = torch.sin(2 * np.pi * x) + 0.5 * torch.sin(4 * np.pi * x)
        
        return (-u_xx - source).pow(2).mean()


class InversePoissonRPITPINN(RPITPINN):
    """R-PIT PINN for the 1D inverse Poisson problem ``-u''(x) = f(x)``."""
    
    def __init__(self, source_strength: float = 1.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): stored but never read in this class — the source term in
        # compute_physics_loss is hard-coded. Confirm whether it should scale f.
        self.source_strength = source_strength
    
    def compute_physics_loss(self, x: torch.Tensor) -> torch.Tensor:
        """
        Mean squared residual of the 1D Poisson equation ``-u''(x) = f(x)``.
        
        The source term is fixed to f(x) = sin(2πx) + 0.5·sin(4πx). When
        ``self.uncertainty_output`` is set, the residual uses the first
        ``self.base_output_dim`` columns of ``forward(x, add_noise=False)``
        (the mean prediction).
        
        Args:
            x: Collocation points; ``requires_grad`` is switched on in place.
        
        Returns:
            Scalar loss ``mean((-u'' - f) ** 2)``.
        """
        x.requires_grad_(True)
        
        def derivative(y: torch.Tensor) -> torch.Tensor:
            # dy/dx, built with create_graph so it can be differentiated again.
            return torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=torch.ones_like(y),
                create_graph=True,
                retain_graph=True,
            )[0]
        
        # Physics is enforced on the noise-free mean prediction only.
        if self.uncertainty_output:
            prediction = self.forward(x, add_noise=False)[:, :self.base_output_dim]
        else:
            prediction = self.forward(x)
        
        u_xx = derivative(derivative(prediction))
        source = torch.sin(2 * np.pi * x) + 0.5 * torch.sin(4 * np.pi * x)
        
        return (-u_xx - source).pow(2).mean()


# Factory function to create problem-specific models
def create_problem_model(problem_type: str, method: str = "standard", **kwargs):
    """
    Factory function to create problem-specific models.
    
    Args:
        problem_type: Type of problem ('lorenz', 'burgers', 'inverse_poisson')
        method: Method type ('standard', 'rpit', or 'bayesian')
        **kwargs: Forwarded unchanged to the selected model constructor, or to
            ``create_bayesian_problem_model`` for ``method='bayesian'``.
        
    Returns:
        Problem-specific model instance
    
    Raises:
        ValueError: If ``problem_type`` or ``method`` is not recognised.
    """
    problem_types = ('lorenz', 'burgers', 'inverse_poisson')
    methods = ('standard', 'rpit', 'bayesian')
    
    # Validate before touching any other module so bad arguments fail fast
    # with the same ValueError messages regardless of method.
    if problem_type not in problem_types:
        raise ValueError(f"Unknown problem type: {problem_type}")
    
    if method not in methods:
        raise ValueError(f"Unknown method: {method}")
    
    if method == 'bayesian':
        # Imported only here: standard / R-PIT construction should not depend
        # on bayesian_pinn being importable, and a function-scope import keeps
        # this module from loading it at import time.
        from .bayesian_pinn import create_bayesian_problem_model
        return create_bayesian_problem_model(problem_type, **kwargs)
    
    deterministic_models = {
        'lorenz': {
            'standard': LorenzStandardPINN,
            'rpit': LorenzRPITPINN,
        },
        'burgers': {
            'standard': BurgersStandardPINN,
            'rpit': BurgersRPITPINN,
        },
        'inverse_poisson': {
            'standard': InversePoissonStandardPINN,
            'rpit': InversePoissonRPITPINN,
        },
    }
    
    return deterministic_models[problem_type][method](**kwargs)
