from typing import Callable
import functools
import numpy as np
from tensorflow_probability.substrates import jax as tfp

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training.train_state import TrainState


from src.common import Params

tfd = tfp.distributions


@functools.partial(
    jax.jit,
    static_argnames=("actor_apply_fn", "critic_apply_fn", "action_dim"),
)
@functools.partial(
    jax.vmap, in_axes=(0, None, 0, None, 0, 0, None, None, None)
)
def _sample_actions_optimistic(
    rng: jax.random.PRNGKey,
    actor_apply_fn: Callable,
    actor_params: Params,
    critic_apply_fn: Callable,
    critic_params: Params,
    observations: np.ndarray,
    action_dim: int,
    beta_ub: float,
    delta: float,
):
    """Sample actions after shifting the actor's mean towards a critic upper bound.

    The function is ``vmap``-ed over the leading axis of ``rng``,
    ``actor_params``, ``critic_params`` and ``observations``. Presumably there
    is one slice per agent / ensemble member; confirm against callers. The
    apply functions, ``action_dim``, ``beta_ub`` and ``delta`` are shared
    across slices. For each slice it:

    1. Runs the actor and reads the pre-tanh mean from the captured
       intermediates of its ``Dense_0`` sub-module.
    2. Differentiates ``Q_UB = mean(q1, q2) + beta_ub * |q1 - q2| / 2``
       w.r.t. that mean.
    3. Shifts the mean by ``sqrt(2 * delta) * Sigma g / sqrt(g^T Sigma g)``.
       Here ``Sigma`` is the actor's diagonal covariance. The shift has
       Mahalanobis length ``sqrt(2 * delta)`` under ``Sigma^-1`` (up to the
       ``1e-5`` stabiliser).
    4. Samples from ``Normal(shifted_mean, std)`` and squashes with ``tanh``.

    Args:
        rng: PRNG key, one per vmapped slice.
        actor_apply_fn: Flax-style ``apply``. It must accept
            ``capture_intermediates=True`` / ``mutable=["intermediates"]`` and
            return ``(dist, state)``, where ``dist.distribution.scale`` is a
            diagonal linear operator holding the per-dimension std.
        actor_params: Actor parameters (mapped along axis 0).
        critic_apply_fn: Callable ``(params, observations, actions) -> (q1, q2)``.
        critic_params: Critic parameters (mapped along axis 0).
        observations: Observations (mapped along axis 0).
        action_dim: Not used by this function. It is kept (as a static
            argument) for interface compatibility.
        beta_ub: Weight on the critic disagreement in the upper bound.
        delta: Controls the size of the mean shift (see step 3).

    Returns:
        ``(rng, actions)``: the advanced key and the tanh-squashed actions.
    """
    rng, actor_key = jax.random.split(rng)

    dist, state = actor_apply_fn(
        actor_params,
        observations,
        capture_intermediates=True,
        mutable=["intermediates"],
    )
    intermediates = state["intermediates"]
    # Flax stores each sub-module's outputs as a tuple under "__call__";
    # [0] is the output of its first call.
    # NOTE(review): assumes `Dense_0` is the mean head of the actor.
    pre_tanh_mu_T = intermediates["Dense_0"]["__call__"][0]

    def q_upper_bound(pre_tanh_mean):
        actions = nn.tanh(pre_tanh_mean)
        q1, q2 = critic_apply_fn(critic_params, observations, actions)
        mu_q = (q1 + q2) / 2.0
        # Std of the two critic estimates.
        sigma_q = jnp.abs(q1 - q2) / 2.0
        q_ub = mu_q + beta_ub * sigma_q
        # Use sum, not mean: if `observations` holds several states, each
        # state's gradient is then not scaled by 1/batch_size. For a single
        # state, sum and mean are identical.
        return q_ub.sum()

    grad = jax.grad(q_upper_bound)(pre_tanh_mu_T)

    # Diagonal of Sigma_T (the actor's Gaussian covariance) = std ** 2.
    std = dist.distribution.scale.diag
    variance = jnp.square(std)

    # sqrt(g^T Sigma g) = sqrt(sum_i g_i^2 * var_i) for diagonal Sigma.
    # Reduce only over the trailing (action) axis so that batched states are
    # normalised independently. For a single state this is the full sum.
    denom = (
        jnp.sqrt(jnp.sum(jnp.square(grad) * variance, axis=-1, keepdims=True))
        + 1e-5
    )

    # Optimistic shift of the mean along Sigma g.
    mu_C = jnp.sqrt(2.0 * delta) * variance * grad / denom
    mu_E = pre_tanh_mu_T + mu_C

    # Sample around the shifted mean with the actor's std, then squash.
    pre_tanh_sample = tfd.Normal(loc=mu_E, scale=std).sample(seed=actor_key)
    ac = nn.tanh(pre_tanh_sample)

    return rng, ac


def sample_actions_optimistic(
    rng: jax.random.PRNGKey,
    actor_apply_fn: Callable,
    actor_params: Params,
    critic_apply_fn: Callable,
    critic_params: Params,
    observations: np.ndarray,
    action_dim: int,
    beta_ub: float,
    delta: float,
):
    """Sample optimistically shifted actions (public entry point).

    Forwards all arguments unchanged to the jitted and vmapped
    ``_sample_actions_optimistic`` and returns its result.

    Args:
        rng: PRNG keys, mapped along axis 0.
        actor_apply_fn: Actor ``apply`` function (static under ``jit``).
        actor_params: Actor parameters, mapped along axis 0.
        critic_apply_fn: Callable ``(params, observations, actions) -> (q1, q2)``
            (static under ``jit``).
        critic_params: Critic parameters, mapped along axis 0.
        observations: Observations, mapped along axis 0.
        action_dim: Static argument. It is not used by the sampler.
        beta_ub: Weight on the critic disagreement in the upper bound.
        delta: Controls the size of the optimistic mean shift.

    Returns:
        ``(rng, actions)``, each with a leading mapped axis.
    """
    # Forward positionally on purpose. `jax.vmap` maps keyword arguments over
    # their leading axis regardless of `in_axes`, which would try to batch
    # the callables and the scalar hyper-parameters.
    return _sample_actions_optimistic(
        rng,
        actor_apply_fn,
        actor_params,
        critic_apply_fn,
        critic_params,
        observations,
        action_dim,
        beta_ub,
        delta,
    )
