import tensorflow as tf
from nn.layer import Dense, Relu, fwdd, fwdr
from const import γ, ns, na, nh
from util.rms import s_rms, a_rms

𝔼 = tf.math.reduce_mean  # shorthand for the expectation (mean) operator used in the losses below

class AIRL2:
  """AIRL-style discriminator built from two MLPs, g(s, a) and h(s).

  The discriminator logit is f = g(s, a) + γ·h(ś) - h(s) - logπ.
  It is evaluated on inputs normalized with `s_rms` (states) and
  `a_rms` (actions).
  """
  intype = 'saś'
  # NOTE(review): loss_gp returns (loss, gp_h, gp_g). If callers label that
  # tuple with `keys`, 'gp_r' is the h-net penalty and 'gp_v' the g-net
  # penalty. Confirm this mapping is intended.
  keys = ['dloss', 'gp_r', 'gp_v']

  def __init__(self):
    self.gtype = 'sa'
    self.rtype = ['s', 'a', 'ś', 'logπ']

    # g consumes the concatenated (state, action) vector and h a state only.
    # Both end in a single output unit.
    # The g layers are built before the h layers, as in the original code.
    self.g_net = [
      Dense(ns + na, nh), Relu(),
      Dense(nh, nh), Relu(),
      Dense(nh, 1),
    ]
    self.h_net = [
      Dense(ns, nh), Relu(),
      Dense(nh, nh), Relu(),
      Dense(nh, 1),
    ]

    # Snapshot the variable lists once; fwdg / fwdh read these attributes.
    self.vars = self._vars
    self.g_vars = self._g_vars
    self.h_vars = self._h_vars

  @tf.function
  def fwdg(self, x):
    """Run the g network on an already-normalized (s, a) concatenation."""
    # NOTE(review): g uses `fwdr` while h uses `fwdd` (both from nn.layer).
    # Confirm the asymmetry is intended.
    return fwdr(x, *self.g_vars)

  @tf.function
  def fwdh(self, x):
    """Run the h network on an already-normalized state batch."""
    return fwdd(x, *self.h_vars)

  def _logits(self, s, a, ś, logπ):
    """Normalize raw inputs and return f = g(s, a) + γ·h(ś) - h(s) - logπ."""
    s = s_rms.nrm(s)
    a = a_rms.nrm(a)
    ś = s_rms.nrm(ś)
    g = self.fwdg(tf.concat([s, a], -1))
    h = self.fwdh(s)
    h́ = self.fwdh(ś)
    return g + γ * h́ - h - logπ

  @staticmethod
  def _max_sq_grad(grads):
    """Reduce a `tf.gradients` result to a scalar penalty.

    For each gradient tensor, take the max of its square over axis 0 (the
    batch), concatenate the results along the last axis, and average.
    """
    return 𝔼(tf.concat([tf.reduce_max(tf.square(grad), 0) for grad in grads], -1))

  @tf.function
  def rwd(self, *args):
    """Return the discriminator logit f for a batch passed as (s, a, ś, logπ)."""
    s, a, ś, logπ = args
    return self._logits(s, a, ś, logπ)

  @tf.function
  def loss_gp(self, sπ, aπ, śπ, logpπ, sE, aE, śE, logpE):
    """Compute the discriminator loss and gradient penalties for g and h.

    Policy samples (π) are labeled 0 and expert samples (E) are labeled 1.
    The penalties are evaluated at random interpolations between policy
    samples and a random permutation of the expert samples.

    Returns:
      (loss, gp_h, gp_g)
    """
    nπ = tf.shape(sπ)[0]
    nE = tf.shape(sE)[0]

    s = tf.concat([sπ, sE], 0)
    a = tf.concat([aπ, aE], 0)
    ś = tf.concat([śπ, śE], 0)
    logπ = tf.concat([logpπ, logpE], 0)
    # logπ is a fixed input to the discriminator: block gradients into it.
    f = self._logits(s, a, ś, tf.stop_gradient(logπ))

    l = tf.concat([tf.zeros(nπ), tf.ones(nE)], 0)
    # NOTE(review): the factor 2 presumably rescales the mean over the
    # concatenated batch to a sum of per-class means when nπ == nE. Confirm.
    loss = 2 * 𝔼(tf.nn.sigmoid_cross_entropy_with_logits(l, f))

    # Share ONE permutation between sE and aE so each expert (s, a) pair stays
    # intact. Shuffling the two independently would mix the state of one
    # expert sample with the action of another.
    # The interpolation requires nπ == nE: u has nπ rows and the permuted
    # expert batch has nE rows.
    perm = tf.random.shuffle(tf.range(nE))
    sE_perm = tf.gather(sE, perm)
    aE_perm = tf.gather(aE, perm)
    u = tf.random.uniform([nπ, 1], 0., 1., tf.float32)
    sI = tf.stop_gradient(u * sπ + (1 - u) * sE_perm)
    aI = tf.stop_gradient(u * aπ + (1 - u) * aE_perm)

    sI_nrm = s_rms.nrm(sI)
    aI_nrm = a_rms.nrm(aI)
    gI = self.fwdg(tf.concat([sI_nrm, aI_nrm], -1))
    hI = self.fwdh(sI_nrm)

    # Gradients are taken w.r.t. the raw (un-normalized) interpolants, so the
    # normalization is included in the differentiated path. Both penalties
    # use the same batch-max reduction.
    gp_g = self._max_sq_grad(tf.gradients(gI, [sI, aI]))
    gp_h = self._max_sq_grad(tf.gradients(hI, [sI]))
    return loss, gp_h, gp_g

  @property
  def _vars(self):
    """All trainable variables: g-net layers first, then h-net layers."""
    return [v for layer in self.g_net + self.h_net for v in layer.vars]

  @property
  def _g_vars(self):
    """Variables of the g network, in layer order."""
    return [v for layer in self.g_net for v in layer.vars]

  @property
  def _h_vars(self):
    """Variables of the h network, in layer order."""
    return [v for layer in self.h_net for v in layer.vars]