import chex
import jax
import jax.numpy as jnp
import jax.random as random


def sample_within_threshold(predicted, demonstration, threshold=0.5):
    """
    Truncate a predicted trajectory just before it first comes within ``threshold``
    of the origin, then subsample the demonstration trajectory to the same length.

    Points are stored along the LAST axis of both arrays: distances are taken over
    axis 0 and points are selected with ``[:, i]``. For planar trajectories that
    means ``predicted`` is ``[2, N]`` and ``demonstration`` is ``[2, M]``.

    Args:
        predicted (jax.Array): Predicted trajectory of shape [2, N].
        demonstration (jax.Array): Demonstration trajectory of shape [2, M].
        threshold (float): Distance from (0, 0). The first predicted point strictly
            closer than this, and every point after it, is discarded.

    Returns:
        tuple: ``(sampled_predicted, sampled_demonstration)`` where

        - ``sampled_predicted`` is the kept prefix of ``predicted``, passed through
          ``jnp.squeeze`` (so a single kept point comes back with shape ``[2]``).
        - ``sampled_demonstration`` has exactly as many points as were kept:
          - If the lengths already match, it is ``demonstration`` unchanged.
          - Otherwise, it is taken at a uniform integer stride.
          - When ``kept < M < 2 * kept``, no stride greater than 1 fits. It is then
            obtained by deleting randomly chosen interior points. The first and
            last points are never dropped, and a fixed PRNG key makes this
            reproducible.

        ``(None, None)`` if no predicted point is kept (the very first point is
        already within the threshold).

    Raises:
        NotImplementedError: If more predicted points are kept than the
            demonstration contains (upsampling is not supported).
    """
    # Euclidean distance of each point (column) from the origin.
    distances = jnp.linalg.norm(predicted, axis=0)

    # Keep everything strictly before the first point inside the threshold; if no
    # point ever gets that close, keep the whole trajectory.
    within = jnp.where(distances < threshold)[0]
    n_kept = int(within[0]) if within.size > 0 else predicted.shape[-1]

    if n_kept == 0:
        return None, None

    # The point count comes from ``n_kept``, not from the squeezed array: squeezing
    # a single kept point collapses [2, 1] to [2], whose last dimension is 2, not 1.
    sampled_predicted = jnp.squeeze(predicted[:, :n_kept])
    n_demo = demonstration.shape[-1]

    if n_kept == n_demo:
        # Lengths already agree: nothing to resample.
        sampled_demonstration = demonstration
    elif n_kept < n_demo:
        step_size = n_demo // n_kept
        if step_size > 1:
            # Uniform stride; since step_size <= n_demo / n_kept there are at
            # least n_kept candidates, truncated to exactly n_kept.
            sampled_indices = jnp.arange(0, n_demo, step_size)[:n_kept]
            sampled_demonstration = demonstration[:, sampled_indices]
        else:
            # n_kept < n_demo < 2 * n_kept: randomly drop interior points (indices
            # 1 .. n_demo - 2, so the endpoints survive) until the sizes match.
            num_points_to_drop = n_demo - n_kept
            indices_to_drop = random.choice(
                random.key(42), jnp.arange(1, n_demo - 1), (num_points_to_drop,), replace=False
            )
            sampled_demonstration = jnp.delete(demonstration, indices_to_drop, axis=1)
    else:
        raise NotImplementedError(
            f"cannot upsample a demonstration of {n_demo} points to {n_kept} predicted points"
        )

    return sampled_predicted, sampled_demonstration


@jax.jit
def deque_append(
    memory: jax.Array, front: int, rear: int, n_elements: int, element: chex.Array, is_first: bool = False
):
    """Push ``element`` onto the newest end of a fixed-capacity row buffer.

    ``memory`` stores one entry per row; its leading dimension is the capacity.
    Row 0 (the oldest entry) is overwritten with ``element``, then every row is
    rotated up by one. After the call the newest entry is the last row and the
    previous oldest entry has been evicted.

    Args:
        memory: Buffer whose rows are the stored elements.
        front: Passed through unchanged.
        rear: Incremented by one and clipped to ``[0, capacity - 1]``.
        n_elements: Incremented by one and capped at the capacity.
        element: Row to append. Both ``lax.cond`` branches must produce the same
            shape/dtype, so it should match a row of ``memory``.
        is_first: When true, every row is filled with ``element`` instead of
            overwriting only row 0.

    Returns:
        ``(new_memory, front, new_rear, new_n_elements)``.
    """
    capacity = memory.shape[0]

    def _fill_all(_):
        # Seed the entire buffer with copies of the first element.
        return jnp.tile(element[None, :], (capacity, 1))

    def _overwrite_oldest(_):
        return memory.at[0, :].set(element)

    written = jax.lax.cond(is_first, _fill_all, _overwrite_oldest, operand=None)

    # Rotate rows up by one so the slot just written (row 0) becomes the last row.
    rotated = jnp.roll(written, shift=-1, axis=0)

    new_rear = jnp.clip(rear + 1, 0, capacity - 1)
    new_count = jnp.minimum(n_elements + 1, capacity)
    return rotated, front, new_rear, new_count


def evaluate_mean_absolute_jerk(traj, axis=-1):
    """
    Evaluate the jerk cost of a trajectory.

    Jerk is approximated by the third finite difference of ``traj`` along ``axis``.
    No time-step scaling is applied, i.e. unit spacing between samples is assumed.

    NOTE: despite the function name, the value returned is the mean of the
    *squared* jerk magnitude, ``mean(|d^3 traj|^2)``, not the mean absolute jerk.
    This is the existing numeric behavior and is preserved for callers that
    compare against previously computed values.

    Args:
        traj (jax.Array): Trajectory samples; time runs along ``axis``.
        axis (int): Axis along which to difference. Defaults to the last axis,
            which is what the function has always used.

    Returns:
        jax.Array: Scalar mean, over all elements, of the squared third difference.
            If ``traj`` has fewer than 4 samples along ``axis`` the jerk array is
            empty and the mean is NaN.
    """
    jerk = jnp.diff(traj, n=3, axis=axis)
    # abs() before squaring keeps the result real (|z|^2) even for complex input.
    return jnp.mean(jnp.abs(jerk) ** 2)
