"""
traj_transforms.py

Contains trajectory transforms used in the orca data pipeline. Trajectory transforms operate on a dictionary
that represents a single trajectory, meaning each tensor has the same leading dimension (the trajectory length).
"""

import logging
from typing import Dict, Optional, Union

import tensorflow as tf


def chunk_act_obs(
    traj: Dict,
    window_size: int,
    future_action_window_size: int = 0,
    dataset_statistics: Optional[Union[dict, str]] = None,
    post_goal_action_mask_steps: int = 5,
) -> Dict:
    """
    Chunks actions and observations into the given window_size.

    "observation" keys are given a new axis (at index 1) of size `window_size` containing `window_size - 1`
    observations from the past and the current observation. "action" is given a new axis (at index 1) of size
    `window_size + future_action_window_size` containing `window_size - 1` actions from the past, the current
    action, and `future_action_window_size` actions from the future. "pad_mask" is added to "observation" and
    indicates whether an observation should be considered padding (i.e. if it had come from a timestep
    before the start of the trajectory).

    Action slots that refer to a timestep before the start of the trajectory or after the goal timestep
    (`traj["task"]["timestep"]` if present, otherwise the last timestep) are replaced by neutral actions.
    Dimensions flagged in `traj["absolute_action_mask"]` repeat the nearest in-range action. All other
    (relative) dimensions get the normalized value that maps back to 0 under BOUNDS_Q99 normalization,
    or 0 when no `dataset_statistics` are given.

    `traj["action_mask"]` (bool, shape [traj_len, window_size + future_action_window_size]) is added. It is
    True for action slots whose source timestep lies in [0, goal_timestep + post_goal_action_mask_steps].

    Args:
        traj: Trajectory dict with at least "action", "observation" and "task" entries. Updated in place.
        window_size: Number of observation timesteps per chunk (past + current).
        future_action_window_size: Number of future actions appended to each action chunk.
        dataset_statistics: Loaded statistics dict providing `["action"]["q01"]` and `["action"]["q99"]`.
            If None, relative neutral actions are zeros in normalized space.
        post_goal_action_mask_steps: How many action slots past the goal timestep are still marked valid in
            "action_mask" (those slots hold neutral actions). Defaults to 5, the previously hard-coded value.

    Returns:
        The same `traj` dict with chunked "observation"/"action" plus the added "pad_mask" and "action_mask".

    Raises:
        TypeError: If `dataset_statistics` is a string rather than a loaded dict.
    """
    if isinstance(dataset_statistics, str):
        raise TypeError(
            "dataset_statistics must be a loaded statistics dict, not a string; load it before calling chunk_act_obs."
        )

    traj_len = tf.shape(traj["action"])[0]
    action_dim = traj["action"].shape[-1]

    # Row t holds the source timesteps t - window_size + 1 .. t (observations) and
    # t - window_size + 1 .. t + future_action_window_size (actions). Entries may be out of range.
    timesteps = tf.range(traj_len)[:, None]
    chunk_indices = timesteps + tf.range(-window_size + 1, 1)
    action_chunk_indices = timesteps + tf.range(-window_size + 1, 1 + future_action_window_size)

    if "timestep" in traj["task"]:
        goal_timestep = traj["task"]["timestep"]
    else:
        goal_timestep = tf.fill([traj_len], traj_len - 1)

    # Clamp into the valid range so out-of-range slots copy the nearest valid observation/action; the
    # copied actions are what absolute action dimensions use as their neutral value below.
    floored_chunk_indices = tf.maximum(chunk_indices, 0)
    floored_action_chunk_indices = tf.minimum(tf.maximum(action_chunk_indices, 0), goal_timestep[:, None])

    traj["observation"] = tf.nest.map_structure(lambda x: tf.gather(x, floored_chunk_indices), traj["observation"])
    traj["action"] = tf.gather(traj["action"], floored_action_chunk_indices)

    # indicates whether an entire observation is padding
    traj["observation"]["pad_mask"] = chunk_indices >= 0

    # if no absolute_action_mask was provided, assume all actions are relative
    if "absolute_action_mask" not in traj and future_action_window_size > 0:
        logging.warning(
            "future_action_window_size > 0 but no absolute_action_mask was provided. "
            "Assuming all actions are relative for the purpose of making neutral actions."
        )
    absolute_action_mask = traj.get("absolute_action_mask", tf.zeros([traj_len, action_dim], dtype=tf.bool))

    # Neutral value for relative action dimensions. Actions are already normalized at this point, so the
    # neutral action is the normalized image of 0 rather than 0 itself. BOUNDS_Q99 normalization is
    # hard-coded here: x_norm = 2 * (x - q01) / (q99 - q01 + 1e-8) - 1.
    if dataset_statistics is None:
        logging.warning(
            "chunk_act_obs called without dataset_statistics; relative neutral actions default to 0 "
            "in normalized space."
        )
        relative_neutral_action = tf.zeros_like(traj["action"])
    else:
        low = dataset_statistics["action"]["q01"]
        high = dataset_statistics["action"]["q99"]
        norm_zero_action = 2 * (0 - low) / (high - low + 1e-8) - 1
        # NOTE(review): unlike a clipped normalization, this value is not clipped to [-1, 1]; confirm that
        # matches how actions are normalized/unnormalized elsewhere in the pipeline.
        relative_neutral_action = tf.cast(
            tf.broadcast_to(norm_zero_action, tf.shape(traj["action"])), dtype=traj["action"].dtype
        )

    neutral_actions = tf.where(absolute_action_mask[:, None, :], traj["action"], relative_neutral_action)

    # actions past the goal timestep or before the start timestep should become neutral
    action_past_goal = action_chunk_indices > goal_timestep[:, None]
    action_before_start = action_chunk_indices < 0
    action_out_of_range = tf.logical_or(action_past_goal, action_before_start)
    traj["action"] = tf.where(action_out_of_range[:, :, None], neutral_actions, traj["action"])

    # Valid action slots: not before the start and at most `post_goal_action_mask_steps` past the goal.
    traj["action_mask"] = tf.logical_and(
        action_chunk_indices >= 0,
        action_chunk_indices <= goal_timestep[:, None] + post_goal_action_mask_steps,
    )

    return traj


