"""
Heat Equation solver for TensorGalerkin

Equation: ∂u/∂t = a²Δu

This module implements the Galerkin residual computation for the heat equation
using finite element discretization with implicit time stepping.

Weak form: ∫ (u^n - u^{n-1})/dt · v dx + a² ∫ ∇u^n · ∇v dx = 0
Discrete form: (M + dt·a²·A)u^n = M·u^{n-1}
"""

import numpy as np
import scipy.sparse
import torch
from typing import Optional

from .base import EquationDatasetSequential
from ..discretization import tri3, quad4, tri_gauss_points, quad_gauss_points
from ..utils import apply_zero_boundary


class HeatEquation(EquationDatasetSequential):
    """
    Heat Equation solver using Galerkin finite element method.
    
    Equation: ∂u/∂t = a²Δu  in Ω
              u = 0         on ∂Ω (Dirichlet BC)
    
    Uses implicit Euler time stepping.  With the consistent mass matrix M and
    the stiffness matrix A, one step reads (M + dt·a²·A) u^n = M u^{n-1}.
    
    Buffers registered by ``precompute`` (via ``register_buffer``):
        weight, elements, shape_val, shape_grad, jac_det, JxW,
        boundary_mask, boundary_value,
        ele2msh_edge_indices, ele2msh_edge_values, ele2msh_node, M, K
    """
    
    def __init__(self, mesh, a: float = 1.0, dt: float = 0.01, fast_mode: bool = False):
        """
        Initialize HeatEquation.
        
        Parameters:
        -----------
            mesh: meshio.Mesh
                The finite element mesh.  ``precompute`` reads
                ``mesh.cells_dict["triangle"]`` or ``mesh.cells_dict["quad"]``,
                the first two columns of ``mesh.points`` and
                ``mesh.point_data["boundary_mask"]`` / ``["boundary_value"]``.
            a: float
                Coefficient ``a`` of the equation; it enters every operator
                squared (``a * a``), i.e. the diffusivity is a² (default: 1.0)
            dt: float
                Time step (default: 0.01).  The system matrix ``K`` is assembled
                with the value of ``dt`` (and ``a``) at assembly time; after
                changing either attribute call ``update_mesh`` again.
            fast_mode: bool
                Use fast matrix-based computation (default: False)
        """
        super().__init__()
        self.dt = dt
        self.a = a
        self.fast_mode = fast_mode
        self.update_mesh(mesh)
    
    def precompute(self, mesh):
        """
        Precompute finite element matrices from mesh.
        
        Evaluates the shape functions at the quadrature points, builds the
        sparse element-to-mesh scatter operators and assembles the global mass
        matrix ``M`` and implicit-Euler system matrix ``K = M + dt·a²·A``.
        
        Parameters:
        -----------
            mesh: meshio.Mesh
                Mesh with triangle or quad cells (triangles take precedence
                when both are present).
        
        Raises:
        -------
            ValueError
                If the mesh has neither triangle nor quad cells.
        """
        # Each quadrature row is (weight, xi, eta).
        if "triangle" in mesh.cells_dict.keys():
            elements = mesh.cells_dict["triangle"]
            qpoints = tri_gauss_points(ngp=1)
            w, xi, eta = qpoints[:, 0], qpoints[:, 1], qpoints[:, 2]
            mesh_points_2d = mesh.points[elements][:, :, :2]
            shape_val, shape_grad, jac_det = tri3(xi, eta, mesh_points_2d)
        elif "quad" in mesh.cells_dict.keys():
            elements = mesh.cells_dict["quad"]
            qpoints = quad_gauss_points(ngp=4)
            w, xi, eta = qpoints[:, 0], qpoints[:, 1], qpoints[:, 2]
            mesh_points_2d = mesh.points[elements][:, :, :2]
            shape_val, shape_grad, jac_det = quad4(xi, eta, mesh_points_2d)
        else:
            raise ValueError("Mesh should have triangle or quad cells")
        
        # Store precomputed quantities.  The einsums below index them as
        # shape_val[q, i], shape_grad[e, q, i, d] and jac_det[e, q].
        self.register_buffer("weight", torch.from_numpy(w).float())
        self.register_buffer("elements", torch.from_numpy(elements).long())
        self.register_buffer("shape_val", torch.from_numpy(shape_val).float())
        self.register_buffer("shape_grad", torch.from_numpy(shape_grad).float())
        self.register_buffer("jac_det", torch.from_numpy(jac_det).float())
        # |det J| * w with a trailing unit axis so it broadcasts over basis functions.
        self.register_buffer("JxW", (torch.abs(self.jac_det) * self.weight)[..., None])
        
        # Boundary conditions
        self.register_buffer("boundary_mask", torch.from_numpy(mesh.point_data["boundary_mask"]).bool())
        self.register_buffer("boundary_value", torch.from_numpy(mesh.point_data['boundary_value']).float())
        
        # Element topology
        n_elements, n_basis = self.elements.shape
        n_nodes = mesh.points.shape[0]
        self.n_elements = n_elements
        self.n_basis = n_basis
        self.n_nodes = n_nodes
        self.n_quadrature = self.weight.shape[0]
        
        # Build edge connectivity for matrix assembly: one (row, col) node pair
        # per local matrix entry, flattened in (element, i, j) order so that it
        # lines up with ``X_local.reshape(-1)`` below.
        local_pairs = [(i, j) for i in range(n_basis) for j in range(n_basis)]
        elem_u = np.stack([elements[:, i] for i, _ in local_pairs], -1).ravel()
        elem_v = np.stack([elements[:, j] for _, j in local_pairs], -1).ravel()
        
        # Remove duplicated edges (COO -> CSR merges repeated (row, col) pairs).
        tmp = scipy.sparse.coo_matrix((
            np.ones_like(elem_u),
            (elem_u, elem_v),
        ), shape=(n_nodes, n_nodes)).tocsr().tocoo()
        edge_u, edge_v = tmp.row, tmp.col
        n_edges = len(edge_u)
        
        # Lookup table: (row, col) -> index of that unique edge.
        eids_csr = scipy.sparse.coo_matrix((
            np.arange(n_edges), (edge_u, edge_v)
        ), shape=(n_nodes, n_nodes)).tocsr()
        
        # Cast to int64 explicitly so it can be stacked with torch.arange below
        # regardless of the platform's default NumPy integer width.
        elem_eids = torch.from_numpy(np.array(eids_csr[elem_u, elem_v]).ravel()).long()
        
        # Element-to-mesh edge mapping: row = unique edge, col = local entry.
        self.register_buffer("ele2msh_edge_indices", torch.stack([
            elem_eids,
            torch.arange(n_elements * n_basis * n_basis)
        ], 0))
        self.register_buffer("ele2msh_edge_values", torch.ones_like(elem_eids).double())
        
        ele2msh_edge = torch.sparse_coo_tensor(
            self.ele2msh_edge_indices,
            self.ele2msh_edge_values,
            (n_edges, n_elements * n_basis * n_basis)
        ).to_sparse_csr()
        
        # Element-to-node mapping: row = global node, col = local (element, basis) slot.
        self.register_buffer("ele2msh_node", torch.sparse_coo_tensor(
            torch.stack([
                self.elements.flatten(),
                torch.arange(n_elements * n_basis),
            ]),
            torch.ones(n_elements * n_basis).double(),
            (n_nodes, n_basis * n_elements)
        ).to_sparse_csr())
        
        # Precompute mass and stiffness matrices
        jxw = self.JxW.squeeze(2)  # [n_elements, n_quadrature]
        M_local = torch.einsum("qi,qj,eq->eqij", self.shape_val, self.shape_val, jxw)
        A_local = torch.einsum("eqib,eqjb,eq->eqij", self.shape_grad, self.shape_grad, jxw)
        K_local = M_local + self.dt * self.a * self.a * A_local
        
        M_local = M_local.sum(1)  # [n_elements, n_basis, n_basis]
        K_local = K_local.sum(1)
        
        # The scatter operator is float64 while the local matrices are float32;
        # cast the dense operand so the sparse product never sees mixed dtypes.
        # M and K are therefore stored in the scatter operator's dtype.
        map_dtype = self.ele2msh_edge_values.dtype
        M_vals = torch.sparse.mm(ele2msh_edge, M_local.reshape(-1, 1).to(map_dtype)).squeeze(1)
        K_vals = torch.sparse.mm(ele2msh_edge, K_local.reshape(-1, 1).to(map_dtype)).squeeze(1)
        
        edges = torch.from_numpy(np.stack([edge_u, edge_v], 0))
        
        self.register_buffer("M", torch.sparse_coo_tensor(edges, M_vals, (n_nodes, n_nodes)))
        self.register_buffer("K", torch.sparse_coo_tensor(edges, K_vals, (n_nodes, n_nodes)))
    
    def update_mesh(self, mesh):
        """Update the mesh and recompute all matrices."""
        self.mesh = mesh
        self.precompute(mesh)
    
    def compute_residual(self, U_t1: torch.Tensor, U_t2: torch.Tensor) -> torch.Tensor:
        """
        Compute Galerkin residual for heat equation.
        
        R^I = ∫ ((u^n - u^{n-1})/dt · N^I + a²∇u^n · ∇N^I) dx
        
        Both code paths return this same quantity.  The slow path integrates
        the weak form element by element; the fast path evaluates
        (K·u^n - M·u^{n-1}) / dt with the preassembled matrices, which is the
        same expression because K = M + dt·a²·A.
        
        ``apply_zero_boundary`` is applied to both inputs and to the result.
        Leading batch dimensions are handled by wrapping the per-sample
        function in ``torch.vmap`` once per batch dimension.
        
        Parameters:
        -----------
            U_t1: torch.Tensor [..., n_nodes]
                Solution at previous time step (u^{n-1})
            U_t2: torch.Tensor [..., n_nodes]
                Solution at current time step (u^n)
                
        Returns:
        --------
            R: torch.Tensor [..., n_nodes]
                Residual at each node.  Its dtype follows the sparse operators
                (float64 as built by ``precompute``).
        
        Raises:
        -------
            AssertionError
                If the two inputs differ in shape or device.
        """
        assert U_t1.shape == U_t2.shape, "U_t1 and U_t2 should have the same shape"
        assert U_t1.device == U_t2.device, "U_t1 and U_t2 should be on the same device"
        
        def compute_residual_slow(U_t1, U_t2):
            U_t1_constraint = apply_zero_boundary(U_t1, self.boundary_mask)
            U_t2_constraint = apply_zero_boundary(U_t2, self.boundary_mask)
            
            # Gather nodal values per element: [n_elements, n_basis]
            elemU_t1 = U_t1_constraint[self.elements]
            elemU_t2 = U_t2_constraint[self.elements]
            
            # Interpolate the fields (and the gradient of u^n) at the quadrature points.
            phi_t1 = torch.einsum('gb,eb->eg', self.shape_val, elemU_t1)
            phi_t2 = torch.einsum('gb,eb->eg', self.shape_val, elemU_t2)
            gradphi_t2 = torch.einsum('egbd,eb->egd', self.shape_grad, elemU_t2)
            
            # Flatten (element, quadrature point) into one axis for the pointwise vmap.
            n_points = self.n_elements * self.n_quadrature
            shp_val = self.shape_val.repeat((self.n_elements, 1, 1)).view(
                (n_points, self.n_basis))
            shp_grad = self.shape_grad.view((n_points, self.n_basis, 2))
            phi_t1 = phi_t1.flatten()
            phi_t2 = phi_t2.flatten()
            gradphi_t2 = gradphi_t2.view((n_points, 2))
            
            def form(phi_t2, phi_t1, gradphi_t2, basis, grad_basis):
                """R^I = (u^n - u^{n-1})/dt · N^I + a²∇u^n · ∇N^I"""
                return (phi_t2 - phi_t1) / self.dt * basis + self.a * self.a * grad_basis @ gradphi_t2
            
            integral = torch.vmap(form)(phi_t2, phi_t1, gradphi_t2, shp_val, shp_grad)
            integral = integral.view((self.n_elements, self.n_quadrature, self.n_basis))
            integral = integral * self.JxW
            integral = integral.sum(1)
            
            # Scatter-add element contributions to the global nodes.  The dense
            # operand is cast to the scatter operator's dtype so the sparse
            # product never sees mixed dtypes.
            local = integral.reshape(-1, 1).to(self.ele2msh_node.dtype)
            R = torch.sparse.mm(self.ele2msh_node, local).squeeze(1)
            R_constraint = apply_zero_boundary(R, self.boundary_mask)
            
            return R_constraint
        
        def compute_residual_fast(U_t1, U_t2):
            U_t1_constraint = apply_zero_boundary(U_t1, self.boundary_mask)
            U_t2_constraint = apply_zero_boundary(U_t2, self.boundary_mask)
            
            # Match the dense vectors to the dtype of the assembled matrices.
            u_prev = U_t1_constraint.reshape(-1, 1).to(self.M.dtype)
            u_curr = U_t2_constraint.reshape(-1, 1).to(self.K.dtype)
            
            # K·u^n - M·u^{n-1} is dt times the documented residual; divide by
            # dt so that fast and slow mode agree.  K was assembled with the dt
            # in effect when precompute() ran.
            R = (torch.sparse.mm(self.K, u_curr) -
                 torch.sparse.mm(self.M, u_prev)).squeeze(1) / self.dt
            R_constraint = apply_zero_boundary(R, self.boundary_mask)
            
            return R_constraint
        
        fn = compute_residual_fast if self.fast_mode else compute_residual_slow
        # One vmap layer per leading batch dimension; the inner functions only
        # ever see a single [n_nodes] vector.
        for _ in range(U_t1.dim() - 1):
            fn = torch.vmap(fn)
        R = fn(U_t1, U_t2)
        
        return R
    
    def compute_energy(self, Us: torch.Tensor) -> torch.Tensor:
        """
        Compute energy of solution: E = ∫ u² dx
        
        Evaluated discretely as uᵀ·M·u with the assembled mass matrix, for
        every vector along the last axis of ``Us``.
        
        Parameters:
        -----------
            Us: torch.Tensor [..., n_nodes]
                Solution values
                
        Returns:
        --------
            energy: torch.Tensor [...]
                Energy value
        """
        # Collapse all leading dims so a single sparse product M @ Usᵀ handles
        # any batch shape; cast to the matrix dtype to avoid mixed-dtype mm.
        flat = Us.reshape(-1, Us.shape[-1]).to(self.M.dtype)
        MU = torch.sparse.mm(self.M, flat.t().contiguous()).t().reshape(Us.shape)
        energy = (Us * MU).sum(-1)
        return energy

