"""pretraining prompt strategies"""
from typing import Generator

from transformers import BatchEncoding

from axolotl.prompt_tokenizers import PromptTokenizingStrategy


class PretrainTokenizer:
    """Pass-through prompter for pretraining: the raw text is the prompt."""

    def build_prompt(self, prompt) -> Generator[str, None, None]:
        """Yield ``prompt`` back unchanged as the generator's only item."""
        yield from (prompt,)


class PretrainTokenizationStrategy(PromptTokenizingStrategy):
    """Tokenize raw pretraining text into strided, EOS-terminated chunks.

    Text longer than the configured maximum length is split by the tokenizer
    into several sequences (``return_overflowing_tokens=True``) using
    ``stride`` as the tokenizer's ``stride`` argument.
    """

    @property
    def supports_batched(self):
        """Always ``True``.

        NOTE(review): presumably consulted by the dataset pipeline to decide
        whether ``tokenize_prompt`` may receive a batch of rows - confirm.
        """
        return True

    def __init__(
        self, *args, max_length=None, text_column="text", stride=256, **kwargs
    ):
        """Initialise the strategy.

        Args:
            *args: Positional arguments forwarded to the parent strategy.
            max_length: Maximum tokens per produced sequence (including the
                appended EOS). When falsy, ``self.max_length`` is left as set
                by the parent initialiser.
            text_column: Key of the row field that holds the raw text.
            stride: Value passed as the tokenizer's ``stride`` argument.
                Defaults to 256, the previously hard-coded value.
            **kwargs: Keyword arguments forwarded to the parent strategy.
        """
        super().__init__(*args, **kwargs)
        if max_length:
            self.max_length = max_length
        self.text_column = text_column
        self.stride = stride

    def _tokenize(
        self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
    ) -> BatchEncoding:
        """Tokenize ``prompt`` into one or more chunks.

        Args:
            prompt: Text to tokenize (``tokenize_prompt`` passes the value of
                the configured text column straight through).
            add_eos_token: Append the tokenizer's EOS id (and a matching
                attention-mask ``1``) to every produced sequence.
            strip_bos_token: Drop the leading token of every produced
                sequence when it equals the tokenizer's BOS id.

        Returns:
            The tokenizer's ``BatchEncoding`` with ``input_ids`` and
            ``attention_mask`` replaced by the adjusted lists.
        """
        res = self.tokenizer(
            prompt,
            truncation=True,
            # Keep one slot free so an appended EOS still fits in max_length.
            max_length=self.max_length - 1,
            add_special_tokens=True,
            return_overflowing_tokens=True,
            stride=self.stride,
        )
        input_ids = res["input_ids"]
        attention_mask = res["attention_mask"]

        if strip_bos_token:
            bos_id = self.tokenizer.bos_token_id
            # Trim ids and mask together so both stay the same length.
            trimmed = [
                (ids[1:], mask[1:]) if ids and ids[0] == bos_id else (ids, mask)
                for ids, mask in zip(input_ids, attention_mask)
            ]
            input_ids = [ids for ids, _ in trimmed]
            attention_mask = [mask for _, mask in trimmed]

        if add_eos_token:
            eos_id = self.tokenizer.eos_token_id
            input_ids = [ids + [eos_id] for ids in input_ids]
            attention_mask = [mask + [1] for mask in attention_mask]

        res["input_ids"] = input_ids
        res["attention_mask"] = attention_mask

        return res

    def tokenize_prompt(self, prompt):
        """Tokenize the ``text_column`` field of ``prompt`` with the defaults."""
        return self._tokenize(prompt[self.text_column])


def load(tokenizer, cfg):
    """Build a ``PretrainTokenizationStrategy`` from the run configuration.

    The text column comes from the first ``cfg.pretraining_dataset`` entry,
    falling back to ``"text"`` when that value is falsy. ``max_length`` is set
    to 64 times ``cfg.sequence_len``.
    """
    seq_len = cfg.sequence_len
    column = cfg.pretraining_dataset[0]["text_column"] or "text"
    return PretrainTokenizationStrategy(
        PretrainTokenizer(),
        tokenizer,
        cfg.train_on_inputs,
        seq_len,
        text_column=column,
        max_length=seq_len * 64,
    )
