"""
Base class for gradient estimators.
"""

import numpy as np
from abc import ABC, abstractmethod

class GradEstimator(ABC):
    """Base class for gradient estimators.

    Subclasses supply the random directions by implementing
    :meth:`generate_noise`; :meth:`estimate` then forms a forward-difference
    gradient estimate averaged over ``zoo_batch_size`` sampled directions.

    Attributes:
        zoo_batch_size (int): The number of sampled directions per gradient
            estimate.
        mu (float): The perturbation size (finite-difference step).
    """

    def __init__(self, zoo_batch_size=16, mu=1e-6):
        """Initialize a gradient estimator.

        Args:
            zoo_batch_size (int, optional): The number of sampled directions
                per gradient estimate. Defaults to 16.
            mu (float, optional): The perturbation size. Defaults to 1e-6.
        """
        self.zoo_batch_size = zoo_batch_size
        self.mu = mu
        # Number of evaluations of ``f`` that ``estimate`` makes per sampled
        # direction: one at ``x + mu * v`` and one at ``x``.
        self._cost = 2

    @abstractmethod
    def generate_noise(self, x):
        """Generate a perturbation direction for gradient estimation.

        Args:
            x (numpy.ndarray): The point at which to estimate the gradient.

        Returns:
            numpy.ndarray: The direction vector ``v``. It must be
            broadcast-compatible with ``x``, since it is used as
            ``x + mu * v``.
        """
        pass

    def estimate(self, f, x):
        """Estimate the gradient of a function at a point.

        Computes the forward-difference estimate::

            grad ~= 1 / (mu * B) * sum_i (f(x + mu * v_i) - f(x)) * v_i

        where ``B = zoo_batch_size`` and each ``v_i`` comes from
        :meth:`generate_noise`.

        Args:
            f (callable): The function to differentiate. It is expected to
                return a scalar for each evaluated point.
            x (numpy.ndarray): The point at which to estimate the gradient.
                Integer inputs are accepted; the estimate is then returned
                as ``float64``.

        Returns:
            numpy.ndarray: The estimated gradient, shaped like ``x``. It keeps
            ``x``'s dtype when that dtype is floating or complex, and is
            ``float64`` otherwise.

        Raises:
            ValueError: If ``zoo_batch_size`` is less than 1 or ``mu`` is
                zero. Either would otherwise make the final division 0/0
                and silently yield NaNs.
        """
        if self.zoo_batch_size < 1:
            raise ValueError(
                f"zoo_batch_size must be a positive integer, got {self.zoo_batch_size!r}"
            )
        if self.mu == 0:
            raise ValueError("mu must be non-zero")

        # Accumulate in a floating dtype. For an integer ``x``, a plain
        # ``np.zeros_like(x)`` would be an integer buffer, and the in-place
        # ``+=`` of float products below would raise a casting error.
        # Floating and complex inputs keep their own precision.
        x_dtype = np.asarray(x).dtype
        acc_dtype = x_dtype if np.issubdtype(x_dtype, np.inexact) else np.float64
        grad = np.zeros_like(x, dtype=acc_dtype)

        for _ in range(self.zoo_batch_size):
            v = self.generate_noise(x)
            noisy_x = x + self.mu * v
            # f(x) is re-evaluated for every direction, which matches the two
            # evaluations per direction recorded in ``self._cost``.
            # NOTE(review): if f is deterministic this could be hoisted out of
            # the loop; confirm how ``_cost`` is consumed elsewhere first.
            diff = f(noisy_x) - f(x)
            grad += diff * v

        return grad / (self.mu * self.zoo_batch_size)