import torch

"""Goal: quantization of parameters."""
"""Purpose: map float values to integer symbols (float -> int)."""


class BitQuantizer:
    """Uniform quantizer centred on the mean of the fitting data.

    Values are mapped to integer symbols ``round(rate * (x - Z))``, where ``Z``
    is the mean of the fitting data. ``rate`` is chosen so that the fitted
    data's largest deviation from ``Z`` lands on the edge of the signed range
    ``[-(2**(n-1) - 1), 2**(n-1) - 1]`` with ``n = bitNum * 8``.

    NOTE(review): the shift is ``bitNum * 8``, so ``bitNum`` acts as a *byte*
    count despite its name (e.g. ``bitNum=1`` gives symbols in ``[-127, 127]``).
    """

    def __init__(self, bitNum, data):
        """Fit the offset ``Z`` and scale ``rate`` to ``data``.

        Args:
            bitNum: Width parameter; the symbol range spans
                ``2 ** (bitNum * 8) - 2`` quantization steps.
            data: Tensor used for fitting. Non-floating tensors are converted
                to float first, because ``mean()`` rejects integer dtypes.
        """
        self.bitNum = bitNum
        if isinstance(data, torch.Tensor) and not data.is_floating_point():
            data = data.float()
        self.Z = data.mean()
        # `*` binds tighter than `<<`; parenthesized to make that explicit.
        quantInter = (1 << (bitNum * 8)) - 2
        # Twice the largest deviation from the mean: the symmetric data span.
        dataInter = 2 * max(abs(data.max() - self.Z), abs(data.min() - self.Z))
        if dataInter == 0:
            # Constant data: every value equals Z and maps to symbol 0 for any
            # finite rate. Dividing by zero would make rate inf, so quant()
            # would compute inf * 0 = NaN and cast it to a meaningless integer.
            self.rate = 1.0
        else:
            self.rate = quantInter / dataInter

    def quant(self, data):
        """Return integer symbols ``round(rate * (data - Z))`` as int64.

        No clamping is applied, so values outside the fitted range can produce
        symbols beyond the nominal signed range.
        """
        return torch.round(self.rate * (data - self.Z)).long()

    def quant_inv(self, symbols):
        """Map symbols back to approximate values: ``symbols / rate + Z``."""
        return symbols / self.rate + self.Z


class LinearQuantizer:
    """Fixed-point quantizer with a power-of-two scale.

    Symbols are ``round(value * 2**m)``; reconstruction divides by ``2**m``,
    so the quantization step is ``2**-m``. The ``data`` argument is accepted
    but never used (presumably kept for constructor parity with
    ``BitQuantizer`` — confirm with callers).
    """

    def __init__(self, data, m=8):
        """Record ``m`` and the scale factor ``2**m``; ``data`` is ignored."""
        self.m = m
        self.rate = pow(2, m)

    def quant(self, data):
        """Scale ``data`` by ``2**m``, round, and return int64 symbols."""
        scaled = data * self.rate
        return scaled.round().long()

    def quant_inv(self, symbols):
        """Undo the scaling by dividing the symbols by ``2**m``."""
        scale = self.rate
        return symbols / scale


if __name__ == "__main__":
    N = 100_000  # number of samples; an int, so torch.randn accepts it directly
    bitNum = 2  # width parameter for BitQuantizer (used by the alternative below)
    m = 3  # LinearQuantizer scale exponent: step size is 2**-m

    # 1. Generate data: N samples from a standard normal distribution.
    data = torch.randn(N)

    # 2. Build a quantizer.
    # Alternative: lq = BitQuantizer(bitNum, data)
    # LinearQuantizer's signature is (data, m); passing (m, data) would make
    # the rate the tensor 2 ** data instead of the scalar 2 ** m.
    lq = LinearQuantizer(data, m)

    # 3. Quantize and reconstruct.
    syms = lq.quant(data)
    # Number of distinct symbols. set() over a tensor's elements hashes the
    # 0-dim tensors by identity (always reporting N); unique() compares values.
    print(syms.unique().numel())
    dataQuant = lq.quant_inv(syms)

    # 4. Print original data, symbol range, and reconstruction.
    print(data)
    print(syms.max(), syms.min())
    print(dataQuant)
