import tensorflow as tf
import numpy as np
from const import a_scale

# Short alias for tf.add_n (element-wise sum of a list of tensors).
Σ = tf.add_n

# Active entropy / log-probability functions; assigned by set_pd() below.
entropy = None
logp = None

# 0.5 * log(2*pi*e): per-dimension entropy term of a Gaussian (added to logstd).
ent_const = tf.constant(.5 * np.log(2.0 * np.pi * np.e), dtype=tf.float32)
# 0.5 * log(2*pi): per-dimension normalizer of the Gaussian log-density.
logp_const = tf.constant(.5 * np.log(2.0 * np.pi), dtype=tf.float32)
# log(4): constant in log(1 - tanh(y)^2) = log 4 + log_sigmoid(2y) + log_sigmoid(-2y).
squash_const = tf.constant(np.log(4.0), dtype=tf.float32)

@tf.function
def entropy_gaussian(mean, logstd):
  """Differential entropy of a diagonal Gaussian, summed over the last axis.

  Each dimension contributes logstd + 0.5 * log(2*pi*e). `mean` is unused:
  a Gaussian's entropy depends only on its scale.
  """
  per_dim_entropy = ent_const + logstd
  return tf.reduce_sum(per_dim_entropy, axis=-1)

@tf.function
def squash_logp_gaussian(mean, logstd, u):
  """Log-density of a tanh-squashed sample under a diagonal Gaussian.

  `u` is the pre-squash sample. The result is log N(u; mean, exp(logstd))
  plus the log-abs-det of the inverse squashing Jacobian,
  -log(1 - tanh(u / a_scale)^2), summed over the last axis. That Jacobian
  matches a squash of the form x = a_scale * tanh(u / a_scale) — NOTE(review):
  confirm against the sampler that produces `u`.

  Args:
    mean: Gaussian mean, broadcastable against `u`.
    logstd: log standard deviation, broadcastable against `u`.
    u: pre-squash sample; the last axis is the event dimension.

  Returns:
    Log-probability of the squashed sample, with the last axis reduced.
  """
  # log(1 - tanh(y)^2) = log 4 + log_sigmoid(2y) + log_sigmoid(-2y); the
  # log-sigmoid form stays finite for large |y| where 1 - tanh(y)^2 underflows.
  logdet_inv = - (tf.math.log_sigmoid(-2 * u / a_scale) + tf.math.log_sigmoid(2 * u / a_scale) + squash_const)
  # Reuse the unsquashed density instead of duplicating its formula.
  return logp_gaussian(mean, logstd, u) + tf.reduce_sum(logdet_inv, -1)

@tf.function
def logp_gaussian(mean, logstd, x):
  """Log-density of `x` under a diagonal Gaussian, summed over the last axis.

  log N(x) = -0.5 * sum(((x - mean) / std)^2) - d * 0.5 * log(2*pi) - sum(logstd),
  where std = exp(logstd) and d is the size of the last axis of `x`.
  """
  z = (x - mean) / tf.exp(logstd)
  event_dim = tf.cast(tf.shape(x)[-1], tf.float32)
  neg_logp = (.5 * tf.reduce_sum(tf.square(z), -1)
              + logp_const * event_dim
              + tf.reduce_sum(logstd, -1))
  return tf.math.negative(neg_logp)

@tf.function
def entropy_gmm(zlogits, means, logstds):
  """Upper bound on the entropy of a diagonal-Gaussian mixture.

  Returns sum_k pi_k * (H_k - log pi_k) = H(Z) + sum_k pi_k * H_k, where
  pi = softmax(zlogits) and H_k = sum(logstds_k + 0.5 * log(2*pi*e)) is the
  entropy of component k. This is the joint entropy of (component, sample),
  which upper-bounds the mixture entropy (the exact value has no closed form).
  `means` is unused.

  Presumably zlogits is (..., K) and logstds is (..., K, d) — TODO confirm.
  """
  log_weights = tf.nn.log_softmax(zlogits)
  weights = tf.nn.softmax(zlogits)
  component_entropy = tf.reduce_sum(logstds + ent_const, -1)
  return tf.reduce_sum(weights * (component_entropy - log_weights), -1)

@tf.function
def squash_logp_gmm(zlogits, means, logstds, u):
  """Log-density of a tanh-squashed sample under a diagonal-Gaussian mixture.

  `u` is the pre-squash sample. The result is the mixture log-density
  log sum_k softmax(zlogits)_k * N(u; means_k, exp(logstds_k)) plus the
  log-abs-det of the inverse squashing Jacobian, -log(1 - tanh(u / a_scale)^2),
  summed over the last axis. That Jacobian matches a squash of the form
  x = a_scale * tanh(u / a_scale) — NOTE(review): confirm against the sampler.

  Args:
    zlogits: mixture logits, presumably (batch, K) — TODO confirm.
    means: component means, presumably (batch, K, d) — TODO confirm.
    logstds: component log standard deviations, same layout as `means`.
    u: pre-squash sample, presumably (batch, d); a component axis is
      inserted at position 1 inside logp_gmm.

  Returns:
    Log-probability of the squashed sample, with the event axis reduced.
  """
  # log(1 - tanh(y)^2) = log 4 + log_sigmoid(2y) + log_sigmoid(-2y); the
  # log-sigmoid form stays finite for large |y| where 1 - tanh(y)^2 underflows.
  logdet_inv = - (tf.math.log_sigmoid(-2 * u / a_scale) + tf.math.log_sigmoid(2 * u / a_scale) + squash_const)
  # Reuse the unsquashed mixture density instead of duplicating its formula.
  return logp_gmm(zlogits, means, logstds, u) + tf.reduce_sum(logdet_inv, -1)

@tf.function
def logp_gmm(zlogits, means, logstds, x):
  """Log-density of `x` under a mixture of diagonal Gaussians.

  Computes log sum_k pi_k * N(x; means_k, exp(logstds_k)) with
  pi = softmax(zlogits), using logsumexp over the component axis.
  Presumably zlogits is (batch, K), means/logstds are (batch, K, d) and
  x is (batch, d) — TODO confirm.
  """
  log_weights = tf.nn.log_softmax(zlogits)
  # Insert a component axis at position 1 so x broadcasts against every mean.
  z = (tf.expand_dims(x, 1) - means) / tf.exp(logstds)
  event_dim = tf.cast(tf.shape(x)[-1], tf.float32)
  component_logp = tf.math.negative(
      .5 * tf.reduce_sum(tf.square(z), -1)
      + logp_const * event_dim
      + tf.reduce_sum(logstds, -1))
  return tf.math.reduce_logsumexp(log_weights + component_logp, -1)

def set_pd(policy_type):
  """Select the module-level `entropy` and `logp` functions.

  Args:
    policy_type: 'gaussian' selects entropy_gaussian / squash_logp_gaussian;
      'gmm' selects entropy_gmm / squash_logp_gmm.

  Raises:
    ValueError: if `policy_type` is not one of the supported names. An
      unknown name must not leave `entropy`/`logp` unset or stale from an
      earlier call.
  """
  global entropy, logp
  if policy_type == 'gaussian':
    entropy, logp = entropy_gaussian, squash_logp_gaussian
  elif policy_type == 'gmm':
    entropy, logp = entropy_gmm, squash_logp_gmm
  else:
    raise ValueError(
        f"unknown policy_type {policy_type!r}; expected 'gaussian' or 'gmm'")
