import numpy as np
import torch
from torch.optim import Optimizer, SGD

class SGLD(Optimizer):
    """Stochastic Gradient Langevin Dynamics sampler.

    Each call to :meth:`step` applies, to every parameter that has a
    gradient, the update

        param <- param - h * grad + sqrt(2 * h) * xi,   xi ~ N(0, I)

    Parameters whose ``grad`` is ``None`` are left untouched.

    Args:
        params: iterable of parameters (or parameter-group dicts) to sample.
        h: step size; must be non-negative.

    Raises:
        ValueError: if ``h`` is negative.
    """

    def __init__(self, params, h):
        if h < 0.0:
            raise ValueError("Invalid step size: {}".format(h))
        defaults = dict(h=h)
        super(SGLD, self).__init__(params, defaults)

    def step(self, closure=None):
        """Perform one SGLD update.

        Args:
            closure: optional callable that re-evaluates the model; it is
                called with gradient tracking enabled before the update.

        Returns:
            Whatever ``closure`` returned, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        with torch.no_grad():
            for group in self.param_groups:
                h = group['h']
                noise_scale = np.sqrt(2 * h)  # loop-invariant per group
                for param in group['params']:
                    if param.grad is None:
                        continue
                    noise = torch.randn_like(param)
                    param.add_(-h * param.grad + noise_scale * noise)
        return loss

class pSGLD(Optimizer):
    """Preconditioned SGLD with a diagonal, RMSprop-style preconditioner.

    :meth:`step` does NOT refresh the preconditioner; call
    :meth:`update_preconditioner` (once gradients are available) to do so.
    Both ``V`` (running average of squared gradients) and ``G`` (diagonal of
    the preconditioner) start at zero, so a :meth:`step` taken before any
    :meth:`update_preconditioner` call leaves the parameters unchanged.

    Note: only parameter groups present at construction receive ``Vs`` /
    ``Gs`` buffers; groups added later via ``add_param_group`` do not.

    Args:
        params: iterable of parameters (or parameter-group dicts).
        h: step size; must be non-negative.
        lam: constant added to ``sqrt(V)`` before inverting it.
        alpha: decay rate of the running average ``V``.

    Raises:
        ValueError: if ``h`` is negative.
    """

    def __init__(self, params, h, lam=1e-5, alpha=0.99):
        if h < 0.0:
            raise ValueError("Invalid step size: {}".format(h))
        defaults = dict(h=h, lam=lam, alpha=alpha)
        super(pSGLD, self).__init__(params, defaults)
        for group in self.param_groups:
            group['Vs'] = [torch.zeros_like(param) for param in group['params']]
            # G is a diagonal matrix; only its diagonal is stored, one tensor
            # shaped like each parameter.
            group['Gs'] = [torch.zeros_like(param) for param in group['params']]

    def step(self, closure=None):
        """Perform one preconditioned update using the current ``G``.

        For every parameter with a gradient:

            param <- param - h * G * grad + sqrt(2 * h) * sqrt(G) * xi,
            xi ~ N(0, I)

        Args:
            closure: optional callable that re-evaluates the model; it is
                called with gradient tracking enabled before the update.

        Returns:
            Whatever ``closure`` returned, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        with torch.no_grad():
            for group in self.param_groups:
                h = group['h']
                noise_scale = np.sqrt(2 * h)
                for param, G in zip(group['params'], group['Gs']):
                    if param.grad is None:
                        continue
                    noise = torch.randn_like(param)
                    param.add_(-h * G * param.grad + noise_scale * torch.sqrt(G) * noise)
        return loss

    def update_preconditioner(self):
        """Refresh ``V`` and ``G`` from the current gradients.

        ``V <- alpha * V + (1 - alpha) * grad**2`` and
        ``G <- 1 / (lam + sqrt(V))``. Parameters whose ``grad`` is ``None``
        keep their previous ``V`` and ``G``.
        """
        with torch.no_grad():
            for group in self.param_groups:
                lam, alpha = group['lam'], group['alpha']
                Vs, Gs = group['Vs'], group['Gs']
                for i, param in enumerate(group['params']):
                    if param.grad is None:
                        continue
                    V = alpha * Vs[i] + (1.0 - alpha) * param.grad ** 2
                    # replace the list entries (not in-place) so the stored
                    # buffers are always fresh tensors
                    Vs[i] = V
                    Gs[i] = 1.0 / (lam + torch.sqrt(V))

