import numpy as np

def gram_schmidt(B):
    """
    Orthonormalizes two basis vectors per sample using the Gram-Schmidt process.

    The input is never modified: all normalizations are done out of place,
    so the caller's array keeps its original values. Integer-valued input is
    accepted and promoted to floating point by the divisions.

    Parameters:
    B : array_like of shape (N, 2, D) - Two basis vectors for each sample
        (D = 5 in the documented use case; the code does not depend on D)

    Returns:
    B_orth : ndarray of shape (N, 2, D) - Orthonormalized basis vectors

    Notes:
    Degenerate input is not guarded against. A zero first vector, or a second
    vector parallel to the first, leads to a division by a (near-)zero norm
    and yields non-finite or numerically meaningless entries for that sample.
    """
    B = np.asarray(B)
    first, second = B[:, 0], B[:, 1]

    # B[:, 0] is a view into B; dividing out of place (instead of `/=`)
    # keeps the caller's array intact and also works for integer dtypes.
    u1 = first / np.linalg.norm(first, axis=1, keepdims=True)

    # Remove the component of the second vector that lies along u1.
    overlap = np.einsum('ij,ij->i', second, u1)[:, np.newaxis]
    u2 = second - overlap * u1
    u2 = u2 / np.linalg.norm(u2, axis=1, keepdims=True)

    return np.stack((u1, u2), axis=1)  # Stack orthonormalized vectors

def project_on_gradient_plane(X1, X2, grad_f):
    """
    Projects the displacement X2 - X1 onto the plane spanned by the gradients at X1 and X2.

    For every sample, the gradients evaluated at the two points are
    orthonormalized with gram_schmidt and the displacement vector is
    expressed in, and rebuilt from, that two-vector basis.

    Parameters:
    X1 : ndarray of shape (N, 5)  - Start points (one 5D point per row)
    X2 : ndarray of shape (N, 5)  - End points
    grad_f : function handle      - Called once with X1 and once with X2; expected to
                                    return the gradients as an (N, 5) array

    Returns:
    v_proj : ndarray of shape (N, 5) - Component of X2 - X1 lying in the gradient plane
    """
    # Gradient at X1 and at X2, paired up per sample -> (N, 2, 5)
    gradient_pair = np.stack((grad_f(X1), grad_f(X2)), axis=1)

    # Orthonormal basis of the plane spanned by each gradient pair -> (N, 2, 5)
    plane_basis = gram_schmidt(gradient_pair)

    # Vector pointing from the first point to the second -> (N, 5)
    displacement = X2 - X1

    # Coordinates of the displacement along the two basis vectors -> (N, 2)
    plane_coords = np.einsum('nij,nj->ni', plane_basis, displacement)

    # Linear combination of the basis vectors with those coordinates -> (N, 5)
    return np.einsum('ni,nij->nj', plane_coords, plane_basis)

def weight_probabilities(dy, dists, sample_uniformly=False):
    """Compute the probability with which each candidate weight should be chosen.

    All probabilities are computed in one pass, without removing weights
    one by one.

    Args:
        dy: function differences; reduced with a max of absolute values over
            axis 1, so it must be at least 2-D (presumably one row per
            candidate, one column per output -- confirm against the caller).
        dists: distances between the base points, used as the divisor of the
            reduced differences. NOTE(review): presumably shaped (N, 1) or a
            scalar so the quotient keeps one value per row -- confirm.
        sample_uniformly: if True, ignore the finite-difference magnitudes
            and return a uniform distribution.

    Returns:
        probabilities: 1-D array of probabilities for the weights.
    """
    # Largest absolute change across all output directions, per unit distance,
    # so that steep directions of any output are favoured.
    largest_change = np.abs(dy).max(axis=1, keepdims=True)
    steepness = (largest_change / dists).ravel()
    total = np.sum(steepness)

    # When all gradients are small, avoid dividing by a small number
    # and fall back to the uniform distribution.
    if sample_uniformly or total < 1e-10:
        return np.ones_like(steepness) / steepness.size

    return steepness / total

