"""
Unbiased estimator implementation.
"""

import numpy as np
from scipy.special import zeta
from .base import GradEstimator

# Upper bound applied to the Zipf-sampled index k inside UnbiasedEstimator
# (the sampled value is replaced by min(k, MAX_K) before it is used).
MAX_K = 100 
def muu(k, alpha=0.99):
    """Return the perturbation scale for index ``k``.

    Args:
        k: The exponent applied to ``alpha``.
        alpha: The base of the geometric decay. Defaults to 0.99.

    Returns:
        ``alpha ** k`` plus a constant offset of ``1e-10``.
    """
    offset = 1e-10
    decayed = alpha ** k
    return decayed + offset
# def muu(k, alpha=0.99):
#     return 1/(k+1)**2


class UnbiasedEstimator(GradEstimator):
    """Gradient estimator built from randomly indexed forward differences.

    Each sample draws a random direction ``v`` and a Zipf-distributed index
    ``k`` (clipped to ``MAX_K``). The estimate combines a forward difference
    at step ``mu`` with an importance-weighted difference between the forward
    differences at the shrinking steps ``mu * muu(k - 1, alpha)`` and
    ``mu * muu(k, alpha)``.

    Attributes:
        zoo_batch_size (int): The batch size for gradient estimation
            (presumably set by ``GradEstimator.__init__`` -- confirm).
        mu (float): The perturbation size
            (presumably set by ``GradEstimator.__init__`` -- confirm).
        a (float): The parameter for the Zipf distribution.
        P (int): The number of function evaluations per gradient estimation.
        alpha (float): The decay base passed to ``muu`` for the step scales.
    """

    def __init__(self, P=4, zoo_batch_size=16, mu=1e-6, a=2.0, alpha=0.99):
        """Initialize an unbiased estimator.

        Args:
            P (int, optional): Which estimator variant to use; must be one of
                1, 2, 3 or 4. Only 3 and 4 are implemented. Defaults to 4.
            zoo_batch_size (int, optional): The batch size for gradient estimation. Defaults to 16.
            mu (float, optional): The perturbation size. Defaults to 1e-6.
            a (float, optional): The parameter for the Zipf distribution. Defaults to 2.0.
            alpha (float, optional): The decay base of the step scales
                ``alpha ** k + 1e-10``. Defaults to 0.99.
        """
        super().__init__(zoo_batch_size, mu)
        self.a = a
        assert P in [1,2,3,4], "P must be 1, 2, 3, or 4"
        self.P = P
        self._cost = P  # cost is the number of function evaluations per gradient estimation
        self.alpha = alpha

    def generate_noise(self, x):
        """Generate a random direction with Euclidean norm ``sqrt(len(x))``.

        A standard normal vector is normalised to unit length and then
        rescaled by ``sqrt(len(x))``.

        Args:
            x (numpy.ndarray): The point at which to estimate the gradient.

        Returns:
            numpy.ndarray: The noise vector.
        """
        ndim = len(x)
        vec = np.random.randn(ndim)
        vec /= np.linalg.norm(vec)
        return vec * np.sqrt(ndim)

    def _sample_k(self):
        """Draw the Zipf index ``k`` and clip it to ``MAX_K``."""
        # NOTE(review): k is clipped here, but the weight p_k used later is
        # the unclipped Zipf pmf at k -- confirm this is intended.
        return min(np.random.zipf(a=self.a), MAX_K)

    def _first_term(self, f, x, v):
        """Return the forward difference at step ``mu`` divided by ``zeta(a)``."""
        delta = self.mu
        diff = f(x + delta * v) - f(x)
        # NOTE(review): the division by zeta(a) is kept from the original
        # implementation -- confirm this scaling is intended.
        return diff / delta / zeta(self.a)

    def _scaled_quotient(self, f, x, v, scale):
        """Return the forward difference of ``f`` along ``v`` at step ``mu * scale``."""
        delta = self.mu
        diff = f(x + delta * v * scale) - f(x)
        return diff / delta / scale

    def _correction_term(self, f, x, v, k):
        """Return ``(a_{k+1} - a_k) / p_k`` for the sampled index ``k``.

        ``a_k`` and ``a_{k+1}`` are forward differences at the scales
        ``muu(k - 1, alpha)`` and ``muu(k, alpha)``; ``p_k`` is the Zipf pmf
        ``k ** -a / zeta(a)``.
        """
        # Use the configured alpha rather than muu's default value.
        a_k = self._scaled_quotient(f, x, v, muu(k - 1, self.alpha))
        a_kk = self._scaled_quotient(f, x, v, muu(k, self.alpha))
        # float(k) keeps the power well defined for NumPy integer k with an
        # integer exponent; for Python ints the value is unchanged.
        p_k = float(k) ** -self.a / zeta(self.a)
        return (a_kk - a_k) / p_k

    def P4_estimator(self, f, x):
        """Estimate the gradient using both terms for every sample.

        Args:
            f (callable): The objective; called on arrays shaped like ``x``.
            x (numpy.ndarray): The point at which to estimate the gradient.

        Returns:
            numpy.ndarray: The average over ``zoo_batch_size`` samples.
        """
        grad = np.zeros_like(x)

        for _ in range(self.zoo_batch_size):
            v = self.generate_noise(x)
            k = self._sample_k()

            grad += self._first_term(f, x, v) * v
            grad += self._correction_term(f, x, v, k) * v
        return grad / self.zoo_batch_size

    def P2_estimator(self, f, x):
        """Not implemented for this estimator.

        Raises:
            NotImplementedError: Always; previously this silently returned None.
        """
        raise NotImplementedError("P2_estimator is not implemented for UnbiasedEstimator")

    def P1_estimator(self, f, x):
        """Not implemented for this estimator.

        Raises:
            NotImplementedError: Always; previously this silently returned None.
        """
        raise NotImplementedError("P1_estimator is not implemented for UnbiasedEstimator")

    def P3_estimator(self, f, x):
        """Estimate the gradient using one randomly chosen term per sample.

        A fair coin decides whether a sample contributes the first term or
        the correction term; the average is multiplied by 2 to compensate.
        Only the selected term is evaluated, so no function evaluations are
        spent on a term that would be multiplied by zero.

        Args:
            f (callable): The objective; called on arrays shaped like ``x``.
            x (numpy.ndarray): The point at which to estimate the gradient.

        Returns:
            numpy.ndarray: Twice the average over ``zoo_batch_size`` samples.
        """
        grad = np.zeros_like(x)

        for _ in range(self.zoo_batch_size):
            v = self.generate_noise(x)
            selection_variable = np.random.binomial(n=1, p=0.5)
            # k is drawn for every sample so the sampling order of the
            # random draws does not depend on the coin flip.
            k = self._sample_k()

            if selection_variable:
                grad += self._first_term(f, x, v) * v
            else:
                grad += self._correction_term(f, x, v, k) * v
        return grad / self.zoo_batch_size * 2

    def estimate(self, f, x):
        """Estimate the gradient of ``f`` at ``x`` with the variant chosen by ``P``.

        Args:
            f (callable): The objective; called on arrays shaped like ``x``.
            x (numpy.ndarray): The point at which to estimate the gradient.

        Returns:
            numpy.ndarray: The gradient estimate.

        Raises:
            NotImplementedError: If ``P`` is 1 or 2.
            ValueError: If ``P`` is not 1, 2, 3 or 4.
        """
        if self.P == 4:
            return self.P4_estimator(f, x)
        elif self.P == 2:
            return self.P2_estimator(f, x)
        elif self.P == 1:
            return self.P1_estimator(f, x)
        elif self.P == 3:
            return self.P3_estimator(f, x)
        else:
            raise ValueError("P must be 1, 2, 3, or 4")
        
