"""
Finite element shape functions for TensorGalerkin
"""

import torch 
import numpy as np
from .tensor_api import API


def get_tri3_jacobian(nodes):
    """
    Build the Jacobian matrix of every 3-node (linear) triangular element

    The parent-space gradients of the linear triangle are constant, so a
    single 2x2 matrix per element is produced (there is no Gauss-point axis).

    Parameters:
    -----------
        nodes: torch.Tensor (n_element, n_basis, 2)
            Node coordinates for each element (3 nodes per element)

    Returns:
    --------
        jac: torch.Tensor (n_element, 2, 2)
            Jacobian matrix of each element, laid out as
            jac[e, i, j] = sum_h nodes[e, h, j] * dN_h/dxi_i,
            i.e. row index = parent coordinate (xi, eta),
            column index = global coordinate (x, y)
    """
    API.detect_backend(nodes)
    assert API.dim(nodes) == 3 and nodes.shape[-1] == 2, \
        f"nodes should have shape (batch_size, n_basis, 2), but got {nodes.shape}"

    # One row per basis function, columns are (d/dxi, d/deta).
    # These are the gradients of N0 = 1 - xi - eta, N1 = xi, N2 = eta.
    parent_grads = (
        (-1.0, -1.0),
        (1.0, 0.0),
        (0.0, 1.0),
    )

    dn_dparent = API.zeros((len(parent_grads), 2), device=API.device(nodes))
    for basis_idx, (d_xi, d_eta) in enumerate(parent_grads):
        dn_dparent[basis_idx, 0] = d_xi
        dn_dparent[basis_idx, 1] = d_eta

    # Contract over the node axis h: (b, h, j) x (h, i) -> (b, i, j)
    return API.einsum("bhj,hi->bij", nodes, dn_dparent)


def tri3(xi, eta, nodes, return_jacobian=False):
    """
    Calculate shape functions and gradients for 3-node triangular elements

    The basis functions on the parent triangle are
    N0 = 1 - xi - eta, N1 = xi and N2 = eta.

    Parameters:
    -----------
        xi: torch.Tensor or np.ndarray (n_gauss,)
            Xi coordinates of Gauss points
        eta: torch.Tensor or np.ndarray (n_gauss,)
            Eta coordinates of Gauss points
        nodes: torch.Tensor or np.ndarray (n_element, n_basis, 2)
            Node coordinates for each element (3 nodes per element)
        return_jacobian: bool
            Whether to also return the Jacobian matrices

    Returns:
    --------
        shape_val: torch.Tensor (n_gaussian, n_basis)
            Shape function values at Gauss points
        shape_grad: torch.Tensor (n_element, n_gaussian, n_basis, 2)
            Shape function gradients in global coordinates
        jacdet: torch.Tensor (n_element, n_gaussian)
            Absolute value of the Jacobian determinant at each Gauss point
        jac: torch.Tensor (n_element, n_gaussian, 2, 2) [optional]
            Jacobian matrices at each Gauss point, laid out as
            jac[e, g, i, j] = d(x_j)/d(xi_i); only returned when
            return_jacobian is True
    """
    API.detect_backend(xi)

    # Dimensionality is checked before len() so that 0-d inputs fail with this
    # message, and API.dim is used in the message because np.ndarray has no
    # .dim() method.
    assert API.dim(xi) == 1 and API.dim(eta) == 1, \
        f"xi and eta should be 1D tensor, but got {API.dim(xi)} and {API.dim(eta)}"
    assert len(xi) == len(eta), \
        f"xi and eta should have the same length, but got {len(xi)} and {len(eta)}"
    assert API.dim(nodes) == 3 and nodes.shape[-1] == 2, \
        f"nodes should have shape (batch_size, n_basis, 2), but got {nodes.shape}"

    n_gaussian = xi.shape[0]
    n_basis = 3

    # Shape function values
    shape_val = API.zeros((n_gaussian, n_basis), device=API.device(nodes))
    shape_val[:, 0] = 1.0 - xi - eta
    shape_val[:, 1] = xi
    shape_val[:, 2] = eta

    # Shape function gradients in parent coordinates (constant for tri3,
    # replicated along the Gauss-point axis); last axis is (d/dxi, d/deta)
    shape_grad = API.zeros((n_gaussian, n_basis, 2), device=API.device(nodes))
    shape_grad[:, 0, 0] = -1.0
    shape_grad[:, 0, 1] = -1.0
    shape_grad[:, 1, 0] = 1.0
    shape_grad[:, 1, 1] = 0.0
    shape_grad[:, 2, 0] = 0.0
    shape_grad[:, 2, 1] = 1.0

    # Jacobian: jac[b, g, i, j] = sum_h x_j(node h) * dN_h/dxi_i
    jac = API.einsum("bhj,ghi->bgij", nodes, shape_grad)
    jacdet = API.det(jac)
    # Only the magnitude is returned, so element orientation does not matter
    jacdet_abs = API.abs(jacdet)

    # Transform gradients to global coordinates (chain rule):
    # dN_b/dx_j = sum_i dN_b/dxi_i * ijac[j, i]
    ijac = API.inv(jac)
    shape_grad = API.einsum("gbi,ngji->ngbj", shape_grad, ijac)

    if return_jacobian:
        return shape_val, shape_grad, jacdet_abs, jac
    return shape_val, shape_grad, jacdet_abs


