import tensorflow as tf
from util.pd import entropy, logp
import time
from util import log
from const import α, γ, λ, η
from loss import π_vars, v_vars, πv_vars, fwd, fwdv
from const import ns, na, a_limit, a_scale

# 1 / (1 - γ); get_adv_ret_abs multiplies the reward by this on rows where y is set.
γ_inv = 1. / (1-γ)
# -sum(log(2 * a_scale)): the log-density of a uniform distribution over a box
# whose side lengths are 2 * a_scale. The functions below subtract it from the
# policy log-probability.
# Alternative based on a_limit, currently disabled:
#logp_uniform = - tf.reduce_sum(tf.math.log(2 * a_limit))
logp_uniform = - tf.reduce_sum(tf.math.log(2 * a_scale))

# Factor applied to the carried accumulator in the scan step `fn`.
γλ = γ * λ
# Short aliases. Python NFKC-normalizes identifiers, so `𝔼` is the same name as plain `E`.
𝔼 = tf.reduce_mean
Σ = tf.add_n

@tf.function
def fn(a, x):
  """One step of the scan that accumulates the advantage estimate.

  Args:
    a: accumulator carried over from the previously processed step (the
      callers in this file scan with reverse=True).
    x: pair (δ, t) of the per-step residual and a boolean mask.

  Returns:
    δ where t is true (the carried accumulator is dropped), otherwise
    δ + γλ * a.
  """
  residual, reset = x
  carried = residual + γλ * a
  return tf.where(reset, residual, carried)

@tf.function
def get_logpπ(s, u):
  """Evaluate logp for u under the distribution parameters produced by fwd(s).

  Args:
    s: input forwarded to fwd.
    u: value whose log-probability is requested; passed to logp as the last
      positional argument, after the unpacked outputs of fwd(s).

  Returns:
    Whatever logp returns for those arguments.
  """
  return logp(*fwd(s), u)

@tf.function
def get_logpE(s, a):
  """Evaluate logp for a after mapping it through a scaled inverse tanh.

  The value handed to logp is a_scale * atanh(a / a_scale).

  NOTE(review): atanh is infinite at |a| == a_scale and undefined beyond it;
  presumably callers keep a strictly inside that range — confirm.

  Args:
    s: input forwarded to fwd.
    a: value to be mapped and scored.

  Returns:
    Whatever logp returns for the outputs of fwd(s) and the mapped value.
  """
  dist_params = fwd(s)
  unsquashed = a_scale * tf.atanh(a / a_scale)
  return logp(*dist_params, unsquashed)

@tf.function
def get_adv_ret(s, u, r, ś, t, b, logπ):
  """Compute scan-accumulated advantages and the matching value targets.

  Args:
    s: inputs for the current-step value estimate fwdv(s).
    u: unused here; kept in the signature.
    r: rewards.
    ś: inputs for the next-step value estimate fwdv(ś).
    t: boolean mask handed to the scan step `fn`; where true the accumulator
      is restarted from the local residual.
    b: boolean mask; where true the next-step value is replaced by 0.
    logπ: log-probabilities used for the α-weighted reward correction.

  Returns:
    (gae, v_tdλ): the accumulated advantages with gradients stopped, and
    fwdv(s) + gae.
  """
  value = fwdv(s)
  next_value = tf.where(b, 0., fwdv(ś))
  # Reward minus the α-weighted log-probability measured against logp_uniform.
  shaped_reward = r - α * (logπ - logp_uniform)
  td_residual = shaped_reward + γ * next_value - value
  # Scan from the last step backwards; the result carries no gradient.
  advantage = tf.stop_gradient(
      tf.scan(fn, (td_residual, t), initializer=0., reverse=True))
  return advantage, value + advantage

@tf.function
def get_adv_ret_abs(s, u, r, ś, t, b, y, logπ):
  """Variant of the advantage/target computation with an extra mask y.

  Rows where y is true get a current-step value of 0 and a reward of
  r * γ_inv in place of the α-corrected reward.
  NOTE(review): y presumably flags absorbing states (per the `_abs` suffix) —
  confirm against the caller.

  Args:
    s: inputs for the current-step value estimate fwdv(s).
    u: unused here; kept in the signature.
    r: rewards.
    ś: inputs for the next-step value estimate fwdv(ś).
    t: boolean mask handed to the scan step `fn`; where true the accumulator
      is restarted from the local residual.
    b: boolean mask; where true the next-step value is replaced by 0.
    y: boolean mask selecting the rows treated as described above.
    logπ: log-probabilities used for the α-weighted reward correction.

  Returns:
    (gae, v_tdλ): the accumulated advantages with gradients stopped, and the
    masked current-step value plus gae.
  """
  value = tf.where(y, 0., fwdv(s))
  next_value = tf.where(b, 0., fwdv(ś))
  shaped_reward = tf.where(y, r * γ_inv, r - α * (logπ - logp_uniform))
  td_residual = shaped_reward + γ * next_value - value
  # Scan from the last step backwards; the result carries no gradient.
  advantage = tf.stop_gradient(
      tf.scan(fn, (td_residual, t), initializer=0., reverse=True))
  return advantage, value + advantage

