# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import torch
import deepspeed
import pytest
import random
import numpy as np
import deepspeed.comm as dist
from unit.common import DistributedTest, DistributedFixture
from unit.megatron_model import get_megatron_version
from unit.megatron_model import MockGPT2ModelPipe as GPT2ModelPipe
from deepspeed.utils import RepeatingLoader
from deepspeed.accelerator import get_accelerator
from unit.util import required_minimum_torch_version, required_maximum_torch_version

# Megatron-LM is only exercised against a window of PyTorch releases. Both
# conditions are collected in one list so that each skipif is applied to the
# module; assigning ``pytestmark`` twice would keep only the last marker and
# silently drop the minimum-version check.
pytestmark = [
    pytest.mark.skipif(not required_minimum_torch_version(major_version=1, minor_version=5),
                       reason='Megatron-LM package requires Pytorch version 1.5 or above'),
    pytest.mark.skipif(not required_maximum_torch_version(major_version=1, minor_version=13),
                       reason='Megatron-LM package requires Pytorch version 1.13 or below'),
]


def get_deepspeed_model(model):
    """Initialize ``model`` through DeepSpeed and move the result to the accelerator.

    The config uses a micro-batch size of 1 per GPU and a Lamb optimizer with
    a learning rate of 0.00015. Returns whatever ``.to(...)`` returns for the
    first element of the tuple produced by ``deepspeed.initialize``.
    """
    optimizer_section = {"type": "Lamb", "params": {"lr": 0.00015}}
    config = {"train_micro_batch_size_per_gpu": 1, "optimizer": optimizer_section}

    # deepspeed.initialize returns four values; only the first one is needed here.
    engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config)
    target_device = get_accelerator().device_name()
    return engine.to(target_device)