def subsample(traj: Dict, subsample_length: int) -> Dict:
    """
    Randomly keeps at most `subsample_length` timesteps of a trajectory.

    If the trajectory is longer than `subsample_length`, `subsample_length` distinct timesteps are drawn
    (without replacement, via a random permutation). Every tensor in `traj` is then gathered along its
    leading axis at those indices, so the kept timesteps come back in shuffled rather than chronological
    order. Trajectories that are already short enough are returned unchanged.
    """
    num_steps = tf.shape(traj["action"])[0]
    if subsample_length < num_steps:
        permutation = tf.random.shuffle(tf.range(num_steps))
        keep = permutation[:subsample_length]
        traj = tf.nest.map_structure(lambda leaf: tf.gather(leaf, keep), traj)

    return traj


def add_pad_mask_dict(traj: Dict) -> Dict:
    """
    Adds a dictionary indicating which elements of the observation/task should be treated as padding.
        =>> traj["observation"|"task"]["pad_mask_dict"] = {k: traj["observation"|"task"][k] is not padding}

    String-valued entries (e.g. "language_instruction", "image_*", "depth_*") are padding wherever the
    string is empty. Every other entry is always treated as real data and gets an all-True mask of shape
    [traj_len]. The nested dicts are updated in place and `traj` is returned.
    """
    num_steps = tf.shape(traj["action"])[0]

    def _is_real(value):
        # Empty strings are treated as padding; non-string entries are never padding.
        if value.dtype == tf.string:
            return tf.strings.length(value) != 0
        return tf.ones([num_steps], dtype=tf.bool)

    for group in ("observation", "task"):
        traj[group]["pad_mask_dict"] = {name: _is_real(value) for name, value in traj[group].items()}

    return traj
