import torch


def collate_fn(
    inputs: list[dict],
    feature_extractor,
    feature_extractor_kwargs: None | dict = None,
    label_column: str = "category",
    *args,
    **kwargs,
) -> dict[str, torch.Tensor]:
    if feature_extractor_kwargs is None:
        feature_extractor_kwargs = {
            "sampling_rate": 16000,
            "do_normalize": True,
            "padding": True,
            "return_attention_mask": True,
            "return_tensors": "pt",
        }
    audios = [line["audio"]["array"] for line in inputs]
    batch = feature_extractor(
        audios,
        sampling_rate=16000,
        do_normalize=True,
        padding=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    labels = [dico.pop(label_column) for dico in inputs]
    batch["labels"] = torch.LongTensor(labels)
    return batch
