import jax
import jax.numpy as jnp
from typing import Callable

def anchored_rollout(stepper_apply: Callable[..., jax.Array], anchor_after: int
) -> Callable[..., jax.Array]:
    """
    Build a rollout function that periodically re-anchors to the reference data.

        û_0     = u_0
        û_{t+1} = u_{t+1}      if anchor_after > 0 and t % anchor_after == 0
                = f_θ(û_t)     otherwise,                  t = 0, …, T-2

    An anchored step emits the reference state u_{t+1} itself (the model is not
    called for that step), so the following model step starts from ground truth.
    Because t = 0 satisfies the anchor condition, û_1 = u_1 whenever
    ``anchor_after > 0``; with ``anchor_after == 1`` every step is anchored and
    the output equals ``u``. ``anchor_after <= 0`` disables anchoring entirely,
    giving a free-running rollout from u_0.

    Args:
        stepper_apply: one-step model, called as ``stepper_apply(θ, x)`` with
            ``x`` of shape (B, 1, d); ``result[:, 0, :]`` is used as the next
            state. That slice must have the same shape/dtype as ``u[:, 0, :]``
            because both ``lax.cond`` branches must return matching types.
        anchor_after: anchoring period in steps (``<= 0`` means never anchor).

    Returns:
        rollout_fn[Callable]: ``rollout_fn(θ, u)`` mapping a reference
        trajectory ``u`` of shape (B, T, d) to a trajectory of the same shape
        whose first state is ``u[:, 0, :]``.
    """
    def rollout_fn(θ: jax.Array, u: jax.Array) -> jax.Array:
        _, T, _ = u.shape  # unpacking also enforces a rank-3 input
        u0 = u[:, 0, :]

        t_idx = jnp.arange(T - 1)
        if anchor_after <= 0:
            anchor_mask = jnp.zeros((T - 1,), dtype=bool)
        else:
            # Anchor periodically: t % anchor_after == 0 (always true at t = 0).
            anchor_mask = (t_idx % anchor_after) == 0

        # Reference next-states laid out time-major so scan slices one per step:
        # targets[t] == u[:, t + 1, :].
        targets = jnp.swapaxes(u[:, 1:, :], 0, 1)

        def step(
            carry: tuple[jax.Array, jax.Array],
            xs: tuple[jax.Array, jax.Array],
        ) -> tuple[tuple[jax.Array, jax.Array], jax.Array]:
            u_prev, θ_c = carry
            is_anchor, u_ref = xs

            def use_anchor(_):
                return u_ref

            def use_model(state):
                u_pred = stepper_apply(θ_c, state[:, None, :])
                return u_pred[:, 0, :]

            u_next = jax.lax.cond(is_anchor, use_anchor, use_model, u_prev)
            return (u_next, θ_c), u_next

        _, traj = jax.lax.scan(step, (u0, θ), (anchor_mask, targets))
        # scan stacks outputs time-major: (T-1, B, d) -> (B, T-1, d)
        traj = jnp.swapaxes(traj, 0, 1)
        return jnp.concatenate([u0[:, None, :], traj], axis=1)

    return rollout_fn

def full_rollout(stepper_apply: Callable[..., jax.Array]) -> Callable[..., jax.Array]:
    """
    Build a free-running (autoregressive) rollout function.

        û_0     = u_0
        û_{t+1} = f_θ(û_t),    t = 0, …, T-2

    Only the initial state ``u[:, 0, :]`` is read from the reference trajectory.
    Every later state is produced by feeding the model its own previous output.

    Args:
        stepper_apply: one-step model, called as ``stepper_apply(θ, x)`` with
            ``x`` of shape (B, 1, d); ``result[:, 0, :]`` is the next state.

    Returns:
        rollout_fn[Callable]: ``rollout_fn(θ, u, ot_horizon=200)`` mapping a
        reference trajectory ``u`` of shape (B, T, d) to a predicted trajectory
        of the same shape whose first state is ``u[:, 0, :]``.
    """
    def rollout_fn(θ: jax.Array, u: jax.Array, ot_horizon: int = 200) -> jax.Array:
        # NOTE(review): ot_horizon is accepted for call-site compatibility but is
        # never read here.
        _, n_time, _ = u.shape  # unpacking also enforces a rank-3 input
        start = u[:, 0, :]

        def advance(carry, _) -> tuple[tuple[jax.Array, jax.Array], jax.Array]:
            state, params = carry
            nxt = stepper_apply(params, jnp.expand_dims(state, axis=1))[:, 0, :]
            return (nxt, params), nxt

        _, preds = jax.lax.scan(advance, (start, θ), xs=None, length=n_time - 1)
        # scan stacks along a new leading time axis: (T-1, B, d) -> (B, T-1, d)
        preds = jnp.moveaxis(preds, 0, 1)
        return jnp.concatenate([jnp.expand_dims(start, axis=1), preds], axis=1)

    return rollout_fn
