import jax
import jax.numpy as jnp
import chex

"""Functions from rlax (https://github.com/deepmind/rlax) rewritten using scan instead of python for loops for better performance."""

# Short aliases for the chex typing names used in the annotations of the
# functions in this module.
Array = chex.Array
Numeric = chex.Numeric

def lambda_returns(
    r_t: Array,
    discount_t: Array,
    v_t: Array,
    lambda_: Numeric = 1.,
    stop_target_gradients: bool = False,
) -> Array:
  """Computes multistep lambda-returns with a reverse `jax.lax.scan`.

  Starting from the bootstrap value `v_t[-1]`, the recursion

    G[i] = r_t[i] + discount_t[i] * ((1 - lambda_[i]) * v_t[i] + lambda_[i] * G[i+1])

  is evaluated from the last index down to the first.

  Args:
    r_t: rank-1 float array of rewards.
    discount_t: rank-1 float array of discounts, same shape as `r_t`.
    v_t: rank-1 float array of value estimates, same shape as `r_t`.
    lambda_: mixing parameter; either a scalar or a rank-1 array with one
      entry per step.
    stop_target_gradients: whether to apply `jax.lax.stop_gradient` to the
      returned targets.

  Returns:
    Rank-1 array holding the return for every step, ordered like `r_t`.
  """
  chex.assert_rank([r_t, discount_t, v_t, lambda_], [1, 1, 1, {0, 1}])
  chex.assert_type([r_t, discount_t, v_t, lambda_], float)
  chex.assert_equal_shape([r_t, discount_t, v_t])

  # If scalar make into vector.
  lambda_ = jnp.ones_like(discount_t) * lambda_

  # `lax.scan` needs the carry to keep a single dtype across iterations. The
  # step output has the promoted dtype of all inputs, so seed the carry with
  # that dtype; otherwise a `v_t` narrower than the other inputs makes scan
  # fail. This is a no-op when every input already shares one dtype.
  carry_dtype = jnp.result_type(r_t, discount_t, v_t, lambda_)
  bootstrap = jnp.asarray(v_t[-1], dtype=carry_dtype)

  # Work backwards to compute `G_{T-1}`, ..., `G_0`.
  def step(next_return, inputs):
    reward, discount, lam, value = inputs
    current_return = reward + discount * (
        (1 - lam) * value + lam * next_return)
    # The return is both the carry for the previous step and the emitted output.
    return current_return, current_return

  _, returns = jax.lax.scan(
      step, bootstrap, (r_t, discount_t, lambda_, v_t), reverse=True)

  return jax.lax.select(
      stop_target_gradients, jax.lax.stop_gradient(returns), returns)


def discounted_returns(
    r_t: Array,
    discount_t: Array,
    v_t: Array,
    stop_target_gradients: bool = False,
) -> Array:
  """Computes discounted returns bootstrapped from `v_t`.

  Delegates to `lambda_returns` with the mixing parameter fixed to 1.

  Args:
    r_t: rank-1 float array of rewards.
    discount_t: rank-1 float array of discounts.
    v_t: float value estimate(s) used for bootstrapping; either a scalar or a
      rank-1 array.
    stop_target_gradients: forwarded to `lambda_returns`; whether to stop
      gradients through the returned targets.

  Returns:
    The array produced by `lambda_returns` for these inputs.
  """
  chex.assert_rank([r_t, discount_t, v_t], [1, 1, {0, 1}])
  chex.assert_type([r_t, discount_t, v_t], float)

  # Expand a scalar `v_t` to one entry per step so it lines up with
  # `discount_t`; a vector `v_t` keeps its values.
  per_step_values = jnp.ones_like(discount_t) * v_t

  return lambda_returns(
      r_t,
      discount_t,
      per_step_values,
      lambda_=1.,
      stop_target_gradients=stop_target_gradients,
  )


