from typing import Callable
from acme.jax import networks as networks_lib
import jax
import jax.numpy as jnp
import rlax

from rl_utils import softSPIBB_probs, BCQ_probs, to_qr

def epsilon_greedy_policy(network: networks_lib.FeedForwardNetwork,
                          normalize_fn: Callable[[jnp.ndarray], jnp.ndarray],
                          epsilon: jnp.float32 = 0.001):
    """Build a jitted epsilon-greedy policy on top of ``network``.

    Args:
        network: Network whose ``apply(params, obs)`` output is used as the
            preferences handed to ``rlax.epsilon_greedy``.
        normalize_fn: Applied to the observation before it reaches the network.
        epsilon: Exploration parameter passed to ``rlax.epsilon_greedy``.

    Returns:
        A ``jax.jit``-wrapped ``policy(params, key, observation)`` callable
        returning whatever ``rlax.epsilon_greedy(epsilon).sample`` produces.
    """
    # The distribution only depends on the closed-over epsilon, so build it once.
    sampler = rlax.epsilon_greedy(epsilon)

    def policy(params: networks_lib.Params,
               key: jnp.ndarray,
               observation: jnp.ndarray) -> jnp.ndarray:
        normalized_obs = normalize_fn(observation)
        values = network.apply(params, normalized_obs)
        return sampler.sample(key, values)

    return jax.jit(policy)


def softmax_policy(network: networks_lib.FeedForwardNetwork,
                   normalize_fn: Callable[[jnp.ndarray], jnp.ndarray],
                   epsilon: jnp.float32 = 0.001,
                   temperature: jnp.float32 = 1.0):
    """Build a jitted epsilon-softmax policy on top of ``network``.

    Args:
        network: Network whose ``apply(params, obs)`` output is used as the
            logits handed to ``rlax.epsilon_softmax``.
        normalize_fn: Applied to the observation before it reaches the network.
        epsilon: Mixing parameter passed to ``rlax.epsilon_softmax``.
        temperature: Softmax temperature passed to ``rlax.epsilon_softmax``.
            Defaults to ``1.0``, which was previously hard-coded.

    Returns:
        A ``jax.jit``-wrapped ``policy(params, key, observation)`` callable
        returning whatever ``rlax.epsilon_softmax(...).sample`` produces.
    """
    @jax.jit
    def policy(params: networks_lib.Params,
               key: jnp.ndarray,
               observation: jnp.ndarray) -> jnp.ndarray:
        logits = network.apply(params, normalize_fn(observation))
        # epsilon and temperature are closed over, so they are baked in as
        # constants when the function is traced.
        distribution = rlax.epsilon_softmax(epsilon, temperature)
        return distribution.sample(key, logits)
    return policy


def soft_spibb_policy(q_network: networks_lib.FeedForwardNetwork,
                      error_fn: Callable[[jnp.ndarray], jnp.ndarray],
                      behavior_fn: Callable[[jnp.ndarray], jnp.ndarray],
                      normalize_fn: Callable[[jnp.ndarray], jnp.ndarray],
                      epsilon: jnp.float32,
                      min_prob: jnp.float32 = 0.0):
    """Build a jitted policy that samples from ``softSPIBB_probs`` outputs.

    Args:
        q_network: Network whose ``apply(params, obs)`` output is passed as
            the first argument of ``softSPIBB_probs``.
        error_fn: Called on the raw (un-normalized) observation; its result is
            the second argument of ``softSPIBB_probs``.
        behavior_fn: Called on the raw (un-normalized) observation; its result
            is the third argument of ``softSPIBB_probs``.
        normalize_fn: Applied to the observation before it reaches
            ``q_network`` only.
        epsilon: Scalar replicated once per leading-axis entry of the network
            output and passed as the fourth argument of ``softSPIBB_probs``.
        min_prob: Constant added to every entry of the resulting weights
            before sampling.

    Returns:
        A ``jax.jit``-wrapped ``policy(params, key, observation)`` callable
        returning the result of ``rlax.categorical_sample``.
    """
    @jax.jit
    def policy(params: networks_lib.Params,
               key: jnp.ndarray,
               observation: jnp.ndarray) -> jnp.ndarray:
        q = q_network.apply(params, normalize_fn(observation))
        e = error_fn(observation)
        b = behavior_fn(observation)
        # One epsilon per leading-axis entry so that vmap can map over it
        # together with q, e and b.
        eps = epsilon * jnp.ones((q.shape[0],))
        batch_softSPIBB_probs = jax.vmap(softSPIBB_probs)
        pi = batch_softSPIBB_probs(q, e, b, eps)
        pi = pi + min_prob
        # Adding a floor to every entry means the rows no longer sum to what
        # softSPIBB_probs produced. Renormalise explicitly over the last axis
        # (the axis categorical_sample draws over) instead of relying on the
        # sampler to cope with unnormalised weights.
        pi = pi / jnp.sum(pi, axis=-1, keepdims=True)
        return rlax.categorical_sample(key, pi)
    return policy

