import torch
from torch import Tensor
import torch.nn as nn
from torch.autograd import Variable
from torch.optim import Optimizer
import numpy as np


class SGHMC(Optimizer):
    """SGHMC-style optimizer: momentum update with friction and Gaussian noise.

    For every parameter ``p`` with gradient ``g`` and momentum buffer ``v``,
    one call to :meth:`step` at iteration ``k`` performs::

        lr_k = lr / (1 + k) ** power
        v    = (1 - alpha) * v + lr_k * g
               + sqrt(2 * alpha) * sqrt(lr_k / pseudo_population) * N(0, I)
        p    = p - v

    Args:
        params: Iterable of parameters or parameter-group dicts, as accepted
            by ``torch.optim.Optimizer``.
        lr: Initial learning rate (must be >= 0).
        pseudo_population: Divides the injected-noise variance (must be > 0).
        alpha: Friction coefficient; ``1 - alpha`` is the fraction of the
            previous momentum that is kept (must be >= 0).
        power: Exponent of the polynomial learning-rate decay.

    Note:
        ``lr`` and ``alpha`` are read per parameter group, whereas
        ``pseudo_population`` and ``power`` are read from the instance
        attributes of the same name and therefore apply to all groups.

    Raises:
        ValueError: If ``lr``, ``pseudo_population`` or ``alpha`` is outside
            the range stated above.
    """

    # Hyper-parameters that ``step`` reads from the instance rather than
    # from the parameter groups.
    _INSTANCE_HYPERPARAMETERS = ("pseudo_population", "power")

    def __init__(self,
                 params,
                 lr: float = 1e-3,
                 pseudo_population: float = 1,
                 alpha: float = 1.0,
                 power: float = 0.0,
                 ) -> None:
        # Reject values that would make the noise scale NaN or divide by zero.
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if pseudo_population <= 0:
            raise ValueError(f"Invalid pseudo_population: {pseudo_population}")
        if alpha < 0.0:
            raise ValueError(f"Invalid alpha: {alpha}")

        defaults = dict(lr=lr, pseudo_population=pseudo_population, alpha=alpha, power=power)
        super().__init__(params, defaults)
        self.pseudo_population = pseudo_population
        self.power = power

    def update_parameters(self, parameter_name, parameter_value) -> None:
        """Set hyper-parameter ``parameter_name`` to ``parameter_value`` in every group.

        ``pseudo_population`` and ``power`` are additionally mirrored onto the
        instance attributes, because those are what :meth:`step` and
        :meth:`learning_rate` actually read; without this the update would be
        silently ignored.
        """
        for group in self.param_groups:
            group[parameter_name] = parameter_value
        if parameter_name in self._INSTANCE_HYPERPARAMETERS:
            setattr(self, parameter_name, parameter_value)

    def learning_rate(self, lr_init, k):
        """Return the decayed learning rate ``lr_init / (1 + k) ** self.power``."""
        return lr_init / (1 + k) ** self.power

    @torch.no_grad()
    def step(self, current_step=0, closure=None):
        """Perform a single update of all parameters.

        Args:
            current_step: Iteration index ``k`` used for the learning-rate
                decay. When it equals 0 the momentum buffers are reset to
                zero, so callers must pass an increasing value for the
                momentum to accumulate across calls.
            closure: Optional callable that re-evaluates the model and
                returns the loss. It is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.
        """
        loss = None
        if closure is not None:
            # ``step`` itself runs under ``no_grad``; the closure needs
            # autograd to be able to call ``backward()``.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = self.learning_rate(group['lr'], current_step)
            alpha = group['alpha']
            # Standard deviation of the injected noise; identical for every
            # parameter of the group, so computed once.
            w_sd = np.sqrt(lr / self.pseudo_population)
            noise_scale = np.sqrt(2 * alpha) * w_sd

            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                if current_step == 0 or 'momentum' not in state:
                    state['momentum'] = torch.zeros_like(p)
                v = state['momentum']
                # Update the buffer in place so that the new momentum is
                # kept in ``state`` for the next call.
                v.mul_(1 - alpha).add_(lr * p.grad + noise_scale * torch.randn_like(p))
                p.sub_(v)

        return loss


# Script entry point: intentionally empty, so running this module directly
# does nothing beyond defining the names above.
if __name__ == "__main__":
    pass