import torch
from typing import Union

from torch.utils.data import Dataset, default_collate


def smart_collate_fn(batch: list[tuple[object, ...]]) -> tuple[Union[torch.Tensor, None, list[dict]], ...]:
    """
    Collate a batch of equal-length tuples position by position.

    The batch is transposed so that the i-th elements of every sample are
    gathered together. Each gathered position is then handled as follows:

    - If any element at that position is ``None``, the result is ``None``.
    - Otherwise, if the first element at that position is a ``dict``, the
      elements are returned unchanged as a list of dictionaries (no collation).
    - Otherwise, the elements are passed to ``default_collate``; for normal
      tensors/arrays this yields a tensor with a leading batch dimension.

    Args:
        batch: Samples to collate. Every sample must have the same length.

    Returns:
        A tuple with one entry per sample position, built as described above.

    Raises:
        ValueError: If ``batch`` is empty or its samples differ in length.
    """
    if not batch:
        # Without at least one sample we cannot know how many positions to return.
        raise ValueError("batch must not be empty")

    num_fields = len(batch[0])
    separated_items: list[list[object]] = [[] for _ in range(num_fields)]
    for item in batch:
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if len(item) != num_fields:
            raise ValueError("All items in the batch must have the same length")
        for field, elem in zip(separated_items, item):
            field.append(elem)

    results = []
    for collection in separated_items:
        if any(elem is None for elem in collection):
            # A single missing value makes the whole position unavailable.
            results.append(None)
        elif isinstance(collection[0], dict):
            # Dicts are passed through untouched instead of being merged key-wise.
            results.append(collection)
        else:
            results.append(default_collate(collection))

    return tuple(results)