"""
Mathematical utility functions for TensorGalerkin
"""

import numpy as np
import torch
from typing import Tuple


def manual_seed(seed):
    """
    Seed the random number generators used by this package.

    Seeds, in order: the PyTorch CPU generator, the PyTorch CUDA generator(s),
    and NumPy's legacy global generator, all with the same ``seed``.
    """
    seeders = (torch.manual_seed, torch.cuda.manual_seed, np.random.seed)
    for seeder in seeders:
        seeder(seed)


def apply_zero_boundary(U, mesh_boundary_mask):
    """
    Apply homogeneous (zero) boundary conditions.

    Builds a new array/tensor of zeros shaped like ``U`` and copies ``U`` into
    it only where ``mesh_boundary_mask`` is False, so boundary entries end up
    zero. ``U`` itself is not modified.

    Parameters:
    -----------
        U: torch.Tensor or np.ndarray
            Field values. If 1-D the mask indexes axis 0; otherwise axis 0 is
            treated as a batch axis and the mask indexes axis 1.
        mesh_boundary_mask: torch.Tensor or np.ndarray
            Boolean mask, True at boundary entries. A torch mask is converted
            to NumPy when ``U`` is a NumPy array.

    Returns:
    --------
        U_constraint: same type as U
            Copy of ``U`` with the masked entries set to zero.
    """
    if isinstance(U, np.ndarray):
        result = np.zeros_like(U)
        if isinstance(mesh_boundary_mask, torch.Tensor):
            # NumPy cannot index with a (possibly CUDA) torch tensor.
            mesh_boundary_mask = mesh_boundary_mask.cpu().numpy()
    elif isinstance(U, torch.Tensor):
        result = torch.zeros_like(U)

    interior = ~mesh_boundary_mask
    if U.ndim == 1:
        result[interior] += U[interior]
    else:
        result[:, interior, ...] += U[:, interior, ...]
    return result


def apply_const_boundary(U, mesh_boundary_mask, const):
    """
    Apply constant boundary conditions.

    Builds a new array/tensor shaped like ``U``. Entries where
    ``mesh_boundary_mask`` is True receive ``const``; all other entries are
    copied from ``U``. ``U`` itself is not modified.

    Parameters:
    -----------
        U: torch.Tensor or np.ndarray
            Field values. If 1-D the mask indexes axis 0; otherwise axis 0 is
            treated as a batch axis and the mask indexes axis 1.
        mesh_boundary_mask: torch.Tensor or np.ndarray
            Boolean mask, True at boundary entries. A torch mask is converted
            to NumPy when ``U`` is a NumPy array.
        const: scalar, torch.Tensor or np.ndarray
            Value(s) added onto the zeroed boundary entries (broadcast against
            the selected slice). A torch ``const`` is moved to ``U``'s device
            for torch ``U`` and converted to NumPy for NumPy ``U``.

    Returns:
    --------
        U_constraint: same type as U

    Raises:
    -------
        TypeError: if ``U`` is neither a torch.Tensor nor an np.ndarray.
    """
    if isinstance(U, torch.Tensor):
        U_constraint = torch.zeros_like(U)
        if isinstance(const, torch.Tensor):
            const = const.to(U.device)
    elif isinstance(U, np.ndarray):
        U_constraint = np.zeros_like(U)
        if isinstance(mesh_boundary_mask, torch.Tensor):
            mesh_boundary_mask = mesh_boundary_mask.cpu().numpy()
        if isinstance(const, torch.Tensor):
            # NumPy arithmetic cannot consume CUDA or grad-tracking tensors,
            # and ndarray has no ``.device`` before NumPy 2.0.
            const = const.detach().cpu().numpy()
    else:
        raise TypeError(
            f"U must be a torch.Tensor or np.ndarray, got {type(U).__name__}"
        )

    if len(U.shape) == 1:
        U_constraint[mesh_boundary_mask] += const
        U_constraint[~mesh_boundary_mask] += U[~mesh_boundary_mask]
    else:
        U_constraint[:, mesh_boundary_mask, ...] += const
        U_constraint[:, ~mesh_boundary_mask, ...] += U[:, ~mesh_boundary_mask, ...]
    return U_constraint


