"""
Helmholtz Equation solver for TensorGalerkin

Equation: -Δu - k²u = f

This module implements the Galerkin residual computation for the Helmholtz equation
using finite element discretization.

Weak form: ∫ ∇u · ∇v dx - k² ∫ u · v dx = ∫ f · v dx
Discrete form: (A - k²M)u = b

where:
    A: Stiffness matrix, A_IJ = ∫ ∇N^I · ∇N^J dx
    M: Mass matrix, M_IJ = ∫ N^I · N^J dx  
    b: Load vector, b_I = ∫ f · N^I dx
"""

import torch
import numpy as np
from typing import Union, Optional

from .base import EquationDatasetStatic
from ..discretization import tri3, quad4, tri_gauss_points, quad_gauss_points
from ..utils import apply_zero_boundary


class HelmholtzEquation(EquationDatasetStatic):
    """
    Helmholtz Equation solver using Galerkin finite element method.

    Equation: -Δu - k²u = f  in Ω
              u = 0         on ∂Ω (Dirichlet BC)

    Supports both constant and spatially-varying wavenumber k.

    The stiffness matrix A and the mass matrix M are assembled once per mesh
    (see ``precompute``) and reused by ``compute_residual_fast``. For a
    spatially varying k the k²-weighted mass matrix is re-assembled on each
    residual evaluation because it depends on the k passed in.
    """

    def __init__(self, mesh, k: Union[float, torch.Tensor] = 1.0):
        """
        Initialize HelmholtzEquation.

        Parameters:
        -----------
            mesh: meshio.Mesh
                The finite element mesh
            k: float or torch.Tensor
                Wavenumber. Can be:
                - float: constant wavenumber
                - Tensor [n_nodes]: spatially varying wavenumber k(x)
        """
        super().__init__()
        self.k = k
        self.update_mesh(mesh)

    def precompute(self, mesh):
        """
        Precompute finite element quantities and matrices from mesh.

        Registers the quadrature weights, connectivity, shape-function
        values/gradients, Jacobian determinants, ``JxW`` and the boundary
        data as buffers, stores the mesh sizes as plain attributes, and
        assembles the stiffness and mass matrices.

        Parameters:
        -----------
            mesh: meshio.Mesh
                Mesh with "triangle" or "quad" cells and point data entries
                "boundary_mask" and "boundary_value".

        Raises:
        -------
            ValueError: if the mesh has neither triangle nor quad cells.
        """
        cells = mesh.cells_dict
        if "triangle" in cells:
            elements = cells["triangle"]
            qpoints = tri_gauss_points(ngp=4)
            shape_fn = tri3
        elif "quad" in cells:
            elements = cells["quad"]
            qpoints = quad_gauss_points(ngp=4)
            shape_fn = quad4
        else:
            raise ValueError("Mesh should have triangle or quad cells")

        # Quadrature table columns: weight, xi, eta.
        w, xi, eta = qpoints[:, 0], qpoints[:, 1], qpoints[:, 2]
        # Only the first two coordinates are used (2D problem).
        mesh_points_2d = mesh.points[elements][:, :, :2]
        shape_val, shape_grad, jac_det = shape_fn(xi, eta, mesh_points_2d)

        # Store precomputed quantities using register_buffer for nn.Module compatibility.
        # Shapes as used by the rest of this class: shape_val [n_quad, n_basis],
        # shape_grad [n_elements, n_quad, n_basis, dim], jac_det [n_elements, n_quad].
        self.register_buffer("weight", torch.from_numpy(w).float())
        self.register_buffer("elements", torch.from_numpy(elements).long())
        self.register_buffer("shape_val", torch.from_numpy(shape_val).float())
        self.register_buffer("shape_grad", torch.from_numpy(shape_grad).float())
        self.register_buffer("jac_det", torch.from_numpy(jac_det).float())
        # Integration measure per element and quadrature point, with a trailing unit axis.
        self.register_buffer("JxW", (torch.abs(self.jac_det) * self.weight)[..., None])

        # Boundary conditions
        self.register_buffer("boundary_mask", torch.from_numpy(mesh.point_data["boundary_mask"]).bool())
        self.register_buffer("boundary_value", torch.from_numpy(mesh.point_data['boundary_value']).float())

        # Element topology
        n_elements, n_basis = self.elements.shape
        n_nodes = mesh.points.shape[0]
        self.n_elements = n_elements
        self.n_basis = n_basis
        self.n_nodes = n_nodes
        self.n_quadrature = self.weight.shape[0]

        # Element-to-node mapping: entry (node, e * n_basis + i) is 1 when
        # local basis i of element e lives on that node.
        self.register_buffer("ele2node", torch.sparse_coo_tensor(
            torch.stack([
                self.elements.flatten(),
                torch.arange(n_elements * n_basis),
            ]),
            torch.ones(n_elements * n_basis),
            (n_nodes, n_basis * n_elements)
        ))

        # Precompute stiffness matrix A and mass matrix M
        self._assemble_stiffness_matrix()
        self._assemble_mass_matrix()

    def _scatter_element_matrices(self, local: torch.Tensor) -> torch.Tensor:
        """
        Scatter per-element matrices into a global sparse matrix.

        Parameters:
        -----------
            local: torch.Tensor [n_elements, n_basis, n_basis]
                Element matrices; local[e, i, j] couples the nodes
                elements[e, i] (row) and elements[e, j] (column).

        Returns:
        --------
            torch.sparse.Tensor [n_nodes, n_nodes]
                Coalesced COO matrix; contributions that share an entry are summed.
        """
        n_basis = self.elements.shape[1]
        rows = self.elements[:, :, None].expand(-1, -1, n_basis).reshape(-1)
        cols = self.elements[:, None, :].expand(-1, n_basis, -1).reshape(-1)
        return torch.sparse_coo_tensor(
            torch.stack([rows, cols]),
            local.reshape(-1),
            (self.n_nodes, self.n_nodes),
        ).coalesce()

    def _element_mass(self, quad_scale: torch.Tensor) -> torch.Tensor:
        """
        Element mass matrices sum_g quad_scale[e, g] * N_i(g) * N_j(g), in float64.

        Parameters:
        -----------
            quad_scale: torch.Tensor [n_elements, n_quad], float64
                Per-quadrature-point factor (JxW, optionally times a coefficient).
        """
        N = self.shape_val.to(torch.float64)
        return torch.einsum('gi,gj,eg->eij', N, N, quad_scale)

    def _assemble_stiffness_matrix(self):
        """
        Assemble the stiffness matrix A (sparse).

        A_IJ = ∫_Ω ∇N^I · ∇N^J dx

        The result is registered as the float64 COO buffer ``A``; a CSR copy
        is kept in ``A_csr`` for the sparse products in the residual.
        """
        # Only the first two gradient components enter the 2D dot product.
        grad = self.shape_grad[..., :2].to(torch.float64)
        jxw = self.JxW[..., 0].to(torch.float64)
        local = torch.einsum('egid,egjd,eg->eij', grad, grad, jxw)

        A = self._scatter_element_matrices(local)

        self.register_buffer("A", A)
        self.A_csr = A.to_sparse_csr()

    def _assemble_mass_matrix(self):
        """
        Assemble the mass matrix M (sparse).

        M_IJ = ∫_Ω N^I · N^J dx

        The result is registered as the float64 COO buffer ``M``; a CSR copy
        is kept in ``M_csr`` for the sparse products in the residual.
        """
        jxw = self.JxW[..., 0].to(torch.float64)

        M = self._scatter_element_matrices(self._element_mass(jxw))

        self.register_buffer("M", M)
        self.M_csr = M.to_sparse_csr()

    def assemble_load_vector(self, f: torch.Tensor) -> torch.Tensor:
        """
        Assemble the load vector b for the Helmholtz equation.

        b_I = ∫_Ω f · N^I dx

        f is interpolated from its nodal values to the quadrature points with
        the shape functions before integration.

        Parameters:
        -----------
            f: torch.Tensor [n_nodes] or [batch, n_nodes]
                Source function values at mesh nodes

        Returns:
        --------
            b: torch.Tensor [n_nodes] or [batch, n_nodes]
                Load vector
        """
        is_batched = f.dim() > 1
        # Treat the unbatched case as a batch of one so both share one code path.
        f_rows = f if is_batched else f.unsqueeze(0)
        batch_size = f_rows.shape[0]

        N = self.shape_val.to(f.dtype)
        jxw = self.JxW.squeeze(-1).to(f.dtype)

        elem_f = f_rows[:, self.elements]                       # [batch, n_elements, n_basis]
        quad_f = torch.einsum('gi,bei->beg', N, elem_f)         # f at quadrature points
        integral = torch.einsum('beg,gi,eg->bei', quad_f, N, jxw)

        load_vec = torch.zeros(batch_size, self.n_nodes, device=f.device, dtype=f.dtype)
        # Accumulate every local contribution onto its global node in one pass.
        load_vec.index_add_(1, self.elements.reshape(-1), integral.reshape(batch_size, -1))

        return load_vec if is_batched else load_vec.squeeze(0)

    def assemble_weighted_mass_matrix(self, k2: torch.Tensor) -> torch.Tensor:
        """
        Assemble weighted mass matrix for spatially varying k².

        M_k_IJ = ∫_Ω k²(x) N^I · N^J dx

        k² is interpolated from its nodal values to the quadrature points
        with the shape functions. The matrix is built from detached values,
        so no gradient flows back to ``k2``.

        Parameters:
        -----------
            k2: torch.Tensor [n_nodes]
                Squared wavenumber at each node

        Returns:
        --------
            M_k: torch.sparse.Tensor [n_nodes, n_nodes]
                Weighted mass matrix (float64, CSR layout, on the device of
                the module buffers)
        """
        N = self.shape_val.to(torch.float64)
        jxw = self.JxW[..., 0].to(torch.float64)

        nodal_k2 = k2.detach().to(device=self.elements.device, dtype=torch.float64)
        k2_quad = torch.einsum('gi,ei->eg', N, nodal_k2[self.elements])

        M_k = self._scatter_element_matrices(self._element_mass(jxw * k2_quad))

        return M_k.to_sparse_csr()

    def compute_residual_fast(self, U: torch.Tensor, f: torch.Tensor,
                               k: Union[float, torch.Tensor, None] = None) -> torch.Tensor:
        """
        Fast residual computation using precomputed matrices.

        R = (A - k²M)u - b

        The boundary mask is applied to U before the products and to the
        residual before it is returned.

        Parameters:
        -----------
            U: torch.Tensor [n_nodes] or [batch, n_nodes]
                Solution values at mesh nodes
            f: torch.Tensor [n_nodes] or [batch, n_nodes]
                Source function values at mesh nodes
            k: float, torch.Tensor [n_nodes], or None
                Wavenumber. If None, uses self.k

        Returns:
        --------
            R: torch.Tensor [n_nodes] or [batch, n_nodes]
                Residual at each node
        """
        if k is None:
            k = self.k

        is_batched = U.dim() > 1
        # A tensor with more than one entry is treated as nodal k(x).
        is_k_spatial = isinstance(k, torch.Tensor) and k.numel() > 1

        U_bc = apply_zero_boundary(U, self.boundary_mask)
        # torch.sparse.mm needs a 2D right-hand side with nodes along dim 0.
        U_cols = U_bc.T if is_batched else U_bc.unsqueeze(-1)

        def matvec(matrix: torch.Tensor) -> torch.Tensor:
            # The CSR copies are plain attributes, so they do not follow
            # module.to(...); move and cast them to match U here.
            product = torch.sparse.mm(matrix.to(U.device).to(U_bc.dtype), U_cols)
            return product.T if is_batched else product.squeeze(-1)

        Au = matvec(self.A_csr)

        k2 = k ** 2
        if is_k_spatial:
            k2Mu = matvec(self.assemble_weighted_mass_matrix(k2))
        else:
            k2Mu = k2 * matvec(self.M_csr)

        b = self.assemble_load_vector(f)
        R = Au - k2Mu - b

        return apply_zero_boundary(R, self.boundary_mask)

    def form(self, phi: torch.Tensor, basis: torch.Tensor,
             gradphi: torch.Tensor, gradbasis: torch.Tensor) -> torch.Tensor:
        """
        Weak form of Helmholtz equation: ∫(∇u⋅∇v - k²uv - fv)dΩ = 0

        Note: This is a simplified form without the source term f.

        Returns the integrand (gradbasis @ gradphi).sum(-1) - k² * phi * basis,
        using the wavenumber stored in self.k.
        """
        k2 = self.k ** 2
        return (gradbasis @ gradphi).sum(-1) - k2 * phi * basis

    def update_mesh(self, mesh):
        """Update the mesh and recompute all matrices."""
        self.mesh = mesh
        self.precompute(mesh)

    def set_wavenumber(self, k: Union[float, torch.Tensor]):
        """Set the wavenumber k."""
        self.k = k
