import torch

def unpermute_frame(frame, perm):
    """Apply the inverse of ``perm`` along the point axis of ``frame``.

    If ``frame == original[:, perm]`` (i.e. the points were reordered by
    ``perm``), this returns ``original``.

    Args:
        frame: Tensor of shape ``(batch, n_points, dim)`` or ``(n_points, dim)``.
            A 2-D input is treated as a batch of one; the returned tensor keeps
            that added leading batch dimension.
        perm: Integer permutation(s) of ``range(n_points)``. Shape
            ``(n_points,)`` applies the same permutation to every batch
            element; shape ``(batch, n_points)`` gives one per element.

    Returns:
        Tensor of shape ``(batch, n_points, dim)`` with the point order restored.

    Raises:
        AssertionError: if ``frame`` is not 2-D or 3-D.
    """
    if frame.ndim == 2:
        frame = frame.unsqueeze(0)
    assert frame.ndim == 3, 'frame should have dimension [batchsize n_points x_dim]'
    # Row index for every (batch, point) slot. It is built on frame's device so
    # GPU frames are not indexed with a CPU tensor. Expanding to the full
    # (batch, n_points) shape also makes a perm of the wrong length fail to
    # broadcast instead of silently selecting a different number of points.
    batch_indices = torch.arange(frame.shape[0], device=frame.device)
    batch_indices = batch_indices.unsqueeze(-1).expand(-1, frame.shape[1])
    # argsort of a permutation is its inverse: inv[perm[i]] == i.
    unsort_idcs = torch.argsort(perm, dim=-1)
    return frame[batch_indices, unsort_idcs]


def perm_w_start_idcs(n, start_idcs, *, generator=None):
    """Return a random permutation of ``range(n)`` that begins with ``start_idcs``.

    The indices in ``start_idcs`` occupy the first positions of the result.
    They appear in random order, so their order in ``start_idcs`` is not
    preserved. All remaining indices follow, also in random order. Duplicate
    entries in ``start_idcs`` are counted once.

    Args:
        n: Length of the permutation.
        start_idcs: Indices to place first. This may be anything
            ``torch.as_tensor`` accepts: a list, tuple, array, tensor, or a
            single int. ``None`` or an empty collection yields a plain
            ``torch.randperm(n)``.
        generator: Optional ``torch.Generator`` forwarded to
            ``torch.randperm`` for reproducible output.

    Returns:
        1-D int64 tensor of length ``n`` containing each of ``0 .. n-1`` once.
    """
    if start_idcs is None:
        return torch.randperm(n, generator=generator)
    # Normalise to a tensor so a bare int works as well as a collection
    # (len() on an int would raise TypeError).
    start_idcs = torch.as_tensor(start_idcs)
    if start_idcs.numel() == 0:
        return torch.randperm(n, generator=generator)

    # Membership table: is_start[i] is True iff index i must come first.
    is_start = torch.zeros(n, dtype=torch.bool)
    is_start[start_idcs] = True

    perm = torch.randperm(n, generator=generator)
    # Look up membership in permuted order once. Boolean selection keeps the
    # (random) relative order of perm within each group, so the start indices
    # come first but are shuffled among themselves.
    start_in_perm = is_start[perm]
    return torch.cat([perm[start_in_perm], perm[~start_in_perm]])