from functools import partial
from typing import Any, Dict, List, Sequence, Union
import chex
import einops
import tensorflow as tf
from dlimp.transforms.common import selective_tree_map
import dlimp.dataset
from tensorflow_graphics.geometry.transformation.euler import from_quaternion

def add_concatenated_actions(
    traj: Dict[str, Any],
    window_size: int,
) -> Dict[str, Any]:
    """Re-encode `action_angle` with a cos/sin angle and add look-ahead action chunks.

    The last axis of `traj["action_angle"]` is read as two leading components
    followed by an angle at index 2 (anything past index 2 is dropped). The
    angle is replaced by its cosine and sine, giving a 4-component action.

    The input dict is modified in place: "action", "action_chunked" and
    "action_mask" are written and "action_angle" is removed.

    Args:
        traj: Trajectory dict containing "action_angle", with timesteps on axis 0.
        window_size: Number of consecutive timesteps gathered into each chunk.

    Returns:
        The same dict. With T timesteps, "action_chunked"[t, k] is the encoded
        action at timestep min(t + k, T - 1), and "action_mask"[t, k] is True
        only where t + k < T (the entry is not a clamped repeat of the last step).
    """
    raw_action = traj["action_angle"]
    num_steps = tf.shape(raw_action)[0]

    # (T, window_size) grid of absolute timesteps: row t holds t, t+1, ..., t+window_size-1.
    step_grid = tf.range(num_steps)[:, None] + tf.range(window_size)[None, :]
    in_bounds = step_grid < num_steps
    # Windows that run past the end re-use the final timestep.
    gather_idx = tf.minimum(step_grid, num_steps - 1)

    angle = raw_action[..., 2:3]
    encoded = tf.concat(
        [raw_action[..., :2], tf.math.cos(angle), tf.math.sin(angle)],
        axis=-1,
    )

    traj["action"] = encoded
    traj["action_chunked"] = tf.gather(encoded, gather_idx, axis=0)
    traj["action_mask"] = in_bounds
    del traj["action_angle"]

    return traj


def batch_decode_images(
    x: Dict[str, Any], match: Union[str, Sequence[str]] = "image"
) -> Dict[str, Any]:
    """Decode encoded-image leaves of a (possibly nested) dict.

    A leaf is decoded when at least one of the `match` strings occurs in its
    key path and its dtype is `tf.string`. Matching leaves are decoded
    element-wise along their leading axis with `tf.io.decode_image`
    (`expand_animations=False`); all other leaves are handled by
    `selective_tree_map` without the decoder being applied.

    Args:
        x: Dict (may be nested) of tensors.
        match: One substring, or a sequence of substrings, to look for in key paths.

    Returns:
        The result of `selective_tree_map` for `x` with the decoder applied to
        matching leaves.
    """
    patterns = [match] if isinstance(match, str) else match

    def is_encoded_image(keypath, value):
        # Key-path test first; the dtype is only inspected for matching paths.
        return any(pattern in keypath for pattern in patterns) and value.dtype == tf.string

    decode_one = partial(tf.io.decode_image, expand_animations=False)
    decode_all = partial(tf.vectorized_map, decode_one)

    return selective_tree_map(x, is_encoded_image, decode_all)

def correct_action_keys(
    traj: Dict[str, Any],
) -> Dict[str, Any]:
    """Fill in the trajectory fields the rest of the pipeline expects, in place.

    Each field is only synthesized when it is absent:

    * "action_angle": the first two columns of "action" followed by its last column.
    * "reward": zeros shaped (and typed) like "_len".
    * observation "yaw": component 2 of the Euler angles computed from
      observation "orientation", kept as a trailing size-1 axis.
    * "language_embedding": zeros with 512 columns, built by tiling a
      zero tensor shaped like observation "position".

    Independently of the above, observation "position" is cut down to its
    first two columns whenever its static last dimension is not 2.

    Args:
        traj: Trajectory dict with an "observation" sub-dict.

    Returns:
        The same dict, mutated.
    """
    if "action_angle" not in traj:
        raw_action = traj["action"]
        leading_xy = raw_action[:, :2]
        last_column = tf.reshape(raw_action[:, -1], (-1, 1))
        traj["action_angle"] = tf.concat([leading_xy, last_column], axis=1)

    if "reward" not in traj:
        traj["reward"] = tf.zeros_like(traj["_len"])  # all zeros

    observation = traj["observation"]

    if "yaw" not in observation:
        euler_angles = from_quaternion(observation["orientation"])
        observation["yaw"] = tf.expand_dims(euler_angles[:, 2], axis=1)

    if observation["position"].shape[-1] != 2:
        observation["position"] = observation["position"][:, :2]

    if "language_embedding" not in traj:
        # Position has 2 columns at this point, so tiling 512 / 2 times yields 512 columns.
        blank = tf.zeros_like(observation["position"])
        traj["language_embedding"] = tf.tile(blank, [1, int(512 / 2)])

    return traj

