import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from .jax_helpers import jit_decorator

"""helper functions"""

@jit_decorator
def sample_step(key, probs):
    """Draw a single categorical sample (jitted).

    Args:
        key: PRNG key. It is split in two: one half drives the draw and the
            other is handed back so the caller can keep sampling.
        probs: array of weights over outcomes ``0 .. probs.shape[0] - 1``;
            outcome ``j`` is drawn with weight ``probs[j]``.

    Returns:
        ``(key, index)`` -- the key to use for subsequent draws and the
        sampled outcome index.
    """
    carry_key, draw_key = jax.random.split(key)
    n_outcomes = probs.shape[0]
    sampled_index = jax.random.choice(draw_key, n_outcomes, p=probs)
    return carry_key, sampled_index

"""jitted sampling functions where the first action is fixed (conditioned on) instead of drawn from the policy"""

@jit_decorator
def sample_path_cd(key, state_action_transition_matrix, policy, start_state, path, cond_action):
    """Sample one path whose first action is fixed to ``cond_action``.

    Position 0 is ``start_state``. Position 1 is drawn from
    ``state_action_transition_matrix[:, start_state, cond_action]``, so the
    matrix is indexed as [next_state, state, action]. Every later position
    is drawn from the current state's next-state distribution with the action
    weighted by ``policy[current_state]`` and summed out.

    Args:
        key: PRNG key consumed by this path.
        state_action_transition_matrix: transition probabilities indexed as
            [next_state, state, action].
        policy: indexed by state; ``policy[s]`` is broadcast against
            ``state_action_transition_matrix[:, s]`` (presumably shape
            (n_actions,) -- TODO confirm).
        start_state: state index written into position 0.
        path: integer template array; only its length and dtype matter, its
            contents are overwritten.
        cond_action: action taken from ``start_state`` for the first step.

    Returns:
        The filled path array (same shape as ``path``).
    """
    n_steps = path.shape[0]

    def marginal_next_state_probs(state):
        # Weight each action's next-state column by the policy's probability
        # of taking that action, then sum the action axis away.
        return (state_action_transition_matrix[:, state] * policy[state]).sum(axis=1)

    def keep_going(carry):
        return carry[1] < n_steps

    def advance(carry):
        rng, t, trajectory = carry
        rng, sampled = sample_step(rng, marginal_next_state_probs(trajectory[t - 1]))
        return rng, t + 1, trajectory.at[t].set(sampled)

    trajectory = path.at[0].set(start_state)
    # The first transition ignores the policy and uses the conditioned action.
    key, first_next = sample_step(key, state_action_transition_matrix[:, start_state, cond_action])
    trajectory = trajectory.at[1].set(first_next)
    _, _, trajectory = lax.while_loop(keep_going, advance, (key, 2, trajectory))
    return trajectory

@jit_decorator
def sample_batch_paths_cd(key, state_action_transition_matrix, policy, start_states, path, cond_actions):
    """Sample a batch of paths, one per (start state, conditional action) pair.

    Each path is produced by ``sample_path_cd`` with its own PRNG key. Paths
    are vectorised with ``jax.vmap`` over ``start_states`` and
    ``cond_actions``; the matrix, policy and path template are shared.

    Args:
        key: PRNG key; one fresh key is returned for further sampling.
        state_action_transition_matrix: transition probabilities indexed as
            [next_state, state, action].
        policy: policy indexed by state.
        start_states: 1-D array of start-state indices, one per path.
        path: integer template array whose length sets the path length.
        cond_actions: 1-D array of first actions, aligned with ``start_states``.

    Returns:
        ``(key, paths)`` where ``paths`` stacks one path per start state along
        axis 0.
    """
    # Split once and slice, rather than unpacking the key array into Python
    # objects and re-stacking them. Unpacking emits one array per batch
    # element into the traced program, so compile cost grows with the batch
    # size, and jnp.stack rejects an empty list. The key values are identical.
    keys = jax.random.split(key, start_states.shape[0] + 1)
    key, path_keys = keys[0], keys[1:]
    batched_sampler = jax.vmap(sample_path_cd, in_axes=(0, None, None, 0, None, 0))
    paths = batched_sampler(path_keys, state_action_transition_matrix, policy, start_states, path, cond_actions)
    return key, paths

"""jitted sampling functions where every action is drawn from the policy"""

