from abc import ABC, abstractmethod
from torchtyping import TensorType

class BaseFlowFunction(ABC):
    """Abstract interface for a flow function.

    Concrete subclasses must override both ``update`` and ``get_flows``.
    Both are marked ``@abstractmethod``, so a subclass that leaves either
    one unimplemented cannot be instantiated.
    """

    @abstractmethod
    def update(self, agent) -> None:
        """Update this flow function using ``agent``.

        What ``agent`` is and which state gets modified are left entirely
        to subclasses.
        NOTE(review): presumably one learning/refresh step driven by the
        agent; confirm against the concrete implementations.
        """

    @abstractmethod
    def get_flows(
        self,
        states: TensorType['num_states'],
    ) -> TensorType['num_states']:
        """Return flow values for a batch of ``states``.

        Args:
            states: tensor annotated with a single ``num_states`` dimension.

        Returns:
            A tensor annotated with the same ``num_states`` dimension.
            Presumably this is one flow value per input state; the dtype and
            the meaning of "flow" are defined by subclasses (TODO confirm).
        """
