"""
PDE differential operators for computing physics residuals.
Derivatives are computed with automatic differentiation, with finite-difference fallbacks on uniform 1D grids.
"""

import torch
import torch.nn.functional as F
from typing import Tuple, Optional, Callable


def compute_derivatives_1d(
    u: torch.Tensor,
    x: torch.Tensor,
    max_order: int = 2,
    create_graph: bool = True
) -> dict:
    """
    Compute derivatives up to given order for 1D problems.

    Autograd is tried first. If ``u`` is not connected to ``x`` in the
    autograd graph (e.g. ``x`` does not require grad, or ``u`` was not
    computed from ``x``), the first derivative falls back to finite
    differences. The fallback assumes a uniform grid whose spacing is taken
    from the first two points of ``x``, averaged over the batch.

    Note: autograd differentiates ``sum(u)``, so the result equals the
    pointwise derivative only if each ``u[:, j]`` depends on ``x[:, j]``
    alone.

    Args:
        u: Field values, shape (batch, n_points) or (batch, 1, n_points)
        x: Spatial coordinates, shape (batch, n_points, 1)
        max_order: Maximum derivative order (orders 1 and 2 are supported)
        create_graph: Whether to create computation graph

    Returns:
        Dictionary with keys 'u', 'du_dx' (if max_order >= 1) and
        'd2u_dx2' (if max_order >= 2), each of shape (batch, n_points).
    """
    if u.dim() == 3:
        u = u.squeeze(1)

    derivatives = {'u': u}
    if max_order < 1:
        return derivatives

    def _grad(out, keep_graph):
        # Differentiate w.r.t. ``x`` itself: a derived view such as
        # ``x.squeeze(-1)`` is not an ancestor of ``u`` in the graph, so its
        # gradient would always come back as None.
        g = torch.autograd.grad(
            outputs=out,
            inputs=x,
            grad_outputs=torch.ones_like(out),
            create_graph=keep_graph,
            retain_graph=True,
            allow_unused=True,
        )[0]
        return None if g is None else g.squeeze(-1)

    def _spacing():
        # Uniform grid assumed; spacing of the first cell averaged over batch.
        x_1d = x.squeeze(-1)
        return (x_1d[:, 1] - x_1d[:, 0]).mean().item()

    # First derivative: keep the graph whenever a second derivative follows.
    du_dx = None
    if u.requires_grad and x.requires_grad:
        try:
            du_dx = _grad(u, (max_order > 1) or create_graph)
        except RuntimeError:
            du_dx = None

    use_finite_diff = du_dx is None
    if use_finite_diff:
        import warnings
        warnings.warn(
            "Using finite differences for derivatives (autograd unavailable)",
            RuntimeWarning,
            stacklevel=2,
        )
        du_dx = finite_difference_gradient_1d(u, _spacing())

    derivatives['du_dx'] = du_dx

    if max_order >= 2:
        if use_finite_diff:
            d2u_dx2 = finite_difference_laplacian_1d(u, _spacing())
        elif not du_dx.requires_grad:
            # du/dx came from autograd with create_graph=True yet carries no
            # graph, so it is constant in x: the curvature is exactly zero.
            d2u_dx2 = torch.zeros_like(du_dx)
        else:
            try:
                d2u_dx2 = _grad(du_dx, create_graph)
            except RuntimeError:
                # e.g. an op in the graph without double-backward support
                d2u_dx2 = finite_difference_laplacian_1d(u, _spacing())
            else:
                if d2u_dx2 is None:
                    # du/dx does not depend on x at all.
                    d2u_dx2 = torch.zeros_like(du_dx)

        derivatives['d2u_dx2'] = d2u_dx2

    return derivatives

def gradient(
    u: torch.Tensor,
    x: torch.Tensor,
    create_graph: bool = True
) -> torch.Tensor:
    """
    Compute gradient du/dx using automatic differentiation.

    ``u`` is summed before differentiation, so entry ``[b, j, d]`` of the
    result is the derivative of ``u.sum()`` w.r.t. ``x[b, j, d]``. This
    equals the pointwise gradient only if each output point depends on its
    own coordinate alone.

    Args:
        u: Field values, shape (batch, ..., n_points)
        x: Spatial coordinates, shape (batch, n_points, spatial_dim)
        create_graph: Whether to create computation graph for higher derivatives

    Returns:
        Gradient with the same shape as ``x``: (batch, n_points, spatial_dim).
        All zeros if ``u`` does not depend on ``x``. This includes the case
        where ``x`` did not require grad on entry, since it is then replaced
        by a detached copy that ``u`` cannot depend on.

    Raises:
        RuntimeError: from ``torch.autograd.grad`` if ``u`` does not require
            grad.
    """
    if not x.requires_grad:
        x = x.clone().detach().requires_grad_(True)

    # A single backward pass yields every spatial component at once.
    grad = torch.autograd.grad(
        outputs=u,
        inputs=x,
        grad_outputs=torch.ones_like(u),
        create_graph=create_graph,
        retain_graph=True,
        allow_unused=True
    )[0]

    if grad is None:
        return torch.zeros_like(x)
    return grad


