import torch
import unittest
from apex.transformer.testing import global_vars
from apex.transformer.testing.standalone_bert import bert_model_provider
from apex.transformer.pipeline_parallel.schedules.common import (
    _get_params_for_weight_decay_optimization, build_model
)
from apex.transformer.pipeline_parallel.schedules import get_forward_backward_func
from apex.transformer.pipeline_parallel.utils import (
    average_losses_across_data_parallel_group, unwrap_model, setup_microbatch_calculator
)
from apex.transformer.log_util import set_logging_level
from apex.transformer import tensor_parallel, parallel_state
from apex.transformer.enums import ModelType
from apex.transformer._ucc_util import HAS_UCC
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase, NcclDistributedTestBase
import logging

from torch.testing._internal import common_utils

# Keep test output quiet: raise the threshold of the "torch" and "apex" loggers
# to WARNING, then apply the same level through apex.transformer.log_util's
# helper (presumably covering apex.transformer's own logger).
for _noisy_logger_name in ("torch", "apex"):
    logging.getLogger(_noisy_logger_name).setLevel(logging.WARNING)
del _noisy_logger_name

set_logging_level("WARNING")


class BertTestBase:
    """Shared body of the BERT pipeline-parallel training tests.

    Meant to be combined with a distributed test base that supplies ``rank``,
    ``world_size``, ``DISTRIBUTED_BACKEND`` and the ``unittest`` assertion /
    skip helpers used below.
    """

    def _download_fancy_data(self):
        """Return a synthetic corpus as a 1-D int64 tensor of ASCII byte values.

        Nothing is downloaded: a fixed paragraph is repeated 1024 times and
        encoded as ASCII (any non-ASCII character would become ``?``).
        """
        text = """
    An original sentence not subject to any license restrictions, copyright, or royalty payments. Nothing to see here. Commercial or non-commercial use. Research or non-research purposes. The quick brown fox jumps over the lazy dog. Lorem ipsum.
    """
        text = text * 1024
        encoded = text.encode("ascii", "replace")
        # Iterating ``bytes`` already yields ints; no index loop is needed.
        return torch.tensor(list(encoded))

    # build a batch given sequence_len and batch size
    def _generate_fancy_data_labels(self, sequence_len, batch_size):
        """Build one masked-language-model batch of ``batch_size`` samples.

        Each sample is a ``sequence_len`` window of ``self.fancy_data`` that
        starts at a shuffled offset. With probability ``MASK_PROB`` a position
        is replaced by token 124 (``'|'``); the labels are the original tokens.

        Returns:
            ``(text, label, mask_not)`` long tensors of shape
            ``(batch_size, sequence_len)``. Only tensor-parallel rank 0 builds
            them; ``tensor_parallel.broadcast_data`` hands them to the other
            tensor-parallel ranks so all of them see the same batch.
        """
        temps = []
        for _ in range(batch_size):
            if self.inds is None or self.data_idx >= len(self.inds):
                # Start a new "epoch". Reseed explicitly: pipeline stages consume
                # the global RNG differently, so it would otherwise fall out of
                # sync across ranks.
                torch.manual_seed(self.MANUAL_SEED)
                self.inds = torch.randperm(
                    self.effective_length, device="cuda")
                # One (batch_size, sequence_len) keep-mask per batch of the epoch,
                # plus one spare row; 1 = keep the token, 0 = mask it.
                self.masks = (
                    torch.rand(
                        len(self.inds) // batch_size + 1, batch_size, sequence_len, device="cuda"
                    )
                    >= self.MASK_PROB
                ).long()
                self.MANUAL_SEED += 1
                self.data_idx = 0
                if self.rank == 0:
                    print("new epoch", len(self.inds))
                    print("my start", self.inds[0:5])
                    print("masks_checksum:", torch.sum(self.masks))
            if self.EASY_MODE:
                # Cycle through only the first EASY_MODE_SIZ shuffled offsets.
                data_idx_ = self.data_idx % self.EASY_MODE_SIZ
            else:
                data_idx_ = self.data_idx
            # Window start: any character position, not a multiple of sequence_len.
            offset = self.inds[data_idx_]
            self.data_idx += 1

            curr = self.fancy_data[offset: offset +
                                   sequence_len].clone().detach()
            temps.append(curr)
        temp = torch.stack(temps, dim=0).cuda()
        mask = self.masks[self.data_idx // batch_size]
        mask_not = torch.logical_not(mask).long()
        # Masked positions become token 124 (ord('|')); kept positions are unchanged.
        data = mask * temp + mask_not * 124
        label = temp
        if parallel_state.get_tensor_model_parallel_rank() == 0:
            data_dict = {"text": data, "label": label, "mask_not": mask_not}
        else:
            data_dict = None
        keys = ["text", "label", "mask_not"]
        broadcasted_data = tensor_parallel.broadcast_data(
            keys, data_dict, torch.long)
        return (
            broadcasted_data["text"].long(),
            broadcasted_data["label"].long(),
            broadcasted_data["mask_not"],
        )

    def _fwd_step_func(self, batch, model):
        """Forward-step callback handed to the pipeline schedule.

        Runs ``model`` on one ``(data, label, loss_mask)`` micro-batch and
        returns its output together with ``loss_func``, which reduces that
        output to a masked-LM loss.
        """
        data, label, loss_mask = batch
        y = model(data, torch.ones_like(data), lm_labels=label)

        def loss_func(output_tensor):
            # The model output is a pair; only its first element is used.
            # Presumably it is the per-token LM loss, since it is multiplied
            # element-wise by loss_mask below.
            output_tensor, _ = output_tensor
            lm_loss_ = output_tensor.float()
            # Average the loss over the masked positions only.
            lm_loss = torch.sum(lm_loss_.view(-1) *
                                loss_mask.reshape(-1)) / loss_mask.sum()
            averaged_loss = average_losses_across_data_parallel_group([
                                                                      lm_loss])
            # data_idx grows by global_batch_size (128 in _test_bert) on every
            # training iteration, so 1536 restricts the check to iteration 12
            # onward, once the model has had some steps to train.
            if self.data_idx >= 1536:
                # NOTE (patwang): Loss cutoff might be excessively high but roughly one in five
                # unlucky random seeds do cause loss to spike to just under 8.0
                self.assertLess(averaged_loss, 8.0)
            return lm_loss, {"avg": averaged_loss}

        return y, loss_func

    def _train(
        self, model, optim, virtual_pipeline_model_parallel_size, pipeline_model_parallel_size, async_comm
    ):
        """Run 16 optimizer steps of pipeline-parallel training on synthetic data.

        Args:
            model: list of model chunks, as returned by ``build_model``.
            optim: optimizer over the model's parameter groups.
            virtual_pipeline_model_parallel_size: ``None`` or the interleaving
                factor; used to select the pipeline schedule.
            pipeline_model_parallel_size: number of pipeline stages.
            async_comm: forwarded to the schedule's ``async_comm`` flag.
        """
        args = global_vars.get_args()
        sequence_len = args.seq_length
        micro_batch_size = args.micro_batch_size
        hidden_size = args.hidden_size
        global_batch_size = args.global_batch_size
        forward_backward_func = get_forward_backward_func(
            virtual_pipeline_model_parallel_size, pipeline_model_parallel_size
        )
        # (sequence, micro-batch, hidden) shape handed to the pipeline schedule.
        tensor_shape = (sequence_len, micro_batch_size, hidden_size)
        for _ in range(16):
            batch = self._generate_fancy_data_labels(
                sequence_len, global_batch_size)
            optim.zero_grad()
            forward_backward_func(
                self._fwd_step_func,
                batch,
                model,
                forward_only=False,
                tensor_shape=tensor_shape,
                async_comm=async_comm,
                sequence_parallel_enabled=args.sequence_parallel,
            )
            # All-reduce layernorm parameters across model parallel nodes
            # when sequence parallelism is used: every parameter tagged with
            # ``sequence_parallel_enabled`` has its gradient summed over the
            # tensor-parallel group before the optimizer step.
            if parallel_state.get_tensor_model_parallel_world_size() > 1 and args.sequence_parallel:
                for model_module in model:
                    unwrapped_model = unwrap_model(model_module)
                    for param in unwrapped_model.parameters():
                        if getattr(param, 'sequence_parallel_enabled', False):
                            grad = param.grad
                            torch.distributed.all_reduce(
                                grad, group=parallel_state.get_tensor_model_parallel_group())

            optim.step()

    @unittest.skipUnless(torch.cuda.device_count() > 2, "requires at least 3 gpus")
    def test_bert_without_interleaving(self):
        self._test_bert(virtual_pipeline_model_parallel_size=None)

    @unittest.skipUnless(torch.cuda.device_count() > 2, "requires at least 3 gpus")
    def test_bert_with_interleaving(self):
        # The interleaved schedule is not exercised on the UCC backend.
        if self.DISTRIBUTED_BACKEND == 'ucc':
            self.skipTest('skip interleaving with ucc')
        self._test_bert(virtual_pipeline_model_parallel_size=2)

    def _test_bert(self, virtual_pipeline_model_parallel_size):
        """Set up model parallelism, build the BERT model and train it.

        Args:
            virtual_pipeline_model_parallel_size: ``None`` for the plain
                pipeline schedule, or the number of model chunks per rank for
                the interleaved schedule.
        """
        # Data-generation state read and updated by _generate_fancy_data_labels.
        self.MANUAL_SEED = 42
        self.inds = None
        self.masks = None
        self.data_idx = 0
        self.MASK_PROB = 0.1
        self.EASY_MODE = False
        self.EASY_MODE_SIZ = 32

        # 2-way tensor parallelism only for even world sizes above 4; all other
        # ranks form the pipeline.
        tensor_model_parallel_size = 2 if self.world_size % 2 == 0 and self.world_size > 4 else 1
        pipeline_model_parallel_size = self.world_size // tensor_model_parallel_size

        override_args = {
            "micro_batch_size": 2,
            "num_layers": 16,
            "hidden_size": 256,
            "num_attention_heads": 8,
            "max_position_embeddings": 512,
            "seq_length": 512,
            "global_batch_size": 128,
            "pipeline_model_parallel_size": pipeline_model_parallel_size,
            "tensor_model_parallel_size": tensor_model_parallel_size,
            "bert_binary_head": False,
            "world_size": self.world_size,
            "rank": self.rank,
        }

        global_vars.set_global_variables(override_args=override_args, ignore_unknown_args=True)
        args = global_vars.get_args()

        self.fancy_data = self._download_fancy_data()
        # Number of candidate window start offsets: every seq_length window
        # starting below this value fits entirely inside the corpus.
        self.effective_length = self.fancy_data.size(0) - args.seq_length

        if self.rank == 0:
            print(
                f'testing backend: {self.DISTRIBUTED_BACKEND} with virtual_pipeline_model_parallel_size: {virtual_pipeline_model_parallel_size}')
        # Async p2p communication only for the non-interleaved schedule without
        # sequence parallelism.
        async_comm = not args.sequence_parallel and virtual_pipeline_model_parallel_size is None
        # Needed by the standalone BERT model; 128 covers every 7-bit ASCII
        # token id produced by _download_fancy_data.
        args.padded_vocab_size = 128
        args.model_type = ModelType.encoder_or_decoder
        setup_microbatch_calculator(
            args.rank,
            args.rampup_batch_size,
            args.global_batch_size,
            args.micro_batch_size,
            args.data_parallel_size,
        )
        parallel_state.initialize_model_parallel(
            args.tensor_model_parallel_size,
            args.pipeline_model_parallel_size,
            virtual_pipeline_model_parallel_size,
            default_backend="nccl",
            p2p_backend=self.DISTRIBUTED_BACKEND,
        )

        tensor_parallel.random.model_parallel_cuda_manual_seed(0)
        model = build_model(
            bert_model_provider,
            wrap_with_ddp=parallel_state.get_data_parallel_world_size() > 1,
            virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
            cpu_offload=args.cpu_offload,
        )
        # One model chunk per rank, or one per virtual stage when interleaving.
        assert isinstance(model, list)
        assert len(model) == (
            1
            if virtual_pipeline_model_parallel_size is None
            else virtual_pipeline_model_parallel_size
        )
        _param_groups = _get_params_for_weight_decay_optimization(model)
        optim = torch.optim.Adam(_param_groups)
        self._train(
            model,
            optim,
            virtual_pipeline_model_parallel_size,
            args.pipeline_model_parallel_size,
            async_comm,
        )
        torch.cuda.synchronize()


class NcclBertTest(BertTestBase, NcclDistributedTestBase):
    """Runs the BertTestBase suite under NcclDistributedTestBase."""

    @property
    def world_size(self) -> int:
        # One rank per visible GPU, but never more than eight ranks.
        visible_gpus = torch.cuda.device_count()
        return min(visible_gpus, 8)


@unittest.skipUnless(HAS_UCC, "requires pytorch to be built with native ucc")
class UccBertTest(BertTestBase, UccDistributedTestBase):
    """Runs the BertTestBase suite under UccDistributedTestBase."""

    @property
    def world_size(self) -> int:
        # One rank per visible GPU, but never more than eight ranks.
        visible_gpus = torch.cuda.device_count()
        return min(visible_gpus, 8)


if __name__ == "__main__":
    # Run this module's tests through PyTorch's internal test harness
    # (torch.testing._internal.common_utils).
    common_utils.run_tests()
