import torch
from .diffusionGrid_sampling import forward_reward, backward_trajectory_log_prob
from torch.nn.functional import softplus


def tb(fnet, bnet, logz, env):
    """Return the mean squared gap between model and environment log-rewards.

    Presumably this is the trajectory-balance (TB) objective for a GFlowNet.
    The name suggests this; confirm against ``forward_reward``.

    Args:
        fnet: Forward network, passed unchanged to ``forward_reward``.
        bnet: Backward network, passed unchanged to ``forward_reward``.
        logz: Log-partition estimate, passed unchanged to ``forward_reward``.
        env: Environment. It is given to ``forward_reward`` first, and then
            its ``log_reward()`` method is called.

    Returns:
        The mean over all elements of the squared residual
        ``forward_reward(...) - env.log_reward()``. This assumes both are
        tensors of broadcast-compatible shape; TODO confirm.
    """
    # Keep this order: forward_reward receives env and may update its state
    # (e.g. by sampling trajectories) before log_reward() reads it.
    # NOTE(review): that is an assumption; confirm in diffusionGrid_sampling.
    model_log_reward = forward_reward(env, fnet, bnet, logz)
    residual = model_log_reward - env.log_reward()
    squared_gap = residual.square()
    return squared_gap.mean()