def quad4(xi, eta, nodes, return_jacobian=False):
    """
    Calculate shape functions and gradients for 4-node quadrilateral elements

    The basis functions are the bilinear functions
    N_a = (1 +/- xi) * (1 +/- eta) / 4, ordered so that node 0..3 correspond
    to the parent corners (-1, -1), (1, -1), (1, 1), (-1, 1).

    Parameters:
    -----------
        xi: torch.Tensor or np.ndarray (n_gauss,)
            Xi coordinates of Gauss points
        eta: torch.Tensor or np.ndarray (n_gauss,)
            Eta coordinates of Gauss points
        nodes: torch.Tensor or np.ndarray (n_element, n_basis, 2)
            Node coordinates for each element (4 nodes per element)
        return_jacobian: bool
            Whether to also return the Jacobian matrices

    Returns:
    --------
        shape_val: torch.Tensor (n_gaussian, n_basis)
            Shape function values at Gauss points
        shape_grad: torch.Tensor (n_element, n_gaussian, n_basis, 2)
            Shape function gradients in global coordinates
        jacdet: torch.Tensor (n_element, n_gaussian)
            Absolute value of the Jacobian determinant at each Gauss point
        jac: torch.Tensor (n_element, n_gaussian, 2, 2) [optional]
            Jacobian matrices at each Gauss point, laid out as
            jac[e, g, i, j] = d(x_j)/d(xi_i); only returned when
            return_jacobian is True
    """
    API.detect_backend(xi)

    # Dimensionality is checked before len() so that 0-d inputs fail with this
    # message, and API.dim is used in the message because np.ndarray has no
    # .dim() method.
    assert API.dim(xi) == 1 and API.dim(eta) == 1, \
        f"xi and eta should be 1D tensor, but got {API.dim(xi)} and {API.dim(eta)}"
    assert len(xi) == len(eta), \
        f"xi and eta should have the same length, but got {len(xi)} and {len(eta)}"
    assert API.dim(nodes) == 3 and nodes.shape[-1] == 2, \
        f"nodes should have shape (batch_size, n_basis, 2), but got {nodes.shape}"

    n_gaussian = xi.shape[0]
    n_basis = 4

    # Shape function values (bilinear)
    shape_val = API.zeros((n_gaussian, n_basis), device=API.device(nodes))
    shape_val[:, 0] = (1.0 - xi) * (1.0 - eta) / 4.0
    shape_val[:, 1] = (1.0 + xi) * (1.0 - eta) / 4.0
    shape_val[:, 2] = (1.0 + xi) * (1.0 + eta) / 4.0
    shape_val[:, 3] = (1.0 - xi) * (1.0 + eta) / 4.0

    # Shape function gradients in parent coordinates;
    # last axis is (d/dxi, d/deta)
    shape_grad = API.zeros((n_gaussian, n_basis, 2), device=API.device(nodes))
    shape_grad[:, 0, 0] = (eta - 1.0) / 4.0  # dN1/dxi
    shape_grad[:, 0, 1] = (xi - 1.0) / 4.0   # dN1/deta
    shape_grad[:, 1, 0] = (1.0 - eta) / 4.0  # dN2/dxi
    shape_grad[:, 1, 1] = -(1.0 + xi) / 4.0  # dN2/deta
    shape_grad[:, 2, 0] = (1.0 + eta) / 4.0  # dN3/dxi
    shape_grad[:, 2, 1] = (1.0 + xi) / 4.0   # dN3/deta
    shape_grad[:, 3, 0] = -(1.0 + eta) / 4.0 # dN4/dxi
    shape_grad[:, 3, 1] = (1.0 - xi) / 4.0   # dN4/deta

    # Jacobian: jac[b, g, i, j] = sum_h x_j(node h) * dN_h/dxi_i
    jac = API.einsum("bhj,ghi->bgij", nodes, shape_grad)
    jacdet = API.det(jac)
    # Only the magnitude is returned, so element orientation does not matter
    jacdet_abs = API.abs(jacdet)

    # Transform gradients to global coordinates (chain rule):
    # dN_b/dx_j = sum_i dN_b/dxi_i * ijac[j, i]
    ijac = API.inv(jac)
    shape_grad = API.einsum("gbi,ngji->ngbj", shape_grad, ijac)

    if return_jacobian:
        return shape_val, shape_grad, jacdet_abs, jac
    return shape_val, shape_grad, jacdet_abs


