import tensorflow as tf
from nn.layer import Dense, Relu, fwdd
from const import ns, na, nh
from util.rms import s_rms, a_rms

# Expectation-operator shorthand used by the losses below; same function
# object as tf.reduce_mean (exported under both names).
𝔼 = tf.math.reduce_mean

class GAIL:
  """GAIL discriminator over (state, action) pairs with a gradient-penalty term.

  The network is ``Dense(ns+na, nh) -> Relu -> Dense(nh, nh) -> Relu ->
  Dense(nh, 1)`` applied to the concatenation of states normalised by
  ``s_rms.nrm`` and actions normalised by ``a_rms.nrm``. Its output is used as
  a logit: policy pairs are labelled 0 and expert pairs 1 in ``loss_gp``.
  """
  intype = 'sa'
  rtype  = 'sa'
  # NOTE(review): presumably the names of the (loss, gp) pair returned by
  # loss_gp, used for logging — confirm against the caller.
  keys = ['dloss', 'gp_r']

  def __init__(self):
    """Build the MLP and snapshot its layers' variables into ``self.vars``."""
    self.net = net = []
    net += [Dense(ns+na,nh)]
    net += [Relu()]
    net += [Dense(nh,nh)]
    net += [Relu()]
    net += [Dense(nh,1)]
    # Flat list of every layer's variables, passed to fwdd by fwd().
    self.vars = self._vars

  @tf.function
  def fwd(self, x):
    """Return discriminator logits for already-normalised inputs ``x``."""
    return fwdd(x, *self.vars)

  @tf.function
  def rwd(self, *args):
    """Return the reward ``softplus(f) = -log(1 - sigmoid(f))``.

    Args:
      *args: ``(s, a)`` raw state and action batches; only the first two
        positional arguments are used.
    """
    s, a = args[0], args[1]
    f = self.fwd(tf.concat([s_rms.nrm(s), a_rms.nrm(a)], -1))
    return tf.math.softplus(f)

  @tf.function
  def loss_gp(self, sπ, aπ, sE, aE):
    """Return ``(loss, gp)``: discriminator loss and gradient-penalty term.

    Args:
      sπ, aπ: raw (un-normalised) policy state / action batches.
      sE, aE: raw expert state / action batches. Row i of ``sE`` and row i
        of ``aE`` are treated as one expert sample — TODO confirm with caller.

    Returns:
      loss: twice the mean sigmoid cross-entropy over the concatenated
        policy (label 0) + expert (label 1) batch.
      gp: for each input feature, the max over the batch of the squared
        gradient of the logit w.r.t. the raw input, evaluated at random
        interpolates between policy pairs and expert pairs; then averaged
        over all state and action features.
    """
    nπ = tf.shape(sπ)[0]
    nE = tf.shape(sE)[0]

    s = s_rms.nrm(tf.concat([sπ, sE], 0))
    a = a_rms.nrm(tf.concat([aπ, aE], 0))
    f = self.fwd(tf.concat([s, a], -1))

    # Labels take exactly the shape/dtype of the logits (policy rows 0,
    # expert rows 1): sigmoid_cross_entropy_with_logits requires labels and
    # logits of the same shape, and a rank-1 label vector would not match
    # a (batch, 1) logit tensor.
    l = tf.concat([tf.zeros_like(f[:nπ]), tf.ones_like(f[nπ:])], 0)
    loss = 2 * 𝔼(tf.nn.sigmoid_cross_entropy_with_logits(labels=l, logits=f))

    # One shared permutation of the expert batch, so each interpolate mixes
    # a policy pair with a *matching* expert (s, a) pair. Shuffling sE and
    # aE independently would pair states with actions from different expert
    # samples. Indexing the permutation modulo nE also allows nπ != nE; when
    # nπ == nE this is exactly the permutation itself.
    perm = tf.random.shuffle(tf.range(nE))
    idx = tf.gather(perm, tf.math.floormod(tf.range(nπ), nE))
    sE_mix = tf.gather(sE, idx)
    aE_mix = tf.gather(aE, idx)

    u = tf.random.uniform([nπ, 1], 0., 1., tf.float32)
    # stop_gradient cuts the path back to the input batches and u; the
    # interpolates themselves are the points we differentiate at.
    sI = tf.stop_gradient(u * sπ + (1 - u) * sE_mix)
    aI = tf.stop_gradient(u * aπ + (1 - u) * aE_mix)

    fI = self.fwd(tf.concat([s_rms.nrm(sI), a_rms.nrm(aI)], -1))
    # Gradients are w.r.t. the raw (pre-normalisation) interpolates.
    grads = tf.gradients(fI, [sI, aI])
    gp = 𝔼(tf.concat([tf.reduce_max(tf.square(g), 0) for g in grads], -1))
    return loss, gp

  @property
  def _vars(self):
    """Concatenate every layer's ``vars`` list, in network order."""
    return [v for layer in self.net for v in layer.vars]