import random
import numpy as np

from .base_text_task import TextTask


class TextContinuation(TextTask):
    """Task that asks for the next segment of a document given the previous one.

    Both the source and the target sequences start with the task token
    (id 0, ``<s>``); the source is additionally terminated by the separator
    token (id 2, ``</s>``).
    """

    def __init__(self, config, tokenizer):
        """Initialise the base task and pin the special token ids.

        Args:
            config: task configuration, forwarded to ``TextTask``.
            tokenizer: tokenizer whose ids 0 and 2 must decode to ``<s>`` and
                ``</s>`` respectively (checked with ``assert``).
        """
        super().__init__(config, tokenizer)
        task_id, sep_id = 0, 2
        # Make sure the hard-coded ids really match this tokenizer's vocabulary.
        assert tokenizer.decode([task_id]) == '<s>'
        assert tokenizer.decode([sep_id]) == '</s>'
        self.task_token = task_id
        self.sep_token = sep_id

    def process_document(self, document):
        """Yield ``(source, target, last_segment)`` triples for ``document``.

        For every ``(segment, last_segment)`` pair produced by
        ``self.split_document``, the source is the previous segment wrapped as
        ``[task] + previous + [sep]`` and the target is ``[task] + segment``.
        The very first source uses ``[self.startoftext_id]`` as its "previous
        segment" (``startoftext_id`` and ``split_document`` are not defined in
        this class; presumably they come from ``TextTask``).
        """
        context = [self.startoftext_id]

        for chunk, is_last in self.split_document(document):
            source = [self.task_token, *context, self.sep_token]
            target = [self.task_token, *chunk]
            # The segment just emitted as target becomes the next context.
            context = chunk
            yield source, target, is_last


class TextInfilling(TextTask):
    """Task that masks random spans of a segment and asks for the masked text.

    The source is the segment with each masked span replaced by a single mask
    token; the target is the concatenation of the masked spans, each followed
    by the separator token.
    """

    def __init__(self, config, tokenizer):
        """Initialise the base task, special token ids and masking parameters.

        Args:
            config: task configuration. Reads ``processing.max_seq_len``,
                ``processing.random_max_seq_len_offset``,
                ``processing.mask_token_prob`` and
                ``processing.mask_possion_lambda``.
            tokenizer: tokenizer whose ids 50262 and 2 must decode to
                ``<TextInfill>`` and ``</s>`` (checked with ``assert``).

        Raises:
            ValueError: if masking is enabled (``mask_token_prob > 0``) while
                ``mask_possion_lambda`` is not positive.
        """
        super().__init__(config, tokenizer)
        assert tokenizer.decode([50262]) == '<TextInfill>'
        assert tokenizer.decode([2]) == '</s>'
        # Override the sequence length limit with 1.1x the configured value.
        # NOTE(review): the previous comment here said "shrink", but the factor
        # 1.1 makes the limit 10% larger -- confirm which one is intended.
        self.max_seq_len = int(config.processing.max_seq_len * 1.1)
        # Halve the configured random offset.
        self.random_max_seq_len_offset = int(config.processing.random_max_seq_len_offset / 2)
        self.task_token = 50262
        self.sep_token = 2
        # NOTE(review): unlike the two ids above, the mask id is not verified
        # against the tokenizer.
        self.mask_token = 50264
        # Masking parameters from the config.
        self.mask_token_prob = config.processing.mask_token_prob
        self.mask_possion_lambda = config.processing.mask_possion_lambda

        # With lambda == 0 every sampled span has length 0, and the masking
        # loop in process_document re-samples forever without advancing.
        # Fail fast instead of hanging.
        if self.mask_token_prob > 0 and self.mask_possion_lambda <= 0:
            raise ValueError(
                "mask_possion_lambda must be > 0 when mask_token_prob > 0, "
                "got {!r}".format(self.mask_possion_lambda)
            )

    def process_document(self, document):
        """Yield ``(source, target, last_segment)`` triples for ``document``.

        For every ``(segment, last_segment)`` pair produced by
        ``self.split_document``:

        * ``source`` is ``[task] + masked_segment + [sep]`` where each masked
          span is replaced by one ``mask_token``;
        * ``target`` is ``[task]`` followed by every masked span, each
          terminated by ``sep``; if nothing was masked it is ``[task, sep]``.

        Span starts are chosen with probability ``mask_token_prob`` per
        position and span lengths are drawn from a Poisson distribution with
        mean ``mask_possion_lambda`` (zero-length draws are re-sampled).
        """
        for segment, last_segment in self.split_document(document):
            new_segment = []
            masked_tokens = []

            index = 0
            # True right after a span was masked: the next token is always
            # copied verbatim, so two mask tokens are never adjacent.
            just_masked = False
            while index < len(segment):
                # The uniform draw happens on every iteration (even when
                # just_masked is set) to keep the random stream unchanged.
                if np.random.rand() < self.mask_token_prob and not just_masked:
                    span_size = np.random.poisson(lam=self.mask_possion_lambda)
                    if span_size == 0:
                        # Empty span: try again at the same position.
                        continue
                    just_masked = True
                    # Replace the whole span by a single mask token; the slice
                    # is clipped automatically at the end of the segment.
                    new_segment.append(self.mask_token)
                    masked_tokens.append(segment[index:index + span_size])
                    index += span_size
                else:
                    just_masked = False
                    new_segment.append(segment[index])
                    index += 1

            source = [self.task_token] + new_segment + [self.sep_token]

            # Build the target in place instead of re-allocating the list for
            # every masked span.
            target = [self.task_token]
            for span in masked_tokens:
                target.extend(span)
                target.append(self.sep_token)

            if not masked_tokens:
                # Nothing was masked: emit a lone separator so the target is
                # never just the task token.
                target.append(self.sep_token)

            yield source, target, last_segment


