import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from .jax_helpers import *

"""helper functions"""

@jit_decorator
def sample_step(key, successors, probs, policy):
    """Take one stochastic transition out of a single state.

    An action index is drawn from ``policy`` (a distribution over
    ``policy.shape[0]`` actions), then one entry of ``successors`` is drawn
    with weights taken from column ``action`` of ``probs``. Each draw uses
    its own fresh subkey.

    Returns:
        ``(key, successor)`` -- the advanced key and the sampled entry of
        ``successors``.
    """
    key, action_key = jax.random.split(key)
    n_actions = policy.shape[0]
    action = jax.random.choice(action_key, n_actions, p=policy)

    key, successor_key = jax.random.split(key)
    weights = probs[:, action]
    successor = jax.random.choice(successor_key, successors, p=weights)
    return key, successor

@jit_decorator
def sample_step_cd(key, successors, probs):
    """Draw one entry of ``successors`` with weights ``probs``.

    Used when the action is already fixed, so only the successor is random.

    Returns:
        ``(key, successor)`` -- the advanced key and the sampled entry.
    """
    next_key, draw_key = jax.random.split(key)
    chosen = jax.random.choice(draw_key, successors, p=probs)
    return next_key, chosen

"""jitted sampling functions with conditional actions"""

@jit_decorator
def sample_path_cd(key, successor_state_matrix, probabilities, policy, start_state, path, cond_action):
    """Sample one path whose first transition is taken under a fixed action.

    The first step leaves ``start_state`` under ``cond_action`` (only the
    successor is random). Every later step draws its action from ``policy``
    at the current state. Entry ``i`` of the result is the state reached
    after ``i + 1`` transitions, so the start state itself is not stored.
    ``path`` fixes the length and dtype of the result; its entries are
    overwritten.
    """
    n_steps = path.shape[0]

    # Forced first transition: the action is given, only the successor is drawn.
    key, first = sample_step_cd(
        key,
        successor_state_matrix[:, start_state],
        probabilities[:, start_state, cond_action],
    )
    path = path.at[0].set(first)

    def step(i, carry):
        rng, trajectory = carry
        current = trajectory[i - 1]
        rng, nxt = sample_step(
            rng,
            successor_state_matrix[:, current],
            probabilities[:, current, :],
            policy[current, :],
        )
        return rng, trajectory.at[i].set(nxt)

    # Steps 1..n_steps-1 are policy-driven. Clamping the upper bound to >= 1
    # keeps the trip count at zero for very short paths, exactly like an
    # ``i < n_steps`` loop condition starting from i = 1.
    _, path = lax.fori_loop(1, max(n_steps, 1), step, (key, path))
    return path

@jit_decorator
def sample_batch_paths_cd(key, successor_state_matrix, probabilities, policy, start_states, path, cond_actions):
    """Sample one conditioned path per (start state, conditional action) pair.

    ``sample_path_cd`` is vmapped over ``start_states`` and ``cond_actions``
    (paired by position), each with its own PRNG key. The model arrays and
    the ``path`` template are shared by all rows.

    Returns:
        ``(key, paths)`` -- the advanced key and the stacked outputs of
        ``sample_path_cd``, one row per entry of ``start_states``.
    """
    n_paths = start_states.shape[0]
    # Split once and slice. Unpacking with ``key, *rest = ...`` iterates the
    # key array in Python, emitting one slice op per path plus a re-stack, so
    # trace/compile cost grew with the batch size (and ``jnp.stack`` of an
    # empty list failed for an empty batch). Slicing yields the same keys.
    keys = jax.random.split(key, n_paths + 1)
    key, path_keys = keys[0], keys[1:]
    sample_all = jax.vmap(sample_path_cd, in_axes=(0, None, None, None, 0, None, 0))
    paths = sample_all(path_keys, successor_state_matrix, probabilities, policy, start_states, path, cond_actions)
    return key, paths

"""jitted sampling functions without conditional actions"""