def apply_dirichlet_boundary(U, mesh_boundary_mask, boundary_value):
    """
    Apply Dirichlet boundary conditions by overwriting boundary entries.

    Returns a copy of ``U`` in which every entry selected by
    ``mesh_boundary_mask`` is replaced with ``boundary_value``. ``U`` itself
    is not modified.

    Parameters:
    -----------
        U: torch.Tensor or np.ndarray
            Field values. If 1-D the mask indexes axis 0; otherwise axis 0 is
            treated as a batch axis and the mask indexes axis 1.
        mesh_boundary_mask: torch.Tensor or np.ndarray
            Boolean mask, True at boundary entries. A torch mask is converted
            to NumPy when ``U`` is a NumPy array.
        boundary_value: scalar, torch.Tensor or np.ndarray
            Value(s) to write. Any 0-dimensional value (Python/NumPy scalar,
            0-d array or tensor) is broadcast to every boundary entry. A
            higher-dimensional value is indexed per boundary node and, for
            batched ``U``, broadcast over the batch axis.

    Returns:
    --------
        U_constraint: same type as U

    Raises:
    -------
        TypeError: if ``U`` is neither a torch.Tensor nor an np.ndarray.
    """
    if isinstance(U, torch.Tensor):
        U_constraint = U.clone()
        if isinstance(boundary_value, torch.Tensor):
            boundary_value = boundary_value.to(U.device)
    elif isinstance(U, np.ndarray):
        U_constraint = U.copy()
        if isinstance(mesh_boundary_mask, torch.Tensor):
            mesh_boundary_mask = mesh_boundary_mask.cpu().numpy()
        if isinstance(boundary_value, torch.Tensor):
            boundary_value = boundary_value.detach().cpu().numpy()
    else:
        raise TypeError(
            f"U must be a torch.Tensor or np.ndarray, got {type(U).__name__}"
        )

    if len(U.shape) == 1:
        U_constraint[mesh_boundary_mask] = boundary_value
    elif np.ndim(boundary_value) == 0:
        # Scalars of any flavour (int, float, np.float32, 0-d tensor) cannot be
        # indexed, so assign them directly and let broadcasting fill the slice.
        U_constraint[:, mesh_boundary_mask, ...] = boundary_value
    else:
        U_constraint[:, mesh_boundary_mask, ...] = boundary_value[None, :, ...]
    return U_constraint


class CSRSpMV(torch.autograd.Function):
    """
    CSR sparse matrix-vector product ``b = A @ x`` with autograd support.

    ``A`` is used only for the forward product. Gradients are computed from
    the COO description ``(src, dst, edata)`` of the same matrix, so ``A``
    must hold exactly the entries ``A[src[k], dst[k]] = edata[k]``.

    Gradients are returned for ``x`` (``A^T @ grad_output``) and ``edata``.
    ``A``, ``src`` and ``dst`` receive no gradient.
    """

    @staticmethod
    def forward(ctx, A: torch.Tensor, x: torch.Tensor,
                src: torch.Tensor, dst: torch.Tensor, edata: torch.Tensor) -> torch.Tensor:
        """
        Forward pass: b = A @ x

        Parameters:
        -----------
            A: torch.sparse_csr_tensor
                Sparse CSR matrix of shape (n_rows, n_cols)
            x: torch.Tensor [n_cols]
                Input vector
            src: torch.Tensor
                Row index of each stored entry
            dst: torch.Tensor
                Column index of each stored entry
            edata: torch.Tensor
                Value of each stored entry (same order as src/dst)

        Returns:
        --------
            b: torch.Tensor [n_rows]
                Result vector
        """
        b = torch.mm(A, x[:, None])[:, 0]
        ctx.save_for_backward(x, src, dst, edata)
        return b

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple:
        """
        Backward pass.

        With b[i] = sum_k edata[k] * x[dst[k]] over entries k with src[k] == i:
            dL/d edata[k] = grad_output[src[k]] * x[dst[k]]
            dL/d x[j]     = sum_{k: dst[k] == j} edata[k] * grad_output[src[k]]
                          = (A^T @ grad_output)[j]

        Returns one gradient per forward input: (A, x, src, dst, edata).
        Gradients are only computed for inputs that require them.
        """
        x, src, dst, edata = ctx.saved_tensors
        grad_x = None
        grad_edata = None
        # Upstream gradient gathered at each entry's row.
        grad_rows = grad_output[src]
        if ctx.needs_input_grad[1]:
            # Scatter each entry's contribution into its column: A^T @ grad_output.
            contrib = (edata * grad_rows).to(x.dtype)
            grad_x = torch.zeros_like(x).index_add_(0, dst, contrib)
        if ctx.needs_input_grad[4]:
            grad_edata = grad_rows * x[dst]
        return None, grad_x, None, None, grad_edata


