import math
import numpy as np
import torch
from numpy.random import permutation, poisson


class DataCollatorForMLM(object):
    """Batch examples and corrupt them for masked language modeling.

    Tokens selected for prediction are replaced by the mask token 80% of
    the time, by a random id drawn from ``range(len(tokenizer))`` 10% of
    the time, and left unchanged the remaining 10% of the time.

    Args:
        tokenizer: Object providing ``get_special_tokens_mask``,
            ``convert_tokens_to_ids``, ``mask_token`` and ``__len__``.
            If it exposes a non-``None`` ``pad_token_id``, padding
            positions are never selected.
        mlm_probability: Probability with which each non-special,
            non-padding token is selected for prediction.
    """
    def __init__(self, tokenizer, mlm_probability=0.15):
        self.tokenizer = tokenizer
        self.mlm_probability = mlm_probability

    def __call__(self, examples):
        """Prepare masked token inputs/labels for masked language modeling.

        Args:
            examples: Non-empty sequence of dicts mapping the same keys to
                tensors of identical shape per key. The keys
                ``'token_ids'``, ``'attention_mask'`` and
                ``'example_indices'`` are required.

        Returns:
            Dict with the corrupted ``'token_ids'``, the stacked
            ``'attention_mask'`` and ``'example_indices'``, and
            ``'labels'`` holding the original id at every selected
            position and ``-100`` everywhere else.
        """
        # torch.stack allocates new tensors, so the in-place edits below
        # never touch the caller's per-example tensors.
        examples = {
            k: torch.stack([x[k] for x in examples])
            for k in examples[0].keys()
        }
        token_ids = examples['token_ids']
        attention_mask = examples['attention_mask']
        labels = token_ids.clone()

        # We sample a few tokens in each sequence for masked-LM training
        # (with probability self.mlm_probability, which defaults to 0.15 as
        # in BERT/RoBERTa).
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(
                val, already_has_special_tokens=True)
            for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask,
                                                     dtype=torch.bool),
                                        value=0.0)

        # Use the public ``pad_token_id`` rather than the private
        # ``_pad_token`` attribute, which not every tokenizer defines.
        pad_token_id = getattr(self.tokenizer, 'pad_token_id', None)
        if pad_token_id is not None:
            padding_mask = labels.eq(pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with
        # tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(
            labels.shape, 0.8)).bool() & masked_indices
        token_ids[indices_replaced] = self.tokenizer.convert_tokens_to_ids(
            self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with a random
        # word: half of the selected positions that were not replaced by
        # the mask token (0.5 * 0.2 = 0.1).
        indices_random = \
            torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & \
            masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer),
                                     labels.shape,
                                     dtype=torch.long)
        token_ids[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input
        # tokens unchanged
        return {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'example_indices': examples['example_indices']
        }


class DataCollatorForDenoisingReconstrcution(object):
    """Data collator for the denoising language modeling task used in BART.

    The implementation is based on
    https://github.com/pytorch/fairseq/blob/
    1bba712622b8ae4efb3eb793a8a40da386fe11d0/fairseq/data/denoising_dataset.py.
    The default parameters are based on the BART paper
    https://arxiv.org/abs/1910.13461.

    Args:
        tokenizer: Object providing ``pad_token_id``, ``eos_token_id``,
            ``eoq_token_id``, ``mask_token_id`` and
            ``get_special_tokens_mask``.
        mask_ratio: Fraction of non-special, non-padding tokens to cover
            with mask spans. A falsy value disables masking.
        poisson_lambda: Mean of the Poisson distribution from which span
            lengths are drawn. Must be positive when masking is enabled.
        permutate_sentence_ratio: Fraction of sentences whose positions are
            shuffled. A value ``<= 0`` disables sentence permutation.
    """
    def __init__(self,
                 tokenizer,
                 mask_ratio=0.3,
                 poisson_lambda=3.0,
                 permutate_sentence_ratio=1.0):
        self.tokenizer = tokenizer
        self.mask_ratio = mask_ratio
        self.poisson_lambda = poisson_lambda
        self.permutate_sentence_ratio = permutate_sentence_ratio

    def __call__(self, examples):
        """Stack ``examples`` into a batch and add noise to the inputs.

        Args:
            examples: Non-empty sequence of dicts mapping the same keys to
                tensors of identical shape per key. The keys
                ``'token_ids'``, ``'attention_mask'`` and
                ``'example_indices'`` are required.

        Returns:
            Dict with the noised ``'token_ids'``, the matching
            ``'attention_mask'``, ``'labels'`` holding the token ids as
            they were before any noise was applied, and the stacked
            ``'example_indices'``.
        """
        examples = {
            k: torch.stack([x[k] for x in examples])
            for k in examples[0].keys()
        }
        token_ids = examples['token_ids'].numpy()
        attention_mask = examples['attention_mask'].numpy()
        labels = token_ids.copy()

        do_permutate = False
        if self.permutate_sentence_ratio > 0.0:
            # The first column is left in place; only the remainder of each
            # row takes part in the sentence shuffle.
            token_ids[:, 1:] = self.permutate_sentences(token_ids[:, 1:])
            do_permutate = True

        if self.mask_ratio:
            token_ids, _ = self.add_whole_word_mask(token_ids, do_permutate)
            # Collapsing spans shortens rows, so zero the attention mask
            # past the new number of non-padding tokens.
            # NOTE(review): assumes padding sits at the end of each row.
            num_non_padding = np.sum(token_ids != self.tokenizer.pad_token_id,
                                     axis=-1)
            for row_mask, length in zip(attention_mask, num_non_padding):
                row_mask[length:] = 0

        token_ids = torch.from_numpy(token_ids)
        attention_mask = torch.from_numpy(attention_mask)
        labels = torch.from_numpy(labels)
        return {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'example_indices': examples['example_indices']
        }

    def permutate_sentences(self, inputs):
        """Shuffle the order of sentences within each row of ``inputs``.

        A sentence ends at the first ``eoq_token_id`` / ``eos_token_id``
        that follows a token which is neither of the two. Tokens after the
        last sentence end (e.g. padding) stay where they are. After the
        shuffle, every EOS token of a row is rewritten to ``eoq_token_id``
        and the last non-padding token of that row becomes ``eos_token_id``.

        Args:
            inputs: 2-D integer array of token ids, one row per example.

        Returns:
            A new array of the same shape; ``inputs`` is not modified.
        """
        eos_id = self.tokenizer.eos_token_id
        eoq_id = self.tokenizer.eoq_token_id
        results = inputs.copy()

        for i, row in enumerate(inputs):
            full_stops = (row == eoq_id) | (row == eos_id)
            # Index k marks a stop at k + 1 preceded by a non-stop at k;
            # adding 2 yields the exclusive slice end of that sentence.
            sentence_ends = np.flatnonzero(
                full_stops[1:] & ~full_stops[:-1]) + 2
            num_sentences = sentence_ends.shape[0]
            if num_sentences == 0:
                continue

            # Clamp so that a ratio above 1.0 cannot index past the
            # available sentences.
            num_to_permute = min(
                num_sentences,
                int(np.ceil(
                    (num_sentences * 2 * self.permutate_sentence_ratio) /
                    2.0)))

            substitutions = np.random.permutation(
                num_sentences)[:num_to_permute]
            ordering = np.arange(0, num_sentences)
            ordering[substitutions] = substitutions[np.random.permutation(
                num_to_permute)]

            # bounds[j]:bounds[j + 1] is the slice of sentence j.
            bounds = np.concatenate([[0], sentence_ends])
            index = 0
            for j in ordering:
                sentence = row[bounds[j]:bounds[j + 1]]
                results[i, index:index + sentence.shape[0]] = sentence
                index += sentence.shape[0]

        num_non_padding = np.sum(results != self.tokenizer.pad_token_id,
                                 axis=-1)
        # Keep row and column indices together so a row with zero or
        # several EOS tokens cannot shift the fix-up onto another row.
        eos_rows, eos_cols = np.where(results == eos_id)
        results[eos_rows, eos_cols] = eoq_id
        for r in np.unique(eos_rows):
            results[r, num_non_padding[r] - 1] = eos_id

        return results

    def add_whole_word_mask(self, inputs, do_permutate):
        """Replace random spans of tokens with a single mask token each.

        Span lengths are drawn from a Poisson distribution with mean
        ``self.poisson_lambda`` until roughly ``self.mask_ratio`` of the
        non-special, non-padding tokens are covered. Each run of masked
        positions is collapsed to one ``mask_token_id`` and the rows are
        padded with ``pad_token_id`` back to the original width.

        Args:
            inputs: 2-D integer array of token ids, one row per example.
            do_permutate: If false, labels are set to ``-100`` at masked
                positions; otherwise at special-token positions.

        Returns:
            Tuple ``(new_inputs, labels)`` of arrays with the shape of
            ``inputs``. ``inputs`` itself is not modified.

        Raises:
            ValueError: If masking is required but ``self.poisson_lambda``
                is not positive.
        """
        labels = inputs.copy()
        inputs = inputs.copy()
        pad_id = self.tokenizer.pad_token_id

        special_tokens_mask = np.array([
            self.tokenizer.get_special_tokens_mask(
                val, already_has_special_tokens=True)
            for val in labels.tolist()
        ], dtype=bool)

        # determine how many tokens we need to mask in total
        is_token = (labels != pad_id) & ~special_tokens_mask
        num_to_mask = int(math.ceil(float(is_token.sum()) * self.mask_ratio))
        if num_to_mask == 0:
            return inputs, labels

        if self.poisson_lambda <= 0:
            # With lambda == 0 every draw is 0 and the loop below would
            # never terminate.
            raise ValueError('poisson_lambda must be positive, got '
                             '{!r}'.format(self.poisson_lambda))

        # generate a sufficient number of span lengths
        lengths = poisson(lam=self.poisson_lambda, size=(num_to_mask, ))
        while np.cumsum(lengths, 0)[-1] < num_to_mask:
            lengths = np.concatenate([
                lengths,
                poisson(lam=self.poisson_lambda, size=(num_to_mask, ))
            ])

        # remove all spans of length 0
        # Note that BART inserts additional mask tokens where length == 0,
        # which we do not implement for now as it adds additional complexity
        lengths = lengths[lengths > 0]

        # trim to about num_to_mask tokens
        idx = np.argmin(np.abs(np.cumsum(lengths, 0) - num_to_mask)) + 1
        lengths = lengths[:idx + 1]

        # select span start indices
        token_indices = np.argwhere(is_token)
        span_starts = permutation(token_indices.shape[0])[:lengths.shape[0]]
        # There may be fewer candidate starts than drawn lengths; keep one
        # length per start so the arrays below stay aligned.
        lengths = lengths[:span_starts.shape[0]]

        # prepare mask (fancy indexing yields a copy we can advance freely)
        masked_indices = token_indices[span_starts]
        mask = np.zeros(labels.shape, dtype=bool)

        # mask span start indices
        mask[masked_indices[:, 0], masked_indices[:, 1]] = True
        lengths -= 1

        # fill up spans, advancing each unfinished span one column per pass
        max_index = labels.shape[1] - 1
        remaining = (lengths > 0) & (masked_indices[:, 1] < max_index)
        while np.any(remaining):
            masked_indices[remaining, 1] += 1
            mask[masked_indices[:, 0], masked_indices[:, 1]] = True
            lengths -= 1
            remaining = (lengths > 0) & (masked_indices[:, 1] < max_index)

        # place the mask tokens (spans may have run into special tokens)
        mask[special_tokens_mask] = False
        inputs[mask] = self.tokenizer.mask_token_id

        if not do_permutate:
            labels[mask] = -100
        else:
            labels[special_tokens_mask] = -100

        # remove mask tokens that are not starts of spans; shift without
        # wrap-around so the last column never affects the first one
        previous_masked = np.zeros_like(mask)
        previous_masked[:, 1:] = mask[:, :-1]
        to_remove = mask & previous_masked

        new_inputs = np.full_like(labels, fill_value=pad_id)
        for i, example in enumerate(inputs):
            new_example = example[~to_remove[i]]
            new_inputs[i, 0:new_example.shape[0]] = new_example

        return new_inputs, labels


class DataCollator(object):
    """Run the MLM collator and the denoising collator on the same examples.

    Args:
        tokenizer: Tokenizer handed to both underlying collators.
        mlm_probability: Forwarded to ``DataCollatorForMLM``.
        mask_ratio: Forwarded to ``DataCollatorForDenoisingReconstrcution``.
        poisson_lambda: Forwarded to
            ``DataCollatorForDenoisingReconstrcution``.
        permutate_sentence_ratio: Forwarded to
            ``DataCollatorForDenoisingReconstrcution``.
    """
    def __init__(self,
                 tokenizer,
                 mlm_probability=0.15,
                 mask_ratio=0.3,
                 poisson_lambda=3.0,
                 permutate_sentence_ratio=1.0):
        denoise_args = (tokenizer, mask_ratio, poisson_lambda,
                        permutate_sentence_ratio)
        self.mlm_collator = DataCollatorForMLM(tokenizer, mlm_probability)
        self.denoise_collator = DataCollatorForDenoisingReconstrcution(
            *denoise_args)

    def __call__(self, examples):
        """Collate ``examples`` once per task.

        Returns:
            Dict with the MLM collator's output under ``'mlm'`` and the
            denoising collator's output under ``'denoise'``.
        """
        batch = {}
        # The MLM collator runs first, then the denoising collator.
        batch['mlm'] = self.mlm_collator(examples)
        batch['denoise'] = self.denoise_collator(examples)
        return batch
