"""
Noise generation utilities for Unbiased Zoo.
"""

import numpy as np

def unit_vector(vector):
    """Scale ``vector`` to unit Euclidean (L2) length.

    Args:
        vector (numpy.ndarray): Vector to normalise.

    Returns:
        numpy.ndarray: ``vector`` divided by its L2 norm. A zero vector is not
        special-cased, so the division yields non-finite values in that case.
    """
    length = np.linalg.norm(vector)
    return vector / length

def angle_between(v1, v2):
    """Return the angle, in radians, between vectors ``v1`` and ``v2``.

    Both vectors are normalised with :func:`unit_vector`; the cosine of the
    angle is their dot product, clipped to ``[-1, 1]`` before ``arccos``.

    Args:
        v1 (numpy.ndarray): First vector.
        v2 (numpy.ndarray): Second vector.

    Returns:
        float: Angle in radians, in ``[0, pi]``.

    Examples:
        >>> float(angle_between((1, 0, 0), (0, 1, 0)))
        1.5707963267948966
        >>> float(angle_between((1, 0, 0), (1, 0, 0)))
        0.0
        >>> float(angle_between((1, 0, 0), (-1, 0, 0)))
        3.141592653589793
    """
    cos_theta = np.dot(unit_vector(v1), unit_vector(v2))
    # Rounding can push the cosine marginally outside [-1, 1], where arccos
    # would return NaN; clip it back into the valid domain first.
    return np.arccos(np.clip(cos_theta, -1.0, 1.0))

def sample_spherical(npoints=1, ndim=3):
    """Sample points uniformly from a sphere of radius ``sqrt(ndim)``.

    Each column is an i.i.d. standard-normal vector divided by its own norm
    (a direction uniform on the unit sphere), then scaled by ``sqrt(ndim)``.

    Args:
        npoints (int, optional): Number of points to sample. Defaults to 1.
        ndim (int, optional): Dimension of the ambient space. Defaults to 3.

    Returns:
        numpy.ndarray: Array of shape ``(ndim, npoints)``; every column has
        Euclidean norm ``sqrt(ndim)``.
    """
    radius = np.sqrt(ndim)
    gaussian = np.random.randn(ndim, npoints)
    directions = gaussian / np.linalg.norm(gaussian, axis=0)
    return radius * directions

def get_coordinate_noise(input):
    """Generate coordinate-wise (one-hot) noise.

    Exactly one entry, chosen uniformly at random, is set to 1; all others
    are 0. The result has the same shape and dtype as ``input``
    (via ``np.zeros_like``).

    NOTE(review): unlike the Gaussian, Bernoulli and spherical generators in
    this module (for which ``E[v v^T] = I``), this one-hot noise has
    ``E[v v^T] = I / n``. Confirm callers rescale by ``n`` where needed.

    Args:
        input (numpy.ndarray): Input vector (or array); only its shape and
            dtype are used.

    Returns:
        numpy.ndarray: One-hot noise array shaped like ``input``.
    """
    v = np.zeros_like(input)
    # Pick from the flattened array so a multi-dimensional input gets a single
    # non-zero coordinate (``v[ind]`` would fill a whole row). For 1-D input
    # ``v.size == len(input)``, so the RNG draw is unchanged.
    ind = np.random.choice(v.size)
    v.flat[ind] = 1.0
    return v

def get_spherical_noise(input):
    """Draw one noise vector uniformly from the sphere of radius ``sqrt(n)``.

    ``n`` is ``len(input)``. The point is produced by ``sample_spherical``
    and returned as a flat 1-D vector.

    Args:
        input (numpy.ndarray): Input vector; only its length is used.

    Returns:
        numpy.ndarray: Noise vector of shape ``(len(input),)``.
    """
    points = sample_spherical(npoints=1, ndim=len(input))
    # sample_spherical returns shape (ndim, npoints); collapse to (ndim,).
    return points.flatten()

def get_gaussian_noise(input):
    """Generate standard Gaussian noise shaped like ``input``.

    Every entry is drawn i.i.d. from N(0, 1).

    Args:
        input (array_like): Input vector or array; only its shape is used.
            Plain sequences (lists, tuples) are accepted as well as arrays.

    Returns:
        numpy.ndarray: Float array with the same shape as ``input``.
    """
    # np.shape uses ``input.shape`` when available and otherwise converts the
    # sequence, so lists/tuples work instead of raising AttributeError.
    return np.random.normal(scale=1.0, size=np.shape(input))

