import torch


class UniformOutputModel(torch.nn.Module):
    """Input-independent baseline that outputs ``1 / dims`` in every column.

    Each output row holds ``dims`` copies of ``1 / dims``, so every row sums
    to 1 regardless of the values in the input.
    """

    def __init__(self, dims):
        """
        Args:
            dims: Number of output columns produced for each input row.
        """
        super().__init__()
        self.dims = dims

    def forward(self, x):
        """Return a ``(x.shape[0], dims)`` tensor filled with ``1 / dims``.

        Only the size of the leading dimension of ``x`` is read; its values
        are ignored. The result is allocated on ``x``'s device so it can be
        combined with other tensors from the same forward pass. It uses
        ``x``'s dtype when ``x`` is floating point, and the default floating
        dtype otherwise.
        """
        # Integer/bool inputs would truncate 1 / dims to 0, so only adopt the
        # input dtype when it is floating point (None -> torch default dtype).
        dtype = x.dtype if x.is_floating_point() else None
        ones = torch.ones((x.shape[0], self.dims), dtype=dtype, device=x.device)
        return ones / self.dims

