"""
Heat equation data generators for TensorGalerkin

Solves the heat equation:
    ∂u/∂t = D² * Δu  in Ω   (the numerical solver squares its `D` argument;
                              the analytical solutions correspond to D = 1)
    u = g           on ∂Ω (Dirichlet boundary)
    u(x, 0) = u0    (initial condition)
"""

import numpy as np
import scipy.sparse
import scipy.sparse.linalg
from typing import Union, List
from tqdm import tqdm

from ...discretization import tri3, quad4, tri_gauss_points, quad_gauss_points


class HeatGen:
    """Data generators for heat equation problems.

    MultiAnalytical holds closed-form multi-mode solutions; Random holds a
    random initial-condition sampler and a finite-element time stepper.
    """

    class MultiAnalytical:
        """Multi-mode analytical solutions for heat equation"""

        @staticmethod
        def initial_condition(points: np.ndarray, mu: np.ndarray) -> np.ndarray:
            """
            Generate the heat equation initial condition at each point in the domain.

            The initial condition is a superposition of sinusoidal modes:
            u0(x,y) = -Σ_m (μ_m * sin(πmx) * sin(πmy) / √m / d)

            This equals ``HeatGen.MultiAnalytical.solution(points, mu, 0.0)``.

            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Spatial coordinates (only the first two columns are read).
                mu: array_like (d,) or (N, d)
                    Coefficients for the analytical solution.
                    d is the number of Fourier modes. Extra leading batch
                    dimensions, i.e. shape (..., d), are also accepted.

            Returns:
            --------
                u0: np.ndarray (n_points,) or (N, n_points)
                    Initial condition values; in general of shape
                    mu.shape[:-1] + (n_points,).
            """
            # exp(0) == 1.0 exactly, so the time-dependent solution at t = 0 is the
            # initial condition; delegating keeps a single copy of the modal sum.
            return HeatGen.MultiAnalytical.solution(points, mu, 0.0)

        @staticmethod
        def solution(points: np.ndarray, mu: np.ndarray, t: float) -> np.ndarray:
            """
            Generate the analytical solution for heat equation at time t.

            The solution is:
            u(x,y,t) = -Σ_m (μ_m * sin(πmx) * sin(πmy) * exp(-2m²π²t) / √m / d)

            Each mode sin(πmx)sin(πmy) satisfies Δφ = -2m²π²φ, so this is an exact
            solution of ∂u/∂t = Δu (unit diffusivity) that vanishes on the boundary
            of the unit square [0, 1]².

            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Spatial coordinates (only the first two columns are read).
                mu: array_like (d,) or (N, d)
                    Coefficients for the analytical solution. Extra leading batch
                    dimensions, i.e. shape (..., d), are also accepted.
                t: float
                    Time at which to evaluate the solution

            Returns:
            --------
                ut: np.ndarray (n_points,) or (N, n_points)
                    Solution values at time t; in general of shape
                    mu.shape[:-1] + (n_points,).
            """
            mu = np.asarray(mu)
            points = np.asarray(points)
            d = mu.shape[-1]
            m = np.arange(1, d + 1)                  # mode numbers 1..d, shape (d,)
            x, y = points[:, 0:1], points[:, 1:2]    # (n_points, 1)

            # mu[..., None, :] is (..., 1, d) and the sine factors are (n_points, d),
            # so the product is (..., n_points, d); the mode axis is then summed out.
            terms = (mu[..., None, :] * np.sin(np.pi * m * x) * np.sin(np.pi * m * y)
                     * np.exp(-2 * m * m * np.pi * np.pi * t) / np.sqrt(m) / d)
            return -terms.sum(-1)

    class Random:
        """Random data generators for heat equation"""

        @staticmethod
        def initial_condition(points: np.ndarray,
                              low: float = 0.0,
                              high: float = 1.0,
                              rng: Union[np.random.Generator, None] = None) -> np.ndarray:
            """
            Generate random initial condition for heat equation.

            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Spatial coordinates (only the number of rows is used)
                low: float
                    Lower bound of uniform distribution (default: 0.0)
                high: float
                    Upper bound of uniform distribution (default: 1.0)
                rng: np.random.Generator or None
                    Random generator to draw from. When None (default), the
                    global ``np.random`` state is used.

            Returns:
            --------
                u0: np.ndarray (n_points,)
                    Random initial condition
            """
            sampler = np.random if rng is None else rng
            return sampler.uniform(low=low, high=high, size=(points.shape[0],))

        @staticmethod
        def _element_matrices(mesh):
            """Return (elements, M_elem, A_elem) for the mesh's triangle or quad cells.

            M_elem and A_elem have shape (num_elements, num_basis, num_basis) and hold
            the quadrature-integrated local mass and stiffness matrices. Triangles take
            precedence if the mesh holds both cell types.

            Raises:
                NotImplementedError: if the mesh has neither triangle nor quad cells.
            """
            cells = mesh.cells_dict
            if "triangle" in cells:
                elements = cells['triangle']
                qpoints = tri_gauss_points(ngp=1)
                shape_fn = tri3
            elif "quad" in cells:
                elements = cells['quad']
                qpoints = quad_gauss_points(ngp=4)
                shape_fn = quad4
            else:
                raise NotImplementedError("Mesh should have triangle or quadrilateral cells")

            weights, xi, eta = qpoints[:, 0], qpoints[:, 1], qpoints[:, 2]
            element_coords = mesh.points[elements][:, :, :2]
            shape_val, shape_grad, jac_det = shape_fn(xi, eta, element_coords)

            JxW = jac_det * weights  # [num_elements, num_quadrature_points]
            # Contract the quadrature axis directly instead of materializing (e, q, i, j).
            M_elem = np.einsum("qi,qj,eq->eij", shape_val, shape_val, JxW)
            A_elem = np.einsum("eqib,eqjb,eq->eij", shape_grad, shape_grad, JxW)
            return elements, M_elem, A_elem

        @staticmethod
        def solution(mesh,
                     u0: np.ndarray,
                     D: float = 1.0,
                     T: float = 1.0,
                     dt: float = 0.01,
                     recording: bool = False,
                     verbose: bool = True) -> Union[np.ndarray, List[np.ndarray]]:
            """
            Solve heat equation using finite element method with implicit Euler.

            Solves: ∂u/∂t = D² * Δu   (note: the diffusivity is D squared)
            Using: (M + dt*D²*A) u^{n+1} = M u^n
            Dirichlet data is imposed by static condensation: only interior nodes
            are unknowns, and the known boundary values are moved to the RHS.

            Parameters:
            -----------
                mesh: meshio.Mesh
                    The computational mesh with boundary conditions.
                    Must have 'boundary_mask' and 'boundary_value' in point_data.
                u0: np.ndarray (n_points,)
                    Initial condition values at mesh nodes. Not modified; boundary
                    values are imposed on an internal float64 copy.
                D: float
                    Diffusion parameter; enters the equation as D² (default: 1.0)
                T: float
                    Final time (default: 1.0). One step is taken per entry of
                    np.arange(dt, T + dt / 10, dt), so if T is not a multiple of dt
                    the last step lands before T.
                dt: float
                    Time step size (default: 0.01)
                recording: bool
                    If True, return solution at all time steps (default: False)
                verbose: bool
                    If True, show a progress bar (default: True). Only takes
                    effect when recording is also True.

            Returns:
            --------
                If recording=True:
                    List[np.ndarray]: Solution at each time step [u0, u1, ..., uN]
                Else:
                    np.ndarray (n_points,): Solution at final time T

            Raises:
            -------
                AssertionError: if point_data lacks 'boundary_mask' or 'boundary_value'.
                NotImplementedError: if the mesh has no triangle or quad cells.
            """
            # Validate mesh data
            assert "boundary_mask" in mesh.point_data, "Mesh must have boundary_mask in point_data"
            assert "boundary_value" in mesh.point_data, "Mesh must have boundary_value in point_data"

            elements, M_elem, A_elem = HeatGen.Random._element_matrices(mesh)
            num_points = mesh.points.shape[0]
            num_basis = elements.shape[1]

            # Global (row, col) of every local entry (e, i, j), in M_elem.ravel() order:
            # row = elements[e, i], col = elements[e, j].
            rows = np.repeat(elements, num_basis, axis=1).ravel()
            cols = np.tile(elements, (1, num_basis)).ravel()

            def assemble(local):
                # COO -> CSR conversion sums contributions of elements sharing a node pair.
                return scipy.sparse.coo_matrix(
                    (local.ravel(), (rows, cols)),
                    shape=(num_points, num_points)
                ).tocsr()

            M_global = assemble(M_elem)
            K_global = assemble(M_elem + dt * D * D * A_elem)  # K = M + dt * D² * A

            # Cast explicitly: point data that is not boolean (e.g. 0/1 integers) would
            # make `~mask` a bitwise NOT and corrupt the node partition.
            boundary_mask = np.asarray(mesh.point_data['boundary_mask']).astype(bool)
            boundary_value = np.asarray(mesh.point_data['boundary_value'])
            inner = np.flatnonzero(~boundary_mask)
            outer = np.flatnonzero(boundary_mask)
            g = boundary_value[outer]

            # Static condensation: interior block of K, and the boundary coupling times
            # the (time-independent) Dirichlet data, which is constant across steps.
            K_rows = K_global[inner]
            K_inner = K_rows[:, inner]
            lift = K_rows[:, outer] @ g
            M_rows = M_global[inner]

            # K_inner never changes, so factorize it once instead of re-solving each step.
            if inner.size > 0:
                solve = scipy.sparse.linalg.factorized(K_inner.tocsc())
            else:
                def solve(rhs):
                    # No interior unknowns: every node is prescribed by boundary data.
                    return np.zeros(0)

            # Work on a float64 copy so the caller's u0 is untouched and integer
            # inputs cannot truncate the solution.
            u = np.array(u0, dtype=np.float64)
            u[outer] = g
            history = [u] if recording else None

            steps = np.arange(dt, T + dt / 10, dt)
            if recording and verbose:
                steps = tqdm(steps, desc="Heat equation")

            for _ in steps:
                u_next = np.empty(num_points)
                u_next[outer] = g
                u_next[inner] = solve(M_rows @ u - lift)  # RHS: (M u^n)_inner - K_ou2in g
                u = u_next
                if recording:
                    history.append(u)

            return history if recording else u

