# Module-level registry filled in by set_πv_loss(); every slot starts unset.
#   π_vars  - the π_variables collection last registered
#   v_vars  - the v_variables collection last registered
#   πv_vars - π_vars + v_vars, concatenated once at registration time
#   fwd     - the π_fwd callable last registered
#   fwdv    - the π_fwdv callable last registered
π_vars = v_vars = πv_vars = fwd = fwdv = None

def set_πv_loss(π_variables, v_variables, π_fwd, π_fwdv):
  """Register the π / v variables and their forward functions module-wide.

  Rebinds the module globals ``π_vars``, ``v_vars``, ``πv_vars``, ``fwd`` and
  ``fwdv``, overwriting whatever an earlier call stored. Nothing is validated
  or copied; the arguments are stored by reference.

  Args:
    π_variables: stored as ``π_vars``.
    v_variables: stored as ``v_vars``.
    π_fwd: stored as ``fwd``.
    π_fwdv: stored as ``fwdv``.

  Note:
    ``πv_vars`` is ``π_variables + v_variables`` evaluated once, here. For
    lists/tuples that is a new object, so later in-place changes to either
    input are not reflected in it. The two collections must support ``+``
    with each other (e.g. list + list), or this raises.
    NOTE(review): π / v presumably mean policy / value (RL naming) — confirm.
  """
  global π_vars, v_vars, πv_vars, fwd, fwdv
  # Rebinding order matches the historical one (π_vars, v_vars, πv_vars,
  # fwd, fwdv) so a failing concatenation leaves the same globals set.
  π_vars, v_vars = π_variables, v_variables
  πv_vars = π_vars + v_vars
  fwd, fwdv = π_fwd, π_fwdv

# Module-level registry filled in by set_d_loss(); both slots start unset.
#   d_vars - the d_variables collection last registered
#   d_fn   - the d_loss_grad_fn callable last registered
d_vars = d_fn = None

def set_d_loss(d_variables, d_loss_grad_fn):
  """Register the ``d`` variables and loss/grad function module-wide.

  Rebinds the module globals ``d_vars`` and ``d_fn``, overwriting whatever an
  earlier call stored. Nothing is validated or copied.

  Args:
    d_variables: stored as-is in ``d_vars``.
    d_loss_grad_fn: stored as-is in ``d_fn``. Presumably a callable that
      returns a loss and its gradients (inferred from the name only —
      confirm against callers).
  """
  global d_vars, d_fn
  # Tuple assignment binds left to right: d_vars first, then d_fn.
  d_vars, d_fn = d_variables, d_loss_grad_fn

