import torch
from torch import Tensor
import torch.nn as nn
from torch.autograd import Variable
from torch.optim import Optimizer
import numpy as np


class LKTD(Optimizer):
    """Optimizer implementing the LKTD update.

    Each call to :meth:`step` does two things:

    1. Every parameter that has a gradient is moved by
       ``lr/2 * (grad / (alpha * obs_sd) + prior_gradient)`` plus Gaussian
       noise with standard deviation ``sqrt(lr * sgld_temperature)``.
    2. A caller-supplied augmented variable receives the same kind of noisy
       gradient step and is then pulled towards the observation with a
       scalar gain.

    Args:
        params: iterable of parameters, or of dicts defining param groups.
        lr: initial step size; decayed as ``lr / (current_step + 1)``.
        sgld_temperature: scales the variance of the injected noise.
        obs_sd: observation standard deviation.
        prior_sd: standard deviation of the wide prior component.
        sparse_sd: standard deviation of the narrow prior component; only
            used when ``sparse_ratio < 1``.
        sparse_ratio: mixture weight in ``[0, 1]`` of the wide component.
            ``1`` selects a plain Gaussian prior.
        alpha: enters the gradient scale as ``1 / alpha`` and the noise of
            the observation correction as ``sqrt(2 * (1 - alpha) * T)``.
    """

    def __init__(self,
                 params,
                 lr=1e-4,
                 sgld_temperature=1,
                 obs_sd=1,
                 prior_sd=1,
                 sparse_sd=0.01,
                 sparse_ratio=1,
                 alpha=0.9
                 ):
        defaults = dict(
            lr=lr,
            sgld_temperature=sgld_temperature,
            obs_sd=obs_sd,
            prior_sd=prior_sd,
            sparse_sd=sparse_sd,
            sparse_ratio=sparse_ratio,
            alpha=alpha,
        )
        super().__init__(params, defaults)

    def learning_rate(self, lr_init, k):
        """Return the step size for iteration ``k``: ``lr_init / (k + 1)``."""
        return lr_init / (k + 1)

    def _step_scales(self, group, current_step):
        """Return ``(lr, w_sd, alpha_var_inv)`` for one param group."""
        lr = self.learning_rate(group['lr'], current_step)
        # Standard deviation of the noise injected into each update.
        w_sd = np.sqrt(lr * group['sgld_temperature'])
        # NOTE(review): the name suggests 1 / (alpha * variance), but obs_sd
        # is not squared here -- confirm against the intended derivation.
        alpha_var_inv = 1 / group['alpha'] / group['obs_sd']
        return lr, w_sd, alpha_var_inv

    @torch.no_grad()
    def step(self, aug_param:Variable, observation:Tensor, current_step:int, closure=None):
        """Perform one single step of the LKTD algorithm.

        Args:
            aug_param: augmented variable, updated in place. Its ``.grad``
                must already be populated.
            observation: target that ``aug_param`` is pulled towards; must
                broadcast against ``aug_param``.
            current_step: iteration index used to decay the learning rate.
            closure: optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure
            was given.
        """
        loss = None
        if closure is not None:
            # This method runs under no_grad; re-enable autograd so the
            # closure can build a graph and call backward().
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr, w_sd, alpha_var_inv = self._step_scales(group, current_step)
            noise_sd = float(w_sd)
            for p in group['params']:
                if p.grad is None:
                    continue
                prior_grad = self._prior_gradient(
                    p, group['prior_sd'], group['sparse_sd'], group['sparse_ratio']
                )
                drift = lr / 2 * (alpha_var_inv * p.grad + prior_grad)
                p.sub_(drift + noise_sd * torch.randn_like(p))

        ### Update augmented variables
        # aug_param belongs to no param group. Take the step scales from the
        # last group explicitly instead of relying on variables leaked from
        # the loop above (the values are the same as before).
        lr, w_sd, alpha_var_inv = self._step_scales(self.param_groups[-1], current_step)
        # NOTE(review): r_sd reads self.defaults while the scales above read
        # the last param group; they only differ if a group overrides
        # alpha / sgld_temperature / obs_sd. Kept as in the original.
        r_sd = np.sqrt(
            2 * (1 - self.defaults['alpha']) * self.defaults['sgld_temperature']
        ) * self.defaults['obs_sd']
        # Computed on numpy scalars so a 0/0 stays nan instead of raising.
        k_var = float(w_sd**2 / (w_sd**2 + r_sd**2))
        # Noise is drawn on aug_param's own device, not on a parameter's.
        aug_param.sub_(
            lr / 2 * alpha_var_inv * aug_param.grad
            + float(w_sd) * torch.randn_like(aug_param)
        )
        aug_param.add_(
            k_var * (
                observation - aug_param.detach()
                + float(r_sd) * torch.randn_like(aug_param)
            )
        )
        return loss

    def _prior_gradient(self, param, prior_sd, sparse_sd, sparse_ratio):
        """Return the gradient of the negative log prior at ``param``.

        The prior is ``sparse_ratio * N(0, prior_sd^2)
        + (1 - sparse_ratio) * N(0, sparse_sd^2)``, applied element-wise.

        Raises:
            ValueError: if ``sparse_ratio`` is outside ``[0, 1]``.
        """
        if not 0 <= sparse_ratio <= 1:
            raise ValueError(
                f"sparse_ratio must be in [0, 1], got {sparse_ratio!r}"
            )
        if sparse_ratio == 1:
            # Single Gaussian component.
            return param / prior_sd**2

        # The exponent is evaluated on values clamped to [-prior_sd, prior_sd],
        # which bounds the argument of exp() for large |param|.
        param_trunc = param.clamp(min=-prior_sd, max=prior_sd)
        # Density ratio of the wide component to the narrow component.
        A = (
            sparse_ratio / (1 - sparse_ratio) * sparse_sd / prior_sd
            * torch.exp(-(1 / prior_sd**2 - 1 / sparse_sd**2) * torch.square(param_trunc) / 2)
        )
        # -d/dx log(mixture) = x * (A / prior_sd^2 + 1 / sparse_sd^2) / (A + 1)
        #                    = x * (1 / prior_sd^2
        #                           + (1 / sparse_sd^2 - 1 / prior_sd^2) / (A + 1))
        coef = 1 / prior_sd**2 + (1 / sparse_sd**2 - 1 / prior_sd**2) / (A + 1)
        return coef * param