class SparseMatrix:
    """
    Sparse matrix wrapper with autograd support for matrix-vector multiplication.

    Entries are given in COO form (row, column, value) and stored in CSR
    order. Products are routed through ``CSRSpMV`` so gradients flow to the
    entry values.
    """

    def __init__(self, src: torch.Tensor, dst: torch.Tensor,
                 edata: torch.Tensor, shape: Tuple[int, int], cache: bool = True):
        """
        Initialize sparse matrix from COO-like format.

        Parameters:
        -----------
            src: torch.Tensor
                Source (row) indices, non-negative integers < n_rows
            dst: torch.Tensor
                Destination (column) indices
            edata: torch.Tensor
                Entry values
            shape: Tuple[int, int]
                Matrix shape (n_rows, n_cols)
            cache: bool
                Whether to build the CSR tensor once here (True) or on every
                call to ``mv`` (False)
        """
        assert src.shape == dst.shape == edata.shape, \
            f"src.shape: {src.shape}, dst.shape: {dst.shape}, edata.shape: {edata.shape}"

        num_rows, num_cols = shape

        # Sort entries row-major by (row, column). torch's CSR layout expects the
        # column indices inside each row to be sorted. argsort on the row index
        # alone is not stable and would leave them in arbitrary order.
        sort_key = src.to(torch.long) * num_cols + dst.to(torch.long)
        order = torch.argsort(sort_key)
        src = src[order]
        dst = dst[order]
        edata = edata[order]

        # indptr[i]:indptr[i + 1] is the slice of entries belonging to row i.
        indptr = torch.zeros(num_rows + 1, dtype=torch.long, device=edata.device)
        indptr[1:] = torch.cumsum(torch.bincount(src, minlength=num_rows), 0)

        self.src = src          # row index per entry, in CSR order
        self.dst = dst          # column index per entry, in CSR order
        self.indptr = indptr    # CSR row pointer, length n_rows + 1
        # sparse_csr_tensor requires row pointers and column indices of one dtype.
        self.indices = dst.to(indptr.dtype)
        self.edata = edata      # entry values, in CSR order
        self.shape = shape

        if cache:
            self.A = torch.sparse_csr_tensor(indptr, self.indices, edata, shape)
        else:
            self.A = None

    def mv(self, vector: torch.Tensor) -> torch.Tensor:
        """
        Matrix-vector multiplication with autograd support.

        Parameters:
        -----------
            vector: torch.Tensor [n_cols]
                Input vector

        Returns:
        --------
            result: torch.Tensor [n_rows]
                Result of A @ vector
        """
        if self.A is None:
            A = torch.sparse_csr_tensor(self.indptr, self.indices, self.edata, self.shape)
        else:
            A = self.A
        return CSRSpMV.apply(A, vector, self.src, self.dst, self.edata)


def element_collect(edata: torch.Tensor, elements: torch.Tensor, n_nodes: int) -> torch.Tensor:
    """
    Assemble element-wise values into a global node-wise vector.

    Every value ``edata[e, i]`` is added to node ``elements[e, i]``. This is
    the finite element assembly (scatter-add) step. It is implemented as the
    product of an (n_nodes x n_elements) sparse matrix with a ones vector, so
    gradients flow back to ``edata``.

    Parameters:
    -----------
        edata: torch.Tensor [n_elements, n_basis]
            Element-wise data (one value per basis function per element)
        elements: torch.Tensor [n_elements, n_basis]
            Element connectivity (node indices for each element)
        n_nodes: int
            Total number of nodes in the mesh

    Returns:
    --------
        ndata: torch.Tensor [n_nodes]
            Node-wise data after assembly
    """
    assert edata.shape[0] == elements.shape[0], \
        f"edata and elements must have same number of elements, got {edata.shape[0]} and {elements.shape[0]}"
    assert edata.shape[1] == elements.shape[1], \
        f"edata and elements must have same n_basis, got {edata.shape[1]} and {elements.shape[1]}"

    num_elem, num_basis = elements.shape[0], elements.shape[1]
    device = edata.device

    # Entry (elements[e, i], e) of the assembly matrix holds edata[e, i]:
    # the row is the global node, the column is the owning element.
    node_rows = elements.flatten()
    elem_cols = torch.arange(0, num_elem, device=device).repeat_interleave(num_basis)
    assembly = SparseMatrix(node_rows, elem_cols, edata.flatten(), (n_nodes, num_elem))

    # Multiplying by a ones vector sums each node's row.
    ones = torch.ones([num_elem], device=device, dtype=edata.dtype)
    return assembly.mv(ones)