@jit_decorator
def sample_path(key, state_action_transition_matrix, policy, start_state, path):
    """Sample one path in which every transition follows the policy.

    Position 0 is ``start_state``. Each later position is drawn from
    ``state_action_transition_matrix[:, s]`` weighted by ``policy[s]`` and
    summed over the trailing axis, where ``s`` is the previous state. Axis 0
    of the matrix is therefore the next state and axis 1 the current state.

    Args:
        key: PRNG key consumed by this path.
        state_action_transition_matrix: transition probabilities whose axis 0
            is the next state and axis 1 the current state (presumably axis 2
            is the action -- TODO confirm).
        policy: indexed by state; ``policy[s]`` is broadcast against
            ``state_action_transition_matrix[:, s]``.
        start_state: state index written into position 0.
        path: integer template array; only its length and dtype matter, its
            contents are overwritten.

    Returns:
        The filled path array (same shape as ``path``).
    """
    n_steps = path.shape[0]

    def advance(carry):
        rng, t, trajectory = carry
        current = trajectory[t - 1]
        # Marginalise the action out under the policy to get next-state probs.
        next_probs = (state_action_transition_matrix[:, current] * policy[current]).sum(axis=1)
        rng, sampled = sample_step(rng, next_probs)
        return rng, t + 1, trajectory.at[t].set(sampled)

    trajectory = path.at[0].set(start_state)
    _, _, trajectory = lax.while_loop(lambda carry: carry[1] < n_steps, advance, (key, 1, trajectory))
    return trajectory

@jit_decorator
def sample_batch_paths(key, state_action_transition_matrix, policy, start_states, path):
    """Sample a batch of paths, one per entry of ``start_states``.

    Each path is produced by ``sample_path`` with its own PRNG key. Paths are
    vectorised with ``jax.vmap`` over ``start_states``; the matrix, policy and
    path template are shared.

    Args:
        key: PRNG key; one fresh key is returned for further sampling.
        state_action_transition_matrix: transition probabilities whose axis 0
            is the next state and axis 1 the current state.
        policy: policy indexed by state.
        start_states: 1-D array of start-state indices, one per path.
        path: integer template array whose length sets the path length.

    Returns:
        ``(key, paths)`` where ``paths`` stacks one path per start state along
        axis 0.
    """
    # Split once and slice, rather than unpacking the key array into Python
    # objects and re-stacking them. Unpacking emits one array per batch
    # element into the traced program, so compile cost grows with the batch
    # size, and jnp.stack rejects an empty list. The key values are identical.
    keys = jax.random.split(key, start_states.shape[0] + 1)
    key, path_keys = keys[0], keys[1:]
    batched_sampler = jax.vmap(sample_path, in_axes=(0, None, None, 0, None))
    paths = batched_sampler(path_keys, state_action_transition_matrix, policy, start_states, path)
    return key, paths

class Matrix_Sampler:

    """Matrix sampler: wraps a state action transition matrix and a policy and samples paths from them.

    Input attributes:
        state_action_transition_matrix: the state action transition matrix to sample from,
            indexed as [next_state, state, action] by the sampling functions it is passed to
        policy: the policy to sample actions from, indexed by state
    """

    def __init__(self, state_action_transition_matrix, policy):
        self.state_action_transition_matrix = state_action_transition_matrix
        self.policy = policy

    def sample_paths(self, key, start_states, num_paths, path_length, cond_actions=None):
        """Sample ``num_paths`` paths of ``path_length`` states from each start state.

        Args:
            key: JAX PRNG key.
            start_states: 1-D array (or sequence) of start-state indices.
            num_paths: number of paths to sample per start state.
            path_length: number of states in each path, including the start state.
            cond_actions: optional 1-D array (or sequence) with one action per
                start state. When given, the first transition of every path
                from ``start_states[i]`` uses action ``cond_actions[i]``
                instead of the policy.

        Returns:
            ``(key, paths)``: an updated PRNG key and an int32 array of shape
            ``(len(start_states), num_paths, path_length)``.

        Raises:
            ValueError: if ``cond_actions`` does not have one entry per start state.
        """
        start_states = jnp.asarray(start_states)
        n_start_states = start_states.shape[0]
        if cond_actions is not None:
            cond_actions = jnp.asarray(cond_actions)
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if cond_actions.shape[0] != n_start_states:
                raise ValueError(
                    f"cond_actions must have one entry per start state: "
                    f"got {cond_actions.shape[0]} actions for {n_start_states} start states"
                )

        # Nothing to sample: return a correctly shaped empty result and leave
        # the key untouched instead of building a zero-length batch.
        if n_start_states == 0 or num_paths == 0:
            return key, jnp.zeros((n_start_states, num_paths, path_length), dtype=jnp.int32)

        # Flatten to one entry per path: each start state repeated num_paths times.
        flat_start_states = jnp.repeat(start_states[:, jnp.newaxis], num_paths, axis=1).ravel()
        path_template = jnp.zeros(path_length, dtype=jnp.int32)
        if cond_actions is None:
            key, paths = sample_batch_paths(key, self.state_action_transition_matrix, self.policy, flat_start_states, path_template)
        else:
            flat_cond_actions = jnp.repeat(cond_actions[:, jnp.newaxis], num_paths, axis=1).ravel()
            key, paths = sample_batch_paths_cd(key, self.state_action_transition_matrix, self.policy, flat_start_states, path_template, flat_cond_actions)

        # Use the known path length rather than -1 so zero-length paths reshape cleanly.
        return key, jnp.reshape(paths, (n_start_states, num_paths, path_length))





        
    