def divergence(
    v: torch.Tensor,
    x: torch.Tensor,
    create_graph: bool = True
) -> torch.Tensor:
    """
    Compute divergence ∇·v.

    Each component ``v[:, i, :]`` is summed before differentiation. The
    result therefore equals the pointwise divergence only if each output
    point depends on its own coordinates alone.

    Args:
        v: Vector field, shape (batch, spatial_dim, n_points)
        x: Spatial coordinates, shape (batch, n_points, spatial_dim)
        create_graph: Whether to create computation graph

    Returns:
        Divergence tensor, shape (batch, n_points). Components of ``v`` that
        do not depend on ``x`` contribute zero.
    """
    if not x.requires_grad:
        x = x.clone().detach().requires_grad_(True)

    spatial_dim = x.shape[2]
    # Start from a zero tensor (not the int 0) so the documented tensor
    # shape is returned even when no component depends on x.
    div = torch.zeros(x.shape[:2], dtype=x.dtype, device=x.device)

    for i in range(spatial_dim):
        v_i = v[:, i, :]  # i-th component, (batch, n_points)

        grad = torch.autograd.grad(
            outputs=v_i,
            inputs=x,
            grad_outputs=torch.ones_like(v_i),
            create_graph=create_graph,
            retain_graph=True,
            allow_unused=True
        )[0]

        if grad is not None:
            # Add ∂v_i/∂x_i to divergence
            div = div + grad[..., i]

    return div


def laplacian(
    u: torch.Tensor,
    x: torch.Tensor,
    create_graph: bool = True
) -> torch.Tensor:
    """
    Compute Laplacian ∇²u = Σ ∂²u/∂x_i².

    ``u`` is summed before each differentiation. The result therefore equals
    the pointwise Laplacian only if each output point depends on its own
    coordinates alone.

    Args:
        u: Field values, shape (batch, n_points) or (batch, 1, n_points)
        x: Spatial coordinates, shape (batch, n_points, spatial_dim)
        create_graph: Whether to create computation graph for the result

    Returns:
        Laplacian, shape (batch, n_points). Zeros if ``u`` does not depend
        on ``x``.
    """
    # Ensure proper shape
    if u.dim() == 3:
        u = u.squeeze(1)

    if not x.requires_grad:
        x = x.clone().detach().requires_grad_(True)

    spatial_dim = x.shape[2]
    lapl = torch.zeros_like(u)

    # First derivatives always keep a graph: second derivatives need it.
    du_dx = torch.autograd.grad(
        outputs=u,
        inputs=x,
        grad_outputs=torch.ones_like(u),
        create_graph=True,
        retain_graph=True,
        allow_unused=True
    )[0]

    if du_dx is None:
        return lapl

    # Compute second derivatives
    for i in range(spatial_dim):
        du_dxi = du_dx[..., i]

        if not du_dxi.requires_grad:
            # du/dx_i carries no graph despite create_graph=True, so it is
            # constant in x (e.g. u linear): its derivative is zero.
            # Differentiating it would raise RuntimeError.
            continue

        d2u_dxi2_full = torch.autograd.grad(
            outputs=du_dxi,
            inputs=x,
            grad_outputs=torch.ones_like(du_dxi),
            create_graph=create_graph,
            retain_graph=True,
            allow_unused=True
        )[0]

        if d2u_dxi2_full is not None:
            lapl = lapl + d2u_dxi2_full[..., i]

    return lapl