class Distribution:
    """Placeholder base class for distributions over an index ``n``.

    Subclasses in this module override ``sample`` and ``pdf``; the base
    implementations do nothing and return ``None``.
    """

    def __init__(self):
        """Create a distribution; the base class holds no state."""

    def sample(self):
        """Draw one sample. The base implementation returns ``None``."""
        return None

    def pdf(self, n):
        """Return the probability of ``n``. The base implementation returns ``None``."""
        return None

class ZipfDistribution(Distribution):
    """Zipf distribution with exponent ``a``.

    Attributes:
        a (float): The exponent passed to ``np.random.zipf`` and
            ``scipy.special.zeta``.
    """

    def __init__(self, a=2.0):
        """Store the exponent.

        Args:
            a (float, optional): The Zipf exponent. Defaults to 2.0.
        """
        self.a = a

    def sample(self):
        """Draw one sample with ``np.random.zipf``."""
        return np.random.zipf(a=self.a)

    def pdf(self, n):
        """Return the probability mass ``n ** -a / zeta(a)``.

        Args:
            n: The index (scalar or array) at which to evaluate the mass.

        Returns:
            The probability mass at ``n``.
        """
        # Promote n to floating point first: a NumPy integer raised to a
        # negative integer power raises ValueError (e.g. np.int64 n with a=2).
        base = n * 1.0
        return base ** -self.a / zeta(self.a)

class GeometricDistribution(Distribution):
    """Geometric distribution with success probability ``p``.

    Attributes:
        p (float): The success probability passed to ``np.random.geometric``.
    """

    def __init__(self, p=0.9):
        """Store the success probability.

        Args:
            p (float, optional): The success probability. Defaults to 0.9.
        """
        self.p = p

    def sample(self):
        """Draw one sample with ``np.random.geometric``."""
        draw = np.random.geometric(p=self.p)
        return draw

    def pdf(self, n):
        """Return the probability mass ``p * (1 - p) ** (n - 1)``.

        Args:
            n: The index at which to evaluate the mass.

        Returns:
            The probability mass at ``n``.
        """
        failure = 1 - self.p
        return self.p * failure ** (n - 1)
    


