# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
from io import StringIO
import unittest
from unittest.mock import MagicMock, patch

import torch

from fairseq import data, checkpoint_utils


def mock_trainer(epoch, num_updates, iterations_in_epoch):
    """Return a ``MagicMock`` trainer pre-loaded with canned checkpoint data.

    Args:
        epoch: value stored under ``'epoch'`` in the train-iterator state.
        num_updates: value returned by ``get_num_updates()``.
        iterations_in_epoch: value stored under ``'iterations_in_epoch'``
            in the train-iterator state.

    Returns:
        A ``MagicMock`` whose ``load_checkpoint()`` returns a dict with a
        single ``'train_iterator'`` entry (``'shuffle'`` is always False)
        and whose ``get_num_updates()`` returns ``num_updates``.
    """
    iterator_state = {
        'epoch': epoch,
        'iterations_in_epoch': iterations_in_epoch,
        'shuffle': False,
    }
    fake_trainer = MagicMock()
    fake_trainer.get_num_updates.return_value = num_updates
    fake_trainer.load_checkpoint.return_value = {'train_iterator': iterator_state}
    return fake_trainer


def mock_dict():
    """Return a ``MagicMock`` dictionary with fixed special-symbol indices.

    The mock's ``pad()``, ``eos()`` and ``unk()`` methods return 1, 2 and 3
    respectively.
    """
    special_indices = {'pad': 1, 'eos': 2, 'unk': 3}
    dictionary = MagicMock()
    for method_name, index in special_indices.items():
        getattr(dictionary, method_name).return_value = index
    return dictionary


def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
    """Build a mocked trainer together with a real epoch batch iterator.

    The dataset is made from the integers ``0 .. epoch_size - 1`` laid out
    as a single row, wrapped in a ``TokenBlockDataset`` (``block_size=1``)
    and then a ``LanguagePairDataset``. The batch sampler places exactly one
    index in each batch, so the iterator yields ``epoch_size`` batches.

    Args:
        epoch: epoch number reported by the mocked checkpoint state.
        epoch_size: number of tokens, and therefore batches, in the dataset.
        num_updates: value the mocked trainer reports for its update count.
        iterations_in_epoch: iteration offset reported by the mocked
            checkpoint state.

    Returns:
        A ``(trainer, epoch_itr)`` tuple.
    """
    token_row = torch.LongTensor(list(range(epoch_size))).view(1, -1)
    block_ds = data.TokenBlockDataset(
        token_row,
        sizes=[token_row.size(-1)],
        block_size=1,
        pad=0,
        eos=1,
        include_targets=False,
    )
    pair_ds = data.LanguagePairDataset(block_ds, block_ds.sizes, mock_dict(), shuffle=False)
    # One sample index per batch.
    single_item_batches = [[idx] for idx in range(epoch_size)]
    iterator = data.EpochBatchIterator(
        dataset=pair_ds,
        collate_fn=pair_ds.collater,
        batch_sampler=single_item_batches,
    )
    return mock_trainer(epoch, num_updates, iterations_in_epoch), iterator


def get_mock_args(finetune_from_model=None):
    """Return a ``MagicMock`` standing in for parsed command-line arguments.

    All ``reset_*`` flags are False, ``optimizer_overrides`` is the string
    ``'{}'`` and ``model_parallel_size`` is 1.

    Args:
        finetune_from_model: value assigned to the ``finetune_from_model``
            attribute (defaults to None).
    """
    settings = {
        'optimizer_overrides': '{}',
        'reset_dataloader': False,
        'reset_meters': False,
        'reset_optimizer': False,
        'reset_lr_scheduler': False,
        'finetune_from_model': finetune_from_model,
        'model_parallel_size': 1,
    }
    namespace = MagicMock()
    for attr_name, attr_value in settings.items():
        setattr(namespace, attr_name, attr_value)
    return namespace


