import torch


class SafeInvariant(torch.nn.Module):
    """Hold a barrier, an uncertainty model and a policy in one module.

    The three callables are stored as plain attributes. Any of them that is a
    ``torch.nn.Module`` is registered as a submodule by ``nn.Module``'s
    attribute-assignment hook.
    """

    def __init__(self, barrier, uncertainty, policy):
        super().__init__()
        # Keep the assignment order barrier -> policy -> uncertainty: it fixes
        # the submodule registration order (children(), state_dict keys).
        self.barrier = barrier
        self.policy = policy
        self.uncertainty = uncertainty

    def U(self, states, actions=None):
        """Evaluate ``self.uncertainty(states, actions)``.

        Args:
            states: forwarded unchanged to the policy (if needed) and to the
                uncertainty model.
            actions: actions to evaluate. When ``None``, they are produced by
                calling ``self.policy(states)``.

        Returns:
            Whatever ``self.uncertainty`` returns.
        """
        chosen = self.policy(states) if actions is None else actions
        return self.uncertainty(states, chosen)

    def L(self, states):
        """Return the barrier evaluated at ``states``."""
        barrier_fn = self.barrier
        return barrier_fn(states)