class UnbiasedEstimatorV2(GradEstimator):
    """Gradient estimator whose random index comes from a pluggable distribution.

    The distribution object supplies ``sample()`` (the random index ``n``)
    and ``pdf(n)`` (its probability). The step used for index ``n`` is the
    base perturbation ``mu`` times the probability mass not yet covered by
    indices ``1 .. n - 1``.

    Attributes:
        P (int): Which estimator variant ``estimate`` dispatches to.
        density: The distribution object given as ``p``.
    """

    def __init__(self, p, P=4, zoo_batch_size=16, mu=1e-6):
        """Initialize an unbiased estimator.

        Args:
            p: The index distribution; must provide ``sample()`` and ``pdf(n)``.
            P (int, optional): Estimator variant, one of 1, 2, 3 or 4. Defaults to 4.
            zoo_batch_size (int, optional): Number of samples averaged per
                estimate. Defaults to 16.
            mu (float, optional): The base perturbation size. Defaults to 1e-6.
        """
        super().__init__(zoo_batch_size, mu)
        assert P in [1,2,3,4], "P must be 1, 2, 3, or 4"
        self.P = P
        self._cost = P
        self.density = p

    def generate_noise(self, x):
        """Generate a random direction with Euclidean norm ``sqrt(len(x))``.

        Args:
            x (numpy.ndarray): The point at which to estimate the gradient.

        Returns:
            numpy.ndarray: The noise vector.
        """
        dim = len(x)
        direction = np.random.randn(dim)
        direction /= np.linalg.norm(direction)
        return direction * np.sqrt(dim)

    def _sample_n(self):
        """Draw the random index from the configured distribution."""
        return self.density.sample()

    def _mu(self, n):
        """Return the step size for index ``n``.

        This is ``mu * (1 - sum(pdf(i) for i in 1 .. n - 1))``. The sum is
        accumulated term by term, so one call costs ``n - 1`` pdf evaluations.
        """
        covered = 0
        for index in range(1, n):
            covered = covered + self.density.pdf(index)
        return self.mu * (1 - covered)

    def _estimate(self, f, x, v, n):
        """Return the forward difference of ``f`` along ``v`` at the step for index ``n``."""
        step = self._mu(n)
        shifted = x + step * v
        return (f(shifted) - f(x)) / step

    def _weighted_increment(self, f, x, v, n):
        """Return ``(a_{n+1} - a_n) / pdf(n)`` for the forward differences ``a_n``."""
        lower = self._estimate(f, x, v, n)
        upper = self._estimate(f, x, v, n + 1)
        return (upper - lower) / self.density.pdf(n)

    def P4_estimator(self, f, x):
        """Average the index-1 difference plus the weighted increment over the batch."""
        total = np.zeros_like(x)
        for _ in range(self.zoo_batch_size):
            v = self.generate_noise(x)
            n = self._sample_n()
            # First the forward difference at index 1, then the correction.
            total += self._estimate(f, x, v, 1) * v
            total += self._weighted_increment(f, x, v, n) * v
        return total / self.zoo_batch_size

    def P2_estimator(self, f, x):
        """Average only the weighted increment over the batch, scaled by 2."""
        total = np.zeros_like(x)
        for _ in range(self.zoo_batch_size):
            v = self.generate_noise(x)
            n = self._sample_n()
            total += self._weighted_increment(f, x, v, n) * v
        return total / self.zoo_batch_size * 2

    def P1_estimator(self, f, x):
        """Average ``a_n / pdf(n)`` over the batch, scaled by 4."""
        total = np.zeros_like(x)
        for _ in range(self.zoo_batch_size):
            v = self.generate_noise(x)
            n = self._sample_n()
            weighted = self._estimate(f, x, v, n) / self.density.pdf(n)
            total += weighted * v
        return total / self.zoo_batch_size * 2 * 2

    def P3_estimator(self, f, x):
        """Average one coin-selected term per sample, scaled by 2.

        A fair coin picks either the index-1 forward difference or the
        weighted increment; the index ``n`` is drawn for every sample
        regardless of the coin.
        """
        total = np.zeros_like(x)
        for _ in range(self.zoo_batch_size):
            v = self.generate_noise(x)
            use_first = np.random.binomial(n=1, p=0.5)
            n = self._sample_n()

            if not use_first:
                total += self._weighted_increment(f, x, v, n) * v
                continue
            total += self._estimate(f, x, v, 1) * v
        return total / self.zoo_batch_size * 2

    def estimate(self, f, x):
        """Estimate the gradient of ``f`` at ``x`` with the variant chosen by ``P``.

        Raises:
            ValueError: If ``P`` is not 1, 2, 3 or 4.
        """
        variants = (
            (4, self.P4_estimator),
            (2, self.P2_estimator),
            (1, self.P1_estimator),
            (3, self.P3_estimator),
        )
        for count, variant in variants:
            if self.P == count:
                return variant(f, x)
        raise ValueError("P must be 1, 2, 3, or 4")

