import torch

class ADGPCompressor(object):
    """Top-k magnitude sparsifier for tensors.

    Every input tensor is reshaped to ``(self.users, -1)`` and selection is
    done independently along dim 1. ``users`` is fixed to 1 here, so the
    whole tensor is treated as a single row.

    Budgets derived from ``args``:
      * ``k  = int(2 * size * (1 - args.cr))`` -- number of entries kept by
        ``compress`` / ``compress1``; ``compress2`` keeps ``k // 2``.
      * ``k1 = int(size * args.cr / (args.p + 1))`` -- number of
        largest-magnitude entries ``compress2`` discards before selecting.

    NOTE(review): ``k <= size`` only when ``args.cr >= 0.5``; the meaning of
    ``cr`` and ``p`` is defined by the caller -- confirm there. A budget
    larger than the row length is clamped (every entry kept) instead of
    letting ``torch.topk`` raise.
    """

    def __init__(self, size, shape, args):
        self.cuda = not args.no_cuda  # stored only; not used by the methods below
        self.size = size
        self.shape = shape  # target shape restored by decompress()
        self.users = 1
        self.k = int(2 * size * (1 - args.cr))
        self.k1 = int(size * args.cr / (args.p + 1))

    def _rows(self, vec):
        """Return ``vec`` as a ``(self.users, -1)`` tensor.

        ``reshape`` (not ``view``) so non-contiguous inputs such as
        transposed tensors are accepted; for contiguous inputs it returns
        the same view ``view`` would.
        """
        return vec.reshape(self.users, -1)

    @staticmethod
    def _topk_indices(rows, k):
        """Indices of the ``k`` largest-magnitude entries in each row.

        ``k`` is clamped to the row length so an over-large budget selects
        the whole row rather than making ``torch.topk`` raise.
        """
        k = min(k, rows.size(1))
        return torch.topk(torch.abs(rows), k=k, dim=1)[1]

    def _keep_mask(self, rows, k):
        """0/1 mask (same dtype/device as ``rows``), 1 at each row's top-``k`` magnitudes."""
        mask = torch.zeros_like(rows)
        mask.scatter_(1, self._topk_indices(rows, k), 1)
        return mask

    def compress(self, vec):
        """Zero all but the top-``k`` magnitude entries of ``vec``.

        Returns a ``(self.users, -1)`` tensor; ``decompress`` restores
        ``self.shape``.
        """
        rows = self._rows(vec)
        return rows * self._keep_mask(rows, self.k)

    def compress1(self, vec):
        """Return ``(rows, mask)`` without applying the mask.

        ``rows`` is ``vec`` reshaped to ``(self.users, -1)``. ``mask`` is the
        0/1 top-``k`` selection mask for ``rows``.
        """
        rows = self._rows(vec)
        return rows, self._keep_mask(rows, self.k)

    def compress2(self, vec, mask):
        """Select among entries below the very largest ones.

        1. Zero the top-``k1`` magnitude entries of each row.
        2. Multiply by ``mask`` (caller-supplied; presumably a 0/1 tensor
           broadcastable to ``(self.users, -1)`` -- confirm against caller).
        3. Keep the top-``k // 2`` magnitudes of what remains.

        Returns ``(selected, ind)``, where ``ind`` is the 0/1 mask from
        step 3. If fewer than ``k // 2`` nonzero entries remain after steps
        1-2, ``ind`` also marks some positions whose value is zero.
        """
        rows = self._rows(vec)
        drop_largest = torch.ones_like(rows)
        drop_largest.scatter_(1, self._topk_indices(rows, self.k1), 0)
        rows = rows * drop_largest * mask
        ind = self._keep_mask(rows, self.k // 2)
        return rows * ind, ind

    def decompress(self, signature):
        """Reshape a compressed tensor back to ``self.shape``."""
        return signature.reshape(self.shape)

