# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from functools import lru_cache

import numpy as np
import torch
from fairseq.data import Dictionary, data_utils

from . import BaseWrapperDataset, LRUCacheDataset


class MaskTokensDataset(BaseWrapperDataset):
    """
    A wrapper Dataset for masked language modeling.

    Input items are masked according to the specified masking probability.
    The noise for an item depends only on ``(seed, epoch, index)``, so a
    source instance (``return_masked_tokens=False``) and a target instance
    (``return_masked_tokens=True``) built with the same arguments agree on
    which positions are masked; :meth:`apply_mask` builds such a pair.

    Args:
        dataset: Dataset to wrap.
        vocab: Dictionary with the vocabulary and special tokens.
        pad_idx: Id of pad token in vocab
        mask_idx: Id of mask token in vocab
        return_masked_tokens: controls whether to return the noised input
            sequence (the default) or to return a tensor with the original
            masked token IDs (and *pad_idx* elsewhere). The latter is useful
            as targets for masked LM training.
        seed: Seed for random number generator for reproducibility.
        mask_prob: probability of replacing a token with *mask_idx*.
        leave_unmasked_prob: probability that a masked token is unmasked.
        random_token_prob: probability of replacing a masked token with a
            random token from the vocabulary.
        freq_weighted_replacement: sample random replacement words based on
            word frequencies in the vocab.
        mask_whole_words: only mask whole words. This should be a byte mask
            over vocab indices, indicating whether it is the beginning of a
            word. We will extend any mask to encompass the whole word.
        mask_multiple_length : repeat each mask index multiple times. Default
            value is 1.
        mask_stdev : standard deviation of masks distribution in case of
            multiple masking. Default value is 0.
        skip_masking: if True, items are returned as an unmodified copy.
    """

    @classmethod
    def apply_mask(cls, dataset: torch.utils.data.Dataset, *args, **kwargs):
        """Return the source and target datasets for masked LM training."""
        # Both wrappers share one cached underlying dataset so each raw item
        # is only loaded once for the (source, target) pair.
        dataset = LRUCacheDataset(dataset)
        return (
            LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=False)),
            LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=True)),
        )

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        vocab: Dictionary,
        pad_idx: int,
        mask_idx: int,
        return_masked_tokens: bool = False,
        seed: int = 1,
        mask_prob: float = 0.15,
        leave_unmasked_prob: float = 0.1,
        random_token_prob: float = 0.1,
        freq_weighted_replacement: bool = False,
        mask_whole_words: torch.Tensor = None,
        mask_multiple_length: int = 1,
        mask_stdev: float = 0.0,
        skip_masking: bool = False,
    ):
        assert 0.0 < mask_prob < 1.0
        assert 0.0 <= random_token_prob <= 1.0
        assert 0.0 <= leave_unmasked_prob <= 1.0
        # The two are disjoint fractions of the masked positions.
        assert random_token_prob + leave_unmasked_prob <= 1.0
        assert mask_multiple_length >= 1
        assert mask_stdev >= 0.0

        self.dataset = dataset
        self.vocab = vocab
        self.pad_idx = pad_idx
        self.mask_idx = mask_idx
        self.return_masked_tokens = return_masked_tokens
        self.seed = seed
        self.mask_prob = mask_prob
        self.leave_unmasked_prob = leave_unmasked_prob
        self.random_token_prob = random_token_prob
        self.mask_whole_words = mask_whole_words
        self.mask_multiple_length = mask_multiple_length
        self.mask_stdev = mask_stdev
        self.skip_masking = skip_masking

        # Sampling distribution for random replacement tokens. Special tokens
        # (the first ``nspecial`` ids) are never sampled. ``self.weights`` is
        # only read when random replacement is enabled.
        if random_token_prob > 0.0:
            if freq_weighted_replacement:
                weights = np.array(self.vocab.count)
            else:
                weights = np.ones(len(self.vocab))
            weights[: self.vocab.nspecial] = 0
            self.weights = weights / weights.sum()

        self.epoch = 0

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True  # only the noise changes, not item sizes

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        # The epoch is part of the per-item RNG seed, so noise changes per epoch.
        self.epoch = epoch

    def __getitem__(self, index: int):
        return self.__getitem_cached__(self.seed, self.epoch, index)

    # NOTE(review): lru_cache on a method is one class-wide cache keyed on
    # ``self``; it keeps up to 8 recently used instances alive (ruff B019).
    # Consider a per-instance cache if instance lifetimes matter.
    @lru_cache(maxsize=8)
    def __getitem_cached__(self, seed: int, epoch: int, index: int):
        """Return the noised input (or the MLM target) for ``index``.

        All randomness comes from a generator seeded from
        ``(seed, epoch, index)``. The draws that select the masked positions
        happen before the ``return_masked_tokens`` branch, so source and target
        instances agree on the mask.
        """
        # hash() of a tuple of ints is not affected by PYTHONHASHSEED, so this
        # derived seed is reproducible across processes (assuming int inputs).
        rng_seed = int(hash((seed, epoch, index)) % 1e6)
        rng = np.random.default_rng(rng_seed)
        item = self.dataset[index]
        sz = len(item)

        assert (
            self.mask_idx not in item
        ), "Dataset contains mask_idx (={}), this is not expected!".format(
            self.mask_idx,
        )
        if self.skip_masking:
            return torch.from_numpy(np.copy(item))

        if self.mask_whole_words is not None:
            # From here on ``sz`` counts words, not tokens. Word-level masks
            # are expanded back to token level with ``np.repeat(.., word_lens)``.
            word_begins_mask = self.mask_whole_words.gather(0, item)
            word_begins_idx = word_begins_mask.nonzero().view(-1)
            sz = len(word_begins_idx)
            words = np.split(word_begins_mask, word_begins_idx)[1:]
            assert len(words) == sz
            word_lens = list(map(len, words))

        # decide elements to mask
        mask = np.full(sz, False)
        num_mask = int(
            # add a random number for probabilistic rounding
            self.mask_prob * sz / float(self.mask_multiple_length)
            + rng.random()
        )

        # multiple masking as described in the vq-wav2vec paper (https://arxiv.org/abs/1910.05453)
        mask_idc = rng.choice(sz, num_mask, replace=False)
        if self.mask_stdev > 0.0:
            # Each span start gets its own (non-negative, rounded) length.
            lengths = rng.normal(
                self.mask_multiple_length, self.mask_stdev, size=num_mask
            )
            lengths = [max(0, int(round(x))) for x in lengths]
            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ],
                dtype=np.int64,
            )
        else:
            mask_idc = np.concatenate(
                [mask_idc + i for i in range(self.mask_multiple_length)]
            )
        # Spans may run past the end of the sequence; drop those positions.
        mask_idc = mask_idc[mask_idc < len(mask)]
        try:
            mask[mask_idc] = True
        except Exception:  # something wrong; report context and re-raise
            print("Assigning mask indexes {} to mask {} failed!".format(mask_idc, mask))
            raise

        if self.return_masked_tokens:
            # exit early if we're just returning the masked tokens
            # (i.e., the targets for masked LM training)
            if self.mask_whole_words is not None:
                mask = np.repeat(mask, word_lens)
            new_item = np.full(len(mask), self.pad_idx)
            new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1]
            return torch.from_numpy(new_item)

        # decide unmasking and random replacement
        # Among masked positions, a fraction ``rand_or_unmask_prob`` is either
        # left unchanged or replaced by a random token. That fraction is then
        # split in proportion to the two probabilities.
        rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
        if rand_or_unmask_prob > 0.0:
            rand_or_unmask = mask & (rng.random(sz) < rand_or_unmask_prob)
            if self.random_token_prob == 0.0:
                unmask = rand_or_unmask
                rand_mask = None
            elif self.leave_unmasked_prob == 0.0:
                unmask = None
                rand_mask = rand_or_unmask
            else:
                unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
                decision = rng.random(sz) < unmask_prob
                unmask = rand_or_unmask & decision
                rand_mask = rand_or_unmask & (~decision)
        else:
            unmask = rand_mask = None

        if unmask is not None:
            # ``unmask`` is a subset of ``mask``, so XOR clears exactly those bits.
            mask = mask ^ unmask

        if self.mask_whole_words is not None:
            mask = np.repeat(mask, word_lens)

        new_item = np.copy(item)
        new_item[mask] = self.mask_idx
        if rand_mask is not None:
            # Random-replacement positions are still in ``mask``; overwrite the
            # mask token placed above with sampled vocabulary ids.
            num_rand = rand_mask.sum()
            if num_rand > 0:
                if self.mask_whole_words is not None:
                    rand_mask = np.repeat(rand_mask, word_lens)
                    num_rand = rand_mask.sum()

                new_item[rand_mask] = rng.choice(
                    len(self.vocab),
                    num_rand,
                    p=self.weights,
                )

        return torch.from_numpy(new_item)