def get_topology(mp, pp, world_size):
    """Build a PipeModelDataParallelTopology for the requested parallel degrees.

    ``pp`` and ``mp`` are forwarded as ``num_pp`` and ``num_mp``. The
    data-parallel degree is ``world_size // (pp * mp)``; an AssertionError is
    raised when ``pp * mp`` does not divide ``world_size`` evenly.
    """
    ranks_per_replica = pp * mp
    assert world_size % ranks_per_replica == 0

    # Imported locally, after the divisibility check has passed.
    from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
    return PipeModelDataParallelTopology(num_pp=pp, num_mp=mp, num_dp=world_size // ranks_per_replica)


class ConfigurablePP(DistributedTest):
    """Fixtures shared by the configurable pipeline-parallel test classes in this file."""

    @pytest.fixture(autouse=True)
    def reset_random(self, seed=1234):
        """Seed the Python, NumPy, torch and accelerator random generators with ``seed``."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        accelerator = get_accelerator()
        accelerator.manual_seed_all(seed)

    @pytest.fixture
    def inputs(self, bs=1, seq_len=1, hidden_size=128):
        """Return a ``(hidden_states, attention_mask)`` tuple of freshly sampled tensors."""
        hidden_shape = (bs, seq_len, hidden_size)
        mask_shape = (bs, seq_len)
        # Hidden states are sampled first, then the boolean mask, so the RNG
        # is consumed in a fixed order.
        hidden_states = torch.randn(hidden_shape)
        attention_mask = torch.randint(0, 2, mask_shape, dtype=torch.bool)
        return hidden_states, attention_mask


class TestConfigurablePP(ConfigurablePP):
    """Checkpoint save/load round trip on a fixed mp_size=2, pp_size=2 layout."""
    mp_size = 2
    pp_size = 2
    world_size = 4  # mp_size * pp_size

    @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.")
    def test_pp_basic(self, inputs, tmpdir):
        """Save a checkpoint, reload it, and check that eval outputs are unchanged."""
        # basic test case, mp_size=2, pp_size=2, verify ckpt saving/loading.
        model_args = dict(num_layers=8, hidden_size=128, num_attention_heads=8, max_position_embeddings=128)

        pipe_model = GPT2ModelPipe(num_layers=8,
                                   num_stages=self.pp_size,
                                   mp_size=self.mp_size,
                                   args_others=model_args,
                                   topo=get_topology(self.mp_size, self.pp_size, self.world_size))
        engine = get_deepspeed_model(pipe_model)

        ckpt_tag = 'pp_basic'
        client_state = {'checkpoint_version': get_megatron_version()}
        engine.save_checkpoint(tmpdir, tag=ckpt_tag, client_state=client_state)

        # Only the first and last pipeline stages are handed a data iterator.
        feeds_data = engine.is_first_stage() or engine.is_last_stage()
        data_iter = iter(RepeatingLoader([(inputs[0], 0)])) if feeds_data else None

        def run_eval():
            return engine.eval_batch(data_iter=data_iter, compute_loss=False, reduce_output=None)

        baseline = run_eval()

        # Reload the checkpoint that was just written, fenced by barriers on both sides.
        dist.barrier()
        engine.load_checkpoint(tmpdir, tag=ckpt_tag, load_optimizer_states=False, load_lr_scheduler_states=False)
        dist.barrier()

        test = run_eval()

        if test is None:
            return

        assert len(baseline) == len(test)
        # Compare outputs of each microbatch
        for baseline_mb, test_mb in zip(baseline, test):
            for b, t in zip(baseline_mb, test_mb):
                if not b.is_floating_point():  # don't compare masks
                    continue
                matches = torch.allclose(b, t, atol=1e-07)
                assert matches, f"Baseline output {baseline} is not equal to save-then-load output {test}"


# Fixture that derives the checkpoint tag from the parametrized sizes. A
# distinct tag per parameter combination is needed because all tests in
# TestConfigurableResizePP will use the same tmpdir.
@pytest.fixture
def checkpoint_tag(mp_size, pp_size, mp_resize, pp_resize):
    return "{}-{}-{}-{}".format(mp_size, pp_size, mp_resize, pp_resize)


# Base class that builds a baseline model, stores its eval output, and then
# saves a checkpoint of it. It is not meant to be used directly as a fixture;
# the world-size specific subclasses are used instead.
class _baseline(DistributedFixture):
    world_size = None

    def run(self, inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size):
        """Evaluate a freshly built model once and persist its output and checkpoint.

        The output tensor is written to ``output-<checkpoint_tag>.pt`` inside
        ``class_tmpdir`` on ranks where ``eval_batch`` returns a result; the
        checkpoint is saved to ``class_tmpdir`` under ``checkpoint_tag``.
        """
        expected_world_size = pp_size * mp_size
        assert int(os.environ["WORLD_SIZE"]) == expected_world_size, \
            "world size does not match provided pp_size and mp_size"

        model_args = dict(num_layers=8, hidden_size=128, num_attention_heads=8, max_position_embeddings=128)
        pipe_model = GPT2ModelPipe(num_layers=8,
                                   num_stages=pp_size,
                                   mp_size=mp_size,
                                   args_others=model_args,
                                   topo=get_topology(mp_size, pp_size, expected_world_size))
        engine = get_deepspeed_model(pipe_model)

        with torch.no_grad():
            device = get_accelerator().device_name()
            inputs = [tensor.to(device) for tensor in inputs]

            # Only the first and last pipeline stages are handed a data iterator.
            feeds_data = engine.is_first_stage() or engine.is_last_stage()
            data_iter = iter(RepeatingLoader([(inputs[0], 0)])) if feeds_data else None

            baseline = engine.eval_batch(data_iter=data_iter, compute_loss=False, reduce_output=None)

            if baseline is not None:
                # Expect exactly one microbatch that holds exactly one tensor.
                assert len(baseline) == 1
                assert len(baseline[0]) == 1
                hidden = baseline[0][0]
                assert torch.is_tensor(hidden)
                torch.save(hidden.cpu(), os.path.join(class_tmpdir, f"output-{checkpoint_tag}.pt"))

            client_state = {'checkpoint_version': get_megatron_version()}
            engine.save_checkpoint(class_tmpdir, tag=checkpoint_tag, client_state=client_state)


# This may look odd, but there is a limitation with DistributedFixture that
# doesn't allow us to reuse a fixture with different world sizes. Support for
# that could be implemented in conftest.py::pytest_fixture_setup and
# common.py::DistributedFixture.
class baseline_ws1(_baseline):
    """Baseline fixture variant that sets ``world_size`` to 1."""
    world_size = 1


class baseline_ws2(_baseline):
    """Baseline fixture variant that sets ``world_size`` to 2."""
    world_size = 2


class baseline_ws4(_baseline):
    """Baseline fixture variant that sets ``world_size`` to 4."""
    world_size = 4


class TestConfigurableResizePP(ConfigurablePP):
    """Load a checkpoint into a model built with a different (mp, pp) layout and compare outputs."""

    def _test(self, inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size, mp_resize, pp_resize):
        """Build the resized model, load the tagged checkpoint, and compare against the saved output.

        The model is built from ``mp_resize`` / ``pp_resize``; ``mp_size`` and
        ``pp_size`` are accepted but not read in this method. The reference
        tensor is loaded from ``output-<checkpoint_tag>.pt`` in ``class_tmpdir``.
        """
        model_args = dict(num_layers=8, hidden_size=128, num_attention_heads=8, max_position_embeddings=128)
        pipe_model = GPT2ModelPipe(num_layers=8,
                                   num_stages=pp_resize,
                                   mp_size=mp_resize,
                                   args_others=model_args,
                                   topo=get_topology(mp_resize, pp_resize, mp_resize * pp_resize))
        engine = get_deepspeed_model(pipe_model)

        with torch.no_grad():
            engine.load_checkpoint(class_tmpdir,
                                   tag=checkpoint_tag,
                                   load_optimizer_states=False,
                                   load_lr_scheduler_states=False)
            device = get_accelerator().device_name()
            inputs = [tensor.to(device) for tensor in inputs]

            # Only the first and last pipeline stages are handed a data iterator.
            feeds_data = engine.is_first_stage() or engine.is_last_stage()
            data_iter = iter(RepeatingLoader([(inputs[0], 0)])) if feeds_data else None

            test = engine.eval_batch(data_iter=data_iter, compute_loss=False, reduce_output=None)

            if test is None:
                return

            # Expect exactly one microbatch that holds exactly one tensor.
            assert len(test) == 1
            assert len(test[0]) == 1
            assert torch.is_tensor(test[0][0])
            actual = test[0][0].cpu()
            expected = torch.load(os.path.join(class_tmpdir, f"output-{checkpoint_tag}.pt"))
            outputs_match = torch.allclose(expected, actual, atol=1e-03)
            assert outputs_match, f"Baseline output {expected} is not equal to save-then-load output {actual}"

    # These tests are divided by baseline model worldsize and test model worldsize
    @pytest.mark.world_size(1)
    @pytest.mark.parametrize("mp_size, pp_size, mp_resize, pp_resize", [(1, 2, 1, 1)])
    @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.")
    def test_world_size_2to1(self, inputs, class_tmpdir, checkpoint_tag, baseline_ws2, mp_size, pp_size, mp_resize,
                             pp_resize):
        """Baseline from the world-size-2 fixture, reloaded at world size 1."""
        self._test(inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size, mp_resize, pp_resize)

    @pytest.mark.world_size(1)
    @pytest.mark.parametrize("mp_size, pp_size, mp_resize, pp_resize", [(2, 2, 1, 1)])
    @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.")
    def test_world_size_4to1(self, inputs, class_tmpdir, checkpoint_tag, baseline_ws4, mp_size, pp_size, mp_resize,
                             pp_resize):
        """Baseline from the world-size-4 fixture, reloaded at world size 1."""
        self._test(inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size, mp_resize, pp_resize)

    @pytest.mark.world_size(2)
    @pytest.mark.parametrize("mp_size, pp_size, mp_resize, pp_resize", [(2, 2, 2, 1)])
    @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.")
    def test_world_size_4to2(self, inputs, class_tmpdir, checkpoint_tag, baseline_ws4, mp_size, pp_size, mp_resize,
                             pp_resize):
        """Baseline from the world-size-4 fixture, reloaded at world size 2."""
        self._test(inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size, mp_resize, pp_resize)

    @pytest.mark.world_size(4)
    @pytest.mark.parametrize("mp_size, pp_size, mp_resize, pp_resize", [(1, 1, 2, 2)])
    @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.")
    def test_world_size_1to4(self, inputs, class_tmpdir, checkpoint_tag, baseline_ws1, mp_size, pp_size, mp_resize,
                             pp_resize):
        """Baseline from the world-size-1 fixture, reloaded at world size 4."""
        self._test(inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size, mp_resize, pp_resize)

    @pytest.mark.world_size(4)
    @pytest.mark.parametrize("mp_size, pp_size, mp_resize, pp_resize", [(1, 2, 1, 4), (2, 1, 2, 2)])
    @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.")
    def test_world_size_2to4(self, inputs, class_tmpdir, checkpoint_tag, baseline_ws2, mp_size, pp_size, mp_resize,
                             pp_resize):
        """Baseline from the world-size-2 fixture, reloaded at world size 4."""
        self._test(inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size, mp_resize, pp_resize)