class PDEOperator:
    """
    Common base for PDE residual operators.

    Concrete operators override ``__call__`` to turn a field ``u`` sampled
    at coordinates ``x`` (plus optional PDE parameters) into residuals.
    """

    def __init__(self, spatial_dim: int):
        # Number of spatial dimensions this operator works in.
        self.spatial_dim = spatial_dim

    def __call__(
        self,
        u: torch.Tensor,
        x: torch.Tensor,
        lambda_params: Optional[dict] = None
    ) -> torch.Tensor:
        """
        Evaluate the PDE residual for ``u`` at ``x``.

        Args:
            u: Field values
            x: Spatial coordinates
            lambda_params: PDE parameters

        Returns:
            Residual values (in subclasses)

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError()


class NonlinearPoissonOperator(PDEOperator):
    """
    Operator for 1D nonlinear Poisson equation:
    d/dx[k(u,x) * du/dx] - w = 0
    """

    def __init__(self):
        super().__init__(spatial_dim=1)

    def diffusion_coefficient(
        self,
        u: torch.Tensor,
        x: torch.Tensor,
        xi: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute k(u,x) = log(1 + exp(u * Σ ξ_i T_i(x))) + 0.1.

        T_i are Chebyshev polynomials of the first kind, evaluated at ``x``
        clamped to [-1, 1]. The softplus form keeps k >= 0.1.

        Args:
            u: Field values, shape (batch, n_points)
            x: Coordinates, shape (batch, n_points, 1)
            xi: Chebyshev coefficients, shape (batch, n_coeffs)

        Returns:
            Diffusion coefficient, shape (batch, n_points)
        """
        # Chebyshev polynomials are defined on [-1, 1].
        x_clamped = torch.clamp(x.squeeze(-1), -1.0, 1.0)  # (batch, n_points)

        n_coeffs = xi.shape[1]

        # Three-term recurrence: T_0 = 1, T_1 = x, T_{n+1} = 2x T_n - T_{n-1}.
        basis = [torch.ones_like(x_clamped)]
        if n_coeffs > 1:
            basis.append(x_clamped)
        for _ in range(2, n_coeffs):
            basis.append(2.0 * x_clamped * basis[-1] - basis[-2])

        # (batch, n_points, n_coeffs)
        basis_tensor = torch.stack(basis, dim=-1)

        # Σ ξ_i T_i(x); xi: (batch, n_coeffs) -> (batch, 1, n_coeffs)
        poly_sum = (basis_tensor * xi.unsqueeze(1)).sum(dim=-1)

        # softplus(z) == log(1 + exp(z)), but torch.exp would overflow to inf
        # for large z, giving inf/NaN coefficients and gradients.
        return F.softplus(u * poly_sum) + 0.1

    def __call__(
        self,
        u: torch.Tensor,
        x: torch.Tensor,
        lambda_params: dict
    ) -> torch.Tensor:
        """
        Compute residual: d/dx[k(u,x) * du/dx] - w

        The divergence term is expanded by the product rule as
        dk/dx * du/dx + k * d2u/dx2. Here dk/dx is taken by finite
        differences on k sampled along the grid, i.e. the total derivative
        of k(u(x), x).

        Args:
            u: Field values, shape (batch, 1, n_points) or (batch, n_points)
            x: Coordinates, shape (batch, n_points, 1)
            lambda_params: Dict with 'xi' (Chebyshev coefficients,
                (batch, n_coeffs)) and 'w' (source term; squeezed on its
                last dim, presumably (batch, 1) — confirm with callers)

        Returns:
            Residual, shape (batch, n_points)
        """
        if u.dim() == 3:
            u = u.squeeze(1)

        xi = lambda_params['xi']
        w = lambda_params['w'].squeeze(-1)  # (batch,)

        # Compute derivatives
        derivs = compute_derivatives_1d(u, x, max_order=2)
        du_dx = derivs['du_dx']
        d2u_dx2 = derivs['d2u_dx2']

        # Compute diffusion coefficient
        k = self.diffusion_coefficient(u, x, xi)

        # dk/dx via finite differences; assumes a uniform grid whose
        # spacing is the batch-averaged first cell width.
        dx = (x[:, 1, 0] - x[:, 0, 0]).mean().item()
        dk_dx = finite_difference_gradient_1d(k, dx)

        # Residual: dk/dx * du/dx + k * d2u/dx2 - w, with w broadcast over
        # the grid points.
        residual = dk_dx * du_dx + k * d2u_dx2 - w.unsqueeze(-1)

        return residual


def finite_difference_gradient_1d(
    u: torch.Tensor,
    dx: float
) -> torch.Tensor:
    """
    Approximate du/dx along dim 1 on a uniform grid.

    The approximation uses second-order central differences at interior
    points and first-order one-sided differences at the two end points.

    Args:
        u: Field values on grid, shape (batch, n_points)
        dx: Grid spacing

    Returns:
        Gradient, shape (batch, n_points)
    """
    h = dx
    two_h = 2 * h
    out = torch.zeros_like(u)

    # Edges: forward difference on the left, backward on the right.
    out[:, 0] = (u[:, 1] - u[:, 0]) / h
    out[:, -1] = (u[:, -1] - u[:, -2]) / h

    # Interior: symmetric (central) difference.
    out[:, 1:-1] = (u[:, 2:] - u[:, :-2]) / two_h

    return out


def finite_difference_laplacian_1d(
    u: torch.Tensor,
    dx: float
) -> torch.Tensor:
    """
    Compute Laplacian using finite differences on a uniform grid.

    Interior points use the central three-point stencil. Each boundary uses
    the same stencil shifted one cell inward (one-sided, first-order
    accurate, exact for quadratics). This matches the first-order boundary
    treatment of ``finite_difference_gradient_1d``.

    Args:
        u: Field values, shape (batch, n_points)
        dx: Grid spacing

    Returns:
        Laplacian, shape (batch, n_points). All zeros if n_points < 3,
        since no second difference can be formed.
    """
    lapl = torch.zeros_like(u)
    if u.shape[1] < 3:
        return lapl

    h2 = dx ** 2

    # Central differences
    lapl[:, 1:-1] = (u[:, 2:] - 2 * u[:, 1:-1] + u[:, :-2]) / h2

    # Boundaries: one-sided three-point stencils (u0 - 2u1 + u2) and
    # (u[-1] - 2u[-2] + u[-3]) — genuine second differences over dx².
    lapl[:, 0] = (u[:, 2] - 2 * u[:, 1] + u[:, 0]) / h2
    lapl[:, -1] = (u[:, -1] - 2 * u[:, -2] + u[:, -3]) / h2

    return lapl