import random


class TextTask:
    """Split a text document into token segments of bounded, jittered length.

    A document is tokenized, wrapped in start-of-text / end-of-text ids and
    then cut into consecutive segments. A segment is closed either when the
    end-of-text id is appended or when it reaches a length limit that is
    re-sampled for every segment (see ``get_max_seq_len``).
    """

    def __init__(self, config, tokenizer):
        """Store the tokenizer and read the segmentation settings.

        Args:
            config: Object exposing ``processing.random_max_seq_len_offset``,
                ``processing.max_seq_len``, ``processing.startoftext_id`` and
                ``processing.endoftext_id``.
            tokenizer: Object whose ``encode(text, add_special_tokens=False)``
                returns a result with an ``ids`` list of token ids
                (presumably a Hugging Face ``tokenizers.Tokenizer`` --
                TODO confirm).
        """
        self.tokenizer = tokenizer

        # config
        self.random_max_seq_len_offset = config.processing.random_max_seq_len_offset
        self.max_seq_len = config.processing.max_seq_len

        self.startoftext_id = config.processing.startoftext_id
        self.endoftext_id = config.processing.endoftext_id
        # Token ids of the segment currently being assembled.
        self.buffer = []

    def reset(self):
        """Discard any partially assembled segment."""
        self.buffer = []

    def get_max_seq_len(self):
        """Return ``max_seq_len`` shifted by a random integer offset.

        The offset is drawn uniformly from
        ``[-random_max_seq_len_offset, +random_max_seq_len_offset]``
        (both ends inclusive).
        """
        jitter = random.randint(
            -self.random_max_seq_len_offset, self.random_max_seq_len_offset
        )
        return self.max_seq_len + jitter

    def split_document(self, document):
        """Tokenize ``document`` and yield it segment by segment.

        Args:
            document: Text passed to ``self.tokenizer.encode``.

        Yields:
            ``(token_ids, last_segment)`` tuples. ``token_ids`` is a list of
            token ids; ``last_segment`` is True when the segment ends with
            the end-of-text id.
        """
        # A previous generator that was abandoned part-way leaves its last
        # yielded segment in self.buffer (the clear after ``yield`` never
        # ran). Start clean so those tokens cannot leak into this document
        # and the list already handed to the caller is not mutated.
        self.reset()

        token_ids = self.tokenizer.encode(document, add_special_tokens=False).ids
        # wrap the document with start and end tokens
        token_ids = [self.startoftext_id] + token_ids + [self.endoftext_id]

        segment_limit = self.get_max_seq_len()

        for token_id in token_ids:
            self.buffer.append(token_id)

            last_segment = token_id == self.endoftext_id
            # NOTE: if the sampled limit is not positive the length check
            # never matches, so the segment is only closed by end-of-text.
            if last_segment or len(self.buffer) == segment_limit:
                # segment token ids, last_segment
                yield self.buffer, last_segment
                # Rebind instead of clearing in place so the list just
                # yielded stays intact for the consumer.
                self.buffer = []
                # draw a fresh length limit for the next segment
                segment_limit = self.get_max_seq_len()

        self.buffer = []

    def process_document(self, document):
        """Yield the ``(token_ids, last_segment)`` segments of ``document``."""
        yield from self.split_document(document)