def normalize_images(
    x: Dict[str, Any],
    match: Union[str, Sequence[str]] = "image",
) -> Dict[str, Any]:
    """
    Can operate on nested dicts. Standardizes any uint8 leaves that have `match`
    anywhere in their path.

    Each selected image is cast to float32, scaled by 1/255, and then shifted and
    scaled per channel with the ImageNet mean and standard deviation. The result
    is therefore NOT confined to [0, 1]. The last axis must broadcast against the
    three per-channel constants.
    """
    patterns = [match] if isinstance(match, str) else match

    def is_uint8_image(keypath, value):
        return any(pattern in keypath for pattern in patterns) and value.dtype == tf.uint8

    def standardize(image: tf.Tensor) -> tf.Tensor:
        """
        Scale to [0, 1], then subtract the ImageNet mean and divide by the ImageNet std.
        """
        channel_mean = tf.constant([0.485, 0.456, 0.406])
        channel_std = tf.constant([0.229, 0.224, 0.225])

        unit_range = tf.cast(image, tf.float32) / 255.0
        return (unit_range - channel_mean) / channel_std

    return selective_tree_map(x, is_uint8_image, standardize)


def subsample_multiple_trajectories(
    trajectory,
    size: int = 8,
    num_trajectories: int = 1,
):
    """Randomly cut fixed-length windows out of a trajectory, each followed by a goal frame.

    For each of `num_trajectories` samples, a window of `size` consecutive
    timesteps is gathered from every leaf of `trajectory`, and one additional
    "goal" timestep (at or after the window start) is appended along axis 1.
    Every leaf of the result therefore has leading shape
    (num_trajectories, size + 1, ...).

    The code reads the keys "_len", "_frame_index", "is_last" and
    "end_is_terminal", and writes/overwrites "reached_goal", "time_to_goal",
    "is_terminal", "reward", "_len" and "mask" in the result.

    Args:
        trajectory: Nested structure of tensors indexed by timestep on axis 0.
        size: Number of consecutive timesteps per window (excluding the goal).
        num_trajectories: Number of windows to sample.

    Returns:
        A new structure (built by `tf.nest.map_structure`) with the sampled
        windows plus the derived fields listed above.
    """
    # Length of the full trajectory, taken from the first "_len" entry.
    traj_len = trajectory["_len"][0]

    # Window start per sample, drawn from [0, traj_len - size).
    # NOTE(review): presumably requires traj_len > size so this range is non-empty — confirm.
    idx = tf.random.uniform(
        shape=(num_trajectories,), minval=0, maxval=traj_len - size, dtype=tf.int32
    )
    # Number of timesteps from each window start to the end of the trajectory.
    remaining_len = traj_len - idx
    # Goal timestep per sample: a gamma-distributed offset (truncated to int) plus a
    # uniform integer in [0, size), wrapped modulo `remaining_len` and shifted by the
    # window start, so idx <= goal_idx < traj_len.
    goal_idx = (
        (
            tf.cast(
                # tf.random.gamma(shape=(num_trajectories,), alpha=6, beta=0.35),
                tf.random.gamma(shape=(num_trajectories,), alpha=1, beta=1 / 20),
                tf.int32,
            )
            + tf.random.uniform(shape=(num_trajectories,), maxval=size, dtype=tf.int32)
        )
        # tf.random.uniform(shape=(num_trajectories,), minval=0, maxval=28, dtype=tf.int32)
        % remaining_len
        + idx
    )

    def subsample(x: tf.Tensor):
        # (num_trajectories, size, ...): the window idx, idx + 1, ..., idx + size - 1.
        sub_trajectory = tf.gather(x, idx[:, None] + tf.range(size)[None, :], axis=0)
        # (num_trajectories, 1, ...): the goal timestep.
        goal = tf.gather(x, goal_idx[:, None], axis=0)
        # Goal is appended as the last entry along axis 1.
        return tf.concat([sub_trajectory, goal], axis=1)

    result = tf.nest.map_structure(subsample, trajectory)

    # Frame index of the appended goal entry, kept as a size-1 axis for broadcasting.
    goal_frame_index = result["_frame_index"][:, -1:]
    # True where a step's frame index is at or past the goal's (always True for the goal slot).
    result["reached_goal"] = result["_frame_index"] >= goal_frame_index
    # Signed frame distance to the goal; negative for steps past the goal frame.
    result["time_to_goal"] = goal_frame_index - result["_frame_index"]
    crashed = result["is_last"] & result["end_is_terminal"]
    result["is_terminal"] = result["reached_goal"] | crashed

    discount = 0.97
    # 0 where the goal has been reached, -1 everywhere else.
    drive_reward = tf.cast(result["reached_goal"], tf.float32) - 1
    # -1 / (1 - discount) on crashed steps, 0 elsewhere.
    crash_reward = -1 / (1 - discount) * tf.cast(crashed, tf.float32)
    result["reward"] = drive_reward + crash_reward

    # Each sampled sequence now holds `size` window steps plus one goal step.
    result["_len"] = tf.fill(tf.shape(result["_len"]), size + 1)

    # All-True boolean mask with the same shape as the reward.
    result["mask"] = tf.ones_like(result["reward"], dtype=tf.bool)

    return result