def tri6(xi, eta, nodes, return_jacobian=False):
    """
    Calculate shape functions and gradients for 6-node triangular elements

    The 6-node triangle is a quadratic element with nodes at vertices and
    edge midpoints. With L = 1 - xi - eta the basis functions are

        N0 = L * (2L - 1)        (vertex at (0, 0))
        N1 = xi * (2 xi - 1)     (vertex at (1, 0))
        N2 = eta * (2 eta - 1)   (vertex at (0, 1))
        N3 = 4 * xi * L          (midpoint of edge 0-1)
        N4 = 4 * xi * eta        (midpoint of edge 1-2)
        N5 = 4 * eta * L         (midpoint of edge 2-0)

    so each N_a equals 1 at its own node and 0 at the other five, and the
    six functions sum to 1 everywhere.

    Parameters:
    -----------
        xi: torch.Tensor or np.ndarray (n_gauss,)
            Xi coordinates of Gauss points
        eta: torch.Tensor or np.ndarray (n_gauss,)
            Eta coordinates of Gauss points
        nodes: torch.Tensor or np.ndarray (n_element, n_basis, 2)
            Node coordinates for each element (6 nodes per element)
        return_jacobian: bool
            Whether to also return the Jacobian matrices

    Returns:
    --------
        shape_val: torch.Tensor (n_gaussian, n_basis)
            Shape function values at Gauss points
        shape_grad: torch.Tensor (n_element, n_gaussian, n_basis, 2)
            Shape function gradients in global coordinates
        jacdet: torch.Tensor (n_element, n_gaussian)
            Absolute value of the Jacobian determinant at each Gauss point
        jac: torch.Tensor (n_element, n_gaussian, 2, 2) [optional]
            Jacobian matrices at each Gauss point, laid out as
            jac[e, g, i, j] = d(x_j)/d(xi_i) (same layout as tri3/quad4);
            only returned when return_jacobian is True
    """
    API.detect_backend(xi)

    # Dimensionality is checked before len() so that 0-d inputs fail with this
    # message rather than with an unrelated TypeError from len().
    assert API.dim(xi) == 1 and API.dim(eta) == 1, \
        f"xi and eta should be 1D tensor, but got {API.dim(xi)} and {API.dim(eta)}"
    assert len(xi) == len(eta), \
        f"xi and eta should have the same length, but got {len(xi)} and {len(eta)}"
    assert API.dim(nodes) == 3 and nodes.shape[-1] == 2, \
        f"nodes should have shape (batch_size, n_basis, 2), but got {nodes.shape}"

    n_gaussian = xi.shape[0]
    n_basis = 6

    # Shape function values (quadratic)
    shape_val = API.zeros((n_gaussian, n_basis), device=API.device(nodes))
    # Expanded form of (1 - xi - eta) * (1 - 2 xi - 2 eta)
    shape_val[:, 0] = 2 * xi**2 + 2 * eta**2 + 4 * xi * eta - 3 * xi - 3 * eta + 1
    shape_val[:, 1] = 2 * xi**2 - xi
    shape_val[:, 2] = 2 * eta**2 - eta
    # Mid-side functions are positive: e.g. N3(0.5, 0) must equal +1
    shape_val[:, 3] = 4 * xi * (1 - xi - eta)
    shape_val[:, 4] = 4 * xi * eta
    shape_val[:, 5] = 4 * eta * (1 - xi - eta)

    # Shape function gradients in parent coordinates;
    # last axis is (d/dxi, d/deta). Each column sums to zero over the six
    # basis functions, as required by partition of unity.
    shape_grad = API.zeros((n_gaussian, n_basis, 2), device=API.device(nodes))
    shape_grad[:, 0, 0] = 4 * xi + 4 * eta - 3
    shape_grad[:, 0, 1] = 4 * eta + 4 * xi - 3

    shape_grad[:, 1, 0] = 4 * xi - 1
    shape_grad[:, 1, 1] = 0

    shape_grad[:, 2, 0] = 0
    shape_grad[:, 2, 1] = 4 * eta - 1

    shape_grad[:, 3, 0] = 4 * (1 - 2 * xi - eta)
    shape_grad[:, 3, 1] = -4 * xi

    shape_grad[:, 4, 0] = 4 * eta
    shape_grad[:, 4, 1] = 4 * xi

    shape_grad[:, 5, 0] = -4 * eta
    shape_grad[:, 5, 1] = 4 * (1 - xi - 2 * eta)

    # Jacobian: jac[b, g, i, j] = sum_h x_j(node h) * dN_h/dxi_i.
    # This index layout must match the inverse-transform einsum below.
    jac = API.einsum("bhj,ghi->bgij", nodes, shape_grad)
    jacdet = API.det(jac)
    # Only the magnitude is returned, so element orientation does not matter
    jacdet_abs = API.abs(jacdet)

    # Transform gradients to global coordinates (chain rule):
    # dN_b/dx_j = sum_i dN_b/dxi_i * ijac[j, i].
    # Written with API.einsum (as in tri3/quad4) instead of the tensor
    # attribute .mT so the same code path serves every backend.
    ijac = API.inv(jac)
    shape_grad = API.einsum("gbi,ngji->ngbj", shape_grad, ijac)

    if return_jacobian:
        return shape_val, shape_grad, jacdet_abs, jac
    return shape_val, shape_grad, jacdet_abs


