import torch

"""目标：按频率生成码表"""
"""功能：int->int"""


class ReCode:
    """Frequency-ordered codebook for integer values (int -> int).

    Each distinct value seen at construction time is assigned a code equal to
    its rank by frequency: the most frequent value gets code 0, the next one
    code 1, and so on. Values with equal counts are ranked in ascending value
    order, so codes are deterministic.

    Attributes:
        elements: distinct values, sorted by descending frequency.
        counts: occurrence count of each entry of ``elements``.
        offset: shift applied to values before looking them up in
            ``ele2code``; it is 0 whenever all values are non-negative.
        ele2code: dense lookup table; ``ele2code[v - offset]`` is the code of
            value ``v``. Slots for values never seen hold 0.
        code2ele: ``code2ele[c]`` is the value whose code is ``c``.
    """

    def __init__(self, nums):
        """Build the codebook from the integer tensor (or sequence) ``nums``."""
        nums = torch.as_tensor(nums)
        # torch.unique returns the distinct values in ascending order.
        elements, counts = torch.unique(nums, return_counts=True)
        # Stable sort: ties in count keep ascending value order, so the code
        # assignment does not depend on an unstable sort's tie-breaking.
        counts, indices = torch.sort(counts, descending=True, stable=True)
        elements = elements[indices]
        self.elements = elements
        self.counts = counts

        if elements.numel() == 0:
            offset, size = 0, 0
        else:
            # Only shift when negative values exist; for non-negative data
            # ele2code[v] is the code of v, as in the unshifted layout.
            offset = min(int(elements.min()), 0)
            size = int(elements.max()) - offset + 1
        self.offset = offset

        r_elements = torch.zeros(size, dtype=torch.long, device=elements.device)
        # Scatter: the value elements[i] gets code i.
        r_elements[elements.long() - offset] = torch.arange(
            elements.numel(), device=elements.device
        )
        self.ele2code = r_elements
        self.code2ele = elements

    def encode(self, nums):
        """Map values to their codes.

        NOTE: a value inside [min, max] that was not seen at construction
        time maps to code 0 (the table's fill value). Values outside that
        range are not validated.
        """
        idx = torch.as_tensor(nums, device=self.ele2code.device)
        if self.offset:
            idx = idx - self.offset
        return self.ele2code[idx].long()

    def decode(self, codes):
        """Map codes back to the original values (inverse of ``encode``)."""
        return self.code2ele[codes].long()


if __name__ == "__main__":
    X = torch.LongTensor([1, 2, 2, 3, 3, 3])
    rc = ReCode(X)
    print(rc.encode(X))
    print(rc.decode(rc.encode(X)))
