from __future__ import annotations

import json
import os
import shutil
import tempfile
import unittest
from unittest.mock import patch

import numpy as np

from transformers import BartTokenizer
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.models.dpr.tokenization_dpr import DPRQuestionEncoderTokenizer
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
from transformers.utils import cached_property, is_datasets_available, is_faiss_available, is_tf_available


if is_tf_available() and is_datasets_available() and is_faiss_available():
    import faiss
    import tensorflow as tf
    from datasets import Dataset

    from transformers import (
        AutoConfig,
        RagConfig,
        RagRetriever,
        RagTokenizer,
        TFAutoModel,
        TFAutoModelForSeq2SeqLM,
        TFRagModel,
        TFRagSequenceForGeneration,
        TFRagTokenForGeneration,
    )
    from transformers.modeling_tf_outputs import TFBaseModelOutput

from ..bart.test_modeling_tf_bart import TFBartModelTester
from ..dpr.test_modeling_tf_dpr import TFDPRModelTester


# NOTE(review): presumably the absolute tolerance for numeric comparisons in this module; the
# tests visible here pass `atol=1e-3` literally instead of referencing this name - confirm usage.
TOLERANCE = 1e-3


def require_retrieval(test_case):
    """
    Decorator for tests that need the dependencies used to perform retrieval with [`RagRetriever`].

    The decorated test is returned unchanged when TensorFlow, `datasets` and `faiss` are all available;
    otherwise it is wrapped in a `unittest.skip`.

    """
    retrieval_deps_available = is_tf_available() and is_datasets_available() and is_faiss_available()
    if retrieval_deps_available:
        return test_case
    return unittest.skip("test requires tensorflow, datasets and faiss")(test_case)


