import torch
from transformers.models.whisper.feature_extraction_whisper import (
    WhisperFeatureExtractor,
)


def extract_features(
    examples: dict[str, list],
    feature_extractor: WhisperFeatureExtractor,
    feature_extractor_kwargs: None | dict = None,
):
    """Run the feature extractor on every audio entry and add the results to the batch.

    Args:
        examples: Batch in column format. ``examples["audio"]`` is iterated and
            each item is indexed with ``["array"]`` to obtain the data passed
            to the extractor.
        feature_extractor: Callable invoked once per audio entry as
            ``feature_extractor(array, **feature_extractor_kwargs)``; its
            result must expose ``.items()``.
        feature_extractor_kwargs: Keyword arguments forwarded to every
            extractor call. When ``None``, defaults to a sampling rate of
            16000 and to device ``"cuda:0"`` if CUDA is available, else
            ``"cpu"``.

    Returns:
        The same ``examples`` object, updated in place: for each key returned
        by the extractor, ``examples[key]`` is the list of per-audio values in
        input order.
    """
    if feature_extractor_kwargs is None:
        # Fall back to CPU so the default path also works on hosts without a GPU.
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        feature_extractor_kwargs = {"device": device, "sampling_rate": 16000}
    features: dict[str, list] = {}
    for audio in examples["audio"]:
        line_features = feature_extractor(audio["array"], **feature_extractor_kwargs)
        for key, value in line_features.items():
            features.setdefault(key, []).append(value)
    # Item assignment (rather than ``update``) keeps this working with any
    # mapping-like batch object that only supports ``__setitem__``.
    for key, value in features.items():
        examples[key] = value
    return examples


def collate_fn(inputs: list[dict], label_column: str = "category", *args, **kwargs):
    """Collate a list of samples into a batch of features and labels.

    Args:
        inputs: Samples to batch. Each must contain ``"input_features"`` and
            the key named by ``label_column``.
        label_column: Key under which each sample stores its label.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.

    Returns:
        A dict with ``"input_features"`` (the per-sample features converted to
        float32 tensors and stacked with ``torch.vstack``) and ``"labels"``
        (an int64 tensor with one label per sample).

    The samples in ``inputs`` are left unmodified, so the same sample dicts
    can be collated again (e.g. on a later epoch).
    """
    batch = {}
    batch["input_features"] = torch.vstack(
        [
            torch.as_tensor(line["input_features"], dtype=torch.float32)
            for line in inputs
        ]
    )
    # Read the label instead of popping it: removing the key from the caller's
    # dicts would raise KeyError the next time the same samples are collated.
    labels = [line[label_column] for line in inputs]
    batch["labels"] = torch.tensor(labels, dtype=torch.long)
    return batch
