# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from fairseq.data import MonolingualDataset
from fairseq.tasks.language_modeling import LanguageModelingConfig, LanguageModelingTask
from tests import utils as test_utils


class TestLMContextWindow(unittest.TestCase):
    """Tests for ``LanguageModelingTask.eval_lm_dataloader`` with a context window."""

    def _assert_batch(self, batch, expected_src, expected_target):
        """Check the first row of *batch* against the expected token lists.

        Uses ``assertEqual`` rather than a bare ``assert`` so the check still
        runs under ``python -O`` and failures show the mismatching values.
        """
        self.assertEqual(batch["net_input"]["src_tokens"][0].tolist(), expected_src)
        self.assertEqual(batch["target"][0].tolist(), expected_target)

    def test_eval_dataloader(self):
        """Iterate three samples with ``context_window=2`` and verify each batch.

        From the second batch onwards the source is expected to start with the
        last two tokens of the previous sample, while the target carries the
        pad index in those two positions.
        """
        dictionary = test_utils.dummy_dictionary(10)
        self.assertEqual(len(dictionary), 14)  # 4 extra special symbols
        pad = dictionary.pad()
        # The expected token lists below are written in terms of this index.
        self.assertEqual(pad, 1)

        dataset = test_utils.TestDataset(
            [
                torch.tensor([4, 5, 6, 7], dtype=torch.long),
                torch.tensor([8, 9, 10, 11], dtype=torch.long),
                torch.tensor([12, 13], dtype=torch.long),
            ]
        )
        dataset = MonolingualDataset(dataset, sizes=[4, 4, 2], src_vocab=dictionary)

        config = LanguageModelingConfig(tokens_per_sample=4)
        task = LanguageModelingTask(config, dictionary)

        eval_dataloader = task.eval_lm_dataloader(
            dataset=dataset,
            batch_size=1,
            context_window=2,
            num_workers=0,
        )

        # First sample: no preceding context; two trailing pad positions in
        # both source and target.
        self._assert_batch(
            next(eval_dataloader),
            expected_src=[4, 5, 6, 7, pad, pad],
            expected_target=[4, 5, 6, 7, pad, pad],
        )

        # Second sample: source is prefixed with the previous tokens 6, 7; the
        # target has pad in those context positions.
        self._assert_batch(
            next(eval_dataloader),
            expected_src=[6, 7, 8, 9, 10, 11],
            expected_target=[pad, pad, 8, 9, 10, 11],
        )

        # Third (short) sample: context 10, 11 followed by the two real tokens.
        self._assert_batch(
            next(eval_dataloader),
            expected_src=[10, 11, 12, 13],
            expected_target=[pad, pad, 12, 13],
        )


# Run the tests in this module with unittest's command-line runner when the
# file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
