import torch
from torch.nn.utils.rnn import pad_sequence


def collate_fn(
    inputs: list[dict], label_column: str = "category", *args, **kwargs
) -> dict[str, torch.Tensor]:
    """Collate a list of samples into a single padded batch.

    Each sample's ``"input_features"`` is converted with ``torch.Tensor``
    (the default float tensor type), has a leading dimension of size 1
    removed if present, and is then zero-padded along its first dimension
    to the longest sample in the batch. ``"attention_mask"`` is handled the
    same way (as a ``LongTensor``) when the first sample provides it.

    The samples in ``inputs`` are not modified, so the same sample dicts can
    safely be collated again (e.g. on a later epoch).

    Args:
        inputs: Non-empty list of sample dicts. Every sample must contain
            ``"input_features"`` and ``label_column``; ``"attention_mask"``
            is optional, but must be present in all samples if it is present
            in the first one.
            NOTE(review): presumably features are ``(1, seq_len, dim)`` or
            ``(seq_len, dim)`` -- confirm against the dataset.
        label_column: Key under which each sample stores its integer label.
        *args: Ignored; accepted for call-site compatibility.
        **kwargs: Ignored; accepted for call-site compatibility.

    Returns:
        A dict with ``"input_features"``, ``"labels"`` and, when available,
        ``"attention_mask"``, each a tensor whose first dimension is the
        batch.

    Raises:
        KeyError: If a sample lacks ``"input_features"`` or ``label_column``.
    """
    batch = {}

    # squeeze(0) only drops the leading dimension when it has size 1.
    features = [
        torch.Tensor(sample["input_features"]).squeeze(0) for sample in inputs
    ]
    batch["input_features"] = pad_sequence(features, batch_first=True)

    # The first sample decides whether the whole batch carries a mask.
    if "attention_mask" in inputs[0]:
        masks = [
            torch.LongTensor(sample["attention_mask"]).squeeze(0)
            for sample in inputs
        ]
        # Zero padding marks the padded positions as masked out.
        batch["attention_mask"] = pad_sequence(masks, batch_first=True)

    # Read the label instead of popping it: popping would mutate the caller's
    # samples and break any later collation of the same dicts.
    labels = [sample[label_column] for sample in inputs]
    batch["labels"] = torch.LongTensor(labels)
    return batch
