import torch


class Predicate:
    """A named predicate backed by a dense grounding tensor.

    The grounding has one dimension per predicate argument, and every
    dimension has the same size (the number of constants). Presumably entry
    ``grounding[i, j, ...]`` holds the value of the predicate applied to
    constants ``i, j, ...`` — TODO confirm against callers.

    Attributes:
        name: Identifier of the predicate.
        grounding: Tensor whose dimensions all have size ``num_constants``.
        arity: Number of dimensions of ``grounding`` (always >= 1).
        num_constants: Size of each dimension of ``grounding``.
    """

    def __init__(self, name: str, grounding: torch.Tensor):
        """Create a predicate.

        Args:
            name: Identifier of the predicate.
            grounding: Tensor with at least one dimension; all dimensions
                must have the same size.

        Raises:
            ValueError: If ``grounding`` is 0-dimensional, or if its
                dimensions do not all have the same size.
        """
        self.name = name

        shape = tuple(grounding.shape)
        # Check arity first: a 0-dim tensor has no shape[0] to compare against.
        # Explicit raises (not asserts) so validation survives `python -O`.
        if len(shape) == 0:
            raise ValueError(
                f"Predicate {name!r}: grounding must have at least one "
                f"dimension, got a 0-dimensional tensor"
            )
        if len(set(shape)) != 1:
            raise ValueError(
                f"Predicate {name!r}: all grounding dimensions must have "
                f"the same size, got shape {shape}"
            )

        self.grounding = grounding
        self.arity = len(shape)
        self.num_constants = shape[0]

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(name={self.name!r}, arity={self.arity}, "
            f"num_constants={self.num_constants})"
        )

    def to(self, device) -> None:
        """Move the grounding tensor to ``device`` in place (rebinds the attribute)."""
        self.grounding = self.grounding.to(device=device)

    def copy_log(self) -> "Predicate":
        """Return a new predicate whose grounding is the element-wise natural log.

        The new predicate is named ``<name>_log``. Entries equal to 0 become
        ``-inf`` (``torch.log`` semantics). The original predicate is not modified.
        """
        return Predicate(self.name + "_log", torch.log(self.grounding))