def subsample_multiple_trajectories_dynamic(
    trajectory,
    size=8,
) -> dlimp.DLataset:
    """Sample a length-dependent number of windows and return them as a dataset.

    The number of windows is the leading dimension of `trajectory["reward"]`
    integer-divided by `size`. The sampled batch from
    `subsample_multiple_trajectories` is turned into a `dlimp.DLataset` with
    `DLataset.from_tensor_slices`, wrapped by `dlimp.dataset._wrap` with
    `is_flattened=False`.

    Args:
        trajectory: Trajectory structure accepted by `subsample_multiple_trajectories`.
        size: Window length forwarded to `subsample_multiple_trajectories`.
    """
    window_count = tf.shape(trajectory["reward"])[0] // size

    # Build the dataset constructor first, then the sampled batch, then combine.
    to_dataset = dlimp.dataset._wrap(
        dlimp.DLataset.from_tensor_slices, is_flattened=False
    )
    sampled = subsample_multiple_trajectories(
        trajectory,
        num_trajectories=window_count,
        size=size,
    )
    return to_dataset(sampled)


def subsample_trajectory(trajectory, size=8):
    """Sample a single window (plus goal frame) and drop the leading sample axis.

    Calls `subsample_multiple_trajectories` with its default sample count and
    squeezes axis 0 of every leaf in the result.

    Args:
        trajectory: Trajectory structure accepted by `subsample_multiple_trajectories`.
        size: Window length forwarded to `subsample_multiple_trajectories`.
    """
    sampled = subsample_multiple_trajectories(trajectory, size=size)

    def drop_sample_axis(leaf):
        return tf.squeeze(leaf, axis=0)

    return tf.nest.map_structure(drop_sample_axis, sampled)


def split_goal_obs(trajectory):
    """Peel the final timestep's observation off as a per-step "goal".

    The last entry along axis 0 of every leaf under "observation" becomes the
    goal. Every leaf of the trajectory then loses its last entry, and the goal
    is repeated `trajectory["_len"][0] - 1` times along a new leading axis and
    stored under "goal".

    Returns:
        A new structure (the input is not mutated) with the trimmed leaves and
        the added "goal" entry.
    """
    num_steps = trajectory["_len"][0] - 1

    def last_step(leaf):
        return leaf[-1]

    def drop_last_step(leaf):
        return leaf[:-1]

    def repeat_over_steps(leaf):
        return tf.repeat(leaf[None], num_steps, axis=0)

    goal_obs = tf.nest.map_structure(last_step, trajectory["observation"])
    trimmed = tf.nest.map_structure(drop_last_step, trajectory)
    trimmed["goal"] = tf.nest.map_structure(repeat_over_steps, goal_obs)

    return trimmed


def remove_strings(trajectory):
    """Delete known string-valued entries from a trajectory dict, in place.

    Removes "language_instruction" and
    trajectory["traj_metadata"]["episode_metadata"]["file_path"] when they are
    present; missing keys at any level are skipped silently.

    Returns:
        The same dict that was passed in.
    """
    trajectory.pop("language_instruction", None)

    # Walk down the metadata path, bailing out as soon as a level is missing.
    if "traj_metadata" not in trajectory:
        return trajectory
    traj_metadata = trajectory["traj_metadata"]

    if "episode_metadata" not in traj_metadata:
        return trajectory
    episode_metadata = traj_metadata["episode_metadata"]

    if "file_path" in episode_metadata:
        del episode_metadata["file_path"]

    return trajectory


