# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile
from unittest import TestCase
from unittest.mock import patch

import numpy as np
from datasets import Dataset

from transformers.models.realm.configuration_realm import RealmConfig
from transformers.models.realm.retrieval_realm import _REALM_BLOCK_RECORDS_FILENAME, RealmRetriever
from transformers.models.realm.tokenization_realm import VOCAB_FILES_NAMES, RealmTokenizer


class RealmRetrieverTest(TestCase):
    def setUp(self):
        self.tmpdirname = tempfile.mkdtemp()
        self.num_block_records = 5

        # Realm tok
        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "[PAD]",
            "[MASK]",
            "test",
            "question",
            "this",
            "is",
            "the",
            "first",
            "second",
            "third",
            "fourth",
            "fifth",
            "record",
            "want",
            "##want",
            "##ed",
            "wa",
            "un",
            "runn",
            "##ing",
            ",",
            "low",
            "lowest",
        ]
        realm_tokenizer_path = os.path.join(self.tmpdirname, "realm_tokenizer")
        os.makedirs(realm_tokenizer_path, exist_ok=True)
        self.vocab_file = os.path.join(realm_tokenizer_path, VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

        realm_block_records_path = os.path.join(self.tmpdirname, "realm_block_records")
        os.makedirs(realm_block_records_path, exist_ok=True)

    def get_tokenizer(self) -> RealmTokenizer:
        return RealmTokenizer.from_pretrained(os.path.join(self.tmpdirname, "realm_tokenizer"))

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    def get_config(self):
        config = RealmConfig(num_block_records=self.num_block_records)
        return config

    def get_dummy_dataset(self):
        dataset = Dataset.from_dict(
            {
                "id": ["0", "1"],
                "question": ["foo", "bar"],
                "answers": [["Foo", "Bar"], ["Bar"]],
            }
        )
        return dataset

    def get_dummy_block_records(self):
        block_records = np.array(
            [
                b"This is the first record",
                b"This is the second record",
                b"This is the third record",
                b"This is the fourth record",
                b"This is the fifth record",
                b"This is a longer longer longer record",
            ],
            dtype=object,
        )
        return block_records

    def get_dummy_retriever(self):
        retriever = RealmRetriever(
            block_records=self.get_dummy_block_records(),
            tokenizer=self.get_tokenizer(),
        )
        return retriever

    def test_retrieve(self):
        config = self.get_config()
        retriever = self.get_dummy_retriever()
        tokenizer = retriever.tokenizer

        retrieved_block_ids = np.array([0, 3], dtype="long")
        question_input_ids = tokenizer(["Test question"]).input_ids
        answer_ids = tokenizer(
            ["the fourth"],
            add_special_tokens=False,
            return_token_type_ids=False,
            return_attention_mask=False,
        ).input_ids
        max_length = config.reader_seq_len

        has_answers, start_pos, end_pos, concat_inputs = retriever(
            retrieved_block_ids, question_input_ids, answer_ids=answer_ids, max_length=max_length, return_tensors="np"
        )

        self.assertEqual(len(has_answers), 2)
        self.assertEqual(len(start_pos), 2)
        self.assertEqual(len(end_pos), 2)
        self.assertEqual(concat_inputs.input_ids.shape, (2, 10))
        self.assertEqual(concat_inputs.attention_mask.shape, (2, 10))
        self.assertEqual(concat_inputs.token_type_ids.shape, (2, 10))
        self.assertEqual(concat_inputs.special_tokens_mask.shape, (2, 10))
        self.assertEqual(
            tokenizer.convert_ids_to_tokens(concat_inputs.input_ids[0]),
            ["[CLS]", "test", "question", "[SEP]", "this", "is", "the", "first", "record", "[SEP]"],
        )
        self.assertEqual(
            tokenizer.convert_ids_to_tokens(concat_inputs.input_ids[1]),
            ["[CLS]", "test", "question", "[SEP]", "this", "is", "the", "fourth", "record", "[SEP]"],
        )

    def test_block_has_answer(self):
        config = self.get_config()
        retriever = self.get_dummy_retriever()
        tokenizer = retriever.tokenizer

        retrieved_block_ids = np.array([0, 3, 5], dtype="long")
        question_input_ids = tokenizer(["Test question"]).input_ids
        answer_ids = tokenizer(
            ["the fourth", "longer longer"],
            add_special_tokens=False,
            return_token_type_ids=False,
            return_attention_mask=False,
        ).input_ids
        max_length = config.reader_seq_len

        has_answers, start_pos, end_pos, _ = retriever(
            retrieved_block_ids, question_input_ids, answer_ids=answer_ids, max_length=max_length, return_tensors="np"
        )

        self.assertEqual([False, True, True], has_answers)
        self.assertEqual([[-1, -1, -1], [6, -1, -1], [6, 7, 8]], start_pos)
        self.assertEqual([[-1, -1, -1], [7, -1, -1], [7, 8, 9]], end_pos)

    def test_save_load_pretrained(self):
        retriever = self.get_dummy_retriever()
        retriever.save_pretrained(os.path.join(self.tmpdirname, "realm_block_records"))

        # Test local path
        retriever = retriever.from_pretrained(os.path.join(self.tmpdirname, "realm_block_records"))
        self.assertEqual(retriever.block_records[0], b"This is the first record")

        # Test mocked remote path
        with patch("transformers.models.realm.retrieval_realm.hf_hub_download") as mock_hf_hub_download:
            mock_hf_hub_download.return_value = os.path.join(
                os.path.join(self.tmpdirname, "realm_block_records"), _REALM_BLOCK_RECORDS_FILENAME
            )
            retriever = RealmRetriever.from_pretrained("google/realm-cc-news-pretrained-openqa")

        self.assertEqual(retriever.block_records[0], b"This is the first record")
