"""
Gaussian quadrature rules for finite element integration
"""

import numpy as np


def tri_gauss_points(ngp: int = 3) -> np.ndarray:
    """
    Return a Gauss quadrature rule for the reference triangle.

    The rules are tabulated on the triangle with vertices (0, 0), (1, 0)
    and (0, 1); the weights of every rule therefore sum to its area, 1/2.

    Parameters:
    -----------
        ngp: int
            Number of Gauss points; only the 1-, 3- and 4-point rules are
            tabulated.

    Returns:
    --------
        qpoints: np.ndarray
            Array of shape (ngp, 3) whose rows are [weight, xi, eta].

    Raises:
    -------
        ValueError
            If no rule is tabulated for ``ngp``.
    """
    third, sixth, fifth = 1.0 / 3.0, 1.0 / 6.0, 1.0 / 5.0

    # Each entry pairs a point count with its rows of (weight, xi, eta).
    # The table is scanned in order with ``==`` so any value comparing equal
    # to 1, 3 or 4 selects the corresponding rule.
    rules = (
        # centroid rule, exact for linear polynomials
        (1, [(0.5, third, third)]),
        # exact for quadratic polynomials
        (3, [(sixth, sixth, sixth),
             (sixth, 2.0 / 3.0, sixth),
             (sixth, sixth, 2.0 / 3.0)]),
        # exact for cubic polynomials; the centroid weight is negative
        (4, [(-9.0 / 32.0, third, third),
             (25.0 / 96.0, fifth, fifth),
             (25.0 / 96.0, 3.0 / 5.0, fifth),
             (25.0 / 96.0, fifth, 3.0 / 5.0)]),
    )

    for count, rows in rules:
        if ngp == count:
            return np.array(rows, dtype=float)

    raise ValueError(f"No data for {ngp}-point rule for triangles")


def quad_gauss_points(ngp: int = 4) -> np.ndarray:
    """
    Generate Gaussian quadrature points and weights for quadrilateral elements

    The rule is the tensor product of a 1D Gauss-Legendre rule on [-1, 1]
    with itself, so an ``ngp``-point rule integrates polynomials of degree
    up to ``2*ngp - 1`` in each direction exactly over [-1, 1] x [-1, 1].

    Parameters:
    -----------
        ngp: int
            Number of Gauss points per direction (any positive integer).
            Total points will be ngp^2. Orders 1-5 use closed-form
            expressions; higher orders are computed numerically with
            ``numpy.polynomial.legendre.leggauss``.

    Returns:
    --------
        qpoints: np.ndarray
            Quadrature points and weights with shape (ngp^2, 3)
            Format: [weight, xi, eta] for each point. Rows are ordered with
            xi varying slowest: row ``i*ngp + j`` holds (xi_i, eta_j).

    Raises:
    -------
        ValueError
            If ``ngp`` is not a positive integer.
    """
    # Reject zero/negative and non-integral orders (2.5, inf, nan) up front;
    # integral floats such as 3.0 are accepted and normalised to int below.
    if not float(ngp).is_integer() or ngp < 1:
        raise ValueError(f"No data for {ngp}-point rule for quadrilaterals")
    ngp = int(ngp)

    # 1D Gauss points (ascending) and weights
    if ngp == 1:
        points_1d = [0.0]
        weights_1d = [2.0]
    elif ngp == 2:
        points_1d = [-np.sqrt(1./3), np.sqrt(1./3)]
        weights_1d = [1.0, 1.0]
    elif ngp == 3:
        points_1d = [-np.sqrt(3./5), 0.0, np.sqrt(3./5)]
        weights_1d = [5./9, 8./9, 5./9]
    elif ngp == 4:
        a = np.sqrt((3.0 - 2.0*np.sqrt(6.0/5.0)) / 7.0)
        b = np.sqrt((3.0 + 2.0*np.sqrt(6.0/5.0)) / 7.0)
        points_1d = [-b, -a, a, b]

        w1 = (18.0 + np.sqrt(30.0)) / 36.0  # weight of the inner points +/-a
        w2 = (18.0 - np.sqrt(30.0)) / 36.0  # weight of the outer points +/-b
        weights_1d = [w2, w1, w1, w2]
    elif ngp == 5:
        a = np.sqrt(5.0 - 2.0*np.sqrt(10.0/7.0)) / 3.0
        b = np.sqrt(5.0 + 2.0*np.sqrt(10.0/7.0)) / 3.0
        points_1d = [-b, -a, 0.0, a, b]

        w1 = (322.0 - 13.0*np.sqrt(70.0)) / 900.0  # outer points +/-b
        w2 = (322.0 + 13.0*np.sqrt(70.0)) / 900.0  # inner points +/-a
        w3 = 128.0 / 225.0                         # centre point
        weights_1d = [w1, w2, w3, w2, w1]
    else:
        # No closed forms are kept beyond 5 points; leggauss returns the
        # nodes in ascending order together with their weights.
        points_1d, weights_1d = np.polynomial.legendre.leggauss(ngp)

    points_1d = np.asarray(points_1d, dtype=float)
    weights_1d = np.asarray(weights_1d, dtype=float)

    # 2D tensor product. indexing="ij" makes xi the slow (row) index so the
    # flattened order is (xi_0, eta_0), (xi_0, eta_1), ..., matching
    # weight = w_i * w_j from the outer product.
    xi, eta = np.meshgrid(points_1d, points_1d, indexing="ij")
    weights_2d = np.outer(weights_1d, weights_1d)

    # Combine into single array: [weight, xi, eta]
    return np.column_stack((weights_2d.ravel(), xi.ravel(), eta.ravel()))