import jax.numpy as jnp
import jax

"""JAX numpy helper functions"""

def jit_decorator(func):
    """Wrap ``func`` with ``jax.jit`` and return the compiled callable.

    Intended to be used as a plain decorator (``@jit_decorator``); no extra
    options are forwarded to ``jax.jit``.

    Args:
        func: The function to be just-in-time compiled.

    Returns:
        The ``jax.jit``-wrapped version of ``func``.
    """
    compiled = jax.jit(func)
    return compiled

def to_jnp(arr):
    """Cast ``arr`` to a JAX array with dtype ``float32``.

    The conversion is intentionally NOT wrapped in ``jax.jit``. Under jit a
    Python list or tuple is flattened into one traced scalar argument per
    element and the function is re-traced for every new container length or
    nesting, which makes the conversion slower rather than faster. Converting
    eagerly yields the same values without that overhead.

    Args:
        arr: Any input accepted by ``jnp.array`` (scalar, nested list/tuple,
            NumPy array, JAX array).

    Returns:
        A JAX array holding the values of ``arr`` with dtype ``jnp.float32``.
    """
    return jnp.array(arr, dtype=jnp.float32)

def to_jnp_int(arr):
    """Cast ``arr`` to a JAX array with dtype ``int32``.

    The conversion is intentionally NOT wrapped in ``jax.jit``. Under jit a
    Python list or tuple is flattened into one traced scalar argument per
    element and the function is re-traced for every new container length or
    nesting, which makes the conversion slower rather than faster. Converting
    eagerly yields the same values without that overhead.

    Args:
        arr: Any input accepted by ``jnp.array`` (scalar, nested list/tuple,
            NumPy array, JAX array).

    Returns:
        A JAX array holding the values of ``arr`` with dtype ``jnp.int32``.
    """
    return jnp.array(arr, dtype=jnp.int32)

@jit_decorator
def compute_state_transition_matrix(state_action_transition_matrix, policy):
    """Marginalise the action out of a state-action transition tensor.

    ``policy`` is multiplied element-wise (with broadcasting) against
    ``state_action_transition_matrix`` and the product is summed over axis 2.

    Args:
        state_action_transition_matrix: Array with at least three axes; axis 2
            is the one summed out (presumably the action axis -- confirm
            against callers).
        policy: Array broadcastable against ``state_action_transition_matrix``
            (presumably shaped ``(state, action)`` for a fixed Markov policy
            -- confirm against callers).

    Returns:
        The policy-weighted transition matrix, i.e. the input tensor with
        axis 2 summed out.
    """
    policy_weighted = policy * state_action_transition_matrix
    return policy_weighted.sum(axis=2)

@jit_decorator
def compute_cond_action_matrix(state_action_transition_matrix, policy, state, action):
    """Build the policy transition matrix and an action-conditioned companion.

    Args:
        state_action_transition_matrix: Transition tensor indexed as
            ``[:, state, action]`` (axis 0 is presumably the next state --
            confirm against callers).
        policy: Policy passed through to ``compute_state_transition_matrix``.
        state: Index of the state whose column is filled in.
        action: Index of the action that is held fixed for ``state``.

    Returns:
        A tuple ``(state_transition_matrix, cond_action_matrix)`` where the
        first entry is the result of ``compute_state_transition_matrix`` and
        the second has the same shape and dtype, is zero everywhere except
        column ``state``, which holds
        ``state_action_transition_matrix[:, state, action]``.
    """
    policy_matrix = compute_state_transition_matrix(state_action_transition_matrix, policy)
    # Distribution selected by fixing `action` in `state`.
    fixed_action_column = state_action_transition_matrix[:, state, action]
    # NOTE(review): every other column stays zero instead of keeping the
    # policy-averaged values -- confirm this is what callers expect.
    conditioned = jnp.zeros_like(policy_matrix).at[:, state].set(fixed_action_column)
    return policy_matrix, conditioned

