"""
Physics-informed residual losses for differential equations: the Lorenz ODE system, the 1D Poisson equation, and the 2D Burgers' equations.
"""
import torch
import torch.nn.functional as F
from typing import Dict, Any, Callable
import numpy as np


class PhysicsLoss:
    """Common interface for residual ("physics") losses.

    Concrete subclasses override :meth:`compute_residual`. The ``device``
    string is stored on the instance; this base class does not otherwise
    use it.
    """

    def __init__(self, device: str = "cpu"):
        self.device = device

    def compute_residual(self, model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        """
        Return the residual of the governing equation at the points ``x``.

        Args:
            model: Neural network approximating the solution.
            x: Collocation points where the equation is enforced.

        Returns:
            Residual tensor.

        Raises:
            NotImplementedError: Always; subclasses must provide this method.
        """
        message = "Subclasses must implement compute_residual"
        raise NotImplementedError(message)


class LorenzPhysicsLoss(PhysicsLoss):
    """Physics loss for the Lorenz system.

    Only the deterministic Lorenz drift is enforced by
    :meth:`compute_residual`; no noise/diffusion term appears in the
    residual, so any stochastic forcing is not modeled here.
    """

    def __init__(self, sigma: float = 10.0, rho: float = 28.0, beta: float = 8.0/3.0, device: str = "cpu"):
        super().__init__(device)
        self.sigma = sigma
        self.rho = rho
        self.beta = beta

    @staticmethod
    def _time_derivative(component: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Return d(component)/dt for a single output column, keeping the graph.

        Each output column must be differentiated on its own. Passing the
        whole (batch, 3) output with ``grad_outputs=ones`` makes autograd
        return the *sum* dx/dt + dy/dt + dz/dt, shaped like ``t``, rather
        than one derivative per component.

        Using ``ones_like`` as ``grad_outputs`` gives per-sample derivatives
        only if the model treats batch rows independently. This is an
        assumption about ``model`` that is not checked here.
        """
        return torch.autograd.grad(
            outputs=component,
            inputs=t,
            grad_outputs=torch.ones_like(component),
            # create_graph keeps the derivative differentiable for training;
            # retain_graph defaults to create_graph, so the forward graph
            # survives for the other components' calls.
            create_graph=True,
        )[0]

    def compute_residual(self, model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the Lorenz system residual.

        The Lorenz system is:
        dx/dt = σ(y - x)
        dy/dt = x(ρ - z) - y
        dz/dt = xy - βz

        Args:
            model: Neural network model mapping t to [x, y, z]; assumed to
                return shape (batch_size, 3).
            x: Time points (batch_size, 1). ``requires_grad`` is enabled on
                this tensor in place.

        Returns:
            Residual tensor of shape (batch_size, 3), one column per equation.
        """
        x.requires_grad_(True)

        # Forward pass
        u = model(x)  # expected shape: (batch_size, 3) for [x, y, z]

        # Extract components (slices keep the (batch_size, 1) shape)
        x_comp = u[:, 0:1]
        y_comp = u[:, 1:2]
        z_comp = u[:, 2:3]

        # Differentiate each component separately (see _time_derivative)
        dx_dt = self._time_derivative(x_comp, x)
        dy_dt = self._time_derivative(y_comp, x)
        dz_dt = self._time_derivative(z_comp, x)

        # Residuals of each equation
        residual_x = dx_dt - self.sigma * (y_comp - x_comp)
        residual_y = dy_dt - (x_comp * (self.rho - z_comp) - y_comp)
        residual_z = dz_dt - (x_comp * y_comp - self.beta * z_comp)

        return torch.cat([residual_x, residual_y, residual_z], dim=1)


class PoissonPhysicsLoss(PhysicsLoss):
    """Physics loss for the 1D Poisson equation: -d²u/dx² = f(x)."""

    def __init__(self, source_function: Callable = None, device: str = "cpu"):
        """
        Args:
            source_function: Callable mapping the collocation tensor ``x`` to
                the source term ``f(x)``. ``None`` selects
                :meth:`_default_source`.
            device: Passed to the base-class constructor.
        """
        super().__init__(device)
        # Compare with None explicitly instead of relying on truthiness: a
        # callable object can be falsy (an empty ``torch.nn.Sequential`` has
        # len() == 0) and would otherwise be silently replaced.
        self.source_function = (
            self._default_source if source_function is None else source_function
        )

    def _default_source(self, x: torch.Tensor) -> torch.Tensor:
        """Default source function: f(x) = sin(2πx) + 0.5sin(4πx)."""
        return torch.sin(2 * np.pi * x) + 0.5 * torch.sin(4 * np.pi * x)

    @staticmethod
    def _grad(outputs: torch.Tensor, inputs: torch.Tensor, allow_unused: bool = False) -> torch.Tensor:
        """Return d(outputs)/d(inputs) with ``grad_outputs=ones``, keeping the graph.

        With ``allow_unused=True``, an ``outputs`` tensor that does not depend
        on ``inputs`` yields zeros instead of an autograd error. Autograd
        signals this case either with ``outputs.requires_grad == False`` or
        with a ``None`` gradient. It occurs legitimately for a second
        derivative of a model that is affine in x.
        """
        if allow_unused and not outputs.requires_grad:
            return torch.zeros_like(inputs)
        grad = torch.autograd.grad(
            outputs=outputs,
            inputs=inputs,
            grad_outputs=torch.ones_like(outputs),
            create_graph=True,  # also retains the graph (retain_graph defaults to this)
            allow_unused=allow_unused,
        )[0]
        return torch.zeros_like(inputs) if grad is None else grad

    def compute_residual(self, model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the Poisson equation residual -u'' - f.

        Args:
            model: Neural network model; assumed to return (batch_size, 1).
            x: Spatial points (batch_size, 1). ``requires_grad`` is enabled on
                this tensor in place.

        Returns:
            Residual tensor.
        """
        x.requires_grad_(True)

        # Forward pass
        u = model(x)  # expected shape: (batch_size, 1)

        # First derivative is strict: if u does not depend on x (e.g. the
        # forward ran under torch.no_grad), autograd raises rather than
        # silently producing zeros.
        du_dx = self._grad(u, x)

        # Second derivative may legitimately be identically zero.
        d2u_dx2 = self._grad(du_dx, x, allow_unused=True)

        # Source term
        f = self.source_function(x)

        # Residual of -∇²u - f = 0
        return -d2u_dx2 - f


class BurgersPhysicsLoss(PhysicsLoss):
    """Physics loss for the 2D viscous Burgers' equations.

    For the velocity field (u, v) the enforced equations are:
        u_t + u u_x + v u_y = nu (u_xx + u_yy)
        v_t + u v_x + v v_y = nu (v_xx + v_yy)
    """

    def __init__(self, nu: float = 0.01, device: str = "cpu"):
        super().__init__(device)
        self.nu = nu  # Viscosity coefficient

    @staticmethod
    def _grad(outputs: torch.Tensor, inputs: torch.Tensor, allow_unused: bool = False) -> torch.Tensor:
        """Return d(outputs)/d(inputs) with ``grad_outputs=ones``, keeping the graph.

        With ``allow_unused=True``, an ``outputs`` tensor that does not depend
        on ``inputs`` yields zeros instead of an autograd error. This covers
        second derivatives of a model that is affine in its inputs.
        """
        if allow_unused and not outputs.requires_grad:
            return torch.zeros_like(inputs)
        grad = torch.autograd.grad(
            outputs=outputs,
            inputs=inputs,
            grad_outputs=torch.ones_like(outputs),
            create_graph=True,  # also retains the graph (retain_graph defaults to this)
            allow_unused=allow_unused,
        )[0]
        return torch.zeros_like(inputs) if grad is None else grad

    def compute_residual(self, model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the Burgers' equation residual.

        The 2D Burgers' equation is:
        ∂u/∂t + u·∇u = ν∇²u

        Args:
            model: Neural network model; assumed to return (batch_size, 2)
                for [u, v].
            x: Spatiotemporal points (batch_size, 3) for [x, y, t].
                ``requires_grad`` is enabled on this tensor in place.

        Returns:
            Residual tensor of shape (batch_size, 2): [u-equation, v-equation].
        """
        x.requires_grad_(True)

        # Forward pass
        uv = model(x)  # expected shape: (batch_size, 2) for [u, v]
        u = uv[:, 0:1]
        v = uv[:, 1:2]

        # Full gradients w.r.t. the input columns [x, y, t]. These are strict:
        # u and v must depend on the inputs.
        grad_u = self._grad(u, x)
        grad_v = self._grad(v, x)

        u_x, u_y, u_t = grad_u[:, 0:1], grad_u[:, 1:2], grad_u[:, 2:3]
        v_x, v_y, v_t = grad_v[:, 0:1], grad_v[:, 1:2], grad_v[:, 2:3]

        # Second derivatives for the Laplacian. Each call returns the full
        # gradient, so we keep only the matching diagonal column.
        u_xx = self._grad(u_x, x, allow_unused=True)[:, 0:1]
        u_yy = self._grad(u_y, x, allow_unused=True)[:, 1:2]
        v_xx = self._grad(v_x, x, allow_unused=True)[:, 0:1]
        v_yy = self._grad(v_y, x, allow_unused=True)[:, 1:2]

        # Residuals
        residual_u = u_t + u * u_x + v * u_y - self.nu * (u_xx + u_yy)
        residual_v = v_t + u * v_x + v * v_y - self.nu * (v_xx + v_yy)

        return torch.cat([residual_u, residual_v], dim=1)


def get_physics_loss(problem_type: str, **kwargs) -> PhysicsLoss:
    """
    Construct the physics-loss object registered under ``problem_type``.

    Args:
        problem_type: Registry key: 'lorenz', 'poisson' or 'burgers'.
        **kwargs: Forwarded unchanged to the selected class's constructor.

    Returns:
        A new instance of the matching ``PhysicsLoss`` subclass.

    Raises:
        ValueError: If ``problem_type`` is not a registered key.
    """
    registry = {
        'lorenz': LorenzPhysicsLoss,
        'poisson': PoissonPhysicsLoss,
        'burgers': BurgersPhysicsLoss,
    }

    loss_cls = registry.get(problem_type)
    if loss_cls is None:
        raise ValueError(f"Unknown problem type: {problem_type}")

    return loss_cls(**kwargs)
