import jax
import jax.numpy as jnp
import chex
import einops
from typing import Optional

from multinav.utils.utils import get_batch_pattern

@jax.jit
def rollout_trajectory(actions: jnp.ndarray, initial_pose: jnp.ndarray):
    """Integrate a sequence of relative planar actions from a starting pose.

    Each step is a 3x3 homogeneous transform that is right-multiplied onto the
    running transform, i.e. every action is expressed in the frame reached by
    the previous step.

    Args:
        actions: ``[T, 4]`` array. After adding the offset ``[1, 0, 1, 0]``
            each row is read as ``(dx, dy, hx, hy)``: a translation and a
            heading vector; the heading vector is divided by its norm
            (plus ``1e-3``) to form the rotation entries.
        initial_pose: ``[4]`` array laid out as ``(x, y, cos, sin)``; it
            becomes the initial transform ``[[cos, -sin, x], [sin, cos, y],
            [0, 0, 1]]``.

    Returns:
        ``(trajectory, directions)``, each ``[T, 2]``: the translation column
        and the first (x-axis) column of the composed transform after each
        step. ``T == 0`` yields two empty ``[0, 2]`` arrays.
    """
    chex.assert_shape(actions, [None, 4])
    chex.assert_shape(initial_pose, [4])

    # Unnormalize actions.
    actions = actions + jnp.array([1, 0, 1, 0])

    # Build all T step transforms at once instead of one per Python iteration.
    heading = actions[:, 2:]
    # NOTE(review): the 1e-3 guard is added to the norm rather than used as a
    # floor, so the rotation block is scaled by |h| / (|h| + 1e-3) < 1 and the
    # composed transform shrinks slightly each step. Kept to preserve outputs.
    heading = heading / (jnp.linalg.norm(heading, axis=-1, keepdims=True) + 1e-3)
    cos_heading = heading[:, 0]
    sin_heading = heading[:, 1]
    zeros = jnp.zeros_like(cos_heading)
    ones = jnp.ones_like(cos_heading)
    step_transforms = jnp.stack(
        [
            jnp.stack([cos_heading, -sin_heading, actions[:, 0]], axis=-1),
            jnp.stack([sin_heading, cos_heading, actions[:, 1]], axis=-1),
            jnp.stack([zeros, zeros, ones], axis=-1),
        ],
        axis=-2,
    )  # [T, 3, 3]

    initial_transform = jnp.array(
        [
            [initial_pose[2], -initial_pose[3], initial_pose[0]],
            [initial_pose[3], initial_pose[2], initial_pose[1]],
            [0, 0, 1],
        ]
    )
    # lax.scan requires a dtype-stable carry; an integer initial pose would
    # otherwise be promoted to float by the first matmul.
    carry_dtype = jnp.result_type(initial_transform.dtype, step_transforms.dtype)

    def compose(transform, step):
        transform = jnp.matmul(transform, step)
        return transform, (transform[:2, 2], transform[:2, 0])

    # scan traces `compose` once, instead of unrolling T copies under jit.
    _, (trajectory, directions) = jax.lax.scan(
        compose, initial_transform.astype(carry_dtype), step_transforms
    )
    return trajectory, directions


@jax.jit
def rollout_trajectories(actions: jnp.ndarray, *, initial_pose: Optional[jnp.ndarray] = None):
    """Roll out relative-action sequences over arbitrary leading batch dims.

    All leading batch dimensions are flattened into one axis,
    ``rollout_trajectory`` is vmapped over it, and the result is unflattened
    again.

    Args:
        actions: ``[*batch_dims, predict_horizon, 4]`` normalized actions.
        initial_pose: optional ``[*batch_dims, 4]`` starting poses
            ``(x, y, cos, sin)``. When omitted, every batch element starts from
            ``(0, 0, 1, 0)``.

    Returns:
        ``[*batch_dims, predict_horizon, 4]`` array whose last axis is the
        per-step position (2 values) followed by the per-step direction
        (2 values).
    """
    *lead_dims, _horizon, act_dim = actions.shape
    assert act_dim == 4, "Expected actions to be of shape [batch_dims..., predict_horizon, 4]"
    pose_shape = lead_dims + [4]

    # NOTE(review): get_batch_pattern is project-local; from its use below it is
    # assumed to return an einops axis-name string for the batch dims plus the
    # matching axis sizes as keyword arguments — confirm.
    pattern, axis_sizes = get_batch_pattern(*lead_dims)

    if initial_pose is None:
        # (0, 0, 1, 0) maps to the identity initial transform.
        initial_pose = jnp.broadcast_to(jnp.array([0, 0, 1, 0]), pose_shape)
    chex.assert_shape(initial_pose, pose_shape)

    flat_actions = einops.rearrange(actions, f"{pattern} p a -> ({pattern}) p a", **axis_sizes)
    flat_poses = einops.rearrange(initial_pose, f"{pattern} a -> ({pattern}) a", **axis_sizes)

    positions, headings = jax.vmap(rollout_trajectory)(flat_actions, flat_poses)
    merged = jnp.concatenate((positions, headings), axis=-1)
    return einops.rearrange(merged, f"({pattern}) p a -> {pattern} p a", **axis_sizes)