def get_bernoulli_noise(input):
    """Generate Rademacher (random +/-1) noise shaped like ``input``.

    Each entry is independently -1 or +1 with probability 1/2, obtained by
    mapping a fair Bernoulli draw ``b`` in {0, 1} to ``2*b - 1``.

    Args:
        input (array_like): Input vector or array; only its shape is used.
            Plain sequences (lists, tuples) are accepted as well as arrays.

    Returns:
        numpy.ndarray: Integer array of +/-1 with the same shape as ``input``.
    """
    # np.shape accepts plain sequences, unlike ``input.shape``.
    return -1 + 2 * np.random.binomial(n=1, p=0.5, size=np.shape(input))

def proj_multiple(num_vectors, a):
    """Sample vectors that are Gaussian orthogonal to ``a`` and +/-1 along it.

    Each returned row is ``P x + xi * a / ||a||``, where:

    - ``x`` is a standard normal vector;
    - ``P = I - a a^T / ||a||^2`` projects onto the orthogonal complement
      of ``a``;
    - ``xi`` is a random sign.

    Every row therefore has component exactly ``xi`` (up to rounding) along
    the unit direction of ``a``, and ``E[v v^T] = P + a a^T / ||a||^2 = I``.

    Args:
        num_vectors (int): Number of vectors (rows) to sample.
        a (array_like): Non-zero 1-D reference vector (e.g. a gradient).

    Returns:
        numpy.ndarray: Array of shape ``(num_vectors, len(a))``.

    Raises:
        ValueError: If ``a`` has zero norm, so its direction is undefined.
    """
    a = np.asarray(a, dtype=float)
    n = len(a)
    a_norm_square = np.dot(a, a)
    if a_norm_square == 0:
        raise ValueError("proj_multiple requires a non-zero vector `a`")
    drift = a / np.sqrt(a_norm_square)

    X = np.random.normal(size=(num_vectors, n))
    xi = np.random.choice([-1, 1], size=num_vectors)

    # Apply P without materialising the n x n matrix:
    # X P^T = X - (X a) a^T / ||a||^2, costing O(num_vectors * n).
    projected = X - np.outer(X @ a, a) / a_norm_square
    return projected + xi[:, np.newaxis] * drift

def get_grad_noise(input, grad):
    """Draw a single gradient-dependent noise vector.

    This is a thin wrapper that requests one sample from ``proj_multiple``
    with ``grad`` as the reference vector and flattens it to 1-D.

    Args:
        input (numpy.ndarray): Input vector. Unused here; it keeps the
            signature parallel to the other generators.
        grad (numpy.ndarray): Gradient vector.

    Returns:
        numpy.ndarray: Noise vector of shape ``(len(grad),)``.
    """
    single_row = proj_multiple(1, grad)
    return single_row.flatten()

def get_optimal(input, a):
    """Generate noise at a fixed angle to ``a`` for gradient estimation.

    The returned vector ``v`` satisfies:

    - ``||v|| = sqrt(n)``, with ``n = len(a)``;
    - ``cos(angle(v, a)) = +/- 1/sqrt(n)``.

    It is built from a random direction orthogonal to ``a`` and the direction
    of ``a`` itself, and is multiplied by a random overall sign.

    Args:
        input (numpy.ndarray): Input vector. Unused; kept for a signature
            parallel to the other generators.
        a (array_like): Non-zero 1-D gradient vector.

    Returns:
        numpy.ndarray: Noise vector of shape ``(n,)``.

    Raises:
        ValueError: If ``a`` has zero norm.
    """
    a = np.asarray(a, dtype=float)
    n = len(a)
    a_norm_sq = a.dot(a)
    if a_norm_sq == 0:
        raise ValueError("get_optimal requires a non-zero vector `a`")
    # Generate a vector orthogonal to a (Gram-Schmidt on a Gaussian sample).
    p = np.random.normal(scale=1.0, size=a.shape)
    w = p - p.dot(a) / a_norm_sq * a
    w_norm = np.sqrt(w.dot(w))
    # Rescale w to ||a||. When the residual is exactly zero (always possible
    # for n == 1, where there is no orthogonal complement) leave w = 0; the
    # old 0/0 division produced NaNs. For n == 1 its weight below is 0 anyway.
    if w_norm > 0:
        w = np.sqrt(a_norm_sq) / w_norm * w
    v = np.sqrt(1 - 1 / n) * w + 1 / np.sqrt(n) * a
    v = v / np.sqrt(v.dot(v))
    return v * np.sqrt(n) * np.random.choice([-1, 1])

# Registry of input-only noise generators, keyed by the name callers select.
# Every entry is called as ``fn(input)`` and returns a noise vector.
# "uniform" means uniform on the sphere of radius sqrt(n) (get_spherical_noise).
# The gradient-based generators (get_grad_noise, get_optimal) take an extra
# argument and are not registered here.
GENERATOR = dict(
    coordinate=get_coordinate_noise,
    uniform=get_spherical_noise,
    gaussian=get_gaussian_noise,
    bernoulli=get_bernoulli_noise,
)