"""
PDE solver for generating ground truth solutions.
Uses finite differences for 1D nonlinear Poisson equation.
"""

import torch
import numpy as np
from typing import Tuple, Dict, Optional
from scipy.sparse import diags
from scipy.sparse.linalg import spsolve


class NonlinearPoissonSolver1D:
    """
    Solver for 1D nonlinear Poisson equation:
    d/dx[k(u,x) * du/dx] = w

    with homogeneous Dirichlet boundary conditions u(a) = u(b) = 0 at the two
    ends of ``domain`` (by default a = -1, b = 1).

    The equation is discretised with a conservative finite-difference scheme
    (k evaluated at cell faces, midway between grid nodes) and the resulting
    nonlinear system is solved with Newton's method using the exact
    tridiagonal Jacobian of that discretisation.
    """

    def __init__(
        self,
        n_points: int = 256,
        domain: Tuple[float, float] = (-1.0, 1.0),
        max_iter: int = 50,
        tol: float = 1e-6
    ):
        """
        Args:
            n_points: Number of grid points, endpoints included (must be >= 2)
            domain: Domain bounds (a, b)
            max_iter: Maximum Newton iterations
            tol: Convergence tolerance on the Euclidean norm of the residual

        Raises:
            ValueError: If ``n_points`` < 2 (grid spacing would be undefined).
        """
        if n_points < 2:
            raise ValueError(f"n_points must be >= 2, got {n_points}")

        self.n_points = n_points
        self.domain = domain
        self.max_iter = max_iter
        self.tol = tol

        # Uniform grid including both boundary nodes
        self.x = np.linspace(domain[0], domain[1], n_points)
        self.dx = (domain[1] - domain[0]) / (n_points - 1)

    def chebyshev_basis(self, x: np.ndarray, n_terms: int = 5) -> np.ndarray:
        """
        Evaluate Chebyshev polynomials of the first kind T_0 .. T_{n_terms-1}.

        Uses the recurrence T_{i+1}(x) = 2 x T_i(x) - T_{i-1}(x).

        Args:
            x: Points to evaluate, shape (n,). Values are clipped to [-1, 1];
               they are NOT rescaled from ``self.domain``.
            n_terms: Number of Chebyshev terms; ``n_terms <= 0`` gives an
                empty (n, 0) basis.

        Returns:
            Basis functions, shape (n, n_terms)
        """
        # Clip (not rescale) into the natural Chebyshev interval [-1, 1]
        x_norm = np.clip(np.asarray(x, dtype=float), -1.0, 1.0)

        basis = [np.ones_like(x_norm), x_norm]
        for _ in range(2, n_terms):
            basis.append(2.0 * x_norm * basis[-1] - basis[-2])

        # Trim to exactly n_terms columns (covers n_terms of 0 and 1)
        return np.stack(basis, axis=1)[:, :max(n_terms, 0)]

    def diffusion_coefficient(
        self,
        u: np.ndarray,
        x: np.ndarray,
        xi: np.ndarray
    ) -> np.ndarray:
        """
        Compute k(u,x) = log(1 + exp(u * Σ ξ_i T_i(x))) + 0.1

        Args:
            u: Solution values, shape (n,)
            x: Grid points, shape (n,)
            xi: Chebyshev coefficients, shape (n_coeffs,)

        Returns:
            Diffusion coefficient, shape (n,); always >= 0.1
        """
        basis = self.chebyshev_basis(x, len(xi))

        # Σ ξ_i T_i(x)
        poly_sum = np.dot(basis, np.asarray(xi, dtype=float))

        # softplus(z) = log(1 + exp(z)), evaluated as logaddexp(0, z) so that
        # large z does not overflow exp() to inf.
        return np.logaddexp(0.0, u * poly_sum) + 0.1

    def diffusion_derivative_u(
        self,
        u: np.ndarray,
        x: np.ndarray,
        xi: np.ndarray
    ) -> np.ndarray:
        """
        Compute dk/du = p(x) * sigmoid(u * p(x)) with p(x) = Σ ξ_i T_i(x).

        Args:
            u: Solution values, shape (n,)
            x: Grid points, shape (n,)
            xi: Chebyshev coefficients, shape (n_coeffs,)

        Returns:
            dk/du, shape (n,)
        """
        basis = self.chebyshev_basis(x, len(xi))
        poly_sum = np.dot(basis, np.asarray(xi, dtype=float))

        # sigmoid(z) = 1 / (1 + exp(-z)) = exp(-log(1 + exp(-z))); this form
        # never evaluates inf/inf (which gave nan for large z).
        sigmoid = np.exp(-np.logaddexp(0.0, -(u * poly_sum)))
        return poly_sum * sigmoid

    def _faces(self, u: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Return (u, x) at the cell faces midway between adjacent grid nodes."""
        u_face = 0.5 * (u[:-1] + u[1:])
        x_face = 0.5 * (self.x[:-1] + self.x[1:])
        return u_face, x_face

    def residual(
        self,
        u: np.ndarray,
        xi: np.ndarray,
        w: float
    ) -> np.ndarray:
        """
        Compute PDE residual.

        Interior rows: (F_{i+1/2} - F_{i-1/2}) / dx - w with face fluxes
        F = k(u_face, x_face) * (u_right - u_left) / dx.
        Boundary rows: u[0] and u[-1] (Dirichlet u = 0).

        Args:
            u: Solution values, shape (n_points,)
            xi: Chebyshev coefficients
            w: Forcing term

        Returns:
            Residual, shape (n_points,)
        """
        u = np.asarray(u, dtype=float)
        res = np.zeros(len(u))

        # Boundary conditions
        res[0] = u[0]
        res[-1] = u[-1]

        # Interior points, vectorised over all faces at once
        u_face, x_face = self._faces(u)
        k_face = self.diffusion_coefficient(u_face, x_face, xi)
        flux = k_face * np.diff(u) / self.dx
        res[1:-1] = np.diff(flux) / self.dx - w

        return res

    def jacobian(
        self,
        u: np.ndarray,
        xi: np.ndarray,
        w: float
    ) -> np.ndarray:
        """
        Compute the exact Jacobian d(residual)/du for Newton's method.

        The discrete operator couples each interior node only to its two
        neighbours, so the Jacobian is tridiagonal (returned as a dense array).

        Args:
            u: Solution values, shape (n_points,)
            xi: Chebyshev coefficients
            w: Forcing term (constant, so it does not enter the Jacobian;
               kept for interface compatibility)

        Returns:
            Jacobian, shape (n_points, n_points)
        """
        u = np.asarray(u, dtype=float)
        n = len(u)
        J = np.zeros((n, n))

        # Boundary rows: residual is u itself
        J[0, 0] = 1.0
        J[-1, -1] = 1.0
        if n < 3:
            return J

        u_face, x_face = self._faces(u)
        k = self.diffusion_coefficient(u_face, x_face, xi)
        # Each face value is the mean of two nodes, hence the 0.5 factor.
        half_dk = 0.5 * self.diffusion_derivative_u(u_face, x_face, xi)
        grad = np.diff(u)

        # dx * dF_f/du_f (left node) and dx * dF_f/du_{f+1} (right node)
        d_left = half_dk * grad - k
        d_right = half_dk * grad + k

        inv_dx2 = 1.0 / self.dx ** 2
        rows = np.arange(1, n - 1)
        # Row i: face i-1 sits on its left, face i on its right.
        J[rows, rows - 1] = -d_left[:-1] * inv_dx2
        J[rows, rows] = (d_left[1:] - d_right[:-1]) * inv_dx2
        J[rows, rows + 1] = d_right[1:] * inv_dx2

        return J

    def solve(
        self,
        xi: np.ndarray,
        w: float,
        u_init: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, Dict]:
        """
        Solve the nonlinear Poisson equation using Newton's method.

        Args:
            xi: Chebyshev coefficients, shape (n_coeffs,)
            w: Forcing term (scalar)
            u_init: Initial guess, shape (n_points,); defaults to zeros

        Returns:
            u: Solution, shape (n_points,)
            info: Dictionary with 'converged', 'iterations', 'residual_norm'
                  and, on failure, 'error'

        Raises:
            ValueError: If ``u_init`` does not have shape (n_points,).
        """
        if u_init is None:
            u = np.zeros(self.n_points)
        else:
            # np.array copies, so the caller's array is never modified
            u = np.array(u_init, dtype=float)
            if u.shape != (self.n_points,):
                raise ValueError(
                    f"u_init must have shape ({self.n_points},), got {u.shape}"
                )

        for it in range(self.max_iter):
            r = self.residual(u, xi, w)
            res_norm = np.linalg.norm(r)

            if res_norm < self.tol:
                return u, {
                    'converged': True,
                    'iterations': it,
                    'residual_norm': res_norm
                }

            # Diverged to inf/nan: further Newton steps are meaningless
            if not np.isfinite(res_norm):
                return u, {
                    'converged': False,
                    'iterations': it,
                    'residual_norm': res_norm,
                    'error': 'Non-finite residual'
                }

            J = self.jacobian(u, xi, w)
            try:
                du = np.linalg.solve(J, -r)
            except np.linalg.LinAlgError:
                return u, {
                    'converged': False,
                    'iterations': it,
                    'residual_norm': res_norm,
                    'error': 'Singular Jacobian'
                }

            u = u + du

        # The last update has not been checked yet; it may have converged.
        res_norm = np.linalg.norm(self.residual(u, xi, w))
        if res_norm < self.tol:
            return u, {
                'converged': True,
                'iterations': self.max_iter,
                'residual_norm': res_norm
            }
        return u, {
            'converged': False,
            'iterations': self.max_iter,
            'residual_norm': res_norm,
            'error': 'Max iterations reached'
        }

    def solve_batch(
        self,
        xi_batch: np.ndarray,
        w_batch: np.ndarray
    ) -> Tuple[np.ndarray, list]:
        """
        Solve for a batch of parameters.

        Args:
            xi_batch: Batch of Chebyshev coefficients, shape (batch_size, n_coeffs)
            w_batch: Batch of forcing terms, shape (batch_size,); a scalar is
                applied to every sample

        Returns:
            solutions: Shape (batch_size, n_points)
            info_list: List of info dicts (one per sample, see ``solve``)
        """
        batch_size = len(xi_batch)

        w_arr = np.asarray(w_batch)
        if w_arr.ndim == 0:
            w_arr = np.full(batch_size, float(w_arr))

        solutions = np.zeros((batch_size, self.n_points))
        info_list = []

        for i in range(batch_size):
            u, info = self.solve(xi_batch[i], w_arr[i])
            solutions[i] = u
            info_list.append(info)

            if not info['converged']:
                print(f"Warning: Sample {i} did not converge. "
                      f"Residual: {info['residual_norm']:.2e}")

        return solutions, info_list


def solve_nonlinear_poisson_torch(xi: torch.Tensor, w: torch.Tensor, 
                                  n_points: int = 128) -> torch.Tensor:
    """
    Solve the nonlinear Poisson equation for a batch of parameters and return
    the solutions as a torch tensor.

    The solve itself is done in NumPy by ``NonlinearPoissonSolver1D``. The
    inputs are detached first, so the returned tensor carries no autograd
    history (gradients do not flow back to ``xi`` or ``w``).

    Args:
        xi: Chebyshev coefficients, shape (batch_size, n_chebyshev)
        w: Forcing term (constant right-hand side of the PDE), shape
           (batch_size,); a 0-d tensor is applied to every sample
        n_points: Number of grid points

    Returns:
        Solutions on grid, float32, shape (batch_size, n_points), on the same
        device as ``xi``
    """
    device = xi.device  # Result goes back to the caller's device

    # .detach(): Tensor.numpy() raises for tensors that require grad.
    # float64 on CPU: the solver works in float64, and NumPy cannot represent
    # some torch dtypes (e.g. bfloat16).
    xi_np = xi.detach().to(device="cpu", dtype=torch.float64).numpy()
    w_np = w.detach().to(device="cpu", dtype=torch.float64).numpy()
    if w_np.ndim == 0:
        w_np = np.full(len(xi_np), float(w_np))

    solver = NonlinearPoissonSolver1D(n_points=n_points)
    solutions, _ = solver.solve_batch(xi_np, w_np)

    return torch.from_numpy(solutions).to(device=device, dtype=torch.float32)