class SGHMC(Optimizer):
    """Stochastic Gradient Hamiltonian Monte Carlo sampler.

    For every parameter with a gradient, :meth:`step` applies

        param    <- param + h * momentum
        momentum <- momentum + h * (-grad - gamma * momentum)
        momentum <- momentum + sqrt(h) * sigma * xi,   xi ~ N(0, I)

    where ``grad`` is the value already stored in ``param.grad``. Parameters
    whose ``grad`` is ``None`` (and their momenta) are left untouched.

    Args:
        params: iterable of parameters (or parameter-group dicts).
        h: step size.
        gamma: coefficient of the ``-gamma * momentum`` damping term.
        sigma: scale of the injected Gaussian noise.
        device: device on which momentum buffers and noise are allocated.
            The default requires CUDA; pass ``torch.device('cpu')`` for
            CPU parameters.
        np_seed: accepted for signature parity with ``EWSG``; unused here.
    """

    def __init__(self, params, h=0.1, gamma=1.0, sigma=np.sqrt(2.0), device=torch.device('cuda'), np_seed=0):
        defaults = dict(h=h, gamma=gamma, sigma=sigma, device=device)
        super(SGHMC, self).__init__(params, defaults)
        for group in self.param_groups:
            group['momentums'] = [torch.zeros_like(param, device=device) for param in group['params']]

    def step(self):
        """Advance positions and momenta by one step (see class docstring)."""
        with torch.no_grad():
            for group in self.param_groups:
                h = group['h']
                gamma = group['gamma']
                sigma = group['sigma']
                device = group['device']
                noise_scale = np.sqrt(h) * sigma

                for param, momen in zip(group['params'], group['momentums']):
                    if param.grad is None:
                        continue
                    # the position update uses the momentum from before this step
                    param.add_(momen, alpha=h)
                    momen.add_(-param.grad - gamma * momen, alpha=h)
                    noise = torch.randn_like(momen, device=device)
                    momen.add_(noise, alpha=noise_scale)

