import torch
import torch.nn as nn
import copy
class Topk_layer(nn.Module):
    """Keep the top-k feature positions (by batch-mean magnitude); zero the rest.

    For an input ``x`` whose first dimension is the batch, the mean absolute
    value of every feature position is taken across dim 0. The
    ``round(num_features * (1 - sparsity))`` positions with the largest mean
    are kept (mask value 1) and all other positions are zeroed (mask value 0).
    The same mask is applied to every sample in the batch.

    Args:
        sparsity: Fraction of feature positions to zero out. Expected to lie
            in ``[0, 1]``; values outside that range produce a ``k`` that
            ``torch.topk`` rejects.

    Attributes:
        weights: Binary mask from the most recent ``forward`` call, shaped
            ``(1, *x.shape[1:])`` and on the same device/dtype as ``x``.
    """

    def __init__(self, sparsity):
        super().__init__()
        self.sparsity = sparsity

    def forward(self, x):
        """Return ``x`` multiplied element-wise by the top-k feature mask.

        Args:
            x: Floating-point tensor with the batch as dimension 0
               (``torch.mean`` requires a floating dtype).

        Returns:
            Tensor with the same shape, dtype and device as ``x``; feature
            positions outside the top-k are zero. Gradients flow only through
            the kept positions.
        """
        feature_shape = x.shape[1:]
        # The mask is a constant selection, so the statistic used to choose it
        # does not need to be part of the autograd graph.
        x_mean = torch.mean(torch.abs(x.detach()), 0)
        num_features = x_mean.numel()
        k = round(num_features * (1 - self.sparsity))
        _, top_idx = torch.topk(x_mean.flatten(), k)

        # Build the mask directly on x's device and dtype so the layer works on
        # CPU and on whichever GPU holds x, and keeps x's dtype in the output.
        # A single vectorized index assignment replaces a per-element loop.
        mask = torch.zeros(num_features, device=x.device, dtype=x.dtype)
        mask[top_idx] = 1
        self.weights = mask.reshape(1, *feature_shape)
        return x * self.weights  # element-wise multiplication, broadcast over batch