if __name__ == "__main__":
    # Disabled scratch check: evaluates and plots a standalone copy of the
    # prior-gradient formula. Running this file directly currently does
    # nothing (the only live statement is `pass`).
    # NOTE(review): the copy below is not in sync with LKTD._prior_gradient --
    # it returns the negated value, and its dense branch is
    # `-prior_sd * param` rather than `param / prior_sd**2`.
    #
    # def prior_gradient(param, prior_sd, sparse_sd, sparse_ratio):
    #     param_trunc = param.clamp(min=-prior_sd, max=prior_sd)
    #     if sparse_ratio == 1:
    #         return -prior_sd * param
    #     elif sparse_ratio < 1:
    #         A = sparse_ratio/(1-sparse_ratio) * sparse_sd/prior_sd * torch.exp(-(1/prior_sd**2 - 1/sparse_sd**2)*torch.square(param_trunc)/2)
    #         coef = 1/prior_sd**2 + torch.div((1/sparse_sd**2 - 1/prior_sd**2),A)
    #         return -coef * param
    
    # import matplotlib.pyplot as plt
    # x = torch.tensor([-0.05, 0.1, 0.2, 0.3, 0.4, 1, 2, 3])
    # y = prior_gradient(x, 0.5, 1, 0.1)
    # print(y)
    # xs = torch.linspace(-3, 3, 1000)
    # ys = prior_gradient(xs, 1, 0.1, 0.1)
    # plt.plot(xs, ys)
    # plt.show()
    pass