import numpy as np
import torch
from torch.optim import Optimizer

class EWSG(Optimizer):
    """Momentum-based stochastic optimizer/sampler with a gradient acceptance test.

    Each parameter has an associated momentum buffer (initialised to zero) and
    a *stored* gradient. The stored gradient -- not ``param.grad`` -- drives
    :meth:`step`. It is refreshed either unconditionally via :meth:`accept`
    or conditionally via the Metropolis-Hastings style test in :meth:`mh`.

    Typical call order (inferred from how the methods depend on each other):
    compute gradients with ``backward()``, call :meth:`accept` once to seed
    the stored gradients, then alternate ``backward()`` / :meth:`mh` /
    :meth:`step`.

    NOTE(review): the name presumably stands for "Exponentially Weighted
    Stochastic Gradients" -- confirm against the accompanying paper/project.

    Args:
        params: iterable of parameters or dicts defining parameter groups.
        h: step size used for both the position and momentum updates.
        gamma: friction coefficient applied to the momentum.
        sigma: scale of the Gaussian noise injected into the momentum; must be
            non-zero for :meth:`mh` because it appears as a divisor there.
    """

    def __init__(self, params, h, gamma, sigma):
        defaults = dict(h=h, gamma=gamma, sigma=sigma)
        super().__init__(params, defaults)
        for group in self.param_groups:
            # One zero-initialised momentum buffer per parameter, matching its
            # shape, dtype and device.
            group['momentums'] = [torch.zeros_like(param) for param in group['params']]
            # Populated by accept() (or refreshed by mh()); None until then.
            group['grads'] = None

    @staticmethod
    def _stored_grads(group):
        """Return the group's stored gradients, failing clearly if unset."""
        grads = group['grads']
        if grads is None:
            raise RuntimeError(
                "EWSG has no stored gradients yet; call accept() after the "
                "first backward pass before calling mh() or step()."
            )
        return grads

    def mh(self):
        """Run the acceptance test between stored and current gradients.

        For every parameter group, an "energy" is computed for the stored
        gradients and for the gradients currently held in ``param.grad``.
        The proposal is accepted when its energy is larger, or otherwise with
        probability ``exp(0.5 * (energy_proposal - energy))``. On acceptance
        the group's stored gradients are replaced by detached copies of
        ``param.grad``.

        Returns:
            bool: ``True`` if the proposal was accepted for at least one
            parameter group, ``False`` otherwise.

        Raises:
            RuntimeError: if no gradients have been stored yet.
        """
        accept = False
        for group in self.param_groups:
            stored = self._stored_grads(group)
            h, gamma, sigma = group['h'], group['gamma'], group['sigma']
            sqrt_h = float(np.sqrt(h))
            energy = energy_proposal = 0.0

            for param, momen, grad in zip(group['params'], group['momentums'], stored):
                x = sqrt_h * gamma * momen / sigma
                energy = energy + (x + sqrt_h / sigma * grad).pow(2).sum()
                energy_proposal = energy_proposal + (x + sqrt_h / sigma * param.grad).pow(2).sum()

            # float() accepts both a 0-d tensor and the plain 0.0 left over
            # when the group holds no parameters.
            diff = float(0.5 * (energy_proposal - energy))
            # exp() is only evaluated for diff <= 0, so it cannot overflow.
            if diff > 0 or np.random.rand() < np.exp(diff):
                accept = True
                for i, param in enumerate(group['params']):
                    stored[i] = param.grad.clone().detach()

        return accept

    def accept(self):
        """Unconditionally store detached copies of the current ``param.grad``.

        Returns:
            bool: always ``True``.
        """
        for group in self.param_groups:
            group['grads'] = [param.grad.clone().detach() for param in group['params']]

        return True

    def step(self):
        """Advance parameters and momenta by one step using the stored gradients.

        For each parameter, in this order:

        1. ``param += h * momentum`` (using the momentum from before this step)
        2. ``momentum += h * (-grad - gamma * momentum)``
        3. ``momentum += sqrt(h) * sigma * N(0, I)``

        Raises:
            RuntimeError: if no gradients have been stored yet.
        """
        for group in self.param_groups:
            stored = self._stored_grads(group)
            h, gamma, sigma = group['h'], group['gamma'], group['sigma']
            noise_scale = float(np.sqrt(h) * sigma)

            for param, momen, grad in zip(group['params'], group['momentums'], stored):
                # .data keeps the in-place update out of autograd's view.
                param.data.add_(momen, alpha=h)

                # The right-hand side is fully evaluated before the in-place
                # add, so the friction term uses the pre-update momentum.
                momen.add_(-grad - gamma * momen, alpha=h)
                momen.add_(torch.randn_like(momen), alpha=noise_scale)