import torch
from torch import Tensor
import torch.nn as nn
from torch.autograd import Variable
from torch.optim import Optimizer
import numpy as np


class LKTDDA(Optimizer):
    """Optimizer pairing noisy gradient steps on the network parameters with a
    gain-weighted correction of an augmented variable.

    For every parameter ``p`` that has a gradient, ``step`` applies::

        p -= lr_t / 2 * (grad / (alpha * obs_sd**2) + state_grad * batch_ratio)
             + sqrt(lr_t * batch_ratio) * N(0, 1)

    where ``lr_t = learning_rate(lr, current_step)``,
    ``batch_ratio = aug_param.numel() / pseudo_population`` and ``state_grad``
    is ``p / state_sd**2`` (no resampled parameters) or
    ``(p - q) / state_sd**2`` (with a resampled counterpart ``q``).

    The augmented variable then takes a similar noisy step on its own gradient,
    followed by ``aug += k * (observation - aug + v_sd * N(0, 1))`` with
    ``k = lr_t / (R_t + lr_t)``, ``R_t = 2 * (1 - alpha) * obs_sd**2`` and
    ``v_sd = sqrt(batch_ratio * R_t)``.

    All hyperparameters are read from ``param_groups`` at every step, so
    ``update_parameters`` takes effect on the next call to ``step``.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: initial step size, decayed by ``learning_rate``; must be >= 0.
        pseudo_population: denominator of ``batch_ratio``; must be > 0.
        obs_sd: observation standard deviation; enters the gradient term as
            ``1 / (alpha * obs_sd**2)`` and ``R_t`` as
            ``2 * (1 - alpha) * obs_sd**2``; must be > 0.
        state_sd: standard deviation used by the prior / transition gradient;
            must be > 0.
        alpha: weight in (0, 1] splitting ``obs_sd**2`` between the gradient
            term and ``R_t``.

    Raises:
        ValueError: if a hyperparameter is outside the range listed above.
    """

    def __init__(self,
                 params,
                 lr=1e-4,
                 pseudo_population=50,
                 obs_sd=1,
                 state_sd=1,
                 alpha=0.9,
                 ):
        # Fail fast: these values otherwise surface later as ZeroDivisionError
        # (alpha, obs_sd or pseudo_population == 0) or NaN (alpha > 1 gives R_t < 0).
        if lr < 0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if pseudo_population <= 0:
            raise ValueError(f"Invalid pseudo_population: {pseudo_population}")
        if obs_sd <= 0:
            raise ValueError(f"Invalid obs_sd: {obs_sd}")
        if state_sd <= 0:
            raise ValueError(f"Invalid state_sd: {state_sd}")
        if not 0 < alpha <= 1:
            raise ValueError(f"Invalid alpha (must be in (0, 1]): {alpha}")

        defaults = dict(lr=lr, pseudo_population=pseudo_population, obs_sd=obs_sd,
                        state_sd=state_sd, alpha=alpha)
        super().__init__(params, defaults)

    def update_parameters(self, parameter_name, parameter_value):
        """Set hyperparameter ``parameter_name`` to ``parameter_value`` in every group.

        No range validation is performed here.
        """
        for group in self.param_groups:
            group[parameter_name] = parameter_value

    def learning_rate(self, lr_init, k, n=10, power=0.9):
        """Return the decayed step size ``lr_init * (n / (k + n)) ** power``.

        Args:
            lr_init: step size at ``k == 0``.
            k: iteration index.
            n: offset controlling how quickly the decay starts (default 10).
            power: decay exponent (default 0.9).
        """
        return lr_init * (n / (k + n)) ** power

    def _group_coefficients(self, group, batch_size, current_step):
        """Return ``(lr, batch_ratio, w_sd, alpha_var_inv)`` for one parameter group."""
        lr = self.learning_rate(group['lr'], current_step)
        batch_ratio = batch_size / group['pseudo_population']
        # B_t == lr: the same quantity scales both the step and the injected noise.
        w_sd = np.sqrt(lr * batch_ratio)
        alpha_var_inv = 1 / group['alpha'] / group['obs_sd'] ** 2
        return lr, batch_ratio, w_sd, alpha_var_inv

    @torch.no_grad()
    def step(self, aug_param: Tensor, observation: Tensor, current_step: int,
             resample_net_params=None, closure=None):
        """Perform one LKTD update of all parameter groups and of ``aug_param``.

        Args:
            aug_param: augmented variable, updated in place when it has a
                ``.grad``; its ``numel()`` is used as the batch size.
            observation: target ``aug_param`` is pulled toward; the result of
                ``observation - aug_param`` must broadcast to ``aug_param``.
            current_step: iteration index passed to ``learning_rate``.
            resample_net_params: optional iterable of tensors paired, in
                order, with the parameters of all groups (consumed once across
                groups). When given, the transition gradient
                ``(p - q) / state_sd**2`` replaces the prior gradient.
            closure: optional callable re-evaluating the loss; it runs with
                gradients enabled.

        Returns:
            The closure's return value, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; like torch.optim's built-in
            # optimizers, re-enable autograd so the closure can backprop.
            with torch.enable_grad():
                loss = closure()

        batch_size = aug_param.numel()
        # One iterator shared by all groups so resampled tensors stay aligned
        # with the parameters across groups (a list would restart per group).
        resample_iter = None if resample_net_params is None else iter(resample_net_params)

        for group in self.param_groups:
            lr, batch_ratio, w_sd, alpha_var_inv = self._group_coefficients(
                group, batch_size, current_step)
            state_sd = group['state_sd']

            if resample_iter is None:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    prior_grad = self._prior_gradient(p, state_sd)
                    p.sub_(lr/2 * (alpha_var_inv * p.grad + prior_grad * batch_ratio)
                           + w_sd * torch.randn_like(p))
            else:
                for p, q in zip(group['params'], resample_iter):
                    if p.grad is None:
                        continue
                    transition_grad = self._transition_gradient(p, q, state_sd)
                    p.sub_(lr/2 * (alpha_var_inv * p.grad + transition_grad * batch_ratio)
                           + w_sd * torch.randn_like(p))

        ### Update augmented variables
        if aug_param.grad is not None:
            # Uses the last parameter group's hyperparameters (with a single
            # group, simply that group's settings).
            group = self.param_groups[-1]
            lr, batch_ratio, w_sd, alpha_var_inv = self._group_coefficients(
                group, batch_size, current_step)
            B_t = lr
            R_t = 2 * (1 - group['alpha']) * group['obs_sd'] ** 2
            v_sd = np.sqrt(batch_ratio * R_t)
            # Scalar gain in [0, 1]: 1 when R_t == 0 (alpha == 1).
            k_var = B_t / (R_t + B_t)
            # Noise is drawn on aug_param's own device, not a network parameter's.
            aug_param.sub_(lr/2 * batch_ratio * alpha_var_inv * aug_param.grad
                           + w_sd * torch.randn_like(aug_param))
            aug_param.add_(k_var * (observation - aug_param.detach()
                                    + v_sd * torch.randn_like(aug_param)))

        return loss

    def _transition_gradient(self, p, q, state_sd):
        """Gradient of ``||p - q||**2 / (2 * state_sd**2)`` with respect to ``p``."""
        return (p - q)/state_sd**2

    def _prior_gradient(self, p, state_sd):
        """Gradient of ``||p||**2 / (2 * state_sd**2)`` with respect to ``p``."""
        return p/state_sd**2
    
if __name__ == "__main__":
    # No command-line entry point is implemented; running this file directly does nothing.
    ...