import jax
import numpy as np
from typing import Callable, Iterable, Optional
# A single trajectory held as a 2-D array: axis 0 is time, axis 1 is features.
Trajectory = np.ndarray        # shape (T, D)
# PRNG key that transforms hand to jax.random functions.
RNG = jax.Array
# Common signature of every transform in this module: takes a trajectory and a
# PRNG key and returns a trajectory. Deterministic transforms still accept the
# key so that all transforms can be chained uniformly.
Transform = Callable[[Trajectory, RNG], Trajectory]

def compose(transforms: Iterable[Transform]) -> Transform:
    """Chain several transforms into a single transform.

    The transforms run in iteration order, each one receiving the output of
    the previous one. The iterable is materialized once here, so a generator
    may be passed and the returned pipeline can still be called repeatedly.

    Args:
        transforms: Transforms to apply, in order.

    Returns:
        A transform that applies every step. With no steps it returns the
        input trajectory object unchanged.
    """
    steps = list(transforms)

    def _composed(
        traj: Trajectory,
        rng: RNG,
    ) -> Trajectory:
        if not steps:
            return traj

        if rng is None:
            # No key supplied (purely deterministic pipeline): pass None
            # through to every step instead of failing on the split.
            step_keys = [None] * len(steps)
        else:
            # Derive an independent subkey per step. Handing the same key to
            # several random steps would make them draw identical samples.
            step_keys = jax.random.split(rng, len(steps))

        for step, step_key in zip(steps, step_keys):
            traj = step(traj, step_key)
        return traj

    return _composed

def add_noise(noise_level: float) -> Transform:
    """Build a transform that perturbs a trajectory with Gaussian noise.

    Each feature column receives noise whose scale is ``noise_level`` times
    that column's standard deviation along axis 0.

    Args:
        noise_level: Relative noise magnitude. Values ``<= 0`` turn the
            transform into a no-op.

    Returns:
        A transform producing ``traj + noise``, or the input object itself
        when noise is disabled.
    """

    def _add_noise(
        traj: Trajectory,
        rng: RNG,
    ) -> Trajectory:
        # Nothing to add: hand back the very same object.
        if noise_level <= 0.0:
            return traj

        # One scale per feature column, kept 2-D so it broadcasts over time.
        scale = noise_level * traj.std(axis=0, keepdims=True)
        perturbation = scale * jax.random.normal(rng, shape=traj.shape)
        return traj + perturbation

    return _add_noise

def crop(window_size: Optional[int]) -> Transform:
    """Build a transform that slices a trajectory to ``traj[:window_size]``.

    Args:
        window_size: Upper bound of the slice along axis 0. ``None`` disables
            cropping. Ordinary slice semantics apply, so a value larger than
            the trajectory length leaves it whole.

    Returns:
        A transform returning the sliced trajectory, or the input object
        itself when ``window_size`` is ``None``.
    """
    def _crop(
        traj: Trajectory,
        rng: RNG,
    ) -> Trajectory:
        # rng is unused; it is accepted only to match the Transform signature.
        return traj if window_size is None else traj[:window_size]

    return _crop


def normalize(
    *,
    mean: Optional[np.ndarray] = None,
    std: Optional[np.ndarray] = None,
    eps: float = 1e-8,
) -> Transform:
    """
    Normalize a trajectory.

    Two modes:
    - Global normalization: mean and std are both provided
    - Per-trajectory normalization: neither is provided; mean/std are
      computed from the trajectory along axis 0

    Args:
        mean: Global mean to subtract. Must be given together with ``std``.
        std: Global standard deviation to divide by. Must be given together
            with ``mean``.
        eps: Small constant added to the standard deviation so that the
            division stays finite when the deviation is zero.

    Returns:
        A transform computing ``(traj - mean) / (std + eps)``.

    Raises:
        ValueError: If exactly one of ``mean`` and ``std`` is provided.
    """

    # Supplying only one statistic used to fall through to per-trajectory
    # normalization, silently discarding the supplied value. Reject it here.
    if (mean is None) != (std is None):
        raise ValueError(
            "normalize() requires both 'mean' and 'std' for global "
            "normalization, or neither for per-trajectory normalization"
        )

    use_global = mean is not None

    if use_global:
        # Convert once up front so list/tuple inputs broadcast like arrays.
        mean = np.asarray(mean)
        std = np.asarray(std)

    def _normalize(
        traj: Trajectory,
        rng: RNG,
    ) -> Trajectory:
        # rng is unused; it is accepted only to match the Transform signature.
        if use_global:
            return (traj - mean) / (std + eps)

        # Per-trajectory normalization: statistics taken over axis 0, kept
        # 2-D so they broadcast back over every row.
        m = traj.mean(axis=0, keepdims=True)
        s = traj.std(axis=0, keepdims=True)
        return (traj - m) / (s + eps)

    return _normalize