def greedy_spibb_policy(q_network: networks_lib.FeedForwardNetwork,
                        error_fn: Callable[[jnp.ndarray], jnp.ndarray],
                        behavior_fn: Callable[[jnp.ndarray], jnp.ndarray],
                        normalize_fn: Callable[[jnp.ndarray], jnp.ndarray],
                        epsilon: jnp.float32,
                        greedy_epsilon: jnp.float32 = 0.001):
    """Build a jitted epsilon-greedy policy over ``softSPIBB_probs`` outputs.

    Args:
        q_network: Network whose ``apply(params, obs)`` output is the first
            argument of ``softSPIBB_probs``.
        error_fn: Called on the raw observation; second argument of
            ``softSPIBB_probs``.
        behavior_fn: Called on the raw observation; third argument of
            ``softSPIBB_probs``.
        normalize_fn: Applied to the observation before ``q_network`` only.
        epsilon: Scalar replicated along the leading axis and passed as the
            fourth argument of ``softSPIBB_probs``.
        greedy_epsilon: Exploration parameter for ``rlax.epsilon_greedy``.

    Returns:
        A ``jax.jit``-wrapped ``policy(params, key, observation)`` callable.
    """
    # Neither of these depends on the call-time inputs, so create them once.
    spibb_over_batch = jax.vmap(softSPIBB_probs)
    sampler = rlax.epsilon_greedy(greedy_epsilon)

    def policy(params: networks_lib.Params,
               key: jnp.ndarray,
               observation: jnp.ndarray) -> jnp.ndarray:
        q_values = q_network.apply(params, normalize_fn(observation))
        errors = error_fn(observation)
        behavior = behavior_fn(observation)
        leading_dim = q_values.shape[0]
        # vmap maps over axis 0 of every argument, so epsilon is tiled to match.
        per_row_eps = epsilon * jnp.ones((leading_dim,))
        preferences = spibb_over_batch(q_values, errors, behavior, per_row_eps)
        return sampler.sample(key, preferences)

    return jax.jit(policy)


def bcq_policy(q_network: networks_lib.FeedForwardNetwork,
               behavior_fn: Callable[[jnp.ndarray], jnp.ndarray],
               normalize_fn: Callable[[jnp.ndarray], jnp.ndarray],
               tau: jnp.float32,
               epsilon: jnp.float32 = 0.001):
    """Build a jitted epsilon-greedy policy over ``BCQ_probs`` outputs.

    Args:
        q_network: Network whose ``apply(params, obs)`` output is the first
            argument of ``BCQ_probs``.
        behavior_fn: Called on the raw observation; second argument of
            ``BCQ_probs``.
        normalize_fn: Applied to the observation before ``q_network`` only.
        tau: Scalar replicated along the leading axis and passed as the third
            argument of ``BCQ_probs``.
        epsilon: Exploration parameter for ``rlax.epsilon_greedy``.

    Returns:
        A ``jax.jit``-wrapped ``policy(params, key, observation)`` callable.
    """
    bcq_over_batch = jax.vmap(BCQ_probs)
    sampler = rlax.epsilon_greedy(epsilon)

    def policy(params: networks_lib.Params,
               key: jnp.ndarray,
               observation: jnp.ndarray) -> jnp.ndarray:
        q_values = q_network.apply(params, normalize_fn(observation))
        behavior = behavior_fn(observation)
        # vmap maps over axis 0 of every argument, so tau is tiled to match.
        per_row_tau = tau * jnp.ones((q_values.shape[0],))
        preferences = bcq_over_batch(q_values, behavior, per_row_tau)
        return sampler.sample(key, preferences)

    return jax.jit(policy)


