"""
Poisson equation solver for TensorGalerkin
"""

import torch
import numpy as np
from typing import Optional, Callable

from .base import EquationDatasetStatic
from ..discretization import tri3, quad4, tri_gauss_points, quad_gauss_points
from ..utils import apply_zero_boundary


class PoissonEquation(EquationDatasetStatic):
    """
    Poisson equation solver: -∇⋅(a∇u) = f

    where:
    - a is the diffusion coefficient
    - f is the source term
    - u is the solution

    The weak-form residual is assembled element by element from the
    ``tri3`` / ``quad4`` shape functions and scattered to the mesh nodes with
    a sparse element-to-node matrix. Nodes flagged in
    ``mesh.point_data["boundary_mask"]`` are passed through
    ``apply_zero_boundary`` (presumably zeroing them; confirm against
    ``..utils``).
    """

    def __init__(self,
                 mesh,
                 a: float = 1.0,
                 dt: float = 0.01,
                 f: Optional[Callable] = None):
        """
        Initialize Poisson equation solver

        Parameters:
        -----------
            mesh: meshio.Mesh
                The computational mesh
            a: float
                Diffusion coefficient (default: 1.0)
            dt: float
                Time step (not used for steady-state, kept for compatibility)
            f: Optional[Callable]
                Source function f(x) -> values at points. It is called with
                the [n_nodes, 2] node coordinates. Defaults to the unit
                source f ≡ 1.
        """
        super().__init__()
        self.dt = dt
        self.a = a
        self.f = f if f is not None else self._unit_source
        self.update_mesh(mesh)

    @staticmethod
    def _unit_source(x: torch.Tensor) -> torch.Tensor:
        """Default source f ≡ 1, allocated with the dtype and device of ``x``."""
        # new_ones (rather than torch.ones) keeps the result on x's device so
        # the default source also works when the residual is evaluated on GPU
        # or in float64.
        return x.new_ones(x.shape[0])

    def precompute(self, mesh):
        """
        Precompute finite element matrices and geometric quantities

        Registers the quadrature weights, connectivity, shape-function values
        and gradients, Jacobian determinants, boundary data and the sparse
        element-to-node assembly matrix as buffers, and stores the mesh sizes
        (``n_elements``, ``n_basis``, ``n_nodes``, ``n_quadrature``).

        Parameters:
        -----------
            mesh: meshio.Mesh
                Mesh with "triangle" or "quad" cells and the point data
                entries "boundary_mask" and "boundary_value"

        Raises:
        -------
            ValueError
                If the mesh has neither triangle nor quad cells
        """
        # Read cells_dict once instead of re-evaluating it for every lookup.
        cells = mesh.cells_dict
        if "triangle" in cells:
            elements = cells["triangle"]
            qpoints = tri_gauss_points(ngp=4)
            shape_fn = tri3
        elif "quad" in cells:
            elements = cells["quad"]
            qpoints = quad_gauss_points(ngp=4)
            shape_fn = quad4
        else:
            raise ValueError("Mesh should have triangle or quad cells")

        # Quadrature table columns: weight, xi, eta
        w, xi, eta = qpoints[:, 0], qpoints[:, 1], qpoints[:, 2]
        # Extract only x,y coordinates (remove z coordinate)
        mesh_points_2d = mesh.points[elements][:, :, :2]
        # NOTE(review): three values are unpacked even with
        # return_jacobian=False; confirm against ..discretization.
        shape_val, shape_grad, jac_det = shape_fn(
            xi, eta, mesh_points_2d, return_jacobian=False)

        # Store precomputed quantities as buffers so they follow self.to(...)
        self.register_buffer("weight", torch.from_numpy(w).float())
        self.register_buffer("elements", torch.from_numpy(elements).long())
        self.register_buffer("shape_val", torch.from_numpy(shape_val).float())
        self.register_buffer("shape_grad", torch.from_numpy(shape_grad).float())
        self.register_buffer("jac_det", torch.from_numpy(jac_det).float())
        self.register_buffer(
            "JxW",
            (torch.abs(self.jac_det) * self.weight)[..., None],
        )  # [n_element, n_quadrature, 1]

        # Boundary conditions
        self.register_buffer(
            "boundary_mask",
            torch.from_numpy(mesh.point_data["boundary_mask"]).bool())
        self.register_buffer(
            "boundary_value",
            torch.from_numpy(mesh.point_data['boundary_value']).float())

        # Element topology
        n_elements, n_basis = self.elements.shape
        n_nodes = mesh.points.shape[0]
        self.n_elements = n_elements
        self.n_basis = n_basis
        self.n_nodes = n_nodes
        self.n_quadrature = self.weight.shape[0]

        # Element-to-node mapping for efficient assembly: entry
        # (node, local_dof) is 1 when the flattened element-local degree of
        # freedom `local_dof` belongs to global node `node`.
        n_local_dofs = n_elements * n_basis
        scatter_indices = torch.stack([
            self.elements.flatten(),
            torch.arange(n_local_dofs),
        ])
        self.register_buffer("ele2node", torch.sparse_coo_tensor(
            scatter_indices,
            torch.ones(n_local_dofs),
            (n_nodes, n_local_dofs),
        ))

    def update_mesh(self, mesh):
        """Update mesh and recompute quantities"""
        self.mesh = mesh
        self.precompute(mesh)

    def form(self, phi: torch.Tensor, basis: torch.Tensor,
             gradphi: torch.Tensor, gradbasis: torch.Tensor) -> torch.Tensor:
        """
        Weak form of Poisson equation: ∫(a∇u⋅∇v - fv)dΩ = 0

        This pointwise form assumes the unit source f ≡ 1;
        ``compute_residual`` handles general nodal source values.

        Parameters:
        -----------
            phi: torch.Tensor [n_quad]
                Solution values at quadrature points (unused by this form)
            basis: torch.Tensor [n_quad, n_basis]
                Basis function values at quadrature points
            gradphi: torch.Tensor [n_quad, 2]
                Solution gradients at quadrature points
            gradbasis: torch.Tensor [n_quad, n_basis, 2]
                Basis function gradients at quadrature points

        Returns:
        --------
            torch.Tensor [n_quad, n_basis]
                Weak form residual contributions
        """
        # Diffusion term: a∇u⋅∇v
        diffusion = self.a * torch.sum(gradphi.unsqueeze(-2) * gradbasis, dim=-1)  # [n_quad, n_basis]

        # Source term f⋅v with f ≡ 1: this is the test-function value itself,
        # not the constant 1 (which would drop the weighting by v).
        source = basis

        return diffusion - source

    def _nodal_source(self, U: torch.Tensor) -> torch.Tensor:
        """Evaluate the source at the mesh nodes, on the device and dtype of ``U``."""
        if self.f is None:
            # Only reachable if a caller reset self.f to None after construction.
            return U.new_ones(U.shape[-1])
        node_xy = torch.from_numpy(self.mesh.points[:, :2]).to(
            device=U.device, dtype=U.dtype)
        # Coerce the result so a source callable that allocates on the default
        # device / dtype cannot cause a mismatch with the buffers.
        return torch.as_tensor(self.f(node_xy), dtype=U.dtype, device=U.device)

    def _pointwise_form(self, gradphi: torch.Tensor, basis: torch.Tensor,
                        grad_basis: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
        """
        Weak form contribution at a single quadrature point
        R^I = a∇u⋅∇N^I - f⋅N^I

        Parameters:
        -----------
            gradphi: torch.Tensor [2]
                Solution gradient at quadrature point
            basis: torch.Tensor [n_basis]
                Basis function values at quadrature point
            grad_basis: torch.Tensor [n_basis, 2]
                Basis function gradients at quadrature point
            f: torch.Tensor []
                Source value at quadrature point

        Returns:
        --------
            torch.Tensor [n_basis]
                Residual contributions for this quadrature point
        """
        diffusion = self.a * (grad_basis @ gradphi)  # [n_basis]
        source = f * basis  # [n_basis]
        return diffusion - source

    def _residual_single(self, U_single: torch.Tensor,
                         f_single: torch.Tensor) -> torch.Tensor:
        """
        Assemble the residual for one sample

        Weak form: R^I = ∫(a∇u⋅∇N^I - f⋅N^I)dΩ

        Parameters:
        -----------
            U_single: torch.Tensor [n_nodes]
                Solution values at nodes for single sample
            f_single: torch.Tensor [n_nodes]
                Source values at nodes for single sample

        Returns:
        --------
            torch.Tensor [n_nodes]
                Nodal residual after the boundary treatment
        """
        # Apply boundary conditions to the solution and to the source
        U_constraint = apply_zero_boundary(U_single, self.boundary_mask)
        f_constraint = apply_zero_boundary(f_single, self.boundary_mask)

        # Extract element node values
        elemU = U_constraint[self.elements]         # [n_element, n_basis]
        elemf = f_constraint[self.elements]         # [n_element, n_basis]

        # Interpolate source and solution gradient to the quadrature points
        quadf = torch.einsum('gb,eb->eg', self.shape_val, elemf)        # [n_element, n_quadrature]
        gradphi = torch.einsum('egbd,eb->egd', self.shape_grad, elemU)  # [n_element, n_quadrature, 2]

        # Flatten all quadrature points across elements for vmap. reshape is
        # used instead of view because einsum results are not guaranteed to
        # be contiguous.
        n_points = self.n_elements * self.n_quadrature
        shp_val = self.shape_val.repeat((self.n_elements, 1, 1)).reshape(
            n_points, self.n_basis)                                      # [n_points, n_basis]
        shp_grad = self.shape_grad.reshape(n_points, self.n_basis, 2)    # [n_points, n_basis, 2]
        gradphi_flat = gradphi.reshape(n_points, 2)                      # [n_points, 2]
        quadf_flat = quadf.reshape(n_points)                             # [n_points]

        # Apply weak form at all quadrature points using vmap
        integral = torch.vmap(self._pointwise_form)(
            gradphi_flat, shp_val, shp_grad, quadf_flat)                 # [n_points, n_basis]

        # Reshape back to element structure, weight by |J| * w and integrate
        integral = integral.reshape(
            self.n_elements, self.n_quadrature, self.n_basis)
        integral = (integral * self.JxW).sum(1)                          # [n_element, n_basis]

        # Global assembly using sparse matrix multiplication. squeeze(-1)
        # drops only the column axis, so a one-node mesh keeps its node axis.
        R = torch.sparse.mm(self.ele2node, integral.reshape(-1, 1)).squeeze(-1)  # [n_nodes]

        # Apply boundary conditions to result
        return apply_zero_boundary(R, self.boundary_mask)

    def compute_residual(self, U: torch.Tensor, f_values: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute residual for Poisson equation using finite element assembly

        Vectorized implementation using torch.vmap and torch.einsum.
        Supports arbitrary batch dimensions through repeated vmap application.

        As a side effect, the module's buffers are moved to the device and
        dtype of ``U``.

        Parameters:
        -----------
            U: torch.Tensor [..., n_nodes]
                Solution values at nodes (supports arbitrary batch dimensions)
            f_values: Optional[torch.Tensor] [..., n_nodes]
                Source values at nodes. When given, it is vmapped together
                with ``U`` and therefore needs the same batch dimensions.
                If None, ``self.f`` is evaluated once at the mesh nodes and
                shared by all samples (unit source if ``self.f`` is None).

        Returns:
        --------
            torch.Tensor [..., n_nodes]
                Residual values (same shape as U)
        """
        self.to(U.device)
        self.to(U.dtype)

        if f_values is None:
            # One shared source for every sample: evaluate it once, outside vmap.
            nodal_source = self._nodal_source(U)

            def residual_fn(U_single):
                return self._residual_single(U_single, nodal_source)

            batched_args = (U,)
        else:
            residual_fn = self._residual_single
            batched_args = (U, f_values)

        # One vmap level per leading batch dimension of U
        for _ in range(U.dim() - 1):
            residual_fn = torch.vmap(residual_fn)
        return residual_fn(*batched_args)


def create_simple_poisson_equation(nx: int = 32, ny: int = 32, 
                                  use_dense_element: bool = False,
                                  device: str = "cpu",
                                  f: Optional[Callable] = None) -> PoissonEquation:
    """
    Build a PoissonEquation from a mesh produced by the legacy mesh generator

    Backward-compatibility helper for the old interface. The keyword
    arguments are packed into a small config object, handed to
    ``Dataset.MeshGen.init_mesh``, and the resulting mesh is wrapped in a
    ``PoissonEquation`` with default coefficients.

    Parameters:
    -----------
        nx, ny: int
            Grid resolution, forwarded as ``config.nx`` / ``config.ny``
        use_dense_element: bool
            Forwarded as ``config.use_dense_element``
        device: str
            Forwarded as ``config.device``; the returned equation is not
            explicitly moved to this device here
        f: Optional[Callable]
            Source function passed on to ``PoissonEquation``

    Returns:
    --------
        PoissonEquation
            Equation built on the generated mesh
    """
    # The mesh still comes from the original (pre-refactor) generator.
    # TODO: This should be replaced with the refactored mesh generation
    from Dataset import MeshGen

    class SimpleConfig:
        # Plain attribute holder; the generator only receives this object.
        pass

    settings = SimpleConfig()
    settings.nx = nx
    settings.ny = ny
    settings.use_dense_element = use_dense_element
    settings.device = device

    return PoissonEquation(MeshGen.init_mesh(settings), f=f)