from jax import Array
import jax.lax as lax
import jax.numpy as np
import jax.random as random


def simulate(
  s: Array,
  a: Array,
  reward_threshold: float,
  step: float = 0.25,
) -> tuple[Array, Array, Array, Array, Array]:
  """Advance every state by one fixed step and label the resulting transition.

  The next state is `s + step`, clipped to the interval [0, 2]. The action
  `a` does not influence the dynamics; it is returned unchanged so that the
  result forms a complete transition tuple.

  Args:
    s: Current states.
    a: Actions paired with `s`; passed through untouched.
    reward_threshold: A state at or above this value counts as done.
    step: Amount added to each state. Defaults to 0.25.

  Returns:
    The tuple `(s, a, r, snew, done)` where
      * `snew` is the clipped successor state,
      * `done` is 1.0 where `s >= reward_threshold` and 0.0 elsewhere,
      * `r` is 1.0 where `snew >= reward_threshold` and the transition did
        not start from a done state, and 0.0 elsewhere.
  """
  # np.clip promotes the scalar bounds to the dtype of `s + step`, whereas
  # lax.clamp needs all three operands to already share one dtype.
  snew = np.clip(s + step, 0., 2.)

  one, zero = np.array(1.), np.array(0.)
  already_done = s >= reward_threshold
  done = np.where(already_done, one, zero)

  # Reward only the transition that first reaches the threshold.
  reaches_goal = np.logical_and(snew >= reward_threshold, np.logical_not(already_done))
  r = np.where(reaches_goal, one, zero)

  return (s, a, r, snew, done)

def get_dataset(
  key: Array,
  num_samples: int,
  data_range: tuple[float, float],
  reward_threshold: float,
) -> tuple[Array, Array, Array, Array, Array]:
  """Build a dataset of one-step transitions from randomly drawn states.

  Args:
    key: JAX PRNG key used to draw the start states.
    num_samples: Number of transitions to generate.
    data_range: `(low, high)` bounds passed to `random.uniform` as
      `minval` and `maxval`.
    reward_threshold: Forwarded to `simulate`.

  Returns:
    The `(s, a, r, snew, done)` tuple produced by `simulate`; `s` and `a`
    both have shape `(num_samples, 1)`.
  """
  low, high = data_range[0], data_range[1]
  s = random.uniform(key, (num_samples, 1), dtype=np.float32, minval=low, maxval=high)

  # `simulate` passes the action through without using it, so a constant
  # placeholder is enough here.
  # NOTE(review): get_dataset_linspaced uses integer ones instead of float
  # zeros for this placeholder -- confirm whether the difference is intended.
  a = np.zeros((num_samples, 1))

  return simulate(s, a, reward_threshold)

def get_dataset_linspaced(
  num_samples: int,
  data_range: tuple[float, float],
  reward_threshold: float,
) -> tuple[Array, Array, Array, Array, Array]:
  """Build a dataset of one-step transitions from evenly spaced states.

  Args:
    num_samples: Number of transitions to generate.
    data_range: `(low, high)` endpoints handed to `np.linspace` as start
      and stop.
    reward_threshold: Forwarded to `simulate`.

  Returns:
    The `(s, a, r, snew, done)` tuple produced by `simulate`.
  """
  # Wrapping the endpoints in length-1 arrays gives the states a trailing
  # axis of size 1, i.e. one row per sample.
  s = np.linspace(np.array([data_range[0]]), np.array([data_range[1]]), num=num_samples)

  # One integer action of value 1 per sample; `simulate` passes it through
  # without using it.
  a = np.ones((num_samples, 1), dtype=int)

  return simulate(s, a, reward_threshold)
