"""
BackTranslation class
-----------------------------------

"""


import random

from transformers import MarianMTModel, MarianTokenizer

from textattack.shared import AttackedText

from .sentence_transformation import SentenceTransformation


class BackTranslation(SentenceTransformation):
    """A type of sentence level transformation that takes in a text input,
    translates it into target language and translates it back to source
    language.

    src_lang (string): source language
    target_lang (string): target language, for the list of supported language check bottom of this page
    src_model: name of the huggingface translation model used to translate from the target language back to the source language
    target_model: name of the huggingface translation model used to translate from the source language to the target language
    chained_back_translation (int): number of randomly sampled target languages to chain back translation
        through for more perturbation (for example, en-es-en-fr-en); 0 performs a single back translation
        through ``target_lang``

    Example::

        >>> from textattack.transformations.sentence_transformations import BackTranslation
        >>> from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
        >>> from textattack.augmentation import Augmenter

        >>> transformation = BackTranslation()
        >>> constraints = [RepeatModification(), StopwordModification()]
        >>> augmenter = Augmenter(transformation = transformation, constraints = constraints)
        >>> s = 'What on earth are you doing here.'

        >>> augmenter.augment(s)
    """

    def __init__(
        self,
        src_lang="en",
        target_lang="es",
        src_model="Helsinki-NLP/opus-mt-ROMANCE-en",
        target_model="Helsinki-NLP/opus-mt-en-ROMANCE",
        chained_back_translation=0,
    ):
        """Load both translation models and their tokenizers.

        ``target_model`` handles source -> target, ``src_model`` handles
        target -> source (see ``_back_translate``).
        """
        self.src_lang = src_lang
        self.target_lang = target_lang
        self.target_model = MarianMTModel.from_pretrained(target_model)
        self.target_tokenizer = MarianTokenizer.from_pretrained(target_model)
        self.src_model = MarianMTModel.from_pretrained(src_model)
        self.src_tokenizer = MarianTokenizer.from_pretrained(src_model)
        self.chained_back_translation = chained_back_translation

    def translate(self, input, model, tokenizer, lang="es"):
        """Translate the first element of ``input`` with ``model``.

        Args:
            input: sequence of strings; only ``input[0]`` is translated.
            model: translation model exposing ``generate``.
            tokenizer: tokenizer matching ``model``.
            lang: language the text is translated into. Accepted either as a
                bare code (``"es"``) or as a complete token (``">>es<<"``).

        Returns:
            The decoded output of ``tokenizer.batch_decode`` for the
            single-sentence batch.
        """
        text = input[0]

        # "en" input is passed to the model without a language token;
        # every other language is selected by prefixing a ``>>code<<`` token.
        if lang != "en":
            if lang.startswith(">>") and lang.rstrip().endswith("<<"):
                # Already a complete token: use it verbatim.
                prefix = lang
            else:
                # The previous check (`">>" and "<<" not in lang`) only tested
                # for "<<", so a half-wrapped code such as "la<<" was glued to
                # the text unwrapped. Strip stray markers and wrap the code.
                prefix = ">>" + lang.strip().strip("<>") + "<< "
            text = prefix + text

        # tokenize the input
        # NOTE(review): prepare_seq2seq_batch appears to be deprecated in newer
        # transformers releases -- confirm against the pinned version.
        encoded_input = tokenizer.prepare_seq2seq_batch([text], return_tensors="pt")

        # translate the input
        translated = model.generate(**encoded_input)
        return tokenizer.batch_decode(translated, skip_special_tokens=True)

    def _back_translate(self, text, target_lang):
        """Translate ``text`` into ``target_lang`` and back into ``self.src_lang``."""
        target_language_text = self.translate(
            [text], self.target_model, self.target_tokenizer, target_lang
        )
        src_language_text = self.translate(
            target_language_text, self.src_model, self.src_tokenizer, self.src_lang
        )
        return src_language_text[0]

    def _get_transformations(self, current_text, indices_to_modify):
        """Return a one-element list holding the back-translated text.

        Args:
            current_text: object whose ``.text`` attribute is the string to
                transform.
            indices_to_modify: not used by this transformation.

        Returns:
            ``[AttackedText(...)]`` built from the final round-trip output.
        """
        text = current_text.text

        if self.chained_back_translation:
            # chained back translation: a random list of target languages is
            # selected from the ones the target tokenizer supports
            # (random.sample raises ValueError if more are requested than exist)
            target_langs = random.sample(
                self.target_tokenizer.supported_language_codes,
                self.chained_back_translation,
            )
        else:
            # single back translation through the configured target language
            target_langs = [self.target_lang]

        # each round trip starts from the output of the previous one
        for target_lang in target_langs:
            text = self._back_translate(text, target_lang)

        return [AttackedText(text)]


"""
List of supported languages
['fr',
 'es',
 'it',
 'pt',
 'pt_br',
 'ro',
 'ca',
 'gl',
 'pt_BR',
 'la',
 'wa',
 'fur',
 'oc',
 'fr_CA',
 'sc',
 'es_ES',
 'es_MX',
 'es_AR',
 'es_PR',
 'es_UY',
 'es_CL',
 'es_CO',
 'es_CR',
 'es_GT',
 'es_HN',
 'es_NI',
 'es_PA',
 'es_PE',
 'es_VE',
 'es_DO',
 'es_EC',
 'es_SV',
 'an',
 'pt_PT',
 'frp',
 'lad',
 'vec',
 'fr_FR',
 'co',
 'it_IT',
 'lld',
 'lij',
 'lmo',
 'nap',
 'rm',
 'scn',
 'mwl']
"""
