"""
BAE (BAE: BERT-Based Adversarial Examples)
============================================

"""
from textattack.constraints.grammaticality import PartOfSpeech
from textattack.constraints.pre_transformation import (
    RepeatModification,
    StopwordModification,
)
from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
from textattack.goal_functions import UntargetedClassification
from textattack.search_methods import GreedyWordSwapWIR
from textattack.transformations import WordSwapMaskedLM

from .attack_recipe import AttackRecipe


class BAEGarg2019(AttackRecipe):
    """Siddhant Garg and Goutham Ramakrishnan, 2019.

    BAE: BERT-based Adversarial Examples for Text Classification.

    https://arxiv.org/pdf/2004.01970

    This recipe is "attack mode" 1 from the paper: BAE-R, word replacement.

    For reference, the paper describes 4 attack modes for BAE, built from the
    R (replace) and I (insert) operations applied to each token t in S:

        • BAE-R: Replace token t (See Algorithm 1)
        • BAE-I: Insert a token to the left or right of t
        • BAE-R/I: Either replace token t or insert a
        token to the left or right of t
        • BAE-R+I: First replace token t, then insert a
        token to the left or right of t
    """

    @staticmethod
    def build(model_wrapper):
        """Assemble the BAE-R attack against ``model_wrapper``.

        Args:
            model_wrapper: The victim model wrapper; it is handed directly to
                the ``UntargetedClassification`` goal function.

        Returns:
            A ``BAEGarg2019`` instance constructed from the goal function,
            the constraint list, the transformation and the search method
            configured below.
        """
        # ------------------------------------------------------------------
        # Transformation
        # ------------------------------------------------------------------
        # From the paper:
        # "In this paper, we present a simple yet novel technique: BAE
        # (BERT-based Adversarial Examples), which uses a language model (LM)
        # for token replacement to best fit the overall context. We perturb an
        # input sentence by either replacing a token or inserting a new token
        # in the sentence, by means of masking a part of the input and using a
        # LM to fill in the mask."
        #
        # Only the top K=50 synonyms from the MLM predictions are considered.
        #
        # [from email correspondence with the author]
        # "When choosing the top-K candidates from the BERT masked LM, we
        # filter out the sub-words and only retain the whole words (by
        # checking if they are present in the GloVE vocabulary)"
        masked_lm_swap = WordSwapMaskedLM(
            method="bae", max_candidates=50, min_confidence=0.0
        )

        # ------------------------------------------------------------------
        # Constraints (order: repeat, stopword, part of speech, USE)
        # ------------------------------------------------------------------
        # 1-2. Never modify the same word twice, and never modify stopwords.
        #
        # 3. Part of speech. From the paper: for the R operations we add an
        #    additional check for grammatical correctness of the generated
        #    adversarial example by filtering out predicted tokens that do not
        #    form the same part of speech (POS) as the original token t_i in
        #    the sentence.
        #
        # 4. Universal Sentence Encoder similarity. From the paper:
        #    "To ensure semantic similarity on introducing perturbations in
        #    the input text, we filter the set of top-K masked tokens (K is a
        #    pre-defined constant) predicted by BERT-MLM using a Universal
        #    Sentence Encoder (USE) (Cer et al., 2018)-based sentence
        #    similarity scorer."
        #
        #    "[We] set a threshold of 0.8 for the cosine similarity between
        #    USE-based embeddings of the adversarial and input text."
        #
        #    [from email correspondence with the author]
        #    "For a fair comparison of the benefits of using a BERT-MLM in our
        #    paper, we retained the majority of TextFooler's specifications.
        #    Thus we:
        #    1. Use the USE for comparison within a window of size 15 around
        #    the word being replaced/inserted.
        #    2. Set the similarity score threshold to 0.1 for inputs shorter
        #    than the window size (this translates roughly to almost always
        #    accepting the new text).
        #    3. Perform the USE similarity thresholding of 0.8 with respect to
        #    the text just before the replacement/insertion and not the
        #    original text (For example: at the 3rd R/I operation, we compute
        #    the USE score on a window of size 15 of the text obtained after
        #    the first 2 R/I operations and not the original text).
        #    ...
        #    To address point (3) from above, compare the USE with the
        #    original text at each iteration instead of the current one (While
        #    doing this change for the R-operation is trivial, doing it for
        #    the I-operation with the window based USE comparison might be
        #    more involved)."
        #
        #    Finally, since the BAE code is based on the TextFooler code, the
        #    threshold has to be adjusted to account for the missing / pi in
        #    the cosine similarity comparison. So the final threshold is
        #    1 - (1 - 0.8) / pi = 1 - (0.2 / pi) = 0.936338023.
        attack_constraints = [
            RepeatModification(),
            StopwordModification(),
            PartOfSpeech(allow_verb_noun_swap=True),
            UniversalSentenceEncoder(
                threshold=0.936338023,
                metric="cosine",
                compare_against_original=True,
                window_size=15,
                skip_text_shorter_than_window=True,
            ),
        ]

        # ------------------------------------------------------------------
        # Goal function: untargeted classification.
        # ------------------------------------------------------------------
        untargeted_goal = UntargetedClassification(model_wrapper)

        # ------------------------------------------------------------------
        # Search method
        # ------------------------------------------------------------------
        # From the paper:
        # "We estimate the token importance Ii of each token
        # t_i ∈ S = [t1, . . . , tn], by deleting ti from S and computing the
        # decrease in probability of predicting the correct label y, similar
        # to (Jin et al., 2019)."
        #
        # • "If there are multiple tokens can cause C to misclassify S when
        # they replace the mask, we choose the token which makes Sadv most
        # similar to the original S based on the USE score."
        # • "If no token causes misclassification, we choose the perturbation
        # that decreases the prediction probability P(C(Sadv)=y) the most."
        wir_search = GreedyWordSwapWIR(wir_method="delete")

        return BAEGarg2019(
            untargeted_goal, attack_constraints, masked_lm_swap, wir_search
        )
