import torch
from torch.utils.data.dataloader import default_collate


def build_custom_collate_fn(jagged_first: int):
    """
    Returns a collate function that treats the first `jagged_first` elements as jagged tensors
    and collates the rest using PyTorch's default method.
    """

    def custom_collate_fn(batch):
        jagged = [
            torch.nested.nested_tensor(
                [sample[k] for sample in batch], layout=torch.jagged
            )
            # [sample[k] for sample in batch]
            for k in range(jagged_first)
        ]
        collated = [sample[jagged_first:] for sample in batch]
        collated = default_collate(collated)
        return *jagged, *collated

    return custom_collate_fn


def collate_fn_alone_jagged(batch):
    x = torch.nested.nested_tensor(batch, layout=torch.jagged)
    return x
