import jax
import jax.numpy as jnp
import chex


def softSPIBB_probs(q: jnp.ndarray, 
                    e: jnp.ndarray, 
                    b: jnp.ndarray, 
                    eps: jnp.float32):
    """Calculates the soft-SPIBB policy

    Starts from the behavior policy and visits the actions in increasing
    order of Q value. For each visited action, probability mass is moved from
    it to the action with the largest error-normalised Q advantage over it,
    for as long as the error budget (initialised to ``eps``) allows.

    Args:
        q: vector of Q values
        e: vector of error values
        b: vector of behavior probabilities
        eps: scalar error tolerance
    Returns:
        vector of soft-SPIBB probabilities
    """
    chex.assert_rank([q, e, b, eps], [1, 1, 1, 0])
    chex.assert_type([q, e, b, eps], float)

    # `.at[...]` below needs a JAX array; this is a no-op for JAX inputs.
    pi = jnp.asarray(b)

    num_actions = len(q)
    sort_idx = jnp.argsort(q)

    sorted_q = q[sort_idx]
    sorted_e = e[sort_idx]
    sorted_b = b[sort_idx]

    allowed_error = eps
    for a_bot in range(num_actions):
        # Best destination for the mass taken from the current low-Q action.
        a_top = jnp.argmax((q - sorted_q[a_bot]) / e)

        # Mass is capped by the source action's behavior probability and by
        # what the remaining error budget permits on each side of the move.
        # NOTE(review): the cap uses the behavior probability, not the current
        # value of `pi` for this action (which may already have received
        # mass). Confirm this is intended against the reference algorithm.
        mass_bot = jnp.minimum(sorted_b[a_bot],
                               allowed_error / (2 * sorted_e[a_bot]))
        mass_top = jnp.minimum(mass_bot, allowed_error / (2 * e[a_top]))

        # Functional index updates; `jax.ops.index_add` no longer exists in
        # current JAX releases.
        pi = pi.at[sort_idx[a_bot]].add(-mass_top)
        pi = pi.at[a_top].add(mass_top)

        # Charge the move against the remaining error budget.
        allowed_error -= mass_top * (sorted_e[a_bot] + e[a_top])

    return pi


def BCQ_probs(q: jnp.ndarray, 
                b: jnp.ndarray, 
                tau: jnp.float32):
    """Calculates the BCQ policy

    Every action is weighted by how its behavior probability, relative to the
    most likely behavior action, compares with ``tau``: weight 1 above the
    threshold, 0.5 exactly at it, and 0 below it. The result is a one-hot
    vector on the action with the largest weighted, shifted Q value.

    Args:
        q: vector of Q values
        b: vector of behavior probabilities
        tau: scalar error tolerance, lower tau allows greater deviation

    Returns:
        jnp.ndarray: vector of BCQ probabilities
    """
    chex.assert_rank([q, b, tau], [1, 1, 0])
    chex.assert_type([q, b, tau], float)

    # Behavior probabilities relative to the most likely action.
    relative_b = b / jnp.max(b)

    # sign() yields -1 / 0 / +1, so the weights are 0 / 0.5 / 1.
    keep_weight = (jnp.sign(relative_b - tau) + 1) / 2.

    # Shift Q so every entry is >= 1; zero-weighted actions then can never
    # beat an action that kept a positive weight.
    shifted_q = q - jnp.min(q) + 1.
    best_action = jnp.argmax(shifted_q * keep_weight)

    return jax.nn.one_hot(best_action, len(q))


def to_qr(outputs, n_actions, n_quantiles):
    """Helper method for QR-DQN.

    Reshapes ``outputs`` to ``(batch, n_quantiles, n_actions)``, where
    ``batch`` is the leading dimension of ``outputs``, and averages over the
    quantile axis.

    Args:
        outputs: array whose first dimension is the batch size and which
            holds ``n_quantiles * n_actions`` values per batch element
        n_actions: number of actions
        n_quantiles: number of quantiles per action

    Returns:
        tuple ``(q_vals, q_dist)``: the per-action mean over quantiles with
        shape ``(batch, n_actions)``, and the reshaped array with shape
        ``(batch, n_quantiles, n_actions)``
    """
    target_shape = (outputs.shape[0], n_quantiles, n_actions)
    q_dist = jnp.reshape(outputs, target_shape)
    # Axis 1 is the quantile axis of the reshaped array.
    return jnp.mean(q_dist, axis=1), q_dist