class TextRecall(TextTask):
    """Task that asks for the segment preceding the one given as source.

    Sources are wrapped as ``[task] + segment + [sep]`` and targets as
    ``[task] + previous_segment``, with task id 50263 (``<TextRecall>``) and
    separator id 2 (``</s>``).
    """

    def __init__(self, config, tokenizer):
        """Initialise the base task and pin the special token ids.

        Args:
            config: task configuration, forwarded to ``TextTask``.
            tokenizer: tokenizer whose ids 50263 and 2 must decode to
                ``<TextRecall>`` and ``</s>`` (checked with ``assert``).
        """
        super().__init__(config, tokenizer)
        task_id, sep_id = 50263, 2
        # Make sure the hard-coded ids really match this tokenizer's vocabulary.
        assert tokenizer.decode([task_id]) == '<TextRecall>'
        assert tokenizer.decode([sep_id]) == '</s>'
        self.task_token = task_id
        self.sep_token = sep_id

    def process_document(self, document):
        """Yield ``(source, target, last_segment)`` triples for ``document``.

        The first triple is built from ``[self.startoftext_id]`` alone (source
        and target both contain it) and is always flagged ``False``. After
        that, each ``(segment, last_segment)`` pair from
        ``self.split_document`` yields the current segment as source and the
        previously seen segment as target; for the first real segment the
        "previous segment" is ``[self.startoftext_id]``.
        """
        recalled = [self.startoftext_id]

        # Bootstrap example: start-of-text maps onto itself.
        yield (
            [self.task_token, *recalled, self.sep_token],
            [self.task_token, *recalled],
            False,
        )

        for chunk, is_last in self.split_document(document):
            source = [self.task_token, *chunk, self.sep_token]
            target = [self.task_token, *recalled]
            # The current chunk is what the next example has to recall.
            recalled = chunk
            yield source, target, is_last


class MultiTextTask(TextTask):
    """Task that delegates each document to one randomly chosen sub-task."""

    def __init__(self, config, tokenizer):
        """Instantiate every sub-task named in the config.

        Args:
            config: task configuration. ``processing.multitasks`` lists the
                sub-task expressions and ``processing.multitasks_probs`` the
                sampling probability of each one.
            tokenizer: tokenizer forwarded to the base class and to every
                sub-task.
        """
        super().__init__(config, tokenizer)

        # NOTE(review): each entry is resolved with eval(), so the config can
        # execute arbitrary code. Only load configs from trusted sources, or
        # consider replacing this with an explicit name -> class registry.
        self.tasks = [
            eval(task_name)(config, tokenizer)
            for task_name in config.processing.multitasks
        ]
        self.tasks_probs = config.processing.multitasks_probs

    def reset(self):
        """Call ``reset()`` on every sub-task."""
        for sub_task in self.tasks:
            sub_task.reset()

    def process_document(self, document):
        """Yield the items one randomly selected sub-task produces for ``document``.

        A single sub-task is drawn per document according to
        ``self.tasks_probs``; everything it yields is passed through unchanged.
        """
        chosen = np.random.choice(len(self.tasks), p=self.tasks_probs)
        yield from self.tasks[chosen].process_document(document)