class TestLoadCheckpoint(unittest.TestCase):
    """Tests for ``checkpoint_utils.load_checkpoint``.

    Every test runs against a mocked trainer and with a handful of
    filesystem helpers patched out (see ``setUp``).
    """

    def setUp(self):
        """Create default mock args and start the filesystem patches."""
        # Registered first so patches are undone even if a later line of
        # setUp raises (tearDown is not called in that case).
        self.addCleanup(patch.stopall)

        self.args_mock = get_mock_args()
        # Replacement objects are stored by patch target so that individual
        # tests can tweak their ``return_value`` / ``side_effect``.
        self.patches = {
            'os.makedirs': MagicMock(),
            'os.path.join': MagicMock(),
            'os.path.isfile': MagicMock(return_value=True),
            'os.path.isabs': MagicMock(return_value=False),
            'fairseq.file_io.PathManager.exists': MagicMock(return_value=False),
        }
        self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
        for applied in self.applied_patches:
            applied.start()

    def test_load_partial_checkpoint(self):
        """Resume from state saved mid-epoch (epoch 2, iteration 50 of 150)."""
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)

            self.assertEqual(epoch_itr.epoch, 2)
            self.assertEqual(epoch_itr.iterations_in_epoch, 50)

            # Requesting the iterator must not advance the restored position.
            itr = epoch_itr.next_epoch_itr(shuffle=False)
            self.assertEqual(epoch_itr.epoch, 2)
            self.assertEqual(epoch_itr.iterations_in_epoch, 50)

            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 50)
            self.assertEqual(epoch_itr.iterations_in_epoch, 51)

            # Consume up to (but not including) the last batch of the epoch.
            for _ in range(150 - 52):
                next(itr)
            self.assertEqual(epoch_itr.iterations_in_epoch, 149)
            self.assertTrue(itr.has_next())
            next(itr)
            self.assertFalse(itr.has_next())

            # The following iterator belongs to a fresh epoch.
            itr = epoch_itr.next_epoch_itr(shuffle=False)
            self.assertTrue(itr.has_next())
            self.assertEqual(epoch_itr.epoch, 3)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)

    def test_load_full_checkpoint(self):
        """Resume from state saved at the end of epoch 2 (150 of 150)."""
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
            itr = epoch_itr.next_epoch_itr(shuffle=False)

            self.assertEqual(epoch_itr.epoch, 3)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)

    def test_load_no_checkpoint(self):
        """With ``os.path.isfile`` reporting False, iteration starts at epoch 1."""
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
            self.patches['os.path.isfile'].return_value = False

            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
            itr = epoch_itr.next_epoch_itr(shuffle=False)

            self.assertEqual(epoch_itr.epoch, 1)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)

    def test_finetune_from_model_args_conflict(self):
        """Setting ``finetune_from_model`` with any ``reset_*`` flag must raise."""
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

            for arg in ['reset_optimizer', 'reset_lr_scheduler', 'reset_meters', 'reset_dataloader']:
                with self.subTest(arg=arg):
                    args_mock = get_mock_args("/temp/checkpoint_pretrained.pt")
                    setattr(args_mock, arg, True)
                    with self.assertRaises(Exception) as context:
                        _, _ = checkpoint_utils.load_checkpoint(args_mock, trainer)

                    # assertIn reports both operands on failure, unlike
                    # assertTrue(x in y).
                    self.assertIn(
                        "--finetune-from-model can not be set together with either --reset-optimizer"
                        " or reset_lr_scheduler or reset_meters or reset_dataloader",
                        str(context.exception),
                    )

    def test_finetune_from_model(self):
        """First launch: only the pretrained model exists, so state is reset."""
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
            from_model_path = "/temp/checkpoint_pretrained.pt"
            args_mock = get_mock_args(from_model_path)
            args_mock.restore_file = "checkpoint_last.pt"

            def mock_finetune_exist(path):
                # Only the pretrained model is reported as existing.
                return path == from_model_path

            self.patches['fairseq.file_io.PathManager.exists'].side_effect = mock_finetune_exist
            _, _ = checkpoint_utils.load_checkpoint(args_mock, trainer)
            # Unpacking doubles as a check that exactly four positional
            # arguments were passed to trainer.load_checkpoint.
            _checkpoint_path, reset_optimizer, reset_lr_scheduler, \
                _optimizer_overrides = trainer.load_checkpoint.call_args[0]
            reset_meters = trainer.load_checkpoint.call_args[1]['reset_meters']
            self.assertTrue(reset_optimizer)
            self.assertTrue(reset_lr_scheduler)
            self.assertTrue(reset_meters)

    def test_finetune_from_model_resume(self):
        """Second launch: ``checkpoint_last.pt`` exists, so nothing is reset."""
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
            from_model_path = "/temp/checkpoint_pretrained.pt"
            args_mock = get_mock_args(from_model_path)
            args_mock.restore_file = "checkpoint_last.pt"

            # launch second time
            # both restore_file=checkpoint_last.pt and finetune_from_model are set
            def mock_finetune_exist(path):
                # ``str.endswith`` is the real method name; the previous
                # ``endsWith`` would raise AttributeError on a genuine str
                # path. NOTE(review): with ``os.path.join`` patched, ``path``
                # may itself be a MagicMock here -- bool() keeps the return
                # value a plain bool in either case.
                return bool(path == from_model_path or path.endswith('checkpoint_last.pt'))

            self.patches['fairseq.file_io.PathManager.exists'].side_effect = mock_finetune_exist
            _, _ = checkpoint_utils.load_checkpoint(args_mock, trainer)
            # Unpacking doubles as a check that exactly four positional
            # arguments were passed to trainer.load_checkpoint.
            _checkpoint_path, reset_optimizer, reset_lr_scheduler, \
                _optimizer_overrides = trainer.load_checkpoint.call_args[0]
            reset_meters = trainer.load_checkpoint.call_args[1]['reset_meters']
            self.assertFalse(reset_optimizer)
            self.assertFalse(reset_lr_scheduler)
            self.assertFalse(reset_meters)

    def tearDown(self):
        """Stop every patch started in ``setUp``."""
        patch.stopall()


# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
