import math

import torch
from torch.nn import functional as F


class UniformBinaryString(torch.nn.Module):
    """Uniform distribution over binary strings of a fixed length.

    Every one of the ``2 ** num_bits`` strings is equally likely, so the
    log-probability of any input is the constant ``-num_bits * log(2)``.
    """

    def __init__(self, num_bits, device, dtype):
        """
        Args:
            num_bits: Length of each binary string.
            device: Device requested by the caller. Stored on the module;
                ``log_prob`` places its result on the device of its inputs.
            dtype: Requested dtype. When it is a floating-point ``torch.dtype``
                it is used for the result of ``log_prob``; otherwise the torch
                default dtype is used.
        """
        super().__init__()
        self.num_bits = num_bits
        # These were previously accepted and silently discarded.
        self.device = device
        self.dtype = dtype

    def log_prob(self, inputs):
        """Return the log-probability of each string in the batch.

        Args:
            inputs: Tensor whose first dimension is the batch. Its values are
                not inspected, since every string is equally likely.

        Returns:
            1-D tensor of shape ``(inputs.shape[0],)`` filled with
            ``-num_bits * log(2)``, on the same device as ``inputs``.
        """
        batch_size = inputs.shape[0]
        # Plain Python float: no need to allocate a tensor (and rely on
        # int -> float promotion) just to obtain log(2).
        value = -self.num_bits * math.log(2.0)
        # Only a floating dtype can represent a log-probability; otherwise
        # fall back to the torch default dtype (the original behaviour).
        if isinstance(self.dtype, torch.dtype) and self.dtype.is_floating_point:
            out_dtype = self.dtype
        else:
            out_dtype = None
        # Match the inputs' device so the result can be combined with other
        # per-sample quantities without a device mismatch.
        return torch.full(
            (batch_size,), value, dtype=out_dtype, device=inputs.device
        )
