from .basegfn import BaseTBGFlowNet, tensor_to_np


class TBGFN(BaseTBGFlowNet):
  """Trajectory balance GFN.

  Learns forward and backward policy. Construction is forwarded to
  BaseTBGFlowNet, and training is delegated to `self.train_tb`
  (presumably provided by the base class -- not visible in this file).
  """

  def __init__(self, args, mdp, actor):
    """Initialise the base class with `args`, `mdp`, `actor`; announce the model.

    The three arguments are passed through to BaseTBGFlowNet unchanged.
    """
    super(TBGFN, self).__init__(args, mdp, actor)
    print('Model: TBGFN')

  def train(self, batch):
    """Train on `batch` via `self.train_tb` and return whatever it returns."""
    outcome = self.train_tb(batch)
    return outcome

def make_model(args, mdp, actor):
  """Factory entry point: build and return a TBGFN.

  `args`, `mdp` and `actor` are handed to the TBGFN constructor as-is.
  """
  return TBGFN(args, mdp, actor)