from functools import lru_cache

import numpy as np
import torch

from fairseq.data import data_utils, Dictionary

from . import BaseWrapperDataset, LRUCacheDataset


class TnfMaskTokensDataset(BaseWrapperDataset):
    """
    A wrapper Dataset for masked language modeling with subword-sibling
    masking.

    Input items are masked according to the specified masking probability.
    When a *subword_dataset* is given, every token that is flagged as a
    subword and is contiguous (through other flagged tokens) to a selected
    flagged token is masked as well, so the remaining pieces of a word do not
    leak the masked piece.

    Args:
        dataset: Dataset to wrap. Items are indexed with torch boolean masks
            and copied with ``np.copy``, so they are assumed to be 1-D torch
            tensors of token ids.
        subword_dataset: Dataset aligned with *dataset* (same index, same
            item length) whose items flag subword positions, or None to
            disable sibling masking. Items must provide ``.numpy()``.
        is_subword_idx: Value in a *subword_dataset* item that marks a
            position as a subword.
        vocab: Dictionary with the vocabulary and special tokens.
        pad_idx: Id of pad token in vocab
        mask_idx: Id of mask token in vocab
        seed: Seed for random number generator for reproducibility.
        mask_prob: probability of replacing a token with *mask_idx*.
        leave_unmasked_prob: probability that a masked token is unmasked.
        random_token_prob: probability of replacing a masked token with a
            random token from the vocabulary.
        freq_weighted_replacement: sample random replacement words based on
            word frequencies in the vocab.
        mask_whole_words: only mask whole words. This should be a byte mask
            over vocab indices, indicating whether it is the beginning of a
            word. We will extend any mask to encompass the whole word.
        return_masked_tokens: controls whether to return the non-masked tokens
            (the default) or to return a tensor with the original masked token
            IDs (and *pad_idx* elsewhere). The latter is useful as targets for
            masked LM training.
        return_tnf_nomask: if True, items are returned as unmodified copies
            (no masking at all).
    """

    @classmethod
    def apply_mask(cls, dataset: torch.utils.data.Dataset,
                   tnf_dataset: torch.utils.data.Dataset,
                   tnf_subword_dataset: torch.utils.data.Dataset,
                   is_subword_idx: int,
                   source_dictionary: Dictionary,
                   tnf_source_dictionary: Dictionary,
                   pad_idx: int,
                   tnf_pad_idx: int,
                   mask_idx: int,
                   tnf_mask_idx: int,
                   seed: int,
                   mask_prob: float,
                   leave_unmasked_prob: float,
                   random_token_prob: float,
                   freq_weighted_replacement: bool,
                   mask_whole_words: torch.Tensor = None):
        """Build the four datasets used for masked LM training.

        Returns a 4-tuple of LRU-cached datasets, in this order:

        1. *dataset* with masking applied (model input),
        2. *tnf_dataset* with masking applied, using ``leave_unmasked_prob``
           and ``random_token_prob`` of 0 (selected tokens always become
           *tnf_mask_idx*),
        3. *tnf_dataset* without any masking,
        4. the targets for *dataset* (original ids at the selected
           positions, *pad_idx* elsewhere).

        All four share *seed*, *mask_prob*, *tnf_subword_dataset* and
        *is_subword_idx*.
        """
        dataset = LRUCacheDataset(dataset)
        tnf_dataset = LRUCacheDataset(tnf_dataset)

        shared = dict(
            subword_dataset=tnf_subword_dataset,
            is_subword_idx=is_subword_idx,
            seed=seed,
            mask_prob=mask_prob,
            freq_weighted_replacement=freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
        )
        text_kwargs = dict(
            shared,
            dataset=dataset,
            vocab=source_dictionary,
            pad_idx=pad_idx,
            mask_idx=mask_idx,
            leave_unmasked_prob=leave_unmasked_prob,
            random_token_prob=random_token_prob,
        )
        # Do not do leave unmask and random replacement for tnf dataset
        tnf_kwargs = dict(
            shared,
            dataset=tnf_dataset,
            vocab=tnf_source_dictionary,
            pad_idx=tnf_pad_idx,
            mask_idx=tnf_mask_idx,
            leave_unmasked_prob=0.0,
            random_token_prob=0.0,
        )
        return (
            LRUCacheDataset(cls(return_masked_tokens=False, **text_kwargs)),
            LRUCacheDataset(cls(return_masked_tokens=False, **tnf_kwargs)),
            LRUCacheDataset(
                cls(return_masked_tokens=False, return_tnf_nomask=True,
                    **tnf_kwargs)),
            LRUCacheDataset(cls(return_masked_tokens=True, **text_kwargs)),
        )

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        subword_dataset: torch.utils.data.Dataset,
        is_subword_idx: int,
        vocab: Dictionary,
        pad_idx: int,
        mask_idx: int,
        seed: int = 1,
        mask_prob: float = 0.15,
        leave_unmasked_prob: float = 0.1,
        random_token_prob: float = 0.1,
        freq_weighted_replacement: bool = False,
        mask_whole_words: torch.Tensor = None,
        return_masked_tokens: bool = False,
        return_tnf_nomask: bool = False,

    ):
        assert 0.0 < mask_prob < 1.0
        assert 0.0 <= random_token_prob <= 1.0
        assert 0.0 <= leave_unmasked_prob <= 1.0
        assert random_token_prob + leave_unmasked_prob <= 1.0

        self.dataset = dataset
        self.subword_dataset = subword_dataset
        self.vocab = vocab
        self.pad_idx = pad_idx
        self.mask_idx = mask_idx
        self.return_masked_tokens = return_masked_tokens
        self.seed = seed
        self.mask_prob = mask_prob
        self.leave_unmasked_prob = leave_unmasked_prob
        self.random_token_prob = random_token_prob
        self.mask_whole_words = mask_whole_words
        self.is_subword_idx = is_subword_idx
        self.return_tnf_nomask = return_tnf_nomask

        # Sampling distribution for random replacement tokens; special
        # symbols are never sampled. Only needed when replacement is enabled.
        if random_token_prob > 0.0:
            if freq_weighted_replacement:
                weights = np.array(self.vocab.count)
            else:
                weights = np.ones(len(self.vocab))
            weights[:self.vocab.nspecial] = 0
            self.weights = weights / weights.sum()

        self.epoch = 0

    def set_epoch(self, epoch, **unused):
        """Record the current epoch; it is part of the masking RNG seed."""
        self.epoch = epoch

    def __getitem__(self, index: int):
        """Return the (possibly masked) item at *index* for the current epoch."""
        # seed and epoch are passed explicitly so that they become part of the
        # cache key: an item cached during a previous epoch must not be
        # served after set_epoch() changed the masking noise.
        return self.__getitem_cached__(self.seed, self.epoch, index)

    def _sibling_positions(self, subword_item, selected: np.ndarray) -> np.ndarray:
        """Return positions of subword tokens adjacent to selected subwords.

        Starting from every position that is both flagged as a subword and
        set in *selected*, walk right and then left while the neighbouring
        tokens are flagged as subwords. The starting positions themselves are
        not included unless they are reached from another starting position.

        Args:
            subword_item: item of ``self.subword_dataset`` (must provide
                ``.numpy()``), same length as *selected*.
            selected: token-level boolean mask.

        Returns:
            int64 array of positions (possibly empty, may contain repeats).
        """
        is_subword = subword_item.numpy() == self.is_subword_idx
        length = len(is_subword)
        siblings = []
        for start in np.flatnonzero(is_subword & selected):
            pos = start + 1
            while pos < length and is_subword[pos]:
                siblings.append(pos)
                pos += 1
            pos = start - 1
            while pos >= 0 and is_subword[pos]:
                siblings.append(pos)
                pos -= 1
        return np.asarray(siblings, dtype=np.int64)

    @lru_cache(maxsize=8)
    def __getitem_cached__(self, seed: int, epoch: int, index: int):
        """Compute the item at *index* under the RNG seeded by (seed, epoch, index).

        Depending on the configuration this returns the unmodified item
        (``return_tnf_nomask``), the prediction targets
        (``return_masked_tokens``) or the masked model input.
        """
        with data_utils.numpy_seed(seed, epoch, index):
            item = self.dataset[index]
            sz = len(item)
            subword_item = None
            if self.subword_dataset is not None:
                subword_item = self.subword_dataset[index]

            assert self.mask_idx not in item, \
                'Dataset contains mask_idx (={}), this is not expected!'.format(
                    self.mask_idx,
                )

            if self.return_tnf_nomask:
                # Nothing random is needed for the unmasked copy.
                return torch.from_numpy(np.copy(item))

            word_lens = None
            if self.mask_whole_words is not None:
                # From here on `sz` counts words; masks are drawn per word
                # and expanded to token level with np.repeat(…, word_lens).
                word_begins_mask = self.mask_whole_words.gather(0, item)
                word_begins_idx = word_begins_mask.nonzero().view(-1)
                sz = len(word_begins_idx)
                words = np.split(word_begins_mask, word_begins_idx)[1:]
                assert len(words) == sz
                word_lens = list(map(len, words))

            # decide elements to mask
            mask = np.full(sz, False)
            num_mask = int(
                # add a random number for probabilistic rounding
                self.mask_prob * sz + np.random.rand()
            )
            mask[np.random.choice(sz, num_mask, replace=False)] = True

            if self.return_masked_tokens:
                # exit early if we're just returning the masked tokens
                # (i.e., the targets for masked LM training)
                if word_lens is not None:
                    mask = np.repeat(mask, word_lens)
                new_item = np.full(len(mask), self.pad_idx)
                new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1]
                if subword_item is not None:
                    # siblings of selected subwords are prediction targets too
                    siblings = self._sibling_positions(subword_item, mask)
                    new_item[siblings] = np.copy(item)[siblings]
                return torch.from_numpy(new_item)

            # decide unmasking and random replacement
            rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
            rand_or_unmask = unmask = rand_mask = None
            if rand_or_unmask_prob > 0.0:
                rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
                if self.random_token_prob == 0.0:
                    unmask = rand_or_unmask
                elif self.leave_unmasked_prob == 0.0:
                    rand_mask = rand_or_unmask
                else:
                    unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
                    decision = np.random.rand(sz) < unmask_prob
                    unmask = rand_or_unmask & decision
                    rand_mask = rand_or_unmask & (~decision)

            if unmask is not None:
                mask = mask ^ unmask

            # Positions whose subword siblings get masked: every selected
            # position except those that receive a random replacement token
            # (XOR re-adds the unmasked ones and removes the random ones).
            # Computed before the whole-word expansion so both operands have
            # the same length.
            if rand_or_unmask is None:
                final_mask = mask
            else:
                final_mask = mask ^ rand_or_unmask

            if word_lens is not None:
                mask = np.repeat(mask, word_lens)
                final_mask = np.repeat(final_mask, word_lens)

            new_item = np.copy(item)
            new_item[mask] = self.mask_idx
            if rand_mask is not None:
                num_rand = rand_mask.sum()
                if num_rand > 0:
                    if word_lens is not None:
                        rand_mask = np.repeat(rand_mask, word_lens)
                        num_rand = rand_mask.sum()

                    new_item[rand_mask] = np.random.choice(
                        len(self.vocab),
                        num_rand,
                        p=self.weights,
                    )

            # mask the siblings of the masked subwords to avoid information leakage!
            if subword_item is not None:
                siblings = self._sibling_positions(subword_item, final_mask)
                new_item[siblings] = self.mask_idx
            return torch.from_numpy(new_item)