@tf.function
def πvgrad(s, u, r, logπ_old, adv, ret):
  """Flattened gradient of the joint policy/value loss plus diagnostics.

  The loss is the sum of a clipped-ratio policy term (ratio clipped to
  [.8, 1.2], advantages shifted by -α), half the mean squared error between
  ret and fwdv(s), and .5 * η times a gradient penalty on fwdv.

  Args:
    s: batch of inputs for fwd and fwdv. The mixing step below broadcasts an
      [n, 1] tensor against s — assumes s is rank 2; TODO confirm.
    u: values scored with logp under fwd(s).
    r: unused here; kept in the signature.
    logπ_old: log-probabilities the ratio is taken against.
    adv: advantages.
    ret: regression targets for fwdv(s).

  Returns:
    (grad, vals): grad is the gradient of the loss w.r.t. πv_vars with every
    piece reshaped to 1-D and concatenated; vals stacks, in order, pi_loss,
    mean entropy, the fraction of rows with |ratio - 1| > .19, the mean of
    -α * (logπ - logp_uniform), v_loss and the gradient penalty.
  """
  dist_params = fwd(s)
  logπ = logp(*dist_params, u)
  ratio = tf.exp(logπ - logπ_old)
  shifted_adv = adv - α

  # Elementwise larger of the unclipped and clipped negated objectives.
  unclipped = - shifted_adv * ratio
  clipped = - shifted_adv * tf.clip_by_value(ratio, .8, 1.2)
  pi_loss = 𝔼(tf.maximum(unclipped, clipped))

  value = fwdv(s)

  # Gradient penalty evaluated at random per-row mixtures of s and a shuffled
  # copy of s.
  batch = tf.shape(s)[0]
  mix = tf.random.uniform([batch, 1], 0., 1., tf.float32)
  s_mix = mix * s + (1 - mix) * tf.random.shuffle(s)
  mix_grads = tf.gradients(fwdv(s_mix), s_mix)
  # NOTE(review): tf.gradients returns a one-element list here, so tf.square
  # sees a leading axis of size 1 and reduce_max over axis 0 only removes that
  # axis — confirm this is the intended reduction.
  vgp = tf.reduce_mean(tf.reduce_max(tf.square(mix_grads), 0))

  v_loss = .5 * 𝔼(tf.square(ret - value))

  total = Σ([pi_loss, v_loss, .5 * η * vgp])
  flat_grad = tf.concat(
      [tf.reshape(g, [-1]) for g in tf.gradients(total, πv_vars)], -1)

  stats = tf.parallel_stack([
      pi_loss,
      𝔼(entropy(*dist_params)),
      𝔼(tf.cast(tf.greater(tf.abs(ratio - 1.), .19), tf.float32)),
      𝔼(-α * (logπ-logp_uniform)),
      v_loss,
      vgp,
  ])
  return flat_grad, stats

@tf.function
def πvgrad_abs(s, u, r, logπ_old, adv, ret, y):
  """Flattened gradient of the joint policy/value loss with a row mask y.

  Same loss as the unmasked variant — clipped-ratio policy term (ratio clipped
  to [.8, 1.2], advantages shifted by -α), half the mean squared value error
  and .5 * η times a gradient penalty — except that on rows where y is true
  the log-ratio is forced to 0 (ratio 1) and the value estimate to 0.
  NOTE(review): y presumably flags absorbing states (per the `_abs` suffix) —
  confirm against the caller.

  Args:
    s: batch of inputs for fwd and fwdv. The mixing step below broadcasts an
      [n, 1] tensor against s — assumes s is rank 2; TODO confirm.
    u: values scored with logp under fwd(s).
    r: unused here; kept in the signature.
    logπ_old: log-probabilities the ratio is taken against.
    adv: advantages.
    ret: regression targets for the masked value estimate.
    y: boolean mask selecting the rows treated as described above.

  Returns:
    (grad, vals): grad is the gradient of the loss w.r.t. πv_vars with every
    piece reshaped to 1-D and concatenated; vals stacks, in order, pi_loss,
    mean entropy, the fraction of rows with |ratio - 1| > .19, the mean of
    -α * (logπ - logp_uniform), v_loss and the gradient penalty.
  """
  # Presumably Gaussian-mixture parameters, going by the original "gmm" note.
  dist_params = fwd(s)
  logπ = logp(*dist_params, u)
  ratio = tf.exp(tf.where(y, 0., logπ - logπ_old))
  shifted_adv = adv - α

  # Elementwise larger of the unclipped and clipped negated objectives.
  unclipped = - shifted_adv * ratio
  clipped = - shifted_adv * tf.clip_by_value(ratio, .8, 1.2)
  pi_loss = 𝔼(tf.maximum(unclipped, clipped))

  value = tf.where(y, 0., fwdv(s))

  # Gradient penalty evaluated at random per-row mixtures of s and a shuffled
  # copy of s.
  batch = tf.shape(s)[0]
  mix = tf.random.uniform([batch, 1], 0., 1., tf.float32)
  s_mix = mix * s + (1 - mix) * tf.random.shuffle(s)
  mix_grads = tf.gradients(fwdv(s_mix), s_mix)
  # NOTE(review): tf.gradients returns a one-element list here, so tf.square
  # sees a leading axis of size 1 and reduce_max over axis 0 only removes that
  # axis — confirm this is the intended reduction.
  vgp = tf.reduce_mean(tf.reduce_max(tf.square(mix_grads), 0))

  v_loss = .5 * 𝔼(tf.square(ret - value))

  total = Σ([pi_loss, v_loss, .5 * η * vgp])
  flat_grad = tf.concat(
      [tf.reshape(g, [-1]) for g in tf.gradients(total, πv_vars)], -1)

  stats = tf.parallel_stack([
      pi_loss,
      𝔼(entropy(*dist_params)),
      𝔼(tf.cast(tf.greater(tf.abs(ratio - 1.), .19), tf.float32)),
      𝔼(-α * (logπ-logp_uniform)),
      v_loss,
      vgp,
  ])
  return flat_grad, stats