import numpy as np
import torch
from torch.optim import Optimizer

class pEWSG(Optimizer):
    """Diagonally preconditioned underdamped-Langevin sampler with an MH gradient filter.

    Each parameter ``x`` carries a momentum ``p`` and is advanced by ``step()``::

        x <- x + h * p
        p <- p + h * (-G * g - gamma * p) + sqrt(h) * sigma * sqrt(G) * xi,  xi ~ N(0, I)

    ``g`` is the most recently *accepted* gradient. It is stored per group under
    ``'grads'`` and is not read from ``param.grad``. ``G`` is a diagonal
    preconditioner maintained by ``update_preconditioner()`` as
    ``1 / (lam + sqrt(V))``, where ``V`` is an exponential moving average
    (factor ``alpha``) of ``g ** 2``.

    A gradient becomes "accepted" in one of two ways:

    * ``accept()`` stores it unconditionally.
    * ``mh()`` runs a Metropolis-Hastings style test against the previously
      accepted gradient.

    NOTE(review): the name suggests a preconditioned EWSG sampler. The intended
    call order looks like backward -> accept()/mh() -> update_preconditioner()
    -> step(); confirm against the training loop that drives it.

    Args:
        params: iterable of parameters or param-group dicts.
        h: step size.
        gamma: coefficient of the linear damping term ``-gamma * p``.
        sigma: scale of the injected Gaussian noise.
        lam: value added to ``sqrt(V)`` before inverting (default ``1e-5``).
        alpha: EMA decay for ``V`` (default ``0.99``).

    Per-group buffers are lists aligned with ``group['params']``:

    * ``'momentums'``, ``'Vs'`` and ``'Gs'`` start at zero.
    * ``'grads'`` starts as ``None``.

    Because ``G`` starts at zero, ``step()`` applies neither gradient force nor
    noise until ``update_preconditioner()`` has run.
    """

    def __init__(self, params, h, gamma, sigma, lam=1e-5, alpha=0.99):
        defaults = dict(h=h, gamma=gamma, sigma=sigma, lam=lam, alpha=alpha)
        # Optimizer.__init__ registers every group through self.add_param_group,
        # which is overridden below to allocate the per-group sampler buffers.
        super().__init__(params, defaults)

    def add_param_group(self, param_group):
        """Register ``param_group`` and allocate its sampler buffers.

        Overridden so that groups added after construction get the same
        ``'momentums'``/``'grads'``/``'Vs'``/``'Gs'`` buffers as the initial
        groups. Without this, ``step()`` would hit a ``KeyError`` on them.
        """
        super().add_param_group(param_group)
        group = self.param_groups[-1]
        params = group['params']
        group['momentums'] = [torch.zeros_like(p) for p in params]
        group['grads'] = [None] * len(params)
        group['Vs'] = [torch.zeros_like(p) for p in params]
        # G is a diagonal matrix; only its diagonal is stored in the implementation.
        group['Gs'] = [torch.zeros_like(p) for p in params]

    @staticmethod
    def _accepted_grads(group):
        """Return ``group['grads']``, raising if no gradient has been accepted yet."""
        grads = group['grads']
        if any(g is None for g in grads):
            raise RuntimeError(
                "pEWSG: no accepted gradient yet; call accept() after the first "
                "backward pass before mh(), update_preconditioner() or step()."
            )
        return grads

    @torch.no_grad()
    def mh(self):
        """Test the current ``param.grad`` against the stored gradient, MH style.

        For each group the code evaluates this "energy" twice::

            sum((sqrt(h) * gamma * p / sigma + sqrt(h) / sigma * G * g) ** 2)

        ``energy`` uses the stored gradient and ``energy_proposal`` uses
        ``param.grad``. Let ``diff = 0.5 * (energy_proposal - energy)``. The
        proposal is accepted if ``diff > 0``, otherwise with probability
        ``exp(diff)``. On acceptance, ``param.grad`` replaces the group's
        stored gradients.

        Returns:
            True if at least one group accepted its proposal.

        Raises:
            RuntimeError: if no gradient has been accepted yet (see ``accept()``).
        """
        accepted = False
        for group in self.param_groups:
            h, gamma, sigma = group['h'], group['gamma'], group['sigma']
            grad_scale = np.sqrt(h) / sigma
            energy = energy_proposal = 0.0

            for param, momen, grad, G in zip(group['params'], group['momentums'],
                                             self._accepted_grads(group), group['Gs']):
                x = np.sqrt(h) * gamma * momen / sigma
                energy += (x + grad_scale * G * grad).pow(2).sum()
                energy_proposal += (x + grad_scale * G * param.grad).pow(2).sum()

            # float() handles both a 0-d tensor and the plain 0.0 left by an empty
            # group, and moves the value off the device in one step.
            diff = 0.5 * float(energy_proposal - energy)
            # exp(diff) is only evaluated when diff <= 0, so it cannot overflow.
            # NOTE: uses NumPy's global RNG, which torch.manual_seed does not seed.
            if diff > 0 or np.random.rand() < np.exp(diff):
                accepted = True
                group['grads'] = [p.grad.detach().clone() for p in group['params']]

        return accepted

    def accept(self):
        """Unconditionally store every ``param.grad`` as its group's accepted gradient.

        Returns:
            Always True.
        """
        for group in self.param_groups:
            group['grads'] = [p.grad.detach().clone() for p in group['params']]
        return True

    @torch.no_grad()
    def step(self):
        """Advance parameters and momenta by one step using the accepted gradients.

        Raises:
            RuntimeError: if no gradient has been accepted yet (see ``accept()``).
        """
        for group in self.param_groups:
            h, gamma, sigma = group['h'], group['gamma'], group['sigma']
            noise_scale = np.sqrt(h) * sigma

            for param, momen, grad, G in zip(group['params'], group['momentums'],
                                             self._accepted_grads(group), group['Gs']):
                # The position is moved with the momentum from *before* this step's update.
                param.add_(momen, alpha=h)
                momen.add_(-G * grad - gamma * momen, alpha=h)
                momen.add_(torch.sqrt(G) * torch.randn_like(momen), alpha=noise_scale)

    @torch.no_grad()
    def update_preconditioner(self):
        """Refresh ``V`` and ``G`` from the accepted gradients (RMSprop-style).

        The update is ``V <- alpha * V + (1 - alpha) * g ** 2`` followed by
        ``G <- 1 / (lam + sqrt(V))``.

        Raises:
            RuntimeError: if no gradient has been accepted yet (see ``accept()``).
        """
        for group in self.param_groups:
            lam, alpha = group['lam'], group['alpha']
            grads = self._accepted_grads(group)
            group['Vs'] = [alpha * V + (1.0 - alpha) * g ** 2
                           for g, V in zip(grads, group['Vs'])]
            group['Gs'] = [1.0 / (lam + torch.sqrt(V)) for V in group['Vs']]
