import torch

def random_split_tensor(input: torch.Tensor, device=None, chunks=2, dim=0):
    """
    Randomly split a tensor into ``chunks`` nearly equal parts along ``dim``.

    The slices of ``input`` along ``dim`` are shuffled with a random
    permutation and then partitioned with ``torch.tensor_split``. When the
    size along ``dim`` is not divisible by ``chunks``, chunk sizes differ by
    at most one (the earlier chunks receive the extra slices).

    Args:
        input: Tensor to split.
        device: Device on which to place the permutation indices. Defaults to
            ``input.device`` so the indices live alongside ``input``, as
            ``torch.index_select`` requires.
        chunks: Number of chunks to produce.
        dim: Dimension along which to shuffle and split.

    Returns:
        A list of ``chunks`` tensors. Each is a random subset of ``input``
        along ``dim``; together they contain every slice exactly once.
    """
    # Use the size of the split dimension itself; shape[0] is only correct
    # when dim == 0.
    size = input.shape[dim]
    index_device = input.device if device is None else device
    # Draw the permutation on the CPU default generator and then move it, so
    # results for a given torch.manual_seed do not depend on input's device.
    perm = torch.randperm(size).to(index_device)
    return [
        torch.index_select(input, dim, part)
        for part in torch.tensor_split(perm, chunks)
    ]