import jax
import jax.numpy as jnp
from typing import Callable

def add_noise_jax(noise_level: float) -> Callable[..., jax.Array]:
    """Build a transform that adds scaled Gaussian noise to a trajectory.

    The noise standard deviation is ``noise_level`` times the trajectory's
    own standard deviation along the second-to-last axis (the time axis,
    per the original author's comment), computed with ``keepdims=True`` so
    it broadcasts back over that axis.

    Args:
        noise_level: Relative noise amplitude. Values ``<= 0`` disable the
            transform; the input is then returned untouched.

    Returns:
        A function ``(traj, rng) -> array`` where ``rng`` is the PRNG key
        handed to ``jax.random.normal``.
    """

    def _add_noise(traj: jnp.ndarray, rng: jax.Array) -> jnp.ndarray:
        if noise_level <= 0.0:
            return traj

        # Draw the noise in the trajectory's own floating dtype. Sampling in
        # the default float dtype would silently promote float16/bfloat16
        # input to float32, so the output dtype would depend on whether
        # noise is enabled. Non-float input keeps the default float dtype.
        traj_dtype = jnp.result_type(traj)
        if jnp.issubdtype(traj_dtype, jnp.floating):
            unit_noise = jax.random.normal(
                rng, shape=traj.shape, dtype=traj_dtype
            )
        else:
            unit_noise = jax.random.normal(rng, shape=traj.shape)

        std = traj.std(axis=-2, keepdims=True)  # Over time dimension
        return traj + noise_level * std * unit_noise

    return _add_noise

def crop_jax(window_size: int | None = None) -> Callable[..., jax.Array]:
    def _crop(traj: jnp.ndarray, rng: jax.Array) -> jnp.ndarray:
        if window_size is None:
            return traj
        return traj[..., :window_size, :]  # Works for batched input
    
    return _crop

def normalize_jax(mean=None, std=None, eps=1e-8) -> Callable[..., jax.Array]:
    """Build a transform that standardises a trajectory.

    Two modes exist, chosen once when the factory is called:

    * Global: when BOTH ``mean`` and ``std`` are given, they are converted
      with ``jnp.asarray`` and applied as ``(traj - mean) / (std + eps)``.
    * Per-trajectory: otherwise the mean and standard deviation are taken
      from the input along its second-to-last axis (``keepdims=True``).
      Note that supplying only one of ``mean``/``std`` selects this mode,
      and the supplied value is ignored.

    Args:
        mean: Optional global mean.
        std: Optional global standard deviation.
        eps: Added to the standard deviation before dividing.

    Returns:
        A function ``(traj, rng) -> array``. ``rng`` is accepted but not
        used by this transform.
    """
    have_global = not (mean is None or std is None)
    global_mean = jnp.asarray(mean) if have_global else None
    global_std = jnp.asarray(std) if have_global else None

    def _normalize(traj: jnp.ndarray, rng: jax.Array) -> jnp.ndarray:
        if have_global:
            center, scale = global_mean, global_std
        else:
            # Statistics of this very input, reduced over axis -2.
            center = traj.mean(axis=-2, keepdims=True)
            scale = traj.std(axis=-2, keepdims=True)
        return (traj - center) / (scale + eps)

    return _normalize