def epsilon_greedy_policy_qr(network: networks_lib.FeedForwardNetwork,
                             normalize_fn: Callable[[jnp.ndarray], jnp.ndarray],
                             n_actions: int,
                             n_quantiles: int = 201,
                             epsilon: jnp.float32 = 0.001):
    """Build a jitted epsilon-greedy policy for a network decoded by ``to_qr``.

    Args:
        network: Network whose ``apply(params, obs)`` output is passed to
            ``to_qr``.
        normalize_fn: Applied to the observation before it reaches the network.
        n_actions: Forwarded to ``to_qr``.
        n_quantiles: Forwarded to ``to_qr``.
        epsilon: Exploration parameter for ``rlax.epsilon_greedy``.

    Returns:
        A ``jax.jit``-wrapped ``policy(params, key, observation)`` callable.
    """
    sampler = rlax.epsilon_greedy(epsilon)

    def policy(params: networks_lib.Params,
               key: jnp.ndarray,
               observation: jnp.ndarray) -> jnp.ndarray:
        raw_output = network.apply(params, normalize_fn(observation))
        # to_qr returns a pair; only its first element is used here.
        decoded = to_qr(raw_output, n_actions, n_quantiles)
        values, _ = decoded
        return sampler.sample(key, values)

    return jax.jit(policy)


def greedy_spibb_policy_qr(q_network: networks_lib.FeedForwardNetwork,
                           error_fn: Callable[[jnp.ndarray], jnp.ndarray],
                           behavior_fn: Callable[[jnp.ndarray], jnp.ndarray],
                           normalize_fn: Callable[[jnp.ndarray], jnp.ndarray],
                           epsilon: jnp.float32,
                           n_actions: int,
                           n_quantiles: int = 201,
                           greedy_epsilon: jnp.float32 = 0.001):
    """Build a jitted epsilon-greedy SPIBB policy for a ``to_qr``-decoded network.

    Args:
        q_network: Network whose ``apply(params, obs)`` output is passed to
            ``to_qr``; the first element of the result feeds
            ``softSPIBB_probs``.
        error_fn: Called on the raw observation; second argument of
            ``softSPIBB_probs``.
        behavior_fn: Called on the raw observation; third argument of
            ``softSPIBB_probs``.
        normalize_fn: Applied to the observation before ``q_network`` only.
        epsilon: Scalar replicated along the leading axis and passed as the
            fourth argument of ``softSPIBB_probs``.
        n_actions: Forwarded to ``to_qr``.
        n_quantiles: Forwarded to ``to_qr``.
        greedy_epsilon: Exploration parameter for ``rlax.epsilon_greedy``.

    Returns:
        A ``jax.jit``-wrapped ``policy(params, key, observation)`` callable.
    """
    spibb_over_batch = jax.vmap(softSPIBB_probs)
    sampler = rlax.epsilon_greedy(greedy_epsilon)

    def policy(params: networks_lib.Params,
               key: jnp.ndarray,
               observation: jnp.ndarray) -> jnp.ndarray:
        raw_output = q_network.apply(params, normalize_fn(observation))
        # to_qr returns a pair; only its first element is used here.
        q_values, _ = to_qr(raw_output, n_actions, n_quantiles)
        errors = error_fn(observation)
        behavior = behavior_fn(observation)
        # vmap maps over axis 0 of every argument, so epsilon is tiled to match.
        per_row_eps = epsilon * jnp.ones((q_values.shape[0],))
        preferences = spibb_over_batch(q_values, errors, behavior, per_row_eps)
        return sampler.sample(key, preferences)

    return jax.jit(policy)