def normalize_actions(trajectory):
    """Rescale "action" and "action_chunked" by the waypoint spacing, in place.

    For each of the two keys, the float32 copy of the original value is saved
    under "<key>_unnormalized". The normalized value divides the first two
    components of the last axis by `trajectory["waypoint_spacing"]`, leaves the
    remaining components unchanged, and then subtracts the constant vector
    [1, 0, 1, 0] (so the last axis must broadcast against 4 components).

    Returns:
        The same dict, mutated.
    """

    def rescale(raw_action):
        unnormalized = tf.cast(raw_action, tf.float32)

        scale = tf.cast(trajectory["waypoint_spacing"], tf.float32)
        # Append trailing size-1 axes until the spacing has the action's rank.
        for _ in range(unnormalized.ndim - scale.ndim):
            scale = tf.expand_dims(scale, axis=-1)

        return unnormalized, scale

    for key in ("action", "action_chunked"):
        unnormalized, scale = rescale(trajectory[key])
        trajectory[f"{key}_unnormalized"] = unnormalized

        rescaled = tf.concat(
            [unnormalized[..., :2] / scale, unnormalized[..., 2:]],
            axis=-1,
        )
        offset = tf.cast(tf.constant([1, 0, 1, 0]), tf.float32)
        trajectory[key] = rescaled - offset

    return trajectory


def normalized_pose(trajectory):
    """Express position and yaw relative to the first timestep's pose.

    Positions are translated by the first position and multiplied by the
    transpose of the 2-D rotation matrix built from the first yaw; yaws have
    the first yaw subtracted. The "observation" sub-dict is updated in place:

    * "pose": [x, y, cos(yaw), sin(yaw)] per timestep, from the relative values.
    * "position": the relative positions, shape (T, 2).
    * "yaw": the relative yaws, shape (T,). The trailing size-1 axis of the
      input yaw is squeezed away and not restored.

    Returns:
        The same trajectory dict.
    """
    observation = trajectory["observation"]
    position = observation["position"]
    yaw = tf.squeeze(observation["yaw"], axis=-1)
    chex.assert_shape(position, (None, 2))
    chex.assert_shape(yaw, (None,))

    origin = position[0]
    origin_yaw = yaw[0]
    cos0 = tf.math.cos(origin_yaw)
    sin0 = tf.math.sin(origin_yaw)

    # Rotation by the initial yaw; its transpose is applied to the translated positions.
    rotation = tf.stack(
        [
            tf.stack([cos0, -sin0]),
            tf.stack([sin0, cos0]),
        ]
    )
    relative_position = tf.linalg.matvec(tf.transpose(rotation), position - origin)
    relative_yaw = yaw - origin_yaw

    yaw_column = relative_yaw[..., None]
    relative_pose = tf.concat(
        [relative_position, tf.math.cos(yaw_column), tf.math.sin(yaw_column)],
        axis=-1,
    )

    observation["pose"] = relative_pose
    observation["position"] = relative_position
    observation["yaw"] = relative_yaw

    return trajectory


def sample_negatives(batch, negative_frac: float):
    """Turn the leading fraction of a batch into negative (mismatched-goal) examples.

    Batch elements whose position `i` satisfies `i / batch_size < negative_frac`
    are marked negative. Their goal is replaced by the goal of a uniformly
    random batch element (which may happen to be their own), and their
    "reward", "is_terminal", "reached_goal" and "time_to_goal" entries are
    overwritten with fixed values. All other elements are left unchanged.

    Args:
        batch: Dict of batched tensors containing "_len", "goal" (a nested
            structure) and the four fields listed above. Mutated in place.
        negative_frac: Fraction of the batch (from the front) to mark negative.

    Returns:
        The same dict.
    """
    batch_size = tf.shape(batch["_len"])[0]

    # Boolean vector of length batch_size: True for the first `negative_frac` of the batch.
    negatives: tf.Tensor = (tf.range(0, batch_size) / batch_size) < negative_frac
    # Negatives take a random element's goal; positives keep their own index.
    goal_idcs = tf.where(
        negatives,
        tf.random.uniform((batch_size,), 0, batch_size, dtype=tf.int32),
        tf.range(0, batch_size),
    )
    batch["goal"] = tf.nest.map_structure(
        lambda x: tf.gather(x, goal_idcs, axis=0), batch["goal"]
    )

    # Values written into every step of a negative example.
    replacements = {
        "reward": -1.0,
        "is_terminal": False,
        "reached_goal": False,
        "time_to_goal": 128,
    }

    for k, v in replacements.items():
        # Expand dims from the _end_ until negatives has the same rank as batch data.
        # The accumulator itself must be expanded: expanding `negatives` each time
        # would cap the mask at rank 2 and never terminate for rank >= 3 data.
        negatives_like_data = negatives
        while negatives_like_data.ndim < batch[k].ndim:
            negatives_like_data = tf.expand_dims(negatives_like_data, axis=-1)
        batch[k] = tf.where(
            negatives_like_data,
            v,
            batch[k],
        )

    return batch
