import numpy as np

class Sampler():
    """Ensemble sampler for a 1-D quadratic potential driven by Langevin-type dynamics.

    Every ensemble member carries a position ``theta`` and a momentum ``r``.
    The per-datum gradient term used throughout is ``N * (theta - data[i])``;
    the full gradient is ``sum_i (theta - data[i])``.

    Two integrators are supported by :meth:`run_uld`:

    * ``'EM'``  -- Euler-Maruyama step with friction ``gamma`` and noise ``sigma``.
    * ``'GLA'`` -- a symplectic-Euler step followed by an exact
      Ornstein-Uhlenbeck update of the momentum.

    ``x_method`` selects the weighting used by the Metropolis-style index
    selection and by :meth:`weights`: ``'old'`` uses the gradient term only,
    ``'new'`` adds the momentum contribution.

    All ``sample_*`` methods store the visited positions in ``self.thetas``
    (shape ``(max_iter, total_ensemble)``) and return ``self``.

    NOTE(review): the variance-adapting samplers overwrite ``self.gamma`` and
    do not restore it, so a later ``sample_*`` call starts from the adapted
    value. The GLA coefficients are computed once per run and are not
    refreshed when ``gamma`` adapts. Both behaviours are kept as in the
    original code -- confirm they are intended.
    """

    def __init__(self, data, total_ensemble, h, gamma, sigma, seed=0):
        """Store the problem definition and hyper-parameters.

        Args:
            data: 1-D array-like of observations, shape ``(N,)``.
            total_ensemble: number of independent chains run in parallel.
            h: integration step size.
            gamma: friction coefficient (scalar; becomes per-chain in the
                variance-adapting samplers).
            sigma: noise amplitude.
            seed: seed of the private ``numpy.random.RandomState``.
        """
        # asarray is a no-op for ndarrays and lets plain sequences support
        # the fancy indexing (``data[indices]``) used below.
        self.data = np.asarray(data)
        self.N = len(self.data)
        self.total_ensemble = total_ensemble
        self.h = h
        self.gamma = gamma
        self.sigma = sigma
        self.rng = np.random.RandomState(seed)

        # Defaults so that methods which do not receive these settings
        # explicitly (sample_ewsg_vr / sample_ewsg_vr1) work on a fresh
        # instance instead of raising AttributeError.
        self.integrator = 'EM'
        self.x_method = 'new'
        self.uniform = False

    # ------------------------------------------------------------------
    # state / configuration helpers
    # ------------------------------------------------------------------
    def initialize(self):
        """Draw standard-normal positions and reset all momenta to zero."""
        self.theta = self.rng.standard_normal(self.total_ensemble)
        self.r = np.zeros(self.total_ensemble)

    def _setup_integrator(self, integrator):
        """Record the integrator and precompute the GLA coefficients if needed."""
        self.integrator = integrator
        if integrator == 'GLA':
            self.factor1 = np.exp(-self.gamma * self.h)
            self.factor2 = np.sqrt(1 - np.exp(-2 * self.gamma * self.h))
            self.aa = self.gamma * self.h / self.factor2
            self.bb = self.factor1 * self.h / self.factor2

    # ------------------------------------------------------------------
    # gradients
    # ------------------------------------------------------------------
    def full_gradient(self):
        """Return ``sum_i (theta - data[i])`` for every chain, shape ``(total_ensemble,)``."""
        return np.sum(self.theta[:, np.newaxis] - self.data, axis=1)

    def stochastic_gradient(self):
        """Return a single-datum gradient estimate per chain and the indices used.

        Returns:
            ``(grad, indices)`` where ``indices`` holds one uniformly drawn
            data index per chain and ``grad = N * (theta - data[indices])``.
        """
        indices = self.rng.randint(0, self.N, size=self.total_ensemble)
        return self.N * (self.theta - self.data[indices]), indices

    def stochastic_gradient2(self, minibatch_size):
        """Return a minibatch gradient estimate per chain (indices drawn with replacement)."""
        indices = self.rng.randint(0, self.N, size=self.total_ensemble * minibatch_size)
        batch = self.data[indices].reshape((self.total_ensemble, minibatch_size))
        return self.N / minibatch_size * (self.theta[:, np.newaxis] - batch).sum(axis=1)

    def compute_gradient_terms(self):
        """Return all per-datum gradient terms, shape ``(total_ensemble, N)``."""
        return self.N * (self.theta[:, np.newaxis] - self.data)

    def _log_weight_diff(self, grad_proposal, grad):
        """Return the log-weight difference between a proposed and the current gradient.

        Raises:
            ValueError: if ``self.x_method`` is neither ``'old'`` nor ``'new'``.
        """
        if self.x_method not in ('old', 'new'):
            raise ValueError(
                f"x_method must be 'old' or 'new', got {self.x_method!r}"
            )
        with_momentum = self.x_method == 'new'

        if self.integrator == 'GLA':
            if with_momentum:
                return ((self.aa * self.r + self.bb * grad_proposal)**2
                        - (self.aa * self.r + self.bb * grad)**2)
            return (self.bb * grad_proposal)**2 - (self.bb * grad)**2

        # Any non-GLA integrator uses the Euler-Maruyama weighting.
        if with_momentum:
            return self.h / self.sigma**2 * (
                (self.gamma * self.r + grad_proposal)**2 - (self.gamma * self.r + grad)**2
            )
        return self.h / self.sigma**2 * (grad_proposal**2 - grad**2)

    def weighted_stochastic_gradient(self, M):
        """Select one data index per chain with ``M`` Metropolis-style proposal rounds.

        Starting from a uniformly drawn index, each round draws a fresh
        uniform proposal per chain and accepts it with probability
        ``exp(min(diff / 2, 0))``.

        Args:
            M: number of proposal rounds (``0`` returns the uniform draw).

        Returns:
            ``(grad, indices)`` for the finally selected indices.
        """
        grad, indices = self.stochastic_gradient()

        for _ in range(M):
            grad_proposal, indices_proposal = self.stochastic_gradient()
            diff = self._log_weight_diff(grad_proposal, grad)
            ratio = np.exp(np.minimum(diff / 2, 0.0))
            accept = self.rng.uniform(size=self.total_ensemble) < ratio
            grad[accept] = grad_proposal[accept]
            indices[accept] = indices_proposal[accept]

        return grad, indices

    # ------------------------------------------------------------------
    # dynamics
    # ------------------------------------------------------------------
    def run_uld(self, grad):
        """Advance ``theta`` and ``r`` by one step of the configured integrator.

        Raises:
            ValueError: if ``self.integrator`` is neither ``'EM'`` nor ``'GLA'``
                (previously such a value silently left the state untouched).
        """
        if self.integrator == 'EM':
            noise = self.rng.standard_normal(self.total_ensemble)
            self.r += (-grad - self.gamma * self.r) * self.h + self.sigma * np.sqrt(self.h) * noise
            self.theta += self.r * self.h
        elif self.integrator == 'GLA':
            # symplectic Euler integration of the Hamiltonian part
            self.r -= grad * self.h
            self.theta += self.r * self.h

            # exact Ornstein-Uhlenbeck update of the momentum
            self.r = self.factor1 * self.r + self.factor2 * self.rng.standard_normal(self.total_ensemble)
        else:
            raise ValueError(
                f"integrator must be 'EM' or 'GLA', got {self.integrator!r}"
            )

    # ------------------------------------------------------------------
    # samplers
    # ------------------------------------------------------------------
    def sample_fg(self, n_data_pass, integrator):
        """Sample using the full gradient; one iteration counts as one data pass."""
        self._setup_integrator(integrator)

        max_iter = n_data_pass
        self.thetas = np.zeros((max_iter, self.total_ensemble))

        self.initialize()
        for i in range(max_iter):
            self.thetas[i] = self.theta
            self.run_uld(self.full_gradient())

        return self

    def sample_sg(self, n_data_pass, minibatch_size, x_method, integrator):
        """Sample using uniform minibatch gradients of size ``minibatch_size``."""
        self.x_method = x_method
        self._setup_integrator(integrator)

        max_iter = n_data_pass * self.N // minibatch_size
        self.thetas = np.zeros((max_iter, self.total_ensemble))

        self.initialize()
        for i in range(max_iter):
            self.thetas[i] = self.theta
            self.run_uld(self.stochastic_gradient2(minibatch_size))

        return self

    def sample_ewsg(self, n_data_pass, M, x_method, integrator):
        """Sample using the weighted single-datum gradient with ``M`` proposal rounds."""
        self.x_method = x_method
        self._setup_integrator(integrator)

        max_iter = n_data_pass * self.N // (M + 1)
        self.thetas = np.zeros((max_iter, self.total_ensemble))

        self.initialize()
        for i in range(max_iter):
            self.thetas[i] = self.theta

            grad, _ = self.weighted_stochastic_gradient(M)
            self.run_uld(grad)

        return self

    def update_gamma(self, grad, indices, alpha):
        """Refresh the stored gradient table and adapt ``gamma`` from its variance.

        The entry ``indices[k]`` of chain ``k`` is replaced by ``grad[k]``;
        ``gamma`` is then set to ``(alpha * h * variance + sigma**2) / 2``
        using the unbiased (``N - 1``) per-chain variance of the table.
        """
        self.gradients[np.arange(self.total_ensemble), indices] = grad
        centered = self.gradients - self.gradients.mean(axis=1, keepdims=True)
        variance = np.square(centered).sum(axis=1) / (self.N - 1)

        self.gamma = (alpha * self.h * variance + self.sigma**2) / 2

    def sample_ewsg_vr(self, n_data_pass, M, alpha, x_method=None, integrator=None):
        """Weighted-gradient sampler with ``gamma`` adapted from a stored gradient table (SAG-style).

        One data pass is budgeted for building the initial gradient table,
        hence ``n_data_pass - 1`` in the iteration count.

        Args:
            n_data_pass: total data-pass budget.
            M: number of proposal rounds per iteration.
            alpha: scale of the variance term in the ``gamma`` update.
            x_method: optional ``'old'``/``'new'``; ``None`` keeps the value
                currently stored on the instance (``'new'`` on a fresh one).
            integrator: optional ``'EM'``/``'GLA'``; ``None`` keeps the value
                currently stored on the instance (``'EM'`` on a fresh one).
        """
        if x_method is not None:
            self.x_method = x_method
        if integrator is not None:
            self._setup_integrator(integrator)

        max_iter = (n_data_pass - 1) * self.N // (M + 1)
        self.thetas = np.zeros((max_iter, self.total_ensemble))

        self.initialize()
        self.gradients = self.compute_gradient_terms()
        self.gamma = self.gamma * np.ones(self.total_ensemble)

        for i in range(max_iter):
            self.thetas[i] = self.theta

            grad, indices = self.weighted_stochastic_gradient(M)
            self.run_uld(grad)

            self.update_gamma(grad, indices, alpha)

        return self

    def weights(self, gradient_terms):
        """Return per-chain normalised weights over the data, rows summing to one.

        Args:
            gradient_terms: array of shape ``(total_ensemble, N)``.

        Raises:
            ValueError: if weights are non-uniform and ``self.x_method`` is
                neither ``'old'`` nor ``'new'``.
        """
        if self.uniform:
            return np.ones((self.total_ensemble, self.N)) / self.N

        scale = self.h / self.sigma**2
        if self.x_method == 'old':
            log_w = scale * gradient_terms**2
        elif self.x_method == 'new':
            log_w = scale * ((self.gamma * self.r)[:, np.newaxis] + gradient_terms)**2
        else:
            raise ValueError(
                f"x_method must be 'old' or 'new', got {self.x_method!r}"
            )

        # Stabilise each row with its own maximum. Subtracting the global
        # maximum lets rows far below it underflow to all zeros, which turns
        # the normalisation into 0/0.
        log_w -= log_w.max(axis=1, keepdims=True)
        w = np.exp(log_w)
        return w / w.sum(axis=1, keepdims=True)

    def sample_ewsg_vr0(self, n_data_pass, M, alpha, x_method, uniform, integrator):
        """Weighted-gradient sampler that recomputes the exact weighted variance every iteration.

        Each iteration evaluates all ``N`` gradient terms to obtain the
        variance used in the ``gamma`` update.
        """
        self.x_method = x_method
        self.uniform = uniform
        self._setup_integrator(integrator)

        max_iter = n_data_pass * self.N // (M + 1)
        self.thetas = np.zeros((max_iter, self.total_ensemble))

        self.initialize()
        self.gamma = self.gamma * np.ones(self.total_ensemble)

        for i in range(max_iter):
            self.thetas[i] = self.theta

            grad, _ = self.weighted_stochastic_gradient(M)
            self.run_uld(grad)

            gradient_terms = self.compute_gradient_terms()
            weights = self.weights(gradient_terms)

            m1 = np.sum(gradient_terms * weights, axis=1)       # first moment
            m2 = np.sum(gradient_terms**2 * weights, axis=1)    # second moment
            variance = m2 - m1**2
            self.gamma = (alpha * self.h * variance + self.sigma**2) / 2

        return self

    def sample_ewsg_vr1(self, n_data_pass, M, alpha, x_method=None, integrator=None):
        """Alias of :meth:`sample_ewsg_vr` (the two were line-for-line identical)."""
        return self.sample_ewsg_vr(n_data_pass, M, alpha, x_method=x_method, integrator=integrator)

    def sample_ewsg_vr2(self, n_data_pass, M, L, alpha, x_method, uniform, integrator):
        """Weighted-gradient sampler with an SVRG-style moment estimate for ``gamma``.

        Every ``lag`` iterations a snapshot of ``theta`` is taken together with
        the weighted first and second moments of all gradient terms there.
        Between snapshots the moments are corrected using only the selected
        datum of each chain.

        Args:
            n_data_pass: total data-pass budget.
            M: number of proposal rounds per iteration.
            L: number of data passes between two snapshots.
            alpha: scale of the variance term in the ``gamma`` update.
            x_method: ``'old'`` or ``'new'``.
            uniform: if true, :meth:`weights` returns uniform weights.
            integrator: ``'EM'`` or ``'GLA'``.
        """
        self.x_method = x_method
        self.uniform = uniform
        self._setup_integrator(integrator)

        max_iter = int(n_data_pass * L / (L + 1) * self.N) // (M + 1)
        # At least 1 so that ``i % lag`` cannot divide by zero when M + 1 > L * N.
        lag = max(1, L * self.N // (M + 1))
        self.thetas = np.zeros((max_iter, self.total_ensemble))

        self.initialize()
        self.gamma = self.gamma * np.ones(self.total_ensemble)

        rows = np.arange(self.total_ensemble)
        for i in range(max_iter):
            if i % lag == 0:
                snapshot_gradients = self.compute_gradient_terms()
                weights = self.weights(snapshot_gradients)
                self.m1 = np.sum(snapshot_gradients * weights, axis=1)
                self.m2 = np.sum(snapshot_gradients**2 * weights, axis=1)
                # Must be a copy: run_uld updates self.theta in place, so a
                # plain reference would always equal the current position.
                self.theta_snapshot = self.theta.copy()

            self.thetas[i] = self.theta

            grad, indices = self.weighted_stochastic_gradient(M)
            self.run_uld(grad)

            picked = weights[rows, indices]
            snapshot_grad = self.N * (self.theta_snapshot - self.data[indices])
            m1 = self.m1 + picked * (grad - snapshot_grad)
            m2 = self.m2 + picked * (grad**2 - snapshot_grad**2)
            variance = m2 - m1**2
            self.gamma = (alpha * self.h * variance + self.sigma**2) / 2

        return self