def soft_spibb_policy_qr(q_network: networks_lib.FeedForwardNetwork,
                         error_fn: Callable[[jnp.ndarray], jnp.ndarray],
                         behavior_fn: Callable[[jnp.ndarray], jnp.ndarray],
                         normalize_fn: Callable[[jnp.ndarray], jnp.ndarray],
                         epsilon: jnp.float32,
                         n_actions: int,
                         n_quantiles: int = 201,
                         min_prob: jnp.float32 = 0.0):
    """Build a jitted sampling SPIBB policy for a ``to_qr``-decoded network.

    Args:
        q_network: Network whose ``apply(params, obs)`` output is passed to
            ``to_qr``; the first element of the result feeds
            ``softSPIBB_probs``.
        error_fn: Called on the raw observation; second argument of
            ``softSPIBB_probs``.
        behavior_fn: Called on the raw observation; third argument of
            ``softSPIBB_probs``.
        normalize_fn: Applied to the observation before ``q_network`` only.
        epsilon: Scalar replicated along the leading axis and passed as the
            fourth argument of ``softSPIBB_probs``.
        n_actions: Forwarded to ``to_qr``.
        n_quantiles: Forwarded to ``to_qr``.
        min_prob: Constant added to every entry of the resulting weights
            before sampling, mirroring ``soft_spibb_policy``. The default of
            ``0.0`` adds nothing.

    Returns:
        A ``jax.jit``-wrapped ``policy(params, key, observation)`` callable
        returning the result of ``rlax.categorical_sample``.
    """
    @jax.jit
    def policy(params: networks_lib.Params,
               key: jnp.ndarray,
               observation: jnp.ndarray) -> jnp.ndarray:
        # to_qr returns a pair; only its first element is used here.
        q, _ = to_qr(q_network.apply(params, normalize_fn(observation)),
                     n_actions, n_quantiles)
        e = error_fn(observation)
        b = behavior_fn(observation)
        # One epsilon per leading-axis entry so that vmap can map over it
        # together with q, e and b.
        eps = epsilon * jnp.ones((q.shape[0],))
        batch_softSPIBB_probs = jax.vmap(softSPIBB_probs)
        pi = batch_softSPIBB_probs(q, e, b, eps)
        pi = pi + min_prob
        # A non-zero floor changes each row's total, so renormalise over the
        # last axis (the axis categorical_sample draws over) before sampling.
        pi = pi / jnp.sum(pi, axis=-1, keepdims=True)
        return rlax.categorical_sample(key, pi)
    return policy


def bcq_policy_qr(q_network: networks_lib.FeedForwardNetwork,
                  behavior_fn: Callable[[jnp.ndarray], jnp.ndarray],
                  normalize_fn: Callable[[jnp.ndarray], jnp.ndarray],
                  tau: jnp.float32,
                  n_actions: int,
                  n_quantiles: int = 201,
                  epsilon: jnp.float32 = 0.001):
    """Build a jitted epsilon-greedy BCQ policy for a ``to_qr``-decoded network.

    Args:
        q_network: Network whose ``apply(params, obs)`` output is passed to
            ``to_qr``; the first element of the result feeds ``BCQ_probs``.
        behavior_fn: Called on the raw observation; second argument of
            ``BCQ_probs``.
        normalize_fn: Applied to the observation before ``q_network`` only.
        tau: Scalar replicated along the leading axis and passed as the third
            argument of ``BCQ_probs``.
        n_actions: Forwarded to ``to_qr``.
        n_quantiles: Forwarded to ``to_qr``.
        epsilon: Exploration parameter for ``rlax.epsilon_greedy``.

    Returns:
        A ``jax.jit``-wrapped ``policy(params, key, observation)`` callable.
    """
    bcq_over_batch = jax.vmap(BCQ_probs)
    sampler = rlax.epsilon_greedy(epsilon)

    def policy(params: networks_lib.Params,
               key: jnp.ndarray,
               observation: jnp.ndarray) -> jnp.ndarray:
        raw_output = q_network.apply(params, normalize_fn(observation))
        # to_qr returns a pair; only its first element is used here.
        q_values, _ = to_qr(raw_output, n_actions, n_quantiles)
        behavior = behavior_fn(observation)
        # vmap maps over axis 0 of every argument, so tau is tiled to match.
        per_row_tau = tau * jnp.ones((q_values.shape[0],))
        preferences = bcq_over_batch(q_values, behavior, per_row_tau)
        return sampler.sample(key, preferences)

    return jax.jit(policy)