class EWSG(Optimizer):
    """SGHMC-style sampler whose gradient is a snapshot gated by an MH test.

    The dynamics in :meth:`step` match SGHMC, but they use the gradient
    snapshot stored in ``group['grads']`` instead of ``param.grad``. The
    snapshot starts as ``None``. It is set unconditionally by :meth:`accept`,
    and it is replaced by the current ``param.grad`` only when :meth:`mh`
    accepts. The presumed call pattern is: compute gradients and call
    :meth:`accept` once, then repeatedly compute a proposed gradient, call
    :meth:`mh`, and call :meth:`step`. Confirm this against the caller.

    Args:
        params: iterable of parameters (or parameter-group dicts).
        h: step size.
        gamma: coefficient of the ``-gamma * momentum`` damping term.
        sigma: scale of the injected Gaussian noise (also used in :meth:`mh`).
        device: device on which momentum buffers and noise are allocated.
            The default requires CUDA; pass ``torch.device('cpu')`` for
            CPU parameters.
        np_seed: seed of the NumPy RNG used for MH acceptance draws.
    """

    _NO_SNAPSHOT_MSG = "EWSG: no gradient snapshot stored; call accept() first"

    def __init__(self, params, h=0.1, gamma=1.0, sigma=np.sqrt(2.0), device=torch.device('cuda'), np_seed=0):
        self.rng = np.random.RandomState(np_seed)
        defaults = dict(h=h, gamma=gamma, sigma=sigma, device=device)
        super(EWSG, self).__init__(params, defaults)
        for group in self.param_groups:
            group['momentums'] = [torch.zeros_like(param, device=device) for param in group['params']]
            group['grads'] = None  # gradient snapshot; populated by accept()

    def mh(self):
        """Metropolis-Hastings test of ``param.grad`` against the snapshot.

        For each group, with ``x = sqrt(h) * gamma * momentum / sigma``::

            E_old  = sum ||x + sqrt(h) / sigma * snapshot_grad||^2
            E_prop = sum ||x + sqrt(h) / sigma * param.grad||^2

        The proposal is accepted with probability
        ``min(1, exp(0.5 * (E_prop - E_old)))``. On acceptance, the group's
        snapshot is overwritten with copies of the current ``param.grad``.
        Parameters lacking either gradient are left out of the energies.

        Returns:
            True if at least one group accepted.

        Raises:
            RuntimeError: if no snapshot has been stored yet (see ``accept``).
        """
        accepted = False
        with torch.no_grad():
            for group in self.param_groups:
                grads = group['grads']
                if grads is None:
                    raise RuntimeError(self._NO_SNAPSHOT_MSG)
                h = group['h']
                gamma = group['gamma']
                sigma = group['sigma']
                root_h = np.sqrt(h)

                energy = energy_proposal = 0.0
                for param, momen, grad in zip(group['params'], group['momentums'], grads):
                    if grad is None or param.grad is None:
                        continue
                    x = root_h * gamma * momen / sigma
                    energy += (x + root_h / sigma * grad).pow(2).sum()
                    energy_proposal += (x + root_h / sigma * param.grad).pow(2).sum()

                # float() handles both a 0-d tensor (any device) and the plain
                # 0.0 left over when every parameter was skipped.
                diff = 0.5 * float(energy_proposal - energy)

                # NOTE(review): this accepts unconditionally when the proposal
                # has the *larger* energy. Confirm the sign against the intended
                # EWSG acceptance ratio.
                # `diff > 0` short-circuits before np.exp, which avoids overflow.
                if diff > 0 or self.rng.rand() < np.exp(diff):
                    accepted = True
                    for i, param in enumerate(group['params']):
                        grads[i] = None if param.grad is None else param.grad.clone().detach()

        return accepted

    def accept(self):
        """Unconditionally store copies of the current gradients as the snapshot.

        Parameters whose ``grad`` is ``None`` are stored as ``None``, and both
        ``mh`` and ``step`` skip them.

        Returns:
            Always True.
        """
        with torch.no_grad():
            for group in self.param_groups:
                group['grads'] = [
                    None if param.grad is None else param.grad.clone().detach()
                    for param in group['params']
                ]

        return True

    def step(self):
        """Advance positions and momenta by one SGHMC step using the snapshot.

        For every parameter with a stored gradient ``g``::

            param    <- param + h * momentum
            momentum <- momentum + h * (-g - gamma * momentum)
            momentum <- momentum + sqrt(h) * sigma * xi,   xi ~ N(0, I)

        Raises:
            RuntimeError: if no snapshot has been stored yet (see ``accept``).
        """
        with torch.no_grad():
            for group in self.param_groups:
                grads = group['grads']
                if grads is None:
                    raise RuntimeError(self._NO_SNAPSHOT_MSG)
                h = group['h']
                gamma = group['gamma']
                sigma = group['sigma']
                device = group['device']
                noise_scale = np.sqrt(h) * sigma

                for param, momen, grad in zip(group['params'], group['momentums'], grads):
                    if grad is None:
                        continue
                    # the position update uses the momentum from before this step
                    param.add_(momen, alpha=h)
                    momen.add_(-grad - gamma * momen, alpha=h)
                    noise = torch.randn_like(momen, device=device)
                    momen.add_(noise, alpha=noise_scale)