import torch
from typing import Optional, Callable, Tuple


# Module-level side effect: importing this module switches off autograd
# gradient tracking via torch.set_grad_enabled(False). PyTorch documents grad
# mode as thread-local, so presumably only the importing thread is affected —
# verify if worker threads are used.
# NOTE(review): presumably done because this module only extracts
# representations (inference); confirm no importer relies on gradients.
torch.set_grad_enabled(False)

class BrainAlignLanguageModelBase:
    """
    Base class for extracting language model representations.

    Subclasses are expected to override ``__init__``, ``__str__`` and
    ``get_layer_representations``; on this base class each of them raises
    ``NotImplementedError``, so the base itself cannot be instantiated.
    """

    # Annotation only (no value is assigned on the base class); subclasses are
    # expected to set it. Presumably the Hugging Face model identifier — confirm.
    hugging_face_model_id: str
    # Annotation only (no value is assigned on the base class); subclasses are
    # expected to set it. Meaning of the label is defined by subclasses.
    model_type: str

    def __init__(self):
        # `NotImplementedError` is the exception for "subclass must override";
        # `NotImplemented` is an operator sentinel and is not callable.
        raise NotImplementedError(
            f"{type(self).__name__} must implement __init__()"
        )

    def __str__(self):
        raise NotImplementedError(
            f"{type(self).__name__} must implement __str__()"
        )

    def get_layer_representations(
        self,
        token_ids: torch.LongTensor,
    ) -> torch.Tensor:
        """Return the hidden representations of ``token_ids`` at every layer of the model.

        Args:
            token_ids: Token ids to run through the model. The expected shape
                is not fixed here; it is defined by the implementing subclass.

        Returns:
            The per-layer hidden representations. Their exact layout is left to
            the implementing subclass.

        Raises:
            NotImplementedError: Always, on this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_layer_representations()"
        )
