from collections.abc import Callable

import torch

from .thingsvision import ThingsvisionModel, load_thingsvision_model


def load_model(
    source: str,
    model_name: str,
    module_names: str | list[str],
    model_parameters: dict | None = None,
    feature_alignment: str | None = None,
    device: str | torch.device = "cuda",
) -> tuple[ThingsvisionModel, Callable]:
    """Load a model via the thingsvision backend.

    This is a thin wrapper: every argument is passed through unchanged, by
    keyword, to ``load_thingsvision_model``, and its result is returned as-is.

    Args:
        source: Where the model comes from. The accepted values are defined
            by ``load_thingsvision_model``, not here.
        model_name: Name of the model to load.
        module_names: A single module name or a list of module names,
            forwarded as-is. Presumably these select the layers whose
            features are extracted (TODO confirm against the loader).
        model_parameters: Optional extra model configuration, or ``None``.
        feature_alignment: Optional alignment setting, or ``None``.
        device: Device to load onto. Defaults to ``"cuda"``.

    Returns:
        The loader's result, annotated as a ``(ThingsvisionModel, Callable)``
        pair. The role of the callable is determined by the loader
        (NOTE(review): presumably a preprocessing transform; confirm).
    """
    # Collect the pass-through arguments once, then unpack them by keyword.
    # Because the loader receives only keyword arguments, the order of the
    # entries here has no effect on the call.
    loader_kwargs = {
        "source": source,
        "model_name": model_name,
        "module_names": module_names,
        "model_parameters": model_parameters,
        "feature_alignment": feature_alignment,
        "device": device,
    }
    return load_thingsvision_model(**loader_kwargs)
