"""
Wave Equation solver for TensorGalerkin

Equation: ∂²u/∂t² = c²Δu

This module implements the Galerkin residual computation for the wave equation
using finite element discretization with central difference time stepping.

Weak form: ∫ (u^{n+1} - 2u^n + u^{n-1})/dt² · v dx + c² ∫ ∇u^n · ∇v dx = 0
Discrete form: M·u^{n+1} = 2M·u^n - M·u^{n-1} - dt²·c²·A·u^n
"""

import numpy as np
import scipy.sparse
import torch
from typing import Optional, Union

from .base import EquationDatasetSequential
from ..discretization import tri3, quad4, tri_gauss_points, quad_gauss_points
from ..utils import apply_zero_boundary, apply_dirichlet_boundary


class WaveEquation(EquationDatasetSequential):
    """
    Wave Equation solver using Galerkin finite element method.
    
    Equation: ∂²u/∂t² = c²Δu  in Ω
              u = g           on ∂Ω (Dirichlet BC; g is read from
                                     mesh.point_data['boundary_value'])
    
    Supports both constant and spatially-varying wave speed c.
    Uses central difference time stepping (explicit).
    """
    
    def __init__(
        self,
        mesh,
        c: Union[float, torch.Tensor] = 1.0,
        dt: float = 0.01,
        fast_mode: bool = False,
        dataset=None,
    ):
        """
        Store the solver configuration and build the matrices for ``mesh``.

        Parameters:
        -----------
            mesh: meshio.Mesh
                Finite element mesh handed to ``update_mesh``.
            c: float or torch.Tensor
                Wave speed, either a constant (float) or a nodal field
                (Tensor [n_nodes]) describing c(x).
            dt: float
                Time step (default: 0.01)
            fast_mode: bool
                Select the matrix-based residual evaluation instead of the
                quadrature-based one (default: False)
            dataset: Optional
                Parametric dataset; when present it can supply the wave
                speed through ``initial_prop_speed``.
        """
        super().__init__()

        # Configuration must be in place before update_mesh(): the
        # precomputation it triggers reads the wave speed and time step.
        self.dt, self.c = dt, c
        self.fast_mode, self.dataset = fast_mode, dataset

        self.update_mesh(mesh)
    
    def precompute(self, mesh):
        """
        Precompute finite element matrices from mesh.

        Reads the "triangle" cells if present, otherwise the "quad" cells,
        evaluates the shape functions at the Gauss points and registers
        geometry, boundary data and connectivity as buffers. It then
        assembles the element mass matrix (``self.M``), calls ``compute_K``
        with ``self.c`` (which builds ``self.A2`` and ``K_global``) and
        finally builds the sparse operators ``self.A1`` (= -M),
        ``self.A3`` (= M) and the ``M_global`` buffer.

        Parameters:
        -----------
            mesh: mesh object
                Must provide ``cells_dict``, ``points`` and the point data
                arrays ``boundary_mask`` and ``boundary_value``.

        Raises:
        -------
            ValueError
                If the mesh has neither triangle nor quad cells.
        """
        # Pick the element family; only the quadrature rule and the shape
        # function evaluator differ between the two supported cell types.
        cells = mesh.cells_dict
        if "triangle" in cells:
            elements = cells["triangle"]
            qpoints = tri_gauss_points(ngp=1)
            shape_fn = tri3
        elif "quad" in cells:
            elements = cells["quad"]
            qpoints = quad_gauss_points(ngp=4)
            shape_fn = quad4
        else:
            raise ValueError("Mesh should have triangle or quad cells")

        # Columns of the quadrature table: weight, xi, eta.
        w, xi, eta = qpoints[:, 0], qpoints[:, 1], qpoints[:, 2]
        # Only the first two coordinates are used (2D problem).
        mesh_points_2d = mesh.points[elements][:, :, :2]
        shape_val, shape_grad, jac_det = shape_fn(xi, eta, mesh_points_2d)

        # Determine dtype from c
        if isinstance(self.c, torch.Tensor):
            dtype = self.c.dtype
        else:
            dtype = torch.float64

        # Store precomputed quantities.
        # The einsum subscripts used below treat shape_val as
        # [n_quadrature, n_basis] and jac_det as [n_elements, n_quadrature].
        self.register_buffer("weight", torch.from_numpy(w).to(dtype))
        self.register_buffer("elements", torch.from_numpy(elements).long())
        self.register_buffer("shape_val", torch.from_numpy(shape_val).to(dtype))
        self.register_buffer("shape_grad", torch.from_numpy(shape_grad).to(dtype))
        self.register_buffer("jac_det", torch.from_numpy(jac_det).to(dtype))
        # Trailing singleton axis lets JxW broadcast over the basis axis.
        self.register_buffer("JxW", (torch.abs(self.jac_det) * self.weight)[..., None].to(dtype))

        # Boundary conditions
        self.register_buffer("boundary_mask", torch.from_numpy(mesh.point_data["boundary_mask"]).bool())
        self.register_buffer("boundary_value", torch.from_numpy(mesh.point_data['boundary_value']).to(dtype))

        # Mesh points
        self.register_buffer("points", torch.from_numpy(mesh.points.astype(float)).to(dtype))

        # Element topology
        n_elements, n_basis = self.elements.shape
        n_nodes = mesh.points.shape[0]
        self.n_elements = n_elements
        self.n_basis = n_basis
        self.n_nodes = n_nodes
        self.n_points = n_nodes
        self.n_quadrature = self.weight.shape[0]

        # All (i, j) local node pairs of every element, flattened in
        # (element, i, j) order so they line up with M.view(-1) / K.view(-1):
        # entry e * n_basis**2 + i * n_basis + j is (elements[e, i], elements[e, j]).
        elem_u = np.repeat(elements, n_basis, axis=1).ravel()
        elem_v = np.tile(elements, (1, n_basis)).ravel()
        n_local = n_elements * n_basis * n_basis

        # Remove duplicated edges: the CSR round trip merges repeated
        # (row, col) pairs, leaving one entry per distinct node pair.
        tmp = scipy.sparse.coo_matrix((
            np.ones_like(elem_u),
            (elem_u, elem_v),
        ), shape=(n_nodes, n_nodes)).tocsr().tocoo()
        edge_u, edge_v = tmp.row, tmp.col
        num_edges = len(edge_u)

        # Lookup table (node pair) -> global edge id.
        eids_csr = scipy.sparse.coo_matrix((
            np.arange(num_edges), (edge_u, edge_v)
        ), shape=(n_nodes, n_nodes)).tocsr()

        # Cast to int64 explicitly: torch.stack below requires the same
        # dtype as torch.arange, and NumPy's default integer is 32-bit on
        # some platforms.
        elem_eids = torch.from_numpy(np.array(eids_csr[elem_u, elem_v]).ravel()).long()

        # Element-to-mesh edge mapping: a [num_edges, n_local] 0/1 matrix
        # whose product with flattened element matrices sums the
        # contributions that fall on the same global edge.
        ele2msh_edge = torch.sparse_coo_tensor(
            torch.stack([elem_eids, torch.arange(n_local)]),
            torch.ones(n_local).to(dtype),
            (num_edges, n_local)
        ).to_sparse_csr()
        self.register_buffer("ele2msh_edge_indices", ele2msh_edge.crow_indices())
        self.register_buffer("ele2msh_edge_col_indices", ele2msh_edge.col_indices())
        self.register_buffer("ele2msh_edge_values", ele2msh_edge.values())
        self.ele2msh_edge = ele2msh_edge

        self.register_buffer("edges", torch.from_numpy(np.stack([edge_u, edge_v], 0)))

        # Element-to-node mapping: [n_nodes, n_elements * n_basis] 0/1
        # matrix that scatters element-local vectors onto global nodes.
        ele2msh_node = torch.sparse_coo_tensor(
            torch.stack([
                self.elements.flatten(),
                torch.arange(n_elements * n_basis, device=self.elements.device),
            ]),
            torch.ones(n_elements * n_basis, device=self.elements.device).to(dtype),
            (n_nodes, n_basis * n_elements)
        )
        self.register_buffer("ele2msh_node_indices", ele2msh_node._indices())
        self.register_buffer("ele2msh_node_values", ele2msh_node._values())
        self.ele2msh_node = ele2msh_node

        # Local mass matrix: sum over quadrature points of N_i N_j JxW.
        M = torch.einsum("qi,qj,eq->eqij", self.shape_val, self.shape_val, self.JxW.squeeze(2))
        M = M.sum(1)  # [n_elements, n_basis, n_basis]
        self.M = M

        # Compute global stiffness-related matrices (A2 and K_global).
        self.compute_K(self.c)

        # Assemble the mass matrix onto global edges once and reuse it:
        # A3 = M, A1 = -M (negating the assembled values is exact), and
        # M_global holds the same values in COO form.
        M_vals = torch.sparse.mm(self.ele2msh_edge, M.reshape(-1, 1))

        self.A1 = torch.sparse_coo_tensor(self.edges, (-M_vals).flatten(), (n_nodes, n_nodes)).to_sparse_csr()
        self.A3 = torch.sparse_coo_tensor(self.edges, M_vals.flatten(), (n_nodes, n_nodes)).to_sparse_csr()

        # Global mass matrix
        self.register_buffer("M_global", torch.sparse_coo_tensor(self.edges, M_vals.squeeze(), (n_nodes, n_nodes)))
    
    def compute_K(self, c=None):
        """
        Compute stiffness-related matrices for given wave speed.

        Builds the element stiffness matrix K (weighted by c²), then
        assembles ``self.A2 = 2M - dt²K`` (sparse CSR) and the ``K_global``
        buffer (sparse COO). The wave speed that was used is stored back
        into ``self.c``.

        Parameters:
        -----------
            c: float or torch.Tensor [n_nodes], optional
                Wave speed. If None, it is taken from
                ``self.dataset.initial_prop_speed(self.points)`` when a
                dataset is set, otherwise from ``self.c``. A tensor with a
                single element is treated as a constant speed.
        """
        # Resolve the wave speed: explicit argument > dataset > stored value.
        if c is None:
            if self.dataset is not None:
                c = self.dataset.initial_prop_speed(self.points)
            else:
                c = self.c

        JxW = self.JxW.squeeze(2)
        is_constant = isinstance(c, (int, float)) or (
            isinstance(c, torch.Tensor) and c.numel() == 1
        )
        if is_constant:
            # Constant speed: c² factors out of the integral.
            K = c * c * torch.einsum("eqib,eqjb,eq->eqij", self.shape_grad, self.shape_grad, JxW)
        else:
            # Nodal speed: interpolate c to the quadrature points with the
            # shape functions and weight the integrand by c(x_q)².
            c_elem = c[self.elements]
            c_quad = torch.einsum("eb,qb->eq", c_elem, self.shape_val)
            K = torch.einsum("eq,eq,eqib,eqjb,eq->eqij", c_quad, c_quad, self.shape_grad, self.shape_grad, JxW)
        self.c = c

        K = K.sum(1)  # [n_elements, n_basis, n_basis]

        # A wave speed tensor of another dtype can promote K away from the
        # dtype of the assembly operator; cast both products' operands so
        # torch.sparse.mm does not fail on a dtype mismatch.
        sparse_dtype = self.ele2msh_edge.dtype

        # A2 = 2M - dt²K
        A2 = 2 * self.M - self.dt * self.dt * K
        A2_vals = torch.sparse.mm(self.ele2msh_edge, A2.view((-1, 1)).to(sparse_dtype))
        self.A2 = torch.sparse_coo_tensor(self.edges, A2_vals.flatten(), (self.n_nodes, self.n_nodes)).to_sparse_csr()

        # Global stiffness matrix
        K_global_vals = torch.sparse.mm(self.ele2msh_edge, K.reshape(-1, 1).to(sparse_dtype)).squeeze()
        self.register_buffer("K_global", torch.sparse_coo_tensor(self.edges, K_global_vals, (self.n_nodes, self.n_nodes)))
    
    def update_mesh(self, mesh):
        """Store ``mesh`` and, if a wave speed is set, rebuild all matrices."""
        self.mesh = mesh
        if self.c is None:
            # No wave speed yet: precomputation is skipped here.
            return
        self.precompute(mesh)
    
    def compute_residual(self, U_t1: torch.Tensor, U_t2: torch.Tensor, 
                         U_t3: torch.Tensor, c: Optional[Union[float, torch.Tensor]] = None) -> torch.Tensor:
        """
        Compute Galerkin residual for wave equation.

        Slow (quadrature) mode evaluates

            R^I = ∫ ((u^{n+1} - 2u^n + u^{n-1})/dt² · N^I + c²∇u^n · ∇N^I) dx

        Fast mode evaluates ``A3·u^{n+1} - A2·u^n - A1·u^{n-1}`` with
        ``A3 = M``, ``A1 = -M`` and ``A2 = 2M - dt²K``.

        NOTE: the fast-mode expression equals ``M(u^{n+1} - 2u^n + u^{n-1})
        + dt²·K·u^n``, i.e. dt² times the slow-mode residual, so the two
        modes differ by that constant factor.

        Side effects: stores the effective wave speed in ``self._c``; in
        fast mode with ``c`` or a dataset available, ``compute_K`` is
        re-run (which also overwrites ``self.c``). If nothing has been
        precomputed yet, ``self.c`` is set to ``c`` and ``precompute`` runs.

        Parameters:
        -----------
            U_t1: torch.Tensor [..., n_nodes]
                Solution at time step n-1
            U_t2: torch.Tensor [..., n_nodes]
                Solution at time step n
            U_t3: torch.Tensor [..., n_nodes]
                Solution at time step n+1
            c: float, torch.Tensor, or None
                Wave speed. If None, uses dataset or self.c. A tensor with
                exactly one element is treated as a constant speed (the
                same rule ``compute_K`` applies); any other 1D tensor is
                treated as a nodal field.

        Returns:
        --------
            R: torch.Tensor [..., n_nodes]
                Residual at each node, passed through ``apply_zero_boundary``
                with the boundary mask.

        Raises:
        -------
            ValueError
                If the effective wave speed is neither a Python scalar nor
                a tensor with at most one dimension.
        """
        assert U_t1.shape == U_t2.shape == U_t3.shape, "All U tensors should have the same shape"
        assert U_t1.device == U_t2.device == U_t3.device, "All U tensors should be on the same device"

        # Lazy precomputation for instances created without a wave speed.
        if not hasattr(self, "shape_val"):
            assert c is not None, "c must be provided when not precomputed"
            self.c = c
            self.precompute(self.mesh)

        # Fast mode bakes c into A2, so it must be rebuilt whenever the
        # wave speed may have changed.
        if self.fast_mode and (c is not None or self.dataset is not None):
            self.compute_K(c=c)

        # Set effective wave speed: explicit argument > dataset > stored.
        if c is not None:
            self._c = c
        elif self.dataset is not None:
            self._c = self.dataset.initial_prop_speed(self.points)
        else:
            self._c = self.c
        wave_speed = self._c

        is_python_scalar = isinstance(wave_speed, (int, float))
        is_low_rank_tensor = isinstance(wave_speed, torch.Tensor) and wave_speed.dim() <= 1
        if not (is_python_scalar or is_low_rank_tensor):
            raise ValueError(f"c should be a scalar or 1D tensor, got shape {self._c.shape}")

        # A 1D tensor is a nodal field unless it holds a single value, in
        # which case it is a constant speed (consistent with compute_K;
        # indexing such a tensor with node ids would go out of bounds).
        is_nodal_speed = (
            isinstance(wave_speed, torch.Tensor)
            and wave_speed.dim() == 1
            and wave_speed.numel() != 1
        )

        def pointwise_form(u_prev_q, u_curr_q, u_next_q, grad_u_q, basis, grad_basis, c_q):
            # Integrand at one quadrature point, one value per basis function.
            return (u_next_q - 2 * u_curr_q + u_prev_q) / self.dt / self.dt * basis + c_q * c_q * grad_basis @ grad_u_q

        def compute_residual_slow(u_prev, u_curr, u_next):
            mask, bvals = self.boundary_mask, self.boundary_value
            u_prev = apply_dirichlet_boundary(u_prev, mask, bvals)
            u_curr = apply_dirichlet_boundary(u_curr, mask, bvals)
            u_next = apply_dirichlet_boundary(u_next, mask, bvals)

            # Gather nodal values per element: [n_elements, n_basis].
            elem_prev = u_prev[self.elements]
            elem_curr = u_curr[self.elements]
            elem_next = u_next[self.elements]

            n_qp_total = self.n_elements * self.n_quadrature

            # Interpolate to quadrature points and flatten (element, qp).
            prev_q = torch.einsum('gb,eb->eg', self.shape_val, elem_prev).flatten()
            curr_q = torch.einsum('gb,eb->eg', self.shape_val, elem_curr).flatten()
            next_q = torch.einsum('gb,eb->eg', self.shape_val, elem_next).flatten()
            grad_curr_q = torch.einsum('egbd,eb->egd', self.shape_grad, elem_curr).view((n_qp_total, 2))

            # Basis values are shared by all elements; tile them so every
            # flattened quadrature point has its own row.
            basis_q = self.shape_val.repeat((self.n_elements, 1, 1)).view((n_qp_total, self.n_basis))
            grad_basis_q = self.shape_grad.view((n_qp_total, self.n_basis, 2))

            fields = (prev_q, curr_q, next_q, grad_curr_q, basis_q, grad_basis_q)
            if is_nodal_speed:
                # Interpolate the nodal speed to the quadrature points.
                c_elem = wave_speed[self.elements]
                c_q = torch.einsum("eb,qb->eq", c_elem, self.shape_val).flatten()
                integral = torch.vmap(pointwise_form)(*fields, c_q)
            else:
                integral = torch.vmap(lambda *args: pointwise_form(*args, wave_speed))(*fields)

            # Weight by JxW, sum over quadrature points, scatter to nodes.
            integral = integral.view((self.n_elements, self.n_quadrature, self.n_basis))
            integral = (integral * self.JxW).sum(1)

            R = torch.sparse.mm(self.ele2msh_node, integral.view((-1, 1))).squeeze()
            return apply_zero_boundary(R, self.boundary_mask)

        def compute_residual_fast(u_prev, u_curr, u_next):
            mask, bvals = self.boundary_mask, self.boundary_value
            u_prev = apply_dirichlet_boundary(u_prev, mask, bvals)
            u_curr = apply_dirichlet_boundary(u_curr, mask, bvals)
            u_next = apply_dirichlet_boundary(u_next, mask, bvals)

            R = (torch.sparse.mm(self.A3, u_next.view((-1, 1))) -
                 torch.sparse.mm(self.A2, u_curr.view((-1, 1))) -
                 torch.sparse.mm(self.A1, u_prev.view((-1, 1)))).squeeze()
            return apply_zero_boundary(R, self.boundary_mask)

        # Choose computation method; every leading (batch) dimension of the
        # inputs is handled by one level of vmap.
        fn = compute_residual_fast if self.fast_mode else compute_residual_slow
        for _ in range(U_t1.dim() - 1):
            fn = torch.vmap(fn)

        return fn(U_t1, U_t2, U_t3)
    
    def compute_energy(self, Us: torch.Tensor) -> tuple:
        """
        Evaluate the discrete energy of a solution time series.

        The velocity is a central difference of the series, so the first
        and last time steps are dropped. Kinetic energy is ½·vᵀ·M_global·v
        and potential energy is ½·uᵀ·K_global·u at each interior step.

        Parameters:
        -----------
            Us: torch.Tensor [n_timesteps, n_nodes]
                Solution time series

        Returns:
        --------
            energy: torch.Tensor [n_timesteps-2]
                Sum of kinetic and potential energy
            kinetic_energy: torch.Tensor [n_timesteps-2]
                Kinetic energy
            potential_energy: torch.Tensor [n_timesteps-2]
                Potential energy
        """
        def half_quadratic_form(matrix, fields):
            # Row-wise ½ · fᵀ · matrix · f for a [n_steps, n_nodes] stack.
            return (0.5 * fields * torch.sparse.mm(matrix, fields.T).T).sum(1)

        velocity = (Us[2:, :] - Us[:-2, :]) / 2 / self.dt
        displacement = Us[1:-1, :]

        kinetic_energy = half_quadratic_form(self.M_global, velocity)
        potential_energy = half_quadratic_form(self.K_global, displacement)

        return kinetic_energy + potential_energy, kinetic_energy, potential_energy
    
    def compute_sympletic(self, Us: torch.Tensor) -> tuple:
        """
        Split a solution time series into velocity and position fields.

        The velocity is the central difference (u^{n+1} - u^{n-1}) / (2·dt),
        so only the interior time steps are returned.

        Parameters:
        -----------
            Us: torch.Tensor [n_timesteps, n_nodes]
                Solution time series

        Returns:
        --------
            Vs: torch.Tensor [n_timesteps-2, n_nodes]
                Velocity field
            Us: torch.Tensor [n_timesteps-2, n_nodes]
                Position field (the input without its first and last step)
        """
        later_steps = Us[2:, :]
        earlier_steps = Us[:-2, :]
        velocity = (later_steps - earlier_steps) / 2 / self.dt
        position = Us[1:-1, :]
        return velocity, position

