# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import unittest

from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models.transformer import TransformerModel
from tests.test_sequence_generator import get_dummy_task_and_parser


class TestInferenceDropout(unittest.TestCase):
    """Check the dropout flags of a small Transformer model around inference.

    Each test builds a ``TransformerModel`` with 2 encoder layers and 1 decoder
    layer and inspects the ``apply_during_inference`` / ``training`` flags of
    the ``dropout_module`` attributes on the encoder, the decoder and their
    layers.

    Assertions use the ``unittest`` assertion methods rather than bare
    ``assert`` statements so that the checks still run under ``python -O``.
    """

    def setUp(self):
        """Create the dummy task and parse default Transformer arguments."""
        self.task, self.parser = get_dummy_task_and_parser()
        TransformerModel.add_args(self.parser)
        self.args = self.parser.parse_args([])
        self.args.encoder_layers = 2
        self.args.decoder_layers = 1
        # Suppress log output (CRITICAL and below) while a test runs.
        logging.disable(logging.CRITICAL)

    def tearDown(self):
        """Re-enable logging that ``setUp`` turned off."""
        logging.disable(logging.NOTSET)

    def _build_model(self):
        """Build the model from ``self.args`` and keep it on the test case."""
        self.transformer_model = TransformerModel.build_model(self.args, self.task)
        return self.transformer_model

    def _build_inference_model(self):
        """Build the model and run ``prepare_for_inference_`` on it."""
        model = self._build_model()
        cfg = convert_namespace_to_omegaconf(self.args)
        model.prepare_for_inference_(cfg)
        return model

    def test_sets_inference_dropout_to_true(self):
        """``retain_dropout`` keeps dropout enabled on all inspected modules."""
        self.args.retain_dropout = True
        model = self._build_inference_model()
        self.assertTrue(
            model.encoder.dropout_module.apply_during_inference,
            "encoder dropout should be applied during inference",
        )
        self.assertTrue(
            model.decoder.dropout_module.apply_during_inference,
            "decoder dropout should be applied during inference",
        )
        for index, layer in enumerate(model.encoder.layers):
            self.assertTrue(
                layer.dropout_module.apply_during_inference,
                f"encoder layer {index} dropout should be applied during inference",
            )
        # Decoder layers are checked as well, mirroring the default-off test.
        for index, layer in enumerate(model.decoder.layers):
            self.assertTrue(
                layer.dropout_module.apply_during_inference,
                f"decoder layer {index} dropout should be applied during inference",
            )

    def test_inference_dropout_false_by_default(self):
        """Without ``retain_dropout`` no inspected module applies dropout."""
        model = self._build_inference_model()
        self.assertFalse(
            model.encoder.dropout_module.apply_during_inference,
            "encoder dropout should be off during inference by default",
        )
        self.assertFalse(
            model.decoder.dropout_module.apply_during_inference,
            "decoder dropout should be off during inference by default",
        )
        for index, layer in enumerate(model.encoder.layers):
            self.assertFalse(
                layer.dropout_module.apply_during_inference,
                f"encoder layer {index} dropout should be off by default",
            )
        for index, layer in enumerate(model.decoder.layers):
            self.assertFalse(
                layer.dropout_module.apply_during_inference,
                f"decoder layer {index} dropout should be off by default",
            )

    def test_applies_training_mode(self):
        """``eval()`` flips the ``training`` flag of every dropout module.

        Encoder and decoder (and their layers) are checked both before and
        after the call so that the transition itself is verified for each.
        """
        model = self._build_model()
        self.assertTrue(model.encoder.dropout_module.training)
        self.assertTrue(model.decoder.dropout_module.training)
        for index, layer in enumerate(model.encoder.layers):
            self.assertTrue(
                layer.dropout_module.training,
                f"encoder layer {index} dropout should start in training mode",
            )
        for index, layer in enumerate(model.decoder.layers):
            self.assertTrue(
                layer.dropout_module.training,
                f"decoder layer {index} dropout should start in training mode",
            )

        model.eval()
        self.assertFalse(model.encoder.dropout_module.training)
        self.assertFalse(model.decoder.dropout_module.training)
        for index, layer in enumerate(model.encoder.layers):
            self.assertFalse(
                layer.dropout_module.training,
                f"encoder layer {index} dropout should leave training mode",
            )
        for index, layer in enumerate(model.decoder.layers):
            self.assertFalse(
                layer.dropout_module.training,
                f"decoder layer {index} dropout should leave training mode",
            )

    def test_retain_modules(self):
        """``retain_dropout_modules`` restricts retained dropout to listed modules."""
        self.args.retain_dropout = True
        self.args.retain_dropout_modules = [
            "TransformerEncoder",
            "TransformerEncoderLayer",
        ]
        model = self._build_inference_model()
        self.assertTrue(
            model.encoder.dropout_module.apply_during_inference,
            "encoder is listed in retain_dropout_modules",
        )
        self.assertFalse(
            model.decoder.dropout_module.apply_during_inference,
            "decoder is not listed in retain_dropout_modules",
        )
        # NOTE(review): encoder layers are intentionally not asserted here;
        # whether they match "TransformerEncoderLayer" depends on the module
        # name the layer's dropout reports - confirm before adding a check.
        for index, layer in enumerate(model.decoder.layers):
            self.assertFalse(
                layer.dropout_module.apply_during_inference,
                f"decoder layer {index} is not listed in retain_dropout_modules",
            )