@require_tf
@require_retrieval
@require_sentencepiece
class TFRagTestMixin:
    all_model_classes = (
        (TFRagModel, TFRagTokenForGeneration, TFRagSequenceForGeneration)
        if is_tf_available() and is_datasets_available() and is_faiss_available()
        else ()
    )
    all_generative_model_classes = (
        (TFRagTokenForGeneration, TFRagSequenceForGeneration)
        if is_tf_available() and is_datasets_available() and is_faiss_available()
        else ()
    )

    retrieval_vector_size = 32
    n_docs = 3
    max_combined_length = 16

    def setUp(self):
        """Create a temporary directory containing small DPR and BART tokenizer files.

        The DPR vocabulary is written with one token per line. The BART vocabulary is written as a
        JSON token-to-index mapping, alongside a merges file.
        """
        self.tmpdirname = tempfile.mkdtemp()

        # --- DPR tokenizer files ---
        wordpiece_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "[PAD]",
            "[MASK]",
            "want",
            "##want",
            "##ed",
            "wa",
            "un",
            "runn",
            "##ing",
            ",",
            "low",
            "lowest",
        ]
        dpr_dir = os.path.join(self.tmpdirname, "dpr_tokenizer")
        os.makedirs(dpr_dir, exist_ok=True)
        self.vocab_file = os.path.join(dpr_dir, DPR_VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as handle:
            handle.write("".join(f"{token}\n" for token in wordpiece_tokens))

        # --- BART tokenizer files ---
        bpe_symbols = [
            "l",
            "o",
            "w",
            "e",
            "r",
            "s",
            "t",
            "i",
            "d",
            "n",
            "\u0120",
            "\u0120l",
            "\u0120n",
            "\u0120lo",
            "\u0120low",
            "er",
            "\u0120lowest",
            "\u0120newer",
            "\u0120wider",
            "<unk>",
        ]
        token_to_id = {symbol: index for index, symbol in enumerate(bpe_symbols)}
        merge_rules = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
        self.special_tokens_map = {"unk_token": "<unk>"}

        bart_dir = os.path.join(self.tmpdirname, "bart_tokenizer")
        os.makedirs(bart_dir, exist_ok=True)
        # This rebinds `self.vocab_file`, which pointed at the DPR vocabulary until here.
        self.vocab_file = os.path.join(bart_dir, BART_VOCAB_FILES_NAMES["vocab_file"])
        self.merges_file = os.path.join(bart_dir, BART_VOCAB_FILES_NAMES["merges_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as handle:
            handle.write(json.dumps(token_to_id) + "\n")
        with open(self.merges_file, "w", encoding="utf-8") as handle:
            handle.write("\n".join(merge_rules))

    @cached_property
    def dpr_tokenizer(self) -> DPRQuestionEncoderTokenizer:
        """Load the DPR question-encoder tokenizer from the `dpr_tokenizer` folder of the temp directory."""
        tokenizer_dir = os.path.join(self.tmpdirname, "dpr_tokenizer")
        return DPRQuestionEncoderTokenizer.from_pretrained(tokenizer_dir)

    @cached_property
    def bart_tokenizer(self) -> BartTokenizer:
        """Load the BART tokenizer from the `bart_tokenizer` folder of the temp directory."""
        tokenizer_dir = os.path.join(self.tmpdirname, "bart_tokenizer")
        return BartTokenizer.from_pretrained(tokenizer_dir)

    def tearDown(self):
        """Delete the temporary directory stored in `self.tmpdirname`."""
        scratch_dir = self.tmpdirname
        shutil.rmtree(scratch_dir)

    def get_retriever(self, config):
        """Build a `RagRetriever` over a three-document in-memory dataset.

        `load_dataset` inside `transformers.models.rag.retrieval_rag` is patched to return the
        in-memory dataset while the retriever is constructed.

        Args:
            config: configuration object handed to `RagRetriever`.

        Returns:
            The constructed `RagRetriever`.
        """
        dim = self.retrieval_vector_size
        records = {
            "id": ["0", "1", "3"],
            "text": ["foo", "bar", "qux"],
            "title": ["Foo", "Bar", "Qux"],
            # constant vectors filled with 1, 2 and 3 respectively
            "embeddings": [np.ones(dim), 2 * np.ones(dim), 3 * np.ones(dim)],
        }
        dataset = Dataset.from_dict(records)
        dataset.add_faiss_index("embeddings", string_factory="Flat", metric_type=faiss.METRIC_INNER_PRODUCT)

        generator_tokenizer = self.bart_tokenizer
        with patch("transformers.models.rag.retrieval_rag.load_dataset", return_value=dataset):
            return RagRetriever(
                config,
                question_encoder_tokenizer=self.dpr_tokenizer,
                generator_tokenizer=generator_tokenizer,
            )

    def check_model_with_retriever(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
    ):
        """Call every model class with a retriever attached and check the shapes of its outputs.

        Checked outputs: `logits`, `generator_enc_last_hidden_state` and `doc_scores`.
        Extra keyword arguments are accepted and ignored.
        """
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        for model_class in self.all_model_classes:
            model = model_class(config, retriever=self.get_retriever(config))

            self.assertTrue(model.config.is_encoder_decoder)

            result = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            decoder_rows, decoder_cols = decoder_input_ids.shape[0], decoder_input_ids.shape[1]
            expanded_rows = self.n_docs * decoder_rows

            # logits
            expected_logits_shape = (expanded_rows, decoder_cols, config.generator.vocab_size)
            self.assertEqual(result.logits.shape, expected_logits_shape)

            # generator encoder last hidden states
            expected_encoder_shape = (expanded_rows, self.max_combined_length, config.generator.hidden_size)
            self.assertEqual(result.generator_enc_last_hidden_state.shape, expected_encoder_shape)

            # doc scores
            expected_scores_shape = (input_ids.shape[0], self.n_docs)
            self.assertEqual(result.doc_scores.shape, expected_scores_shape)

    def check_model_generate_from_context_input_ids(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
    ):
        """Call `generate` on each generative model class using explicitly retrieved contexts.

        The models are built without a retriever. Contexts come from a separate retriever call, and
        the doc scores are computed here from the question encoder output. Only the presence of a
        generation result is asserted.
        """
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        retriever = self.get_retriever(config)

        for model_class in self.all_generative_model_classes:
            model = model_class(config)
            self.assertTrue(model.config.is_encoder_decoder)

            question_states = model.question_encoder(input_ids, attention_mask=attention_mask)[0]

            retrieved = retriever(
                input_ids,
                question_states.numpy(),
                prefix=config.generator.prefix,
                return_tensors="tf",
            )
            context_ids = retrieved["context_input_ids"]
            context_mask = retrieved["context_attention_mask"]
            doc_embeds = tf.cast(retrieved["retrieved_doc_embeds"], tf.float32)

            # compute doc_scores from the question states and the retrieved embeddings
            query = tf.expand_dims(question_states, axis=[1])
            similarity = tf.matmul(query, doc_embeds, transpose_b=True)
            doc_scores = tf.squeeze(similarity, axis=[1])

            generated = model.generate(
                context_input_ids=context_ids,
                context_attention_mask=context_mask,
                doc_scores=doc_scores,
            )

            self.assertIsNotNone(generated)

    def check_model_generate(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
    ):
        """Call `generate` with 2 beams on each generative model class that has a retriever attached.

        Only the presence of a generation result is asserted.
        """
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        for model_class in self.all_generative_model_classes:
            model = model_class(config, retriever=self.get_retriever(config))

            self.assertTrue(model.config.is_encoder_decoder)

            # generate is fed int32 ids
            int32_input_ids = tf.cast(input_ids, tf.int32)
            generation_kwargs = {
                "num_beams": 2,
                "num_return_sequences": 2,
                "decoder_start_token_id": config.generator.eos_token_id,
                "max_new_tokens": 5,
            }
            generated = model.generate(input_ids=int32_input_ids, **generation_kwargs)

            self.assertIsNotNone(generated)

    def check_model_without_retriever(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
    ):
        """Call each model class without an attached retriever, passing retrieved contexts explicitly.

        Contexts come from a separate retriever call and the doc scores are computed here. The shapes
        of `logits`, `generator_enc_last_hidden_state` and `doc_scores` are then checked.
        """
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        retriever = self.get_retriever(config)

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertTrue(model.config.is_encoder_decoder)

            question_states = model.question_encoder(input_ids, attention_mask=attention_mask)[0]

            retrieved = retriever(
                input_ids,
                question_states.numpy(),
                prefix=config.generator.prefix,
                return_tensors="tf",
            )
            context_ids = retrieved["context_input_ids"]
            context_mask = retrieved["context_attention_mask"]
            doc_embeds = tf.cast(retrieved["retrieved_doc_embeds"], tf.float32)

            # compute doc_scores from the question states and the retrieved embeddings
            query = tf.expand_dims(question_states, axis=[1])
            similarity = tf.matmul(query, doc_embeds, transpose_b=True)
            doc_scores = tf.squeeze(similarity, axis=[1])

            result = model(
                input_ids=None,
                context_input_ids=context_ids,
                context_attention_mask=context_mask,
                doc_scores=doc_scores,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            decoder_rows, decoder_cols = decoder_input_ids.shape[0], decoder_input_ids.shape[1]
            expanded_rows = self.n_docs * decoder_rows

            # logits
            expected_logits_shape = (expanded_rows, decoder_cols, config.generator.vocab_size)
            self.assertEqual(result.logits.shape, expected_logits_shape)

            # generator encoder last hidden states
            expected_encoder_shape = (expanded_rows, self.max_combined_length, config.generator.hidden_size)
            self.assertEqual(result.generator_enc_last_hidden_state.shape, expected_encoder_shape)

            # doc scores
            expected_scores_shape = (input_ids.shape[0], self.n_docs)
            self.assertEqual(result.doc_scores.shape, expected_scores_shape)

    def check_model_custom_n_docs(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, n_docs, **kwargs
    ):
        """Call each model class with an explicit `n_docs` and check that output shapes follow it.

        The same `n_docs` value is passed to the retriever call and to the model call. The expected
        shapes use `n_docs` in place of `self.n_docs`.
        """
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        retriever = self.get_retriever(config)

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertTrue(model.config.is_encoder_decoder)

            question_states = model.question_encoder(input_ids, attention_mask=attention_mask)[0]

            retrieved = retriever(
                input_ids,
                question_states.numpy(),
                prefix=config.generator.prefix,
                return_tensors="tf",
                n_docs=n_docs,
            )
            context_ids = retrieved["context_input_ids"]
            context_mask = retrieved["context_attention_mask"]
            doc_embeds = tf.cast(retrieved["retrieved_doc_embeds"], tf.float32)

            # compute doc_scores from the question states and the retrieved embeddings
            query = tf.expand_dims(question_states, axis=[1])
            similarity = tf.matmul(query, doc_embeds, transpose_b=True)
            doc_scores = tf.squeeze(similarity, axis=[1])

            result = model(
                input_ids=None,
                context_input_ids=context_ids,
                context_attention_mask=context_mask,
                doc_scores=doc_scores,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                n_docs=n_docs,
            )

            decoder_rows, decoder_cols = decoder_input_ids.shape[0], decoder_input_ids.shape[1]
            expanded_rows = n_docs * decoder_rows

            # logits
            expected_logits_shape = (expanded_rows, decoder_cols, config.generator.vocab_size)
            self.assertEqual(result.logits.shape, expected_logits_shape)

            # generator encoder last hidden states
            expected_encoder_shape = (expanded_rows, self.max_combined_length, config.generator.hidden_size)
            self.assertEqual(result.generator_enc_last_hidden_state.shape, expected_encoder_shape)

            # doc scores
            expected_scores_shape = (input_ids.shape[0], n_docs)
            self.assertEqual(result.doc_scores.shape, expected_scores_shape)

    def check_model_with_mismatch_n_docs_value(
        self,
        config,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        retriever_n_docs,
        generator_n_docs,
        **kwargs,
    ):
        """Check that an `AssertionError` is raised when retriever and model calls use different `n_docs`.

        Contexts are retrieved with `n_docs=retriever_n_docs`. Each model class is then called with
        `n_docs=generator_n_docs`, and that call is expected to raise `AssertionError`.
        """
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        retriever = self.get_retriever(config)

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertTrue(model.config.is_encoder_decoder)

            question_states = model.question_encoder(input_ids, attention_mask=attention_mask)[0]

            retrieved = retriever(
                input_ids,
                question_states.numpy(),
                prefix=config.generator.prefix,
                return_tensors="tf",
                n_docs=retriever_n_docs,
            )
            context_ids = retrieved["context_input_ids"]
            context_mask = retrieved["context_attention_mask"]
            doc_embeds = tf.cast(retrieved["retrieved_doc_embeds"], tf.float32)

            # compute doc_scores from the question states and the retrieved embeddings
            query = tf.expand_dims(question_states, axis=[1])
            similarity = tf.matmul(query, doc_embeds, transpose_b=True)
            doc_scores = tf.squeeze(similarity, axis=[1])

            with self.assertRaises(AssertionError):
                model(
                    input_ids=None,
                    context_input_ids=context_ids,
                    context_attention_mask=context_mask,
                    doc_scores=doc_scores,
                    decoder_input_ids=decoder_input_ids,
                    decoder_attention_mask=decoder_attention_mask,
                    n_docs=generator_n_docs,
                )

    def check_model_with_encoder_outputs(
        self, config, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, **kwargs
    ):
        """Call each model twice: once in full, then again from the first call's encoder output.

        The second call receives `encoder_outputs` and `doc_scores` taken from the first call and no
        `input_ids`. The output shapes of the second call are checked.
        """
        self.assertIsNotNone(config.question_encoder)
        self.assertIsNotNone(config.generator)

        for model_class in self.all_model_classes:
            model = model_class(config, retriever=self.get_retriever(config))

            self.assertTrue(model.config.is_encoder_decoder)

            first_pass = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            cached_encoder_outputs = TFBaseModelOutput(first_pass.generator_enc_last_hidden_state)

            # run only generator
            second_pass = model(
                input_ids=None,
                encoder_outputs=cached_encoder_outputs,
                doc_scores=first_pass.doc_scores,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )

            decoder_rows, decoder_cols = decoder_input_ids.shape[0], decoder_input_ids.shape[1]
            expanded_rows = self.n_docs * decoder_rows

            # logits
            expected_logits_shape = (expanded_rows, decoder_cols, config.generator.vocab_size)
            self.assertEqual(second_pass.logits.shape, expected_logits_shape)

            # generator encoder last hidden states
            expected_encoder_shape = (expanded_rows, self.max_combined_length, config.generator.hidden_size)
            self.assertEqual(second_pass.generator_enc_last_hidden_state.shape, expected_encoder_shape)

            # doc scores
            expected_scores_shape = (input_ids.shape[0], self.n_docs)
            self.assertEqual(second_pass.doc_scores.shape, expected_scores_shape)

    def test_model_with_retriever(self):
        # Forward `config_and_inputs` as keyword arguments to the shape check with a retriever.
        self.check_model_with_retriever(**self.config_and_inputs)

    def test_model_without_retriever(self):
        # Forward `config_and_inputs` as keyword arguments to the shape check without a retriever.
        self.check_model_without_retriever(**self.config_and_inputs)

    @slow
    def test_model_generate_from_context_input_ids(self):
        # Forward `config_and_inputs` as keyword arguments to the generate-from-contexts check.
        self.check_model_generate_from_context_input_ids(**self.config_and_inputs)

    def test_model_with_encoder_outputs(self):
        # Forward `config_and_inputs` as keyword arguments to the encoder-outputs check.
        self.check_model_with_encoder_outputs(**self.config_and_inputs)

    @slow
    def test_model_generate(self):
        # Forward `config_and_inputs` as keyword arguments to the generate check.
        self.check_model_generate(**self.config_and_inputs)

    def test_model_with_custom_n_docs(self):
        # Run the custom-`n_docs` check with n_docs=1.
        # A fresh dict is built rather than inserting a key into the object returned by
        # `config_and_inputs`, so the extra `n_docs` entry never ends up in that (possibly cached) dict.
        inputs_dict = {**self.config_and_inputs, "n_docs": 1}
        self.check_model_custom_n_docs(**inputs_dict)

    def test_model_with_mismatch_n_docs_value(self):
        # Run the mismatch check with 3 docs on the retriever side and 2 on the generator side.
        # A fresh dict is built rather than inserting keys into the object returned by
        # `config_and_inputs`, so the extra entries never end up in that (possibly cached) dict.
        inputs_dict = {
            **self.config_and_inputs,
            "retriever_n_docs": 3,
            "generator_n_docs": 2,
        }
        self.check_model_with_mismatch_n_docs_value(**inputs_dict)


@require_tf
@require_retrieval
class TFRagDPRBartTest(TFRagTestMixin, unittest.TestCase):
    # Concrete test case: provides `config_and_inputs` for the TFRagTestMixin checks, built from
    # the DPR and BART model testers.

    @cached_property
    def config_and_inputs(self):
        """Assemble a `RagConfig` and the model inputs from the DPR and BART model testers.

        Returns:
            dict with the keys `config`, `input_ids`, `attention_mask`, `decoder_input_ids` and
            `decoder_attention_mask`.
        """
        dpr_outputs = TFDPRModelTester(self).prepare_config_and_inputs()
        bart_outputs = TFBartModelTester(self).prepare_config_and_inputs_for_common()

        question_encoder_config, question_ids, _, question_mask, _, _, _ = dpr_outputs
        generator_config, bart_inputs = bart_outputs

        # the BART tester's encoder inputs are used as the decoder inputs here
        decoder_ids = bart_inputs["input_ids"]
        decoder_mask = bart_inputs["attention_mask"]

        rag_config = RagConfig.from_question_encoder_generator_configs(
            question_encoder_config,
            generator_config,
            n_docs=self.n_docs,
            retrieval_vector_size=self.retrieval_vector_size,
            max_combined_length=self.max_combined_length,
        )

        prepared = {}
        prepared["config"] = rag_config
        prepared["input_ids"] = question_ids
        prepared["attention_mask"] = question_mask
        prepared["decoder_input_ids"] = decoder_ids
        prepared["decoder_attention_mask"] = decoder_mask
        return prepared


@require_tf
@require_retrieval
@require_sentencepiece
@require_tokenizers
class TFRagModelIntegrationTests(unittest.TestCase):
    @cached_property
    def token_model(self):
        """Build a `TFRagTokenForGeneration` from the DPR question encoder and BART generator checkpoints."""
        question_encoder_name = "facebook/dpr-question_encoder-single-nq-base"
        generator_name = "facebook/bart-large-cnn"
        return TFRagTokenForGeneration.from_pretrained_question_encoder_generator(
            question_encoder_name, generator_name
        )

    @cached_property
    def sequence_model(self):
        """Build a `TFRagSequenceForGeneration` from the DPR question encoder and BART generator checkpoints."""
        question_encoder_name = "facebook/dpr-question_encoder-single-nq-base"
        generator_name = "facebook/bart-large-cnn"
        return TFRagSequenceForGeneration.from_pretrained_question_encoder_generator(
            question_encoder_name, generator_name
        )

    def token_model_nq_checkpoint(self, retriever):
        """Load `TFRagTokenForGeneration` from "facebook/rag-token-nq" with `retriever` attached."""
        checkpoint = "facebook/rag-token-nq"
        return TFRagTokenForGeneration.from_pretrained(checkpoint, retriever=retriever)

    def get_rag_config(self):
        """Build the `RagConfig` used by the integration tests.

        The question-encoder and generator configs are loaded from their checkpoints and combined
        with the fixed overrides listed below.
        """
        question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
        generator_config = AutoConfig.from_pretrained("facebook/bart-large-cnn")

        overrides = {
            "bos_token_id": 0,
            "decoder_start_token_id": 2,
            "eos_token_id": 2,
            "is_encoder_decoder": True,
            "pad_token_id": 1,
            "vocab_size": 50264,
            "title_sep": " / ",
            "doc_sep": " // ",
            "n_docs": 5,
            "max_combined_length": 300,
            "dataset": "wiki_dpr",
            "dataset_split": "train",
            "index_name": "exact",
            "index_path": None,
            "use_dummy_dataset": True,
            "retrieval_vector_size": 768,
            "retrieval_batch_size": 8,
            "dataset_revision": "b24a417",
        }
        return RagConfig.from_question_encoder_generator_configs(
            question_encoder_config, generator_config, **overrides
        )

    @slow
    def test_rag_sequence_inference(self):
        # Forward pass of the RAG-sequence model with labels; the logits shape, loss and doc scores
        # are compared against stored reference values.
        config = self.get_rag_config()
        generator_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
        question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        retriever = RagRetriever(
            config,
            question_encoder_tokenizer=question_tokenizer,
            generator_tokenizer=generator_tokenizer,
        )

        model = self.sequence_model
        model.set_retriever(retriever)

        question_encoding = question_tokenizer("who sings does he love me with reba", return_tensors="tf")
        question_ids = question_encoding.input_ids
        label_ids = generator_tokenizer("Linda Davis", return_tensors="tf").input_ids

        result = model(question_ids, labels=label_ids)

        self.assertEqual(result.logits.shape, tf.TensorShape([5, 5, 50264]))

        reference_doc_scores = tf.convert_to_tensor([[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]])
        reference_loss = tf.convert_to_tensor([36.7368])

        tf.debugging.assert_near(result.loss, reference_loss, atol=1e-3)
        tf.debugging.assert_near(result.doc_scores, reference_doc_scores, atol=1e-3)

    @slow
    def test_rag_token_inference(self):
        # Forward pass of the RAG-token model with labels; the logits shape, loss and doc scores
        # are compared against stored reference values.
        config = self.get_rag_config()
        generator_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
        question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        retriever = RagRetriever(
            config,
            question_encoder_tokenizer=question_tokenizer,
            generator_tokenizer=generator_tokenizer,
        )

        model = self.token_model
        model.set_retriever(retriever)

        question_encoding = question_tokenizer("who sings does he love me with reba", return_tensors="tf")
        question_ids = question_encoding.input_ids
        label_ids = generator_tokenizer("Linda Davis", return_tensors="tf").input_ids

        result = model(question_ids, labels=label_ids)

        self.assertEqual(result.logits.shape, tf.TensorShape([5, 5, 50264]))

        reference_doc_scores = tf.convert_to_tensor([[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]])
        reference_loss = tf.convert_to_tensor([36.3557])

        tf.debugging.assert_near(result.loss, reference_loss, atol=1e-3)
        tf.debugging.assert_near(result.doc_scores, reference_doc_scores, atol=1e-3)

    @slow
    def test_rag_token_inference_nq_checkpoint(self):
        # Load the "facebook/rag-token-nq" checkpoint, round-trip it through save_pretrained /
        # from_pretrained, then compare the reloaded model's forward pass against reference values.
        config = self.get_rag_config()
        generator_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
        question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        retriever = RagRetriever(
            config,
            question_encoder_tokenizer=question_tokenizer,
            generator_tokenizer=generator_tokenizer,
        )

        model = self.token_model_nq_checkpoint(retriever=retriever)

        # save and reload; the assertions below run on the reloaded model
        with tempfile.TemporaryDirectory() as save_dir:
            model.save_pretrained(save_dir)
            model = TFRagTokenForGeneration.from_pretrained(save_dir, retriever=retriever)

        question_encoding = question_tokenizer("who sings does he love me with reba", return_tensors="tf")
        question_ids = question_encoding.input_ids
        label_ids = generator_tokenizer("Linda Davis", return_tensors="tf").input_ids

        result = model(question_ids, labels=label_ids)

        self.assertEqual(result.logits.shape, tf.TensorShape([5, 5, 50265]))

        reference_doc_scores = tf.convert_to_tensor([[62.9402, 62.7107, 62.2382, 62.1194, 61.8578]])
        reference_loss = tf.convert_to_tensor([32.521812])

        tf.debugging.assert_near(result.loss, reference_loss, atol=1e-3)
        tf.debugging.assert_near(result.doc_scores, reference_doc_scores, atol=1e-3)

    @slow
    def test_rag_token_inference_save_pretrained(self):
        # Run the RAG-token model once, round-trip it through save_pretrained / from_pretrained,
        # then compare the reloaded model's forward pass against reference values.
        config = self.get_rag_config()
        generator_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
        question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        retriever = RagRetriever(
            config,
            question_encoder_tokenizer=question_tokenizer,
            generator_tokenizer=generator_tokenizer,
        )

        model = self.token_model
        model.set_retriever(retriever)

        question_encoding = question_tokenizer("who sings does he love me with reba", return_tensors="tf")
        question_ids = question_encoding.input_ids
        label_ids = generator_tokenizer("Linda Davis", return_tensors="tf").input_ids

        # model must run once to be functional before loading/saving works
        model(question_ids, labels=label_ids)

        # save and reload; the assertions below run on the reloaded model
        with tempfile.TemporaryDirectory() as save_dir:
            model.save_pretrained(save_dir)
            model = TFRagTokenForGeneration.from_pretrained(save_dir, retriever=retriever)

        result = model(question_ids, labels=label_ids)

        self.assertEqual(result.logits.shape, tf.TensorShape([5, 5, 50264]))

        reference_doc_scores = tf.convert_to_tensor([[75.0286, 74.4998, 74.0804, 74.0306, 73.9504]])
        reference_loss = tf.convert_to_tensor([36.3557])

        tf.debugging.assert_near(result.loss, reference_loss, atol=1e-3)
        tf.debugging.assert_near(result.doc_scores, reference_doc_scores, atol=1e-3)

    @slow
    def test_init_and_from_pretrained(self):
        # Build a TFRagTokenForGeneration directly from the "facebook/rag-sequence-base" config, run
        # it once, then save and reload it. Nothing beyond successful execution is asserted.
        retriever_config = self.get_rag_config()
        generator_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
        question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        retriever = RagRetriever(
            retriever_config,
            question_encoder_tokenizer=question_tokenizer,
            generator_tokenizer=generator_tokenizer,
        )

        model_config = RagConfig.from_pretrained("facebook/rag-sequence-base")
        model = TFRagTokenForGeneration(model_config, retriever=retriever)

        question_encoding = question_tokenizer("who sings does he love me with reba", return_tensors="tf")
        question_ids = question_encoding.input_ids
        decoder_ids = generator_tokenizer("Linda Davis", return_tensors="tf").input_ids

        model(question_ids, decoder_input_ids=decoder_ids)

        # this should not give any warnings
        with tempfile.TemporaryDirectory() as save_dir:
            model.save_pretrained(save_dir)
            model = TFRagTokenForGeneration.from_pretrained(save_dir, retriever=retriever)

    @property
    def test_data_questions(self):
        """Return a new list with the eight questions used by the generation tests."""
        questions = (
            "who got the first nobel prize in physics",
            "when is the next deadpool movie being released",
            "which mode is used for short wave broadcast service",
            "who is the owner of reading football club",
            "when is the next scandal episode coming out",
            "when is the last time the philadelphia won the superbowl",
            "what is the most current adobe flash player version",
            "how many episodes are there in dragon ball z",
        )
        return list(questions)

    @slow
    def test_rag_token_greedy_search(self):
        # Generate answers for the first two questions with `num_beams` forced to 1 and compare
        # the decoded strings with the expected answers.
        checkpoint = "facebook/rag-token-nq"
        tokenizer = RagTokenizer.from_pretrained(checkpoint)
        retriever = RagRetriever.from_pretrained(
            checkpoint, index_name="exact", use_dummy_dataset=True, dataset_revision="b24a417"
        )
        model = TFRagTokenForGeneration.from_pretrained(checkpoint, retriever=retriever)

        # check first two questions
        first_two_questions = self.test_data_questions[:2]
        encoded = tokenizer(
            first_two_questions,
            return_tensors="tf",
            padding=True,
            truncation=True,
        )
        question_ids = encoded.input_ids
        question_mask = encoded.attention_mask

        # make sure only 1 beam is used
        model.config.num_beams = 1

        generated_ids = model.generate(question_ids, attention_mask=question_mask)
        decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        expected_answers = [
            " albert einstein",
            " september 22, 2017",
        ]
        self.assertListEqual(decoded, expected_answers)

    @slow
    def test_rag_token_generate_batch(self):
        # NOTE: the gold labels come from num_beams=4, so this is effectively a beam-search test
        checkpoint = "facebook/rag-token-nq"
        tokenizer = RagTokenizer.from_pretrained(checkpoint)
        retriever = RagRetriever.from_pretrained(
            checkpoint, index_name="exact", use_dummy_dataset=True, dataset_revision="b24a417"
        )
        model = TFRagTokenForGeneration.from_pretrained(checkpoint, retriever=retriever)

        encoded = tokenizer(
            self.test_data_questions,
            return_tensors="tf",
            padding=True,
            truncation=True,
        )
        question_ids = encoded.input_ids
        question_mask = encoded.attention_mask

        expected_answers = [
            " albert einstein",
            " september 22, 2017",
            " amplitude modulation",
            " stefan persson",
            " april 20, 2018",
            " the 1970s",
            " 7.1. 2",
            " 13",
        ]

        # Split into 2 batches of 4 examples to avoid GPU OOM.
        for batch in (slice(None, 4), slice(4, None)):
            generated_ids = model.generate(
                question_ids[batch],
                attention_mask=question_mask[batch],
            )
            decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            self.assertListEqual(decoded, expected_answers[batch])

    @slow
    def test_rag_sequence_generate_batch(self):
        # Generate answers for all test questions in one batch with the RAG-sequence NQ checkpoint
        # and compare the decoded strings with the expected answers.
        checkpoint = "facebook/rag-sequence-nq"
        tokenizer = RagTokenizer.from_pretrained(checkpoint)
        retriever = RagRetriever.from_pretrained(
            checkpoint,
            index_name="exact",
            use_dummy_dataset=True,
            dataset_revision="b24a417",
        )
        model = TFRagSequenceForGeneration.from_pretrained(checkpoint, retriever=retriever)

        encoded = tokenizer(
            self.test_data_questions,
            return_tensors="tf",
            padding=True,
            truncation=True,
        )
        question_ids = encoded.input_ids
        question_mask = encoded.attention_mask

        generated_ids = model.generate(question_ids, attention_mask=question_mask)
        decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        expected_answers = [
            " albert einstein",
            " june 22, 2018",
            " amplitude modulation",
            " tim besley ( chairman )",
            " june 20, 2018",
            " 1980",
            " 7.0",
            " 8",
        ]
        self.assertListEqual(decoded, expected_answers)

    @slow
    def test_rag_sequence_generate_batch_from_context_input_ids(self):
        # Same expected answers as the batch test, but retrieval is done here and `generate`
        # receives the context ids, context mask and doc scores directly.
        checkpoint = "facebook/rag-sequence-nq"
        tokenizer = RagTokenizer.from_pretrained(checkpoint)
        retriever = RagRetriever.from_pretrained(
            checkpoint, index_name="exact", use_dummy_dataset=True, dataset_revision="b24a417"
        )
        model = TFRagSequenceForGeneration.from_pretrained(checkpoint, retriever=retriever)

        encoded = tokenizer(
            self.test_data_questions,
            return_tensors="tf",
            padding=True,
            truncation=True,
        )
        question_ids = encoded.input_ids

        question_states = model.question_encoder(question_ids)[0]
        retrieved = retriever(question_ids.numpy(), question_states.numpy(), return_tensors="tf")

        # compute doc_scores from the question states and the retrieved embeddings
        query = tf.expand_dims(question_states, axis=[1])
        similarity = tf.matmul(query, retrieved["retrieved_doc_embeds"], transpose_b=True)
        doc_scores = tf.squeeze(similarity, axis=[1])

        generated_ids = model.generate(
            context_input_ids=retrieved["context_input_ids"],
            context_attention_mask=retrieved["context_attention_mask"],
            doc_scores=doc_scores,
            do_deduplication=True,
        )
        decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        expected_answers = [
            " albert einstein",
            " june 22, 2018",
            " amplitude modulation",
            " tim besley ( chairman )",
            " june 20, 2018",
            " 1980",
            " 7.0",
            " 8",
        ]
        self.assertListEqual(decoded, expected_answers)


@require_tf
@require_retrieval
class TFRagModelSaveLoadTests(unittest.TestCase):
    """
    Slow save/load checks for the TF RAG generation heads.

    Each test builds a RAG model from separate question-encoder and generator checkpoints in two different ways
    and asserts that both give the same loss on a fixed question/answer pair.
    """

    def get_rag_config(self):
        """
        Build the `RagConfig` shared by the tests in this class.

        The config combines the `facebook/dpr-question_encoder-single-nq-base` and `facebook/bart-large-cnn`
        configs and points the retriever at the dummy `wiki_dpr` dataset with the `exact` index.
        """
        question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
        generator_config = AutoConfig.from_pretrained("facebook/bart-large-cnn")
        return RagConfig.from_question_encoder_generator_configs(
            question_encoder_config,
            generator_config,
            bos_token_id=0,
            decoder_start_token_id=2,
            eos_token_id=2,
            is_encoder_decoder=True,
            pad_token_id=1,
            vocab_size=50264,
            title_sep=" / ",
            doc_sep=" // ",
            n_docs=5,
            max_combined_length=300,
            dataset="wiki_dpr",
            dataset_split="train",
            index_name="exact",
            index_path=None,
            use_dummy_dataset=True,
            retrieval_vector_size=768,
            retrieval_batch_size=8,
            dataset_revision="b24a417",
        )

    def _check_from_pretrained(self, model_class):
        """
        Shared body of the `*_from_pretrained` tests.

        Builds `model_class` through `from_pretrained_question_encoder_generator`, round-trips it through
        `save_pretrained` / `from_pretrained`, and compares its loss with the loss of an instance of the same class
        assembled from an individually loaded question encoder and generator.

        Args:
            model_class: The RAG generation class under test (`TFRagSequenceForGeneration` or
                `TFRagTokenForGeneration`).
        """
        load_weight_prefix = "tf_rag_model_1"

        rag_config = self.get_rag_config()
        rag_decoder_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
        rag_question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base"
        )
        rag_retriever = RagRetriever(
            rag_config,
            question_encoder_tokenizer=rag_question_encoder_tokenizer,
            generator_tokenizer=rag_decoder_tokenizer,
        )

        input_ids = rag_question_encoder_tokenizer(
            "who sings does he love me with reba", return_tensors="tf"
        ).input_ids
        decoder_input_ids = rag_decoder_tokenizer("Linda Davis", return_tensors="tf").input_ids

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model = model_class.from_pretrained_question_encoder_generator(
                "facebook/dpr-question_encoder-single-nq-base",
                "facebook/bart-large-cnn",
                retriever=rag_retriever,
                config=rag_config,
            )
            model.build_in_name_scope()
            # check that the from pretrained methods work
            # NOTE(review): the reloaded model is discarded, so this only verifies that saving and loading run
            # without raising; the loss below is still computed by the in-memory model.
            model.save_pretrained(tmp_dirname)
            model.from_pretrained(tmp_dirname, retriever=rag_retriever)

            loss_pretrained = model(input_ids, labels=decoder_input_ids).loss
            # Drop the first model before the second set of weights is loaded.
            del model

        question_encoder = TFAutoModel.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
        generator = TFAutoModelForSeq2SeqLM.from_pretrained(
            "facebook/bart-large-cnn", load_weight_prefix=load_weight_prefix, name="generator"
        )
        model = model_class(
            config=rag_config, question_encoder=question_encoder, generator=generator, retriever=rag_retriever
        )

        loss_init = model(input_ids, labels=decoder_input_ids).loss

        self.assertAlmostEqual(loss_pretrained, loss_init, places=4)

    @slow
    def test_rag_sequence_from_pretrained(self):
        """Run the shared from-pretrained check on `TFRagSequenceForGeneration`."""
        self._check_from_pretrained(TFRagSequenceForGeneration)

    @slow
    def test_rag_token_from_pretrained(self):
        """Run the shared from-pretrained check on `TFRagTokenForGeneration`."""
        self._check_from_pretrained(TFRagTokenForGeneration)
