import tensorflow as tf
from const import η, γ
from loss import d_vars, d_fn

# Shorthand used below to sum a list of tensors into one tensor
# (tf.math.add_n is the canonical path of the tf.add_n alias).
Σ = tf.math.add_n

def _flat_regularized_grad(d_loss, *gps):
  """Regularize a discriminator loss and return its flattened gradient.

  Shared body of the ``*_grad`` functions below, all of which are
  ``@tf.function``-decorated (``tf.gradients`` needs a graph context).

  Args:
    d_loss: discriminator loss tensor produced by ``d_fn``.
    *gps: one or more gradient-penalty tensors produced by ``d_fn``; each
      adds ``0.5 * η * gp`` to the total loss. ``d_loss`` and every ``gp``
      are presumably scalars (``tf.parallel_stack`` needs matching shapes)
      -- TODO confirm against ``loss.d_fn``.

  Returns:
    ``(grad, vals)``: ``grad`` is the 1-D concatenation of the gradient of
    the total loss w.r.t. each tensor in ``d_vars`` (in ``d_vars`` order);
    ``vals`` is ``tf.parallel_stack([d_loss, *gps])``.
  """
  reg_losses = [0.5 * η * gp for gp in gps]
  # Keep the summation order [d_loss, reg_1, reg_2, ...] used by the
  # original per-algorithm code so results are bit-identical.
  loss = Σ([d_loss, *reg_losses])
  grads = tf.gradients(loss, d_vars)
  grad = tf.concat([tf.reshape(g, [-1]) for g in grads], -1)
  vals = tf.parallel_stack([d_loss, *gps])
  return grad, vals

# NOTE(review): ``d_fn`` is a single import, so presumably ``loss`` picks
# the variant for the configured algorithm; only the ``*_grad`` below whose
# unpacking arity matches ``d_fn``'s return will run without error.

@tf.function
def gail_grad(*args):
  """GAIL discriminator step: expects ``d_fn`` to return ``(d_loss, gp)``.

  Returns ``(flat_grad, [d_loss, gp])``.
  """
  d_loss, gp = d_fn(*args)
  return _flat_regularized_grad(d_loss, gp)

@tf.function
def fairl_grad(*args):
  """FAIRL discriminator step: expects ``d_fn`` to return ``(d_loss, gp)``.

  Returns ``(flat_grad, [d_loss, gp])``.
  """
  d_loss, gp = d_fn(*args)
  return _flat_regularized_grad(d_loss, gp)

@tf.function
def airl_grad(*args):
  """AIRL discriminator step: expects ``d_fn`` to return three values.

  ``(d_loss, gp_r, gp_v)`` -- presumably penalties on the reward and value
  terms (TODO confirm in ``loss``). Both are scaled by ``0.5 * η``.
  Returns ``(flat_grad, [d_loss, gp_r, gp_v])``.
  """
  d_loss, gp_r, gp_v = d_fn(*args)
  return _flat_regularized_grad(d_loss, gp_r, gp_v)

@tf.function
def cairl_grad(*args):
  """CAIRL discriminator step: expects ``d_fn`` to return four values.

  ``(d_loss, gp_ρ, gp_r, gp_v)`` -- three penalty terms, each scaled by
  ``0.5 * η`` (their exact meaning is defined in ``loss``; not visible here).
  Returns ``(flat_grad, [d_loss, gp_ρ, gp_r, gp_v])``.
  """
  d_loss, gp_ρ, gp_r, gp_v = d_fn(*args)
  return _flat_regularized_grad(d_loss, gp_ρ, gp_r, gp_v)