import torch


class LabelDist(torch.nn.Module):
    """Rescale a graph's attribute matrix by one learnable scalar.

    The module owns a single trainable parameter, ``scale``, with shape
    ``(1,)`` and initial value 1.0. ``forward`` multiplies
    ``graph['attr_matrix']`` by that scalar.
    """

    def __init__(self, device: torch.device):
        """Build the module and place it on ``device``.

        Args:
            device: Target device. It is stored as ``self.device``, and the
                module is moved there with ``self.to``.
        """
        super().__init__()
        self.device = device

        # One-element trainable multiplier. At 1.0 the forward pass is
        # initially the identity.
        initial_value = torch.ones(1)
        self.scale = torch.nn.Parameter(initial_value)

        # Move every registered parameter (currently just ``scale``) to
        # the requested device.
        self.to(self.device)

    def forward(self, graph):
        """Return the graph's attribute matrix scaled by ``self.scale``.

        Args:
            graph: An indexable object with an ``'attr_matrix'`` entry.
                Presumably this is a tensor broadcastable against the
                shape-(1,) ``scale`` and on the same device (TODO: confirm
                against callers).

        Returns:
            The element-wise product ``graph['attr_matrix'] * self.scale``.
        """
        attributes = graph['attr_matrix']
        return attributes * self.scale
