# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Optional

import numpy as np
import torch

from fairseq.data import data_utils
from fairseq.data import encoders
from fairseq.data.encoders import register_tagger
from fairseq.data.encoders.tagger import Tagger
from fairseq.data.encoders.utils import get_whole_word_mask


@register_tagger('denoising_tagger')
class DenoisingTagger(Tagger):
    """
    Tagger that applies BART-style sequence-to-sequence denoising noise.

    After the parent :class:`Tagger` has encoded the text, the space-separated
    tokens are corrupted, in this order, by word/span masking, random
    insertion, rotation and local permutation. Each kind of noise is
    controlled by a command-line ratio. All randomness comes from the global
    NumPy / PyTorch RNGs.
    """

    @staticmethod
    def add_args(parser):
        """Register the denoising options on top of the base Tagger options."""
        # fmt: off
        Tagger.add_args(parser)
        parser.add_argument(
            "--mask",
            default=0.0,
            type=float,
            help="fraction of words/subwords that will be masked",
        )
        parser.add_argument(
            "--mask-random",
            default=0.0,
            type=float,
            help="instead of using [MASK], use random token this often",
        )
        parser.add_argument(
            "--insert",
            default=0.0,
            type=float,
            help="insert this percentage of additional random tokens",
        )
        parser.add_argument(
            "--permute",
            default=0.0,
            type=float,
            help="take this proportion of subwords and permute them",
        )
        parser.add_argument(
            "--rotate",
            default=0.0,
            type=float,
            help="rotate this proportion of inputs",
        )
        parser.add_argument(
            "--poisson-lambda",
            default=3.0,
            type=float,
            help="lambda of the Poisson distribution used to sample mask span lengths (span-poisson)",
        )
        parser.add_argument(
            "--permute-sentences",
            default=0.0,
            type=float,
            help="shuffle this proportion of sentences in all inputs",
        )
        parser.add_argument(
            "--mask-length",
            default="subword",
            type=str,
            choices=["subword", "word", "span-poisson"],
            help="mask length to choose",
        )
        parser.add_argument(
            "--replace-length",
            default=1,
            type=int,
            help="when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)",
        )
        parser.add_argument(
            "--no-whole-word-mask-langs",
            type=str,
            default="",
            metavar="N",
            help="languages without spacing between words dont support whole word masking",
        )
        # fmt: on

    def __init__(self, args):
        """
        Read the noise configuration from ``args`` and validate it.

        Raises:
            ValueError: on an unsupported ``replace_length`` / ``mask_length``
                combination.
        """
        super().__init__(args)
        self.args = args
        # Languages listed here get per-token instead of whole-word masking
        # (see ``encode``). Blank entries, e.g. from the empty default, are dropped.
        self.language_without_segmentations = [
            lang.strip()
            for lang in self.args.no_whole_word_mask_langs.split(",")
            if lang.strip()
        ]
        self.mask_ratio = args.mask
        self.random_ratio = args.mask_random
        self.insert_ratio = args.insert
        self.rotate_ratio = args.rotate
        self.permute_sentence_ratio = args.permute_sentences
        self.permute_ratio = args.permute
        self.bpe = encoders.build_bpe(args)

        self.replace_length = args.replace_length
        if self.replace_length not in [-1, 0, 1]:
            raise ValueError(f"invalid arg: replace_length={self.replace_length}")
        if args.mask_length not in ["subword", "word", "span-poisson"]:
            raise ValueError(f"invalid arg: mask-length={args.mask_length}")
        if args.mask_length == "subword" and args.replace_length not in [0, 1]:
            raise ValueError("if using subwords, use replace-length=1 or 0")

        self.mask_span_distribution = None
        if args.mask_length == "span-poisson":
            _lambda = args.poisson_lambda

            # Truncated Poisson(lambda) pmf over span lengths k = 0, 1, ...:
            # stop once the mass drops below 1e-7 (at most 128 entries).
            lambda_to_the_k = 1
            e_to_the_minus_lambda = math.exp(-_lambda)
            k_factorial = 1
            ps = []
            for k in range(0, 128):
                ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
                lambda_to_the_k *= _lambda
                k_factorial *= k + 1
                if ps[-1] < 0.0000001:
                    break
            ps = torch.FloatTensor(ps)
            self.mask_span_distribution = torch.distributions.Categorical(ps)

    def encode(self, x: str, meta: Optional[dict] = None, **kwargs) -> str:
        """
        Tag ``x`` with the parent tagger, then apply denoising noise to it.

        Args:
            x: input text. After tagging it is split on single spaces into tokens.
            meta: per-example metadata. It must provide ``'dictionary'`` (used
                to draw random replacement symbols) and ``'lang'`` (decides
                whether whole-word masking applies). It is also forwarded
                to the parent tagger, e.g. lang code, target lang code and
                corpus tag.
            **kwargs: forwarded to the parent tagger.

        Returns:
            The noised tokens joined back with single spaces.
        """
        if meta is None:
            meta = {}
        x = super().encode(x, meta, **kwargs)
        source = np.array(x.split(' '), dtype=object)

        vocab = meta['dictionary']
        # Masked positions receive this literal symbol string, not an index.
        mask_token = '<mask>'
        lang_mask_whole_words = meta['lang'] not in self.language_without_segmentations

        # TODO: numpy seeding (see denoising_dataset.py); noise currently uses
        # the global numpy/torch RNGs.
        # TODO: --permute-sentences is parsed (self.permute_sentence_ratio) but
        # sentence permutation is not applied yet.

        if self.mask_ratio > 0:
            source = self.add_whole_word_mask(source, vocab, mask_token, lang_mask_whole_words, self.mask_ratio)

        if self.insert_ratio > 0:
            source = self.add_insertion_noise(source, vocab, mask_token, self.insert_ratio)

        if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio:
            source = self.add_rolling_noise(source)

        if self.permute_ratio > 0.0:
            source = self.add_permuted_noise(source, self.permute_ratio)

        if getattr(self.args, "prepend_bos", False):
            assert source[0] == vocab.bos_word

        return ' '.join(source)

    def decode(self, x: str) -> str:
        """Delegate to the parent tagger; noise is not reversed."""
        return super().decode(x)

    def word_starts(self, tokens, mask_whole_word):
        """
        Return a per-token tensor flagging positions where a mask may start.

        With ``mask_whole_word`` set, the BPE decides which tokens begin a word
        (uint8 tensor). Otherwise every token qualifies (float tensor of ones).
        The first and last positions are always cleared so they are never masked.
        """
        if mask_whole_word:
            is_word_start = torch.ByteTensor([self.bpe.is_beginning_of_word(x) for x in tokens])
        else:
            is_word_start = torch.ones(len(tokens))
        is_word_start[0] = 0
        is_word_start[-1] = 0
        return is_word_start

    @staticmethod
    def _random_symbols(vocab, count):
        """Draw ``count`` symbols from ``vocab`` at uniform indices in [1, len(vocab))."""
        return [vocab[i] for i in np.random.randint(low=1, high=len(vocab), size=(int(count),))]

    def _mask_positions(self, tokens, indices, mask_random, vocab, mask_token):
        """Write ``mask_token`` at ``indices``; the ``mask_random`` subset gets random symbols instead."""
        tokens.put(indices, mask_token)
        tokens.put(indices[mask_random], self._random_symbols(vocab, mask_random.sum()))

    def _drop_or_mask(self, tokens, to_keep, indices, mask_random, vocab, mask_token):
        """Handle span continuation positions: delete them unless replace_length == -1, else mask each."""
        if self.replace_length != -1:
            to_keep[indices] = 0
        else:
            self._mask_positions(tokens, indices, mask_random, vocab, mask_token)

    def add_whole_word_mask(self, tokens, vocab, mask_token, mask_whole_word, p):
        """
        Mask about a fraction ``p`` of the word starts in ``tokens``.

        Span lengths are sampled from ``self.mask_span_distribution`` when it is
        set (``--mask-length span-poisson``); otherwise each span is one word.
        ``self.replace_length`` decides what happens to each span: 0 deletes
        it, 1 keeps a single mask, and -1 masks every token. A fraction
        ``self.random_ratio`` of masks become random vocabulary symbols.
        Zero-length spans are turned into mask insertions.

        ``tokens`` (object ndarray of symbol strings) may be modified in place.
        The possibly resized result is returned.
        """
        is_word_start = self.word_starts(tokens, mask_whole_word)
        num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
        num_inserts = 0
        if num_to_mask == 0:
            return tokens

        if self.mask_span_distribution is not None:
            lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))

            # Keep sampling until the spans cover the masking budget.
            cum_length = torch.cumsum(lengths, 0)
            while cum_length[-1] < num_to_mask:
                lengths = torch.cat(
                    [
                        lengths,
                        self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
                    ],
                    dim=0,
                )
                cum_length = torch.cumsum(lengths, 0)

            # Trim to the budget: cut at the first span that reaches it and
            # shorten that span so the total is exact.
            i = 0
            while cum_length[i] < num_to_mask:
                i += 1
            lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
            num_to_mask = i + 1
            lengths = lengths[:num_to_mask]

            # Zero-length spans become insertions rather than masks.
            lengths = lengths[lengths > 0]
            num_inserts = num_to_mask - lengths.size(0)
            num_to_mask -= num_inserts
            if num_to_mask == 0:
                return self.add_insertion_noise(tokens, vocab, mask_token, num_inserts / len(tokens))

            assert (lengths > 0).all()
        else:
            lengths = torch.ones((num_to_mask,)).long()
        assert is_word_start[-1] == 0
        word_starts = is_word_start.nonzero(as_tuple=False)
        indices = word_starts[
            torch.randperm(word_starts.size(0))[:num_to_mask]
        ].squeeze(1)
        mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio

        tokens_length = len(tokens)
        assert tokens_length - 1 not in indices
        to_keep = torch.ones(tokens_length, dtype=torch.bool)
        # Acts as a long length, so spans don't go over the end of the doc.
        is_word_start[-1] = 255
        if self.replace_length == 0:
            to_keep[indices] = 0
        else:
            # Keep the first position of each span, replaced with a mask.
            self._mask_positions(tokens, indices, mask_random, vocab, mask_token)

        if self.mask_span_distribution is not None:
            assert len(lengths.size()) == 1
            assert lengths.size() == indices.size()
            lengths -= 1
            while indices.size(0) > 0:
                assert lengths.size() == indices.size()
                # Each word start crossed consumes one unit of span length.
                lengths -= is_word_start[indices + 1].long()
                uncompleted = lengths >= 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                lengths = lengths[uncompleted]
                self._drop_or_mask(tokens, to_keep, indices, mask_random, vocab, mask_token)
        else:
            # A bit faster when all lengths are 1: extend until the next word start.
            while indices.size(0) > 0:
                uncompleted = is_word_start[indices + 1] == 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                self._drop_or_mask(tokens, to_keep, indices, mask_random, vocab, mask_token)

                assert tokens_length - 1 not in indices

        tokens = tokens[to_keep]

        if num_inserts > 0:
            tokens = self.add_insertion_noise(tokens, vocab, mask_token, num_inserts / len(tokens))

        return tokens

    def add_permuted_noise(self, tokens, p):
        """
        Shuffle a fraction ``p`` of the interior tokens among themselves.

        The first and last positions never move. ``tokens`` is permuted in
        place and also returned.
        """
        num_words = len(tokens)
        num_to_permute = math.ceil(((num_words * 2) * p) / 2.0)
        # Only interior positions 1 .. num_words - 2 are eligible.
        substitutions = torch.randperm(max(num_words - 2, 0))[:num_to_permute] + 1
        # The request can exceed the interior size (e.g. p == 1.0); permute
        # only the positions actually selected instead of indexing past them.
        num_to_permute = substitutions.size(0)
        tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]]
        return tokens

    def add_rolling_noise(self, tokens):
        """
        Rotate the interior tokens by a random offset.

        The first and last tokens stay in place.
        """
        offset = np.random.randint(1, max(1, len(tokens) - 1) + 1)
        return np.concatenate(
            (tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]),
            axis=0,
        )

    def add_insertion_noise(self, tokens, vocab, mask_token, p):
        """
        Insert ``ceil(len(tokens) * p)`` new tokens at random interior positions.

        A ``self.random_ratio`` share of the inserted tokens are random
        vocabulary symbols; the rest are ``mask_token``. The original tokens
        keep their relative order. Returns a new object ndarray.
        """
        if p == 0.0:
            return tokens

        num_tokens = len(tokens)
        n = int(math.ceil(num_tokens * p))

        # Never insert at position 0 or at the final position.
        noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
        noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
        noise_mask[noise_indices] = 1
        # Placeholder so the final assertion catches any unfilled slot.
        result = np.full(num_tokens + n, '<null>', dtype=object)

        num_random = int(math.ceil(n * self.random_ratio))
        result.put(noise_indices[num_random:], mask_token)
        result.put(noise_indices[:num_random], self._random_symbols(vocab, num_random))

        result[~noise_mask] = tokens

        assert (result != '<null>').all()
        return result