def general_off_policy_returns_from_q_and_v(
    q_t: Array,
    v_t: Array,
    r_t: Array,
    discount_t: Array,
    c_t: Array,
    stop_target_gradients: bool = False,
    scan: bool = True,
) -> Array:
  """Calculates targets for various off-policy evaluation algorithms.

  Given a window of experience of length `K+1`, generated by a behaviour policy
  μ, for each time-step `t` we can estimate the return `G_t` from that step
  onwards, under some target policy π, using the rewards in the trajectory, the
  values under π of states and actions selected by μ, according to equation:

    Gₜ = rₜ₊₁ + γₜ₊₁ * (vₜ₊₁ - cₜ₊₁ * q(aₜ₊₁) + cₜ₊₁* Gₜ₊₁),

  where, depending on the choice of `c_t`, the algorithm implements:

    Importance Sampling             c_t = π(x_t, a_t) / μ(x_t, a_t),
    Harutyunyan's et al. Q(lambda)  c_t = λ,
    Precup's et al. Tree-Backup     c_t = π(x_t, a_t),
    Munos' et al. Retrace           c_t = λ min(1, π(x_t, a_t) / μ(x_t, a_t)).

  See "Safe and Efficient Off-Policy Reinforcement Learning" by Munos et al.
  (https://arxiv.org/abs/1606.02647).

  Args:
    q_t: Q-values under π of actions executed by μ at times [1, ..., K - 1].
    v_t: Values under π at times [1, ..., K].
    r_t: rewards at times [1, ..., K].
    discount_t: discounts at times [1, ..., K].
    c_t: weights at times [1, ..., K - 1].
    stop_target_gradients: bool indicating whether or not to apply stop gradient
      to targets.
    scan: if True (default) run the backward recursion with `jax.lax.scan`;
      if False use an unrolled Python loop over the time steps.

  Returns:
    Off-policy estimates of the generalized returns from states visited at times
    [0, ..., K - 1].
  """
  chex.assert_rank([q_t, v_t, r_t, discount_t, c_t], 1)
  chex.assert_type([q_t, v_t, r_t, discount_t, c_t], float)
  chex.assert_equal_shape([q_t, v_t[:-1], r_t[:-1], discount_t[:-1], c_t])

  # G_K-1: the last step has no `q_t`/`c_t` entry and bootstraps from `v_t[-1]`.
  g = r_t[-1] + discount_t[-1] * v_t[-1]

  if scan:
    # `lax.scan` needs the carry to keep a single dtype across iterations. The
    # step output also mixes in `c_t` and `q_t`, so seed the carry with the
    # promoted dtype of all inputs; otherwise a wider `c_t`/`q_t` makes scan
    # fail. This is a no-op when every input already shares one dtype.
    carry_dtype = jnp.result_type(q_t, v_t, r_t, discount_t, c_t)
    g = jnp.asarray(g, dtype=carry_dtype)

    def step(next_return, inputs):
      reward, discount, value, weight, q_value = inputs
      current_return = reward + discount * (
          value - weight * q_value + weight * next_return)
      return current_return, current_return

    # Work backwards to compute `G_K-2`, ..., `G_0`.
    _, earlier_returns = jax.lax.scan(
        step, g, (r_t[:-1], discount_t[:-1], v_t[:-1], c_t, q_t), reverse=True)
    returns = jnp.concatenate([earlier_returns, g[None]])

  else:
    # Work backwards to compute `G_K-2`, ..., `G_0`, collecting newest-first
    # and reversing once at the end (avoids repeated front insertion).
    newest_first = [g]
    for i in reversed(range(q_t.shape[0])):  # [K - 2, ..., 0]
      g = r_t[i] + discount_t[i] * (v_t[i] - c_t[i] * q_t[i] + c_t[i] * g)
      newest_first.append(g)
    returns = jnp.array(newest_first[::-1])

  return jax.lax.select(
      stop_target_gradients, jax.lax.stop_gradient(returns), returns)
