import tensorflow as tf
from nn.layer import Dense, Relu, fwdd, fwdr
from const import α, γ, ns, na, nh
from util.rms import s_rms, a_rms
from loss.πloss import logp_uniform

# Shorthand aliases for TensorFlow ops used by the CAIRL code below:
#   𝔼    - mean reduction (tf.reduce_mean)
#   plus - softplus, log(1 + exp(x)) (tf.math.softplus)
𝔼, plus = tf.reduce_mean, tf.math.softplus

class CAIRL:
  """Discriminator, reward and value networks for an adversarial IRL learner.

  Three MLPs (Dense -> Relu -> Dense -> Relu -> Dense(., 1)) are built from
  ``nn.layer`` primitives:

  * ``ρ_net`` (input ``ns``): state-only logit, trained in ``loss_gp`` to
    separate expert states (label 1) from policy states (label 0).
  * ``r_net`` (input ``ns + na``): state-action reward; ``rwd`` returns
    ``α * softplus(r)``, so the emitted reward is non-negative.
  * ``v_net`` (input ``ns``): state value used in the shaping term
    ``γ·v(ś) − v(s)`` of the discriminator logit.
  """
  # NOTE(review): presumably names the per-sample fields this model consumes
  # (s, a, ś, b); confirm against the caller.
  intype = 'saśb'
  # Labels for the tuple returned by ``loss_gp``, in the same order.
  keys = ['dloss', 'gp_s', 'gp_r', 'gp_v']

  def __init__(self):
    # Reward input type; presumably 'sa' = state-action, matching r_net's
    # ns + na input width.
    self.rtype = 'sa'

    # Networks are constructed in the order ρ, r, v.
    self.ρ_net = self._mlp(ns)
    self.r_net = self._mlp(ns + na)
    self.v_net = self._mlp(ns)

    # Flatten each network's ``layer.vars`` once; the tf.function forward
    # passes below unpack these lists as positional arguments.
    self.vars = self._vars
    self.ρ_vars = self._ρ_vars
    self.r_vars = self._r_vars
    self.v_vars = self._v_vars

  @staticmethod
  def _mlp(n_in):
    """Return layers Dense(n_in, nh), Relu, Dense(nh, nh), Relu, Dense(nh, 1)."""
    return [Dense(n_in, nh), Relu(), Dense(nh, nh), Relu(), Dense(nh, 1)]

  @tf.function
  def fwdρ(self, x):
    """Run ρ_net on ``x`` via ``nn.layer.fwdd``."""
    return fwdd(x, *self.ρ_vars)

  @tf.function
  def fwdr(self, x):
    """Run r_net on ``x``.

    The bare name ``fwdr`` in the body resolves to the helper imported from
    ``nn.layer`` (a module global), not to this method.
    """
    return fwdr(x, *self.r_vars)

  @tf.function
  def fwdv(self, x):
    """Run v_net on ``x`` via ``nn.layer.fwdd``."""
    return fwdd(x, *self.v_vars)

  @tf.function
  def rwd(self, s, a):
    """Return the scaled reward ``α * softplus(r_net([s̃, ã]))``.

    ``s`` and ``a`` are normalized with ``s_rms.nrm`` / ``a_rms.nrm`` and
    concatenated on the last axis before the reward network is applied.
    """
    s = s_rms.nrm(s)
    a = a_rms.nrm(a)
    r = plus(self.fwdr(tf.concat([s, a], -1)))
    return α * r

  @tf.function
  def loss_gp(self, sπ, aπ, śπ, bπ, v́π_old, logpπ, sE, aE, śE, bE, v́E_old, logpE):
    """Compute the discriminator loss and the gradient-penalty terms.

    Policy (π) and expert (E) batches are concatenated and labelled 0 and 1.
    The total loss is the sum of:

    * ``loss_ρ``: sigmoid cross-entropy of the ρ_net logit on ``s``;
    * ``loss_d``: sigmoid cross-entropy of the logit
      ``f = r(s, a) + γ·v(ś) − v(s) + stop_grad(ρ(s) − logπ + logp_uniform)``,
      where ``v(ś)`` is gradient-stopped and zeroed wherever ``b`` is true;
    * ``loss_v``: a small (5e-5) penalty on ``(γ·v(ś) − v(s))²``.

    Gradient penalties are taken w.r.t. random interpolations between policy
    samples and randomly paired expert samples, in raw (un-normalized) input
    space. ``v́π_old`` and ``v́E_old`` are accepted but not used here.

    Returns:
      ``(loss, gp_ρ, gp_r, gp_v)``, ordered as ``CAIRL.keys``.
    """
    nπ = tf.shape(sπ)[0]
    nE = tf.shape(sE)[0]

    s = tf.concat([sπ, sE], 0)
    a = tf.concat([aπ, aE], 0)
    ś = tf.concat([śπ, śE], 0)
    b = tf.concat([bπ, bE], 0)
    logπ = tf.concat([logpπ, logpE], 0)

    s = s_rms.nrm(s)
    ś = s_rms.nrm(ś)
    a = a_rms.nrm(a)

    # Labels: 0 for policy samples, 1 for expert samples.
    # NOTE(review): sigmoid_cross_entropy_with_logits requires labels and
    # logits of identical shape, so this assumes ρ and f come out with the
    # same shape as ``l`` ([nπ + nE]); confirm what nn.layer's helpers return.
    l = tf.concat([tf.zeros(nπ), tf.ones(nE)], 0)

    ρ = self.fwdρ(s)
    loss_ρ = 2 * 𝔼(tf.nn.sigmoid_cross_entropy_with_logits(l, ρ))
    r = plus(self.fwdr(tf.concat([s, a], -1)))
    v = self.fwdv(s)
    # Next-state value, zeroed where b is set (presumably a terminal flag)
    # and excluded from the gradient.
    v́ = tf.stop_gradient(tf.where(b, 0., self.fwdv(ś)))

    f = r + γ * v́ - v + tf.stop_gradient(ρ - logπ + logp_uniform)
    loss_d = 2 * 𝔼(tf.nn.sigmoid_cross_entropy_with_logits(l, f))
    loss_v = 0.00005 * 𝔼(tf.square(γ * v́ - v))
    loss = loss_ρ + loss_d + loss_v

    # Interpolation points between policy samples and expert samples.
    # A single permutation is shared by sE and aE so that the expert end of
    # every (sI, aI) pair is a real expert (s, a) pair. Shuffling the two
    # independently would pair expert states with unrelated expert actions.
    # NOTE(review): the broadcast below requires nπ == nE.
    u = tf.random.uniform([nπ, 1], 0., 1., tf.float32)
    perm = tf.random.shuffle(tf.range(nE))
    sI = tf.stop_gradient(u * sπ + (1 - u) * tf.gather(sE, perm))
    aI = tf.stop_gradient(u * aπ + (1 - u) * tf.gather(aE, perm))
    sI_nrm = s_rms.nrm(sI)
    aI_nrm = a_rms.nrm(aI)
    ρI = self.fwdρ(sI_nrm)
    rI = self.fwdr(tf.concat([sI_nrm, aI_nrm], -1))
    vI = self.fwdv(sI_nrm)

    # tf.gradients returns a list with one gradient per source tensor.
    # Unpack the single-source results: otherwise tf.square would stack the
    # list into a [1, N, d] tensor, and reduce_max(..., 0) would only collapse
    # that length-1 axis instead of the batch axis as it does for gp_r.
    (grad_ρ,) = tf.gradients(ρI, sI)
    grads_r = tf.gradients(rI, [sI, aI])
    (grad_v,) = tf.gradients(vI, sI)

    # Each penalty: per-feature max over the batch of the squared gradient,
    # then averaged over features.
    gp_ρ = 𝔼(tf.reduce_max(tf.square(grad_ρ), 0))
    gp_r = 𝔼(tf.concat([tf.reduce_max(tf.square(g), 0) for g in grads_r], -1))
    gp_v = 𝔼(tf.reduce_max(tf.square(grad_v), 0))
    return loss, gp_ρ, gp_r, gp_v

  @staticmethod
  def _collect_vars(layers):
    """Concatenate ``layer.vars`` for each layer in ``layers``, in order."""
    return [var for layer in layers for var in layer.vars]

  @property
  def _vars(self):
    """Variables of ρ_net, r_net and v_net, concatenated in that order."""
    return self._collect_vars(self.ρ_net + self.r_net + self.v_net)

  @property
  def _ρ_vars(self):
    """Variables of ρ_net, layer by layer."""
    return self._collect_vars(self.ρ_net)

  @property
  def _r_vars(self):
    """Variables of r_net, layer by layer."""
    return self._collect_vars(self.r_net)

  @property
  def _v_vars(self):
    """Variables of v_net, layer by layer."""
    return self._collect_vars(self.v_net)

