import math
import torch
import torch.nn as nn
import pdb
class ActivationSparsity(nn.Module):
    """Applies activation sparsity to the last dimension of input using a K-winners strategy.

    For every vector along the last dimension, only the ``k`` entries with the
    largest *boosted* values are kept and every other entry is set to zero, where
    ``k = floor((1 - act_sparsity) * in_features)``. Boosting multiplies each unit
    by ``exp(beta * (target - duty_cycle))`` with ``target = k / in_features``, so
    units that have won less often than the target density are favoured. The
    boosted values are only used to pick the winners; the kept outputs are the
    original (un-boosted) input values. All leading dimensions are treated as
    independent rows.

    Args:
        alpha (float): smoothing constant of the exponential moving average used
            to update the duty cycle.
            Default: 0.1
        beta (float): boosting strength for units that were active less often
            than the target density.
            Default: 1.5
        act_sparsity (float): fraction of the last dimension that is zeroed;
            roughly ``1 - act_sparsity`` of the entries are kept.
            Default: 0.65

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional dimensions
        - Output: :math:`(N, *)`, same shape as the input

    Examples::

        >>> m = ActivationSparsity(act_sparsity=0.65)
        >>> input = torch.randn(3, 10)
        >>> output = m(input)  # keeps 3 entries per row, zeros the rest
    """

    def __init__(self, alpha=0.1, beta=1.5, act_sparsity=0.65):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.act_sparsity = act_sparsity
        # Per-unit running activation frequency, created lazily on the first
        # forward once the feature size is known. A non-persistent buffer follows
        # ``.to()`` / ``.cuda()`` without adding a key to ``state_dict``.
        self.register_buffer("duty_cycle", None, persistent=False)

    def updateDC(self, inputs, duty_cycle):
        """Return ``duty_cycle`` updated with the activations in ``inputs``.

        Args:
            inputs (Tensor): 2-D ``(rows, in_features)`` tensor of sparse outputs.
            duty_cycle (Tensor): current per-unit duty cycle of shape ``(in_features,)``.

        Returns:
            Tensor: ``(1 - alpha) * duty_cycle + alpha * freq``, where ``freq`` is
            the fraction of rows in which each unit is strictly positive. A
            fraction (rather than a raw count) keeps the duty cycle on the same
            scale as the target density ``k / in_features`` for any batch size.
            An empty batch carries no information and leaves the duty cycle as is.
        """
        if inputs.shape[0] == 0:
            return duty_cycle
        active_fraction = inputs.gt(0).to(duty_cycle.dtype).mean(dim=0)
        return (1 - self.alpha) * duty_cycle + self.alpha * active_fraction

    def forward(self, inputs):
        """Keep the top-``k`` boosted entries of every last-dimension vector.

        Args:
            inputs (Tensor): tensor of shape ``(N, *)``; sparsity is applied along
                the last dimension and every leading index is treated as a row.

        Returns:
            Tensor: same shape as ``inputs``. Non-winning entries are zero and the
            winners carry the original input values, so gradients flow only
            through the winners. In training mode the duty cycle is updated.
        """
        out_shape = inputs.shape
        in_features = out_shape[-1]
        # One row per last-dimension vector so the (in_features,) duty cycle
        # broadcasts against the rows.
        flat = inputs.reshape(-1, in_features)

        if self.duty_cycle is None:
            self.duty_cycle = torch.zeros(in_features, device=flat.device)

        # The tiny tolerance keeps e.g. (1 - 0.9) * 10 == 0.9999999999999998 from
        # truncating to k = 0 when the intended value is exactly 1.
        k = math.floor((1 - self.act_sparsity) * in_features + 1e-9)
        with torch.no_grad():
            target = k / in_features
            boost_coefficient = torch.exp(self.beta * (target - self.duty_cycle))
            # Winners are selected on boosted values; the selection is not differentiable.
            _, indices = (flat * boost_coefficient).topk(k, dim=-1, sorted=False)

        # Copy the un-boosted winner values into an all-zero tensor (differentiable
        # w.r.t. ``flat``). Also valid for k == 0, which yields all zeros.
        outputs = torch.zeros_like(flat).scatter(-1, indices, flat.gather(-1, indices))

        if self.training:
            with torch.no_grad():
                self.duty_cycle = self.updateDC(outputs, self.duty_cycle)

        return outputs.reshape(out_shape)

    def extra_repr(self):
        """Summarise the hyper-parameters and the current duty cycle."""
        return 'act_sparsity={}, alpha={}, beta={}, duty_cycle={}'.format(
            self.act_sparsity, self.alpha, self.beta, self.duty_cycle
        )