import torch
from torch.distributions import Normal
from torch.optim.optimizer import Optimizer, required
import numpy as np



class SGLD(Optimizer):
    """Stochastic Gradient Langevin Dynamics optimizer.

    Each step performs a plain gradient-descent update and, optionally,
    adds zero-mean Gaussian noise to every parameter::

        p <- p - lr * grad + eps,    eps ~ N(0, 2 * lr / beta)

    Args:
        params: iterable of parameters or dicts defining parameter groups.
        lr: learning rate (step size) of each parameter group.
        addnoise: per-group switch; when False the group is updated with
            plain gradient descent regardless of what ``step`` is asked.

    Raises:
        ValueError: if ``lr`` is given and negative.
    """

    def __init__(self, params, lr=required, addnoise=True):
        # A negative step size would make the noise scale sqrt(2*lr/beta) NaN.
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        defaults = dict(lr=lr, addnoise=addnoise)
        super().__init__(params, defaults)

    def step(self, lr=None, add_noise=True, beta=50000, closure=None):
        """Perform a single optimization step.

        Args:
            lr: if not None, overwrites the learning rate stored in every
                parameter group before the update (the change persists).
            add_noise: global switch for the Langevin noise; noise is only
                injected for groups whose own ``addnoise`` flag is also set.
            beta: divisor in the noise variance ``2 * lr / beta``; a larger
                value means less noise.
            closure: optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is
            given.
        """
        loss = None
        if closure is not None:
            # The closure typically runs backward(), so gradients must be on.
            with torch.enable_grad():
                loss = closure()

        with torch.no_grad():
            for group in self.param_groups:
                # Compare against None so that an explicit lr=0.0 is honoured.
                if lr is not None:
                    group['lr'] = lr
                step_size = group['lr']

                inject_noise = bool(group['addnoise'] and add_noise)
                # The noise scale is constant within a group: compute it once.
                noise_std = (
                    float(np.sqrt(2. * step_size / beta))
                    if inject_noise else 0.0
                )

                for p in group['params']:
                    if p.grad is None:
                        continue
                    grad = p.grad

                    # Deterministic gradient-descent part of the update.
                    p.add_(grad, alpha=-step_size)

                    if inject_noise:
                        # Sample directly on the parameter's device and in
                        # its dtype, so the optimizer works on CPU as well
                        # as on any accelerator.
                        p.add_(torch.randn_like(grad) * noise_std)

        return loss