"""Common things and ABCs to support ICL for datasets."""
import abc
from typing import Any, Dict, List

import torch
from transformers import PreTrainedTokenizer

from npeff_torch.models import lm_mcqa

###############################################################################


class IclExampleHelperAbc(abc.ABC):
    """Abstract base for helpers that assemble in-context-learning (ICL) examples.

    A concrete helper provides three steps:

    1. ``make_text_for_individual_example`` renders one example as text.
    2. ``join_examples_text`` combines several rendered examples into one text.
    3. ``encode_text`` tokenizes that text into a dict of tensors.

    The constructor only records the tokenizer and the sequence length on the
    instance; using them is left entirely to subclasses.
    """

    def __init__(
        self,
        *,
        tokenizer: PreTrainedTokenizer,
        sequence_length: int,
    ):
        """Store shared configuration. Subclasses must call this initializer.

        Args:
            tokenizer: Tokenizer kept as ``self.tokenizer`` for subclass use.
            sequence_length: Kept as ``self.sequence_length``. NOTE(review):
                presumably the token budget for encoded text; this class does
                not use it itself, so confirm in subclasses.
        """
        self.tokenizer, self.sequence_length = tokenizer, sequence_length

    @abc.abstractmethod
    def make_text_for_individual_example(
        self,
        example: Dict[str, Any],
        *,
        include_label: bool,
    ) -> str:
        """Render a single example as text.

        Args:
            example: The example to render. Its keys are defined by the
                concrete dataset/subclass.
            include_label: When True, the label is part of the rendered text.
                Typically True for context (demonstration) examples and False
                for the example whose prediction is wanted.

        Returns:
            The text for this one example.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def join_examples_text(self, texts: List[str]) -> str:
        """Combine per-example texts into one text.

        Args:
            texts: Strings produced by ``make_text_for_individual_example``.

        Returns:
            A single joined text.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def encode_text(self, text: str) -> Dict[str, torch.Tensor]:
        """Tokenize a complete text into its tensor representation.

        Args:
            text: The full text, e.g. the output of ``join_examples_text``.

        Returns:
            A mapping from names to tensors. NOTE(review): the exact keys are
            up to the subclass (likely tokenizer outputs such as input ids);
            confirm against implementations.
        """
        raise NotImplementedError()


###############################################################################

# # Make a class for lm_mcqa models to go from tokenized context + label to context.

# class TokenizedLmMcqaContextMaker:

#     def __init__(
#         self, *,
#         model: lm_mcqa.LmMcqaLogitsComputer,
#         tokenizer: PreTrainedTokenizer,
#     ):
#         self._model = model
#         self._tokenizer = tokenizer

#     pass


