# https://github.com/xinyandai/gradient-quantization


import torch


class TopKSparsificationCompressor(object):
    """Top-k magnitude sparsifier for gradient tensors.

    ``compress`` keeps the ``k`` largest-magnitude entries of each of the
    ``users`` rows of the input and zeroes the rest.  ``decompress`` views the
    result back to ``shape``.  The compressed form is a dense tensor with the
    same number of elements, not an (index, value) pair.

    Args:
        size: number of elements in the tensors to be compressed.
        shape: shape that ``decompress`` restores the tensor to.
        args: namespace exposing ``no_cuda`` and ``cr`` attributes.  A
            fraction ``1 - cr`` of the ``size`` entries is kept.
    """

    def __init__(self, size, shape, args):
        # NOTE(review): ``self.cuda`` is not read inside this class; presumably
        # consumed by callers elsewhere.
        self.cuda = not args.no_cuda
        self.size = size
        self.shape = shape
        self.users = 1
        # ``1 - cr`` is inexact in binary floating point (1 - 0.9 ==
        # 0.09999999999999998), so a bare ``int()`` turned e.g. 10 * 0.1 into
        # 0 and dropped every entry.  Rounding to 6 decimals first absorbs
        # that representation error before truncating.
        self.k = int(round(size * (1 - args.cr), 6))

    def compress(self, vec):
        """Zero all but the ``k`` largest-magnitude entries of ``vec``.

        Args:
            vec: tensor expected to hold ``size`` elements.  Any shape is
                accepted, contiguous or not.

        Returns:
            Tensor of shape ``(users, numel // users)``.  The selected values
            stay at their original positions and every other entry is zero.
        """
        # reshape (unlike view) also accepts non-contiguous inputs.
        vec = vec.reshape(self.users, -1)
        idx = torch.topk(vec.abs(), k=self.k, dim=1)[1]
        # Scatter the kept values into zeros instead of multiplying by a 0/1
        # mask: a mask multiply turns unselected inf/NaN entries into NaN.
        return torch.zeros_like(vec).scatter(1, idx, vec.gather(1, idx))

    def decompress(self, signature):
        """Return ``signature`` viewed back to the original ``shape``."""
        return signature.view(self.shape)