"""Utilities for models."""
from typing import Dict, Optional

import torch
from transformers import PreTrainedModel


# Keys in examples that are typically passed to the model. Depending on the
# model/tokenizer, only a subset of these may actually be used.
_TYPICAL_EXAMPLE_KEYS = (
    # Keys that are usually provided directly by the tokenizer.
    'input_ids',
    'attention_mask',
    'token_type_ids',
    # Additional key provided for lm_suffix_mc "models".
    'context_length',
)


def compute_logits(
    model: PreTrainedModel,
    batch: Dict[str, torch.Tensor],
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Run ``model`` on the recognised entries of ``batch`` and return its logits.

    Only entries whose key is listed in ``_TYPICAL_EXAMPLE_KEYS`` are forwarded
    to the model as keyword arguments; every other entry is ignored (and is not
    moved to ``device``). Forwarded values keep the iteration order of ``batch``.

    Args:
        model: Model to call. Its return value must expose a ``logits`` attribute.
        batch: Mapping from input name to tensor.
        device: Passed as-is (including ``None``) to ``.to()`` on every
            forwarded value.

    Returns:
        The ``logits`` attribute of the model's output.
    """
    model_inputs = {}
    for name, tensor in batch.items():
        # Skip anything that is not one of the recognised example keys.
        if name not in _TYPICAL_EXAMPLE_KEYS:
            continue
        model_inputs[name] = tensor.to(device)
    outputs = model(**model_inputs)
    return outputs.logits
