import torch
from typing import Tuple


def trajectories_to_transitions(trajectories: torch.Tensor, data: Tuple[torch.Tensor, int, bool]) -> torch.Tensor:
    """Unpacks padded trajectories back into a batch of transitions.

    Reverts the packing performed by ``transitions_to_trajectories``.

    Args:
        trajectories (torch.Tensor): Zero-padded trajectories. The layout is (trajectory_count, time, *) when the
            stored ``batch_first`` flag is True, and (time, trajectory_count, *) otherwise.
        data (Tuple[torch.Tensor, int, bool]): The ``(mask, batch_size, batch_first)`` triple produced when packing.
            ``mask`` has the same leading layout as ``trajectories`` and marks real (non-padding) steps with 1.0.
            ``batch_size`` is the leading dimension of the restored tensor. ``batch_first`` describes the layout of
            ``trajectories`` and ``mask``.

    Returns:
        A tensor of transitions of shape (batch_size, time, *).
    """
    valid_mask, batch_size, batch_first = data

    # Work in trajectory-major layout so that boolean indexing collects the valid
    # steps trajectory by trajectory, each in time order.
    if batch_first:
        padded, step_mask = trajectories, valid_mask
    else:
        padded = trajectories.transpose(0, 1)
        step_mask = valid_mask.transpose(0, 1)

    valid_steps = padded[step_mask == 1.0]
    feature_shape = padded.shape[2:]
    return valid_steps.reshape(batch_size, -1, *feature_shape)


def transitions_to_trajectories(
    transitions: torch.Tensor, dones: torch.Tensor, batch_first: bool = False
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, int, bool]]:
    """Packs a tensor of transitions into a tensor of trajectories.

    A trajectory ends at every step whose ``dones`` entry is nonzero. The last step of each batch row always ends a
    trajectory, so no trajectory spans two rows.

    Example:
        >>> transitions = torch.tensor([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]])
        >>> dones = torch.tensor([[0, 0, 1], [0, 1, 0]])
        >>> trajectories, (mask, batch_size, batch_first) = transitions_to_trajectories(
        ...     transitions, dones, batch_first=True
        ... )
        >>> trajectories.tolist()
        [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [0, 0]], [[11, 12], [0, 0], [0, 0]]]
        >>> mask.tolist()
        [[1.0, 1.0, 1.0], [1.0, 1.0, 0.0], [1.0, 0.0, 0.0]]
        >>> batch_size, batch_first
        (2, True)

    Args:
        transitions (torch.Tensor): Tensor of transitions of shape (batch_size, time, *).
        dones (torch.Tensor): Tensor of transition terminations of shape (batch_size, time). Any nonzero entry
            (int, float or bool) marks a termination.
        batch_first (bool): Whether the first dimension of the output tensor should be the batch dimension. Defaults to
            False.
    Returns:
        A tuple ``(trajectories, (mask, batch_size, batch_first))``. ``trajectories`` is zero-padded with shape
        (time, trajectory_count, *), or (trajectory_count, time, *) if ``batch_first`` is True. ``mask`` is a float
        tensor with the same two leading dimensions, holding 1.0 on real steps and 0.0 on padding. The second element
        is the data needed to revert the operation.
    """
    batch_size = transitions.shape[0]

    # Force a termination at the final step of every batch row so trajectories never cross rows.
    padded_dones = dones.clone()
    padded_dones[:, -1] = 1

    # Flat (row-major) indices of all termination steps. Prepending -1 and differencing neighbouring indices yields
    # each trajectory's length. The sentinel uses the index dtype (int64) rather than ``dones.new``, because a
    # bool/uint8 ``dones`` would otherwise turn -1 into True (1) or reject it outright.
    done_indices = padded_dones.reshape(-1).nonzero()[:, 0]
    sentinel = torch.full((1,), -1, dtype=done_indices.dtype, device=done_indices.device)
    boundaries = torch.cat((sentinel, done_indices))
    trajectory_lengths = boundaries[1:] - boundaries[:-1]

    # Split the flattened transitions at the trajectory boundaries and zero-pad to a common length.
    trajectory_list = torch.split(transitions.flatten(0, 1), trajectory_lengths.tolist())
    trajectories = torch.nn.utils.rnn.pad_sequence(trajectory_list, batch_first=batch_first)

    # Step t of trajectory i is valid iff t < length_i. Broadcasting a row of step indices against the column of
    # lengths gives a (trajectory_count, max_length) mask. Creating the indices on the lengths' device keeps this
    # working on any device, not only CUDA.
    steps = torch.arange(int(trajectory_lengths.max()), device=trajectory_lengths.device)
    mask = (trajectory_lengths.unsqueeze(1) > steps).float()

    if not batch_first:
        mask = mask.T

    return trajectories, (mask, batch_size, batch_first)