@jit_decorator
def sample_path(key, successor_state_matrix, probabilities, policy, start_state, path):
    """Sample one path from ``start_state``, drawing every action from ``policy``.

    Entry ``i`` of the result is the state reached after ``i + 1``
    transitions; the start state itself is not stored. ``path`` fixes the
    length and dtype of the result, and its entries are overwritten.
    """
    def step(i, carry):
        rng, trajectory = carry
        # The first transition leaves start_state; each later one leaves the
        # state written on the previous iteration.
        current = jnp.where(i == 0, start_state, trajectory[i - 1])
        rng, nxt = sample_step(
            rng,
            successor_state_matrix[:, current],
            probabilities[:, current, :],
            policy[current, :],
        )
        return rng, trajectory.at[i].set(nxt)

    _, path = lax.fori_loop(0, path.shape[0], step, (key, path))
    return path

@jit_decorator
def sample_batch_paths(key, successor_state_matrix, probabilities, policy, start_states, path):
    """Sample one policy-driven path per entry of ``start_states``.

    ``sample_path`` is vmapped over ``start_states``, each with its own PRNG
    key. The model arrays and the ``path`` template are shared by all rows.

    Returns:
        ``(key, paths)`` -- the advanced key and the stacked outputs of
        ``sample_path``, one row per entry of ``start_states``.
    """
    n_paths = start_states.shape[0]
    # Split once and slice. Unpacking with ``key, *rest = ...`` iterates the
    # key array in Python, emitting one slice op per path plus a re-stack, so
    # trace/compile cost grew with the batch size (and ``jnp.stack`` of an
    # empty list failed for an empty batch). Slicing yields the same keys.
    keys = jax.random.split(key, n_paths + 1)
    key, path_keys = keys[0], keys[1:]
    sample_all = jax.vmap(sample_path, in_axes=(0, None, None, None, 0, None))
    paths = sample_all(path_keys, successor_state_matrix, probabilities, policy, start_states, path)
    return key, paths

class Successor_Sampler:
    """Sample batches of state paths from a tabular transition model and policy.

    Stores the model arrays and forwards them to the module-level jitted
    samplers ``sample_batch_paths`` / ``sample_batch_paths_cd``.
    """

    def __init__(self, successor_state_matrix, probabilities, policy):
        # Stored unchanged and forwarded on every call. Going by how the
        # samplers index them (TODO confirm): successor_state_matrix[:, s]
        # lists candidate successors of state s, probabilities[:, s, a] their
        # weights under action a, and policy[s, :] the action distribution at s.
        self.successor_state_matrix = successor_state_matrix
        self.probabilities = probabilities
        self.policy = policy

    def sample_paths(self, key, start_states, num_paths, path_length, cond_actions=None):
        """Sample ``num_paths`` paths of ``path_length`` steps from each start state.

        Args:
            key: JAX PRNG key; it is split and the advanced key is returned.
            start_states: array of start states, expected to be 1-D (one
                output row per entry).
            num_paths: number of paths sampled per start state.
            path_length: number of transitions stored per path (the start
                state itself is not part of the path).
            cond_actions: optional array with one action per start state. When
                given, the first transition from ``start_states[j]`` is taken
                under ``cond_actions[j]`` instead of an action drawn from the
                policy.

        Returns:
            ``(key, paths)`` where ``paths`` is reshaped to
            ``(len(start_states), num_paths, -1)``; the path template is int32.

        Raises:
            ValueError: if ``cond_actions`` does not have one entry per start
                state.
        """
        n_start_states = start_states.shape[0]

        # Explicit check rather than ``assert`` so the validation is not
        # stripped when Python runs with -O.
        if cond_actions is not None and cond_actions.shape[0] != n_start_states:
            raise ValueError(
                "cond_actions must have one entry per start state: got "
                f"{cond_actions.shape[0]} actions for {n_start_states} start states"
            )

        def repeat_per_path(x):
            # [a, b] -> [a]*num_paths + [b]*num_paths for 1-D x, so the paths of
            # one start state are contiguous and the final reshape groups them.
            return jnp.repeat(x[:, jnp.newaxis], num_paths, axis=1).ravel()

        flat_start_states = repeat_per_path(start_states)
        path_template = jnp.zeros(path_length, dtype=jnp.int32)

        if cond_actions is None:
            key, paths = sample_batch_paths(
                key, self.successor_state_matrix, self.probabilities, self.policy,
                flat_start_states, path_template,
            )
        else:
            key, paths = sample_batch_paths_cd(
                key, self.successor_state_matrix, self.probabilities, self.policy,
                flat_start_states, path_template, repeat_per_path(cond_actions),
            )

        reshaped_paths = jnp.reshape(paths, (n_start_states, num_paths, -1))
        return key, reshaped_paths