def quad9(xi, eta, nodes, return_jacobian=False):
    """
    Placeholder for the shape functions of 9-node quadrilateral elements

    The 9-node quadrilateral is the biquadratic Lagrange element (the
    serendipity quadrilateral is the 8-node variant without a centre node).
    The parameters and return values below describe the intended contract,
    mirroring the other element routines in this module; none of it is
    computed yet.

    Parameters:
    -----------
        xi: torch.Tensor or np.ndarray (n_gauss,)
            Xi coordinates of Gauss points
        eta: torch.Tensor or np.ndarray (n_gauss,)
            Eta coordinates of Gauss points
        nodes: torch.Tensor or np.ndarray (n_element, n_basis, 2)
            Node coordinates for each element (9 nodes per element)
        return_jacobian: bool
            Whether to return Jacobian matrices

    Returns:
    --------
        shape_val: torch.Tensor (n_gaussian, n_basis)
            Shape function values at Gauss points
        shape_grad: torch.Tensor (n_element, n_gaussian, n_basis, 2)
            Shape function gradients in global coordinates
        jacdet: torch.Tensor (n_element, n_gaussian)
            Jacobian determinant at each Gauss point
        jac: torch.Tensor (n_element, n_gaussian, 2, 2) [optional]
            Jacobian matrices at each Gauss point

    Raises:
    -------
        NotImplementedError
            Always; the arguments are not inspected.

    Note:
    -----
        This is a stub implementation. Full implementation pending.
    """
    message = "quad9 shape functions not yet implemented"
    raise NotImplementedError(message)