"""
Allen-Cahn equation data generators for TensorGalerkin

Solves the Allen-Cahn equation:
    ∂u/∂t = D * Δu + ε² * u * (1 - u²)  in Ω
    u = g                                on ∂Ω (Dirichlet boundary)
    u(x, 0) = u0                         (initial condition)

The Allen-Cahn equation is a nonlinear parabolic PDE that models phase 
separation in multi-component alloy systems.
"""

import numpy as np
import scipy.sparse
import scipy.sparse.linalg
from typing import Union, List
from tqdm import tqdm
import warnings

from ...discretization import tri3, quad4, tri_gauss_points, quad_gauss_points


class AllenCahnGen:
    """Data generators for Allen-Cahn equation problems.

    The nested namespaces group static helpers:

    * ``MultiAnalytical`` - sinusoidal multi-mode initial conditions.
    * ``Random`` - uniformly random initial conditions and the FEM/Newton
      time-stepping solver.
    """

    class MultiAnalytical:
        """Multi-mode analytical initial conditions for Allen-Cahn equation"""

        @staticmethod
        def initial_condition(points: np.ndarray,
                            a: np.ndarray,
                            r: float = 0.5) -> np.ndarray:
            """
            Generate sinusoidal initial condition (same as wave equation).

            The initial condition is a superposition of sinusoidal modes:
            u0(x,y) = π/K² * Σ_{i,j} (a_{ij} * (i²+j²)^{-r} * sin(πix) * sin(πjy))

            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Spatial coordinates (only the first two columns are read)
                a: np.ndarray (K, K) or (N, K, K)
                    Coefficient matrix for the initial condition. Additional
                    leading batch axes are broadcast the same way as N.
                r: float
                    Decay exponent for frequency modes (default: 0.5).
                    Integer values are accepted as well.

            Returns:
            --------
                u0: np.ndarray (n_points,) or (N, n_points)
                    Initial condition values
            """
            points = np.asarray(points)
            a = np.asarray(a)
            K = a.shape[-1]

            # Float mode numbers: an integer array raised to a negative integer
            # power (e.g. r=1) raises ValueError in NumPy.
            modes = np.arange(1, K + 1, dtype=np.float64)
            i, j = np.meshgrid(modes, modes, indexing="ij")  # (K, K), i varies along rows
            decay = (i * i + j * j) ** (-r)  # (K, K)

            # Insert a point axis in front of the two mode axes so that
            # a: (..., 1, K, K) broadcasts against x, y: (n_points, 1, 1).
            a = a[..., None, :, :]
            x = points[:, 0][:, None, None]
            y = points[:, 1][:, None, None]

            u0 = (np.pi / K / K * (a * decay *
                  np.sin(np.pi * i * x) * np.sin(np.pi * j * y))).sum((-2, -1))
            return u0

    class Random:
        """Random data generators for Allen-Cahn equation"""

        @staticmethod
        def initial_condition(points: np.ndarray,
                            low: float = 0.0,
                            high: float = 1.0) -> np.ndarray:
            """
            Generate random initial condition for Allen-Cahn equation.

            One independent sample from U[low, high) is drawn per point using
            NumPy's global random state.

            Parameters:
            -----------
                points: np.ndarray (n_points, 2)
                    Spatial coordinates (only the number of rows is used)
                low: float
                    Lower bound of uniform distribution (default: 0.0)
                high: float
                    Upper bound of uniform distribution (default: 1.0)

            Returns:
            --------
                u0: np.ndarray (n_points,)
                    Random initial condition
            """
            return np.random.uniform(low=low, high=high, size=(points.shape[0],))

        @staticmethod
        def solution(mesh,
                    u0: np.ndarray,
                    D: float = 1.0,
                    T: float = 1.0,
                    dt: float = 0.01,
                    epsilon: float = 22.0,
                    recording: bool = False,
                    verbose: bool = True) -> Union[np.ndarray, List[np.ndarray]]:
            """
            Solve Allen-Cahn equation using finite element method with Newton iteration.

            Solves: ∂u/∂t = D * Δu + ε² * u * (1 - u²)

            The equation is discretized in time using implicit Euler and the
            nonlinear system is solved using Newton-Raphson iteration
            (at most 10 iterations per step, tolerance 1e-10 on ||R||).

            Parameters:
            -----------
                mesh: meshio.Mesh
                    The computational mesh with boundary conditions.
                    Must have 'boundary_mask' and 'boundary_value' in point_data.
                    'boundary_mask' is interpreted as boolean (non-zero = boundary).
                u0: np.ndarray (n_points,)
                    Initial condition values at mesh nodes. The array is copied;
                    the caller's data is not modified.
                D: float
                    Diffusion coefficient (default: 1.0)
                T: float
                    Final time (default: 1.0)
                dt: float
                    Time step size (default: 0.01)
                epsilon: float
                    Interface parameter (default: 22.0)
                    Controls the sharpness of phase boundaries.
                recording: bool
                    If True, return solution at all time steps (default: False)
                verbose: bool
                    If True, show progress bar (default: True). The bar is only
                    displayed when `recording` is True as well.

            Returns:
            --------
                If recording=True:
                    List[np.ndarray]: Solution at each time step [u0, u1, ..., uN]
                Else:
                    np.ndarray (n_points,): Solution at final time T

            Raises:
            -------
                AssertionError: if the mesh lacks the required point_data entries.
                NotImplementedError: if the mesh has neither triangle nor quad cells.

            Warns:
            ------
                UserWarning if Newton's method does not reach the tolerance
                within 10 iterations (including when the residual becomes NaN).
            """
            # Validate mesh data
            assert "boundary_mask" in mesh.point_data, "Mesh must have boundary_mask in point_data"
            assert "boundary_value" in mesh.point_data, "Mesh must have boundary_value in point_data"

            # Reaction term f(u) = -ε² u (u² - 1) and its derivative f'(u).
            eps2 = epsilon**2

            def reaction(x):
                return -eps2 * x * (x**2 - 1)

            def d_reaction(x):
                return -eps2 * (3 * x**2 - 1)

            points = mesh.points

            # Determine element type and setup quadrature.
            # The einsum signatures below index shape_val as (n_quad, n_basis),
            # shape_grad as (n_element, n_quad, n_basis, 2) and jac_det as
            # (n_element, n_quad).
            if "triangle" in mesh.cells_dict:
                elements = mesh.cells_dict['triangle']
                qpoints = tri_gauss_points(ngp=1)
                quadrature_weight, xi, eta = qpoints[:, 0], qpoints[:, 1], qpoints[:, 2]
                mesh_points_2d = points[elements][:, :, :2]
                shape_val, shape_grad, jac_det = tri3(xi, eta, mesh_points_2d)
            elif "quad" in mesh.cells_dict:
                elements = mesh.cells_dict['quad']
                qpoints = quad_gauss_points(ngp=4)
                quadrature_weight, xi, eta = qpoints[:, 0], qpoints[:, 1], qpoints[:, 2]
                mesh_points_2d = points[elements][:, :, :2]
                shape_val, shape_grad, jac_det = quad4(xi, eta, mesh_points_2d)
            else:
                raise NotImplementedError("Mesh should have triangle or quadrilateral cells")

            # Cast to bool so that `~mask` is a logical negation even when the
            # mask is stored as 0/1 integers or floats.
            boundary_mask = np.asarray(mesh.point_data['boundary_mask']).astype(bool)
            boundary_value = np.asarray(mesh.point_data['boundary_value'])
            num_points = points.shape[0]
            num_elements, num_basis = elements.shape

            # Scatter matrix: sums per-element nodal contributions into mesh nodes.
            ele2msh_node = scipy.sparse.coo_matrix(
                (np.ones([num_elements * num_basis]),
                 (elements.ravel(), np.arange(num_elements * num_basis))),
                shape=(num_points, num_elements * num_basis)
            ).tocsr()

            # Row/column node of every local (i, j) basis pair, ordered as
            # (element, i, j) to match `integral.ravel()` in the Jacobian.
            elem_u = np.repeat(elements, num_basis, axis=1).ravel()
            elem_v = np.tile(elements, (1, num_basis)).ravel()

            # Remove duplicated edges (COO -> CSR merges repeated (u, v) pairs)
            tmp = scipy.sparse.coo_matrix(
                (np.ones_like(elem_u), (elem_u, elem_v)),
                shape=(num_points, num_points)
            ).tocsr().tocoo()
            edge_u, edge_v = tmp.row, tmp.col
            num_edges = len(edge_u)

            # Lookup table (u, v) -> unique edge id
            eids_csr = scipy.sparse.coo_matrix(
                (np.arange(num_edges), (edge_u, edge_v)),
                shape=(num_points, num_points)
            ).tocsr()

            # Scatter matrix: sums per-element (i, j) entries into unique edges.
            elem_eids = np.array(eids_csr[elem_u, elem_v]).ravel()
            ele2msh_edge = scipy.sparse.coo_matrix(
                (np.ones_like(elem_eids), (elem_eids, np.arange(num_elements * num_basis * num_basis))),
                shape=(num_edges, num_elements * num_basis * num_basis)
            ).astype(np.float64)

            # Compute Jacobian weighted quadrature weights
            JxW = jac_det * quadrature_weight  # [num_elements, num_quadrature_points]

            # Solution-independent basis products, hoisted out of the Newton loop.
            mul_uv = np.einsum('gi,gj->gij', shape_val, shape_val)  # [n_quad, n_basis, n_basis]
            dot_graduv = np.einsum('egid,egjd->egij', shape_grad, shape_grad)  # [n_element, n_quad, n_basis, n_basis]

            def assemble_residual(U_t1, U_t2):
                """
                Assemble the residual vector R.

                R^I = ∫ [(u^n - u^{n-1})/dt * N^I + D*∇u^n·∇N^I - f(u^n)*N^I] dx

                The nodal vectors keep their Dirichlet values so that elements
                touching the boundary see the prescribed data; only the rows of
                R belonging to boundary nodes are zeroed.
                """
                elemU_t1 = U_t1[elements]  # [n_element, n_basis]
                elemU_t2 = U_t2[elements]

                # Interpolate to quadrature points
                phi_t1 = np.einsum('gb,eb->eg', shape_val, elemU_t1)  # [n_element, n_quad]
                phi_t2 = np.einsum('gb,eb->eg', shape_val, elemU_t2)
                gradphi_t2 = np.einsum('egbd,eb->egd', shape_grad, elemU_t2)  # [n_element, n_quad, 2]

                phidot = (phi_t2 - phi_t1) / dt  # [n_element, n_quad]

                # Compute residual integrand
                integral = (np.einsum("eg,gb->egb", phidot, shape_val)
                          + np.einsum("egd,egbd->egb", D * gradphi_t2, shape_grad)
                          - np.einsum("eg,gb->egb", reaction(phi_t2), shape_val))

                integral = np.einsum("egb,eg->eb", integral, JxW)  # [n_element, n_basis]

                R = ele2msh_node @ integral.ravel()
                R[boundary_mask] = 0.0  # boundary rows are constrained, not solved
                return R

            def assemble_jacobian(U_t2):
                """
                Assemble the per-edge values of K = -∂R/∂u.

                ∂R^I/∂u^J = ∫ [N^I·N^J/dt + D*∇N^I·∇N^J - df(u^n)*N^I·N^J] dx
                """
                elemU_t2 = U_t2[elements]  # [n_element, n_basis]
                phi_t2 = np.einsum('gb,eb->eg', shape_val, elemU_t2)  # [n_element, n_quad]

                # Note the negative sign: K = -∂R/∂u, so that K du = R gives
                # the Newton update u <- u + du.
                integral = -1.0 * (
                    (1.0 / dt) * mul_uv[None, ...]  # [1, n_quad, n_basis, n_basis]
                    + D * dot_graduv
                    - np.einsum("eg,gij->egij", d_reaction(phi_t2), mul_uv)
                )

                integral = np.einsum("egij,eg->eij", integral, JxW)  # [n_element, n_basis, n_basis]

                K = ele2msh_edge @ integral.ravel()
                return K

            # Restrict the linear system to interior (non-Dirichlet) nodes.
            is_inner_node = ~boundary_mask
            is_outer_node = boundary_mask
            is_inner_edge = is_inner_node[edge_u] & is_inner_node[edge_v]
            n_inner_nodes = int(is_inner_node.sum())

            # Interior nodes get consecutive local ids; boundary nodes keep -1.
            local_nids = np.full((num_points,), -1, dtype=np.int64)
            local_nids[is_inner_node] = np.arange(n_inner_nodes)

            # Sparsity pattern of the interior block is fixed; build it once.
            inner_rows = local_nids[edge_u[is_inner_edge]]
            inner_cols = local_nids[edge_v[is_inner_edge]]

            # Initialize solution on a float copy so the caller's array is not
            # modified and Newton updates are not truncated for integer input.
            u0 = np.array(u0, dtype=np.float64)
            u0[is_outer_node] = boundary_value[is_outer_node]

            Us = [u0.copy()] if recording else None

            # Time stepping (T + dt/10 guards against losing the last step to
            # floating-point round-off in arange)
            times = np.arange(dt, T + dt / 10, dt)
            if recording and verbose:
                iterator = tqdm(times, desc="Allen-Cahn")
            else:
                iterator = times

            uold = u0.copy()

            for t_ in iterator:
                u = uold.copy()
                converged = False

                # Newton iteration (max 10 iterations)
                for _ in range(10):
                    K = assemble_jacobian(u)
                    R = assemble_residual(uold, u)

                    K_inner = scipy.sparse.coo_matrix(
                        (K[is_inner_edge], (inner_rows, inner_cols)),
                        shape=(n_inner_nodes, n_inner_nodes)
                    ).tocsr()

                    # Boundary nodes already hold their prescribed values, so
                    # their increment is zero and contributes nothing to the
                    # right-hand side.
                    du_inner = scipy.sparse.linalg.spsolve(K_inner, R[is_inner_node])
                    u[is_inner_node] += du_inner

                    # Norm of the residual the update was computed from.
                    rnorm = np.linalg.norm(R)
                    if rnorm < 1e-10:
                        converged = True
                        break

                # Also fires when rnorm is NaN (diverged iteration).
                if not converged:
                    warnings.warn(f"Newton solver did not converge at t={t_:.6f}, ||R||={rnorm:.2e}")

                uold = u

                if recording:
                    Us.append(u.copy())

            if recording:
                return Us
            return uold

