# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from contextlib import nullcontext
from copy import deepcopy
from types import SimpleNamespace

import pytest
import torch

from megatron.core import InferenceParams
from megatron.core import parallel_state as ps
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
from megatron.core.models.multimodal import context_parallel
from megatron.core.models.multimodal.llava_model import LLaVAModel
from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import is_te_min_version
from megatron.training.global_vars import set_args
from tests.unit_tests.test_utilities import Utils


class TestLLaVAModel:
    """Unit tests for LLaVAModel with a small CLIP vision encoder and a dummy language model.

    pytest calls ``setup_method`` before every test, so each test gets a freshly built
    ``self.model``. Per-test changes (device, dtype, frozen parameters) therefore do not
    carry over to other tests.
    """

    @pytest.mark.internal  # The model is under active development and its methods may change.
    def setup_method(self, method):
        """Initialize model parallel state and build a small LLaVAModel in ``self.model``."""
        Utils.initialize_model_parallel(1, 1)
        model_parallel_cuda_manual_seed(123)

        self.language_hidden_size = 64
        self.language_num_attention_heads = 4

        language_config = TransformerConfig(
            num_layers=3,
            hidden_size=self.language_hidden_size,
            num_attention_heads=self.language_num_attention_heads,
            use_cpu_initialization=False,
        )
        vision_config = TransformerConfig(
            num_layers=2, hidden_size=16, num_attention_heads=2, use_cpu_initialization=False
        )
        vision_projection_config = TransformerConfig(
            num_layers=2,
            hidden_size=self.language_hidden_size,
            ffn_hidden_size=32,
            num_attention_heads=1,
            use_cpu_initialization=False,
        )

        language_layer_spec = get_gpt_layer_with_transformer_engine_spec()
        vision_layer_spec = deepcopy(language_layer_spec)
        vision_projection_spec = deepcopy(language_layer_spec.submodules.mlp.submodules)

        language_config.language_model_type = "dummy"
        vision_config.vision_model_type = "clip"
        self.model = LLaVAModel(
            language_transformer_config=language_config,
            language_transformer_layer_spec=language_layer_spec,
            language_vocab_size=8192,
            language_max_sequence_length=4096,
            vision_transformer_config=vision_config,
            vision_transformer_layer_spec=vision_layer_spec,
            drop_vision_class_token=False,
            vision_projection_config=vision_projection_config,
            vision_projection_layer_spec=vision_projection_spec,
            img_h=336,
            img_w=336,
            patch_dim=14,
        )

    @pytest.mark.internal
    def teardown_method(self, method):
        """Tear down the model parallel state created in ``setup_method``."""
        Utils.destroy_model_parallel()

    @pytest.mark.internal
    def test_constructor(self):
        """The model builds and has the expected total parameter count."""
        assert isinstance(self.model, LLaVAModel)

        num_weights = sum(p.numel() for p in self.model.parameters())
        assert num_weights == 1488736

    @pytest.mark.internal
    def test_set_input_tensor(self):
        """set_input_tensor forwards the tensor to the vision model's decoder."""
        expected_shape = (1, 2, 3, 4)
        input_tensor = torch.zeros(expected_shape)
        self.model.set_input_tensor(input_tensor)
        assert self.model.vision_model.decoder.input_tensor.shape == expected_shape

    @pytest.mark.internal
    def test_preprocess_data(self):
        """Check how _preprocess_data merges image and text embeddings, labels and loss mask.

        The image and text embeddings are synthetic and distinguishable: images are
        non-negative aranges and text is negated aranges. This lets each output position
        be traced back to its source. The hidden size (72) differs from the model's
        hidden size. Presumably this is fine because _preprocess_data only rearranges
        embeddings; confirm if that changes.
        """
        self.model.cuda()

        hidden_size = 72

        # 3 images with 1 tile and 2 images with 2 tiles = 7 tiles.
        # 577 = (336 / 14) ** 2 + 1, matching img_h/img_w=336, patch_dim=14 and
        # drop_vision_class_token=False in setup_method (the +1 presumably the class token).
        image_embeddings = (
            torch.arange(577 * 7 * hidden_size, dtype=torch.float)
            .reshape(577, 7, hidden_size)
            .cuda()
        )

        image_token_index = self.model.image_token_index
        input_ids = torch.arange(1024).expand(5, 1024).cuda()
        input_ids[0, 0] = image_token_index  # image before text
        input_ids[1, 100] = image_token_index  # image in between
        input_ids[2, -1] = image_token_index  # image at the end
        # input_ids[3] - no image
        input_ids[4, 50] = image_token_index  # two images in between
        input_ids[4, 150] = image_token_index

        # Using negative sign to distinguish from image embeddings.
        language_embeddings = (
            -torch.arange(5 * 1024 * hidden_size, dtype=torch.float)
            .reshape(5, 1024, hidden_size)
            .cuda()
        )

        # Labels are input_ids shifted to left by one.
        labels = torch.arange(1, 1025, dtype=torch.int).expand(5, 1024).cuda()
        # labels[0] - image token got dropped by shift to left by one.
        labels[1, 99] = image_token_index
        labels[2, -2] = image_token_index
        # labels[3] - no image.
        labels[4, 49] = image_token_index
        labels[4, 149] = image_token_index

        loss_mask = torch.ones((5, 1024), dtype=torch.float).cuda()
        # Mask some text inputs (the text mask should carry over)
        loss_mask[:2, :10] = 0.0
        loss_mask[:2, 110:120] = 0.0

        # Number of tiles for each image in the batch.
        num_image_tiles = torch.tensor([1, 2, 1, 2, 1], dtype=torch.int).cuda()

        use_inference_kv_cache = False
        inference_params = None

        embeddings, labels, loss_mask = self.model._preprocess_data(
            image_embeddings,
            language_embeddings,
            input_ids,
            loss_mask,
            labels,
            use_inference_kv_cache,
            inference_params,
            image_token_index,
            num_image_tiles,
        )

        img_seq_len = 577
        # The fifth sample has 2 images with 3 tiles and 1024 text tokens.
        max_seq_len = 3 * img_seq_len - 2 + 1024

        assert embeddings.shape == torch.Size((max_seq_len, 5, hidden_size))
        assert labels.shape == torch.Size((5, max_seq_len))
        assert loss_mask.shape == labels.shape

        # First sample where image is before text (index 0).
        expected_embeddings = torch.empty(max_seq_len, hidden_size).cuda()
        expected_embeddings[:577] = image_embeddings[:, 0]
        expected_embeddings[577:1600] = language_embeddings[0, 1:]
        expected_embeddings[1600:] = 0  # padding

        expected_labels = torch.empty(max_seq_len, dtype=torch.int).cuda()
        expected_labels[:576] = -100  # image
        expected_labels[576:1600] = torch.arange(1, 1025, dtype=torch.int)
        expected_labels[1600:] = -100  # padding

        # Text token i of this sample lands at position 576 + i, so the masked text
        # ranges [:10] and [110:120] map to [577:586] and [686:696].
        expected_loss_mask = torch.empty(max_seq_len, dtype=torch.float).cuda()
        expected_loss_mask[:577] = 0
        expected_loss_mask[577:586] = 0
        expected_loss_mask[586:686] = 1
        expected_loss_mask[686:696] = 0
        expected_loss_mask[696:1600] = 1
        expected_loss_mask[1600:] = 0

        assert torch.allclose(embeddings[:, 0], expected_embeddings)
        assert torch.allclose(labels[0], expected_labels)
        assert torch.allclose(loss_mask[0], expected_loss_mask)

        # Second sample where image is in between (index 100). The image has 2 tiles.
        expected_embeddings = torch.empty(max_seq_len, hidden_size).cuda()
        expected_embeddings[:100] = language_embeddings[1, :100]
        expected_embeddings[100:677] = image_embeddings[:, 1]
        expected_embeddings[677:1254] = image_embeddings[:, 2]
        expected_embeddings[1254:2177] = language_embeddings[1, 101:]
        expected_embeddings[2177:] = 0  # padding

        expected_labels = torch.empty(max_seq_len, dtype=torch.int).cuda()
        expected_labels[:99] = torch.arange(1, 100)
        expected_labels[99:1253] = -100  # image
        expected_labels[1253:2177] = torch.arange(101, 1025)
        expected_labels[2177:] = -100  # padding

        expected_loss_mask = torch.empty(max_seq_len, dtype=torch.float).cuda()
        expected_loss_mask[:10] = 0
        expected_loss_mask[10:99] = 1
        # Last text position before the image is not required to predict the first image embedding.
        expected_loss_mask[99] = 0
        expected_loss_mask[100:1254] = 0
        expected_loss_mask[1254:1263] = 1
        expected_loss_mask[1263:1273] = 0
        expected_loss_mask[1273:2177] = 1
        expected_loss_mask[2177:] = 0  # padding

        assert torch.allclose(embeddings[:, 1], expected_embeddings)
        assert torch.allclose(labels[1], expected_labels)
        assert torch.allclose(loss_mask[1], expected_loss_mask)

        # Third sample where image is at the end.
        expected_embeddings = torch.empty(max_seq_len, hidden_size).cuda()
        expected_embeddings[:1023] = language_embeddings[2, :1023]
        expected_embeddings[1023:1600] = image_embeddings[:, 3]
        expected_embeddings[1600:] = 0  # padding

        expected_labels = torch.empty(max_seq_len, dtype=torch.int).cuda()
        expected_labels[:1022] = torch.arange(1, 1023)
        expected_labels[1022:1599] = -100
        expected_labels[1599] = 1024
        expected_labels[1600:] = -100  # padding

        expected_loss_mask = torch.empty(max_seq_len, dtype=torch.float).cuda()
        expected_loss_mask[:1022] = 1
        # Last text position before the image is not required to predict the first image embedding.
        expected_loss_mask[1022] = 0
        expected_loss_mask[1023:1600] = 0
        expected_loss_mask[1600:] = 0  # padding

        assert torch.allclose(embeddings[:, 2], expected_embeddings)
        assert torch.allclose(labels[2], expected_labels)
        assert torch.allclose(loss_mask[2], expected_loss_mask)

        # Fourth sample where there is no image.
        expected_embeddings = torch.empty(max_seq_len, hidden_size).cuda()
        expected_embeddings[:1024] = language_embeddings[3]
        expected_embeddings[1024:] = 0  # padding

        expected_labels = torch.empty(max_seq_len, dtype=torch.int).cuda()
        expected_labels[:1024] = torch.arange(1, 1025)
        expected_labels[1024:] = -100  # padding

        expected_loss_mask = torch.empty(max_seq_len, dtype=torch.float).cuda()
        expected_loss_mask[:1024] = 1
        expected_loss_mask[1024:] = 0  # padding

        assert torch.allclose(embeddings[:, 3], expected_embeddings)
        assert torch.allclose(labels[3], expected_labels)
        assert torch.allclose(loss_mask[3], expected_loss_mask)

        # Fifth sample has two images in between (indices 50 and 150). The first image has two tiles.
        # This is the longest sample, so it has no padding.
        expected_embeddings = torch.empty(max_seq_len, hidden_size).cuda()
        expected_embeddings[:50] = language_embeddings[4, :50]
        expected_embeddings[50:627] = image_embeddings[:, 4]  # two tiles
        expected_embeddings[627:1204] = image_embeddings[:, 5]
        expected_embeddings[1204:1303] = language_embeddings[4, 51:150]
        expected_embeddings[1303:1880] = image_embeddings[:, 6]
        expected_embeddings[1880:] = language_embeddings[4, 151:]

        expected_labels = torch.empty(max_seq_len, dtype=torch.int).cuda()
        expected_labels[:49] = torch.arange(1, 50)
        expected_labels[49:1203] = -100  # image
        expected_labels[1203:1302] = torch.arange(51, 150)
        expected_labels[1302:1879] = -100  # image
        expected_labels[1879:] = torch.arange(151, 1025)

        expected_loss_mask = torch.empty(max_seq_len, dtype=torch.float).cuda()
        expected_loss_mask[:49] = 1
        expected_loss_mask[49:1204] = 0
        expected_loss_mask[1204:1302] = 1
        expected_loss_mask[1302:1880] = 0
        expected_loss_mask[1880:] = 1

        assert torch.allclose(embeddings[:, 4], expected_embeddings)
        assert torch.allclose(labels[4], expected_labels)
        assert torch.allclose(loss_mask[4], expected_loss_mask)

    @pytest.mark.internal
    def test_forward(self):
        """Run forward with and without labels, packed sequences, text-only input and inference params."""
        self.model.cuda()

        # 3 images with 1 tile and 2 images with 2 tiles.
        img = torch.randn((7, 3, 336, 336)).cuda()

        image_token_index = self.model.image_token_index
        input_ids = torch.randint(0, 2048, (5, 1024)).cuda()
        input_ids[0, 0] = image_token_index  # image before text
        input_ids[1, 100] = image_token_index  # image in between
        input_ids[2, -1] = image_token_index  # image at the end
        # input_ids[3] - no image
        input_ids[4, 50] = image_token_index
        input_ids[4, 150] = image_token_index

        position_ids = torch.arange(0, 1024, dtype=torch.int).expand(5, 1024).cuda()

        loss_mask = torch.ones((5, 1024)).cuda()

        attention_mask = None  # Causal.

        labels = torch.randint(0, 2048, (5, 1024)).cuda()
        labels[1, 99] = image_token_index
        labels[2, -2] = image_token_index

        num_image_tiles = torch.tensor([1, 2, 1, 2, 1], dtype=torch.int).cuda()

        # Try with labels.
        loss, new_loss_mask = self.model.forward(
            img,
            input_ids,
            position_ids,
            attention_mask,
            labels,
            loss_mask,
            num_image_tiles=num_image_tiles,
        )

        # The maximum sequence length is given by the sample with 2 images in 3 tiles,
        # minus two image token indices, plus other text tokens.
        img_seq_len = 577
        max_seq_len = img_seq_len * 3 - 2 + 1024
        assert loss.shape == new_loss_mask.shape == torch.Size((5, max_seq_len))

        # Try with labels and PackedSeqParams. Only micro batch size 1 is supported in this mode.
        packed_seq_params = PackedSeqParams(
            qkv_format="thd",
            cu_seqlens_q=torch.tensor(
                [0, 512, 1024, 1600], dtype=torch.int32
            ).cuda(),  # Just example values.
            cu_seqlens_kv=torch.tensor([0, 512, 1024, 1600], dtype=torch.int32).cuda(),
            max_seqlen_q=torch.tensor(1600, dtype=torch.int32).cuda(),
            max_seqlen_kv=torch.tensor(1600, dtype=torch.int32).cuda(),
        )

        # NOTE: Packing is only supported with BF16. Use BF16 here and switch back to default.
        self.model.to(torch.bfloat16)
        try:
            loss, new_loss_mask = self.model.forward(
                img[:1].to(torch.bfloat16),
                input_ids[:1],
                position_ids[:1],
                attention_mask,
                labels[:1],
                loss_mask[:1],
                num_image_tiles=num_image_tiles[:1],
                packed_seq_params=packed_seq_params,
            )
        finally:
            # Restore FP32 even if the packed forward fails, so the model is never left in BF16.
            self.model.to(torch.float32)

        # 1600 = 577 (img_seq_len) + 1024 (text tokens in the first sample) - 1 (image token).
        assert loss.shape == new_loss_mask.shape == torch.Size((1, 1600))

        # Try text-only input.
        loss, new_loss_mask = self.model.forward(
            torch.tensor([], dtype=torch.float).cuda(),
            torch.randint(0, 2048, (5, 1024)).cuda(),
            position_ids,
            attention_mask,
            torch.randint(0, 2048, (5, 1024)).cuda(),
            loss_mask,
            num_image_tiles=torch.tensor([], dtype=torch.int).cuda(),
        )

        assert loss.shape == new_loss_mask.shape == torch.Size((5, 1024))

        # Try without labels and without inference params.
        logits, _ = self.model.forward(
            img,
            input_ids,
            position_ids,
            attention_mask,
            labels=None,
            loss_mask=None,
            num_image_tiles=num_image_tiles,
        )
        assert logits.shape == torch.Size((5, max_seq_len, 8192))

        # Try without labels and with inference params.
        inference_params = InferenceParams(5, max_seq_len)
        logits, _ = self.model.forward(
            img,
            input_ids,
            position_ids,
            attention_mask,
            labels=None,
            loss_mask=None,
            num_image_tiles=num_image_tiles,
            inference_params=inference_params,
        )
        assert logits.shape == torch.Size((5, max_seq_len, 8192))

        # Check KV cache got populated correctly.
        kv_dict = inference_params.key_value_memory_dict

        assert kv_dict["image_tokens_count"] == 577 * 7
        for layer_no in range(1, 4):  # 3 layers in the model.
            layer_kv = kv_dict[layer_no]
            # Expected shape is [sequence_len, batch_size, num_heads, hidden_size_per_head]
            # (hidden_size_per_head = 64 / 4 = 16).
            assert (
                layer_kv[0].shape
                == layer_kv[1].shape
                == torch.Size((max_seq_len, 5, self.language_num_attention_heads, 16))
            )

    @pytest.mark.internal
    def test_forward_fsdp(self):
        """Test FSDP workaround for text-only data.

        FSDP can hang with text-only data. As a workaround, we run the vision model with a dummy image,
        but then effectively discard the image embeddings.
        """
        self.model.cuda()

        # Dummy image for the FSDP workaround but not image tiles.
        img = torch.zeros((1, 3, 336, 336)).cuda()
        num_image_tiles = torch.tensor([], dtype=torch.int).cuda()

        # No image tag in the input ids (text-only sample).
        image_token_index = self.model.image_token_index
        input_ids = torch.arange(1024, device="cuda").unsqueeze(0)
        assert (
            torch.sum(input_ids == image_token_index) == 0
        ), "expected no image tag in the input ids"

        position_ids = torch.arange(1024, device="cuda").unsqueeze(0)

        loss_mask = torch.ones((1, 1024), device="cuda")

        attention_mask = None  # Causal.

        labels = torch.arange(1, 1025, device="cuda").unsqueeze(0)

        # Mock the FSDP attribute; always reset it, even if forward raises.
        self.model.vision_model._is_fsdp_managed_module = True
        try:
            loss, new_loss_mask = self.model.forward(
                img,
                input_ids,
                position_ids,
                attention_mask,
                labels,
                loss_mask,
                num_image_tiles=num_image_tiles,
            )
        finally:
            self.model.vision_model._is_fsdp_managed_module = False

        # The dummy image must not add tokens: output length equals the text length.
        assert loss.shape == new_loss_mask.shape == torch.Size((1, 1024))

    @pytest.mark.internal
    def test_save_load(self, tmp_path):
        """Round-trip the model state dict through torch.save / torch.load."""
        path = tmp_path / "model.pt"
        torch.save(self.model.state_dict(), path)

        self.model.load_state_dict(torch.load(path))

    @pytest.mark.internal
    def test_freeze(self):
        """Freezing language and vision models leaves only the projection trainable."""
        self.model.freeze(
            freeze_language_model=True, freeze_vision_model=True, freeze_vision_projection=False
        )

        for module in [self.model.language_model, self.model.vision_model]:
            for param in module.parameters():
                assert not param.requires_grad

        for param in self.model.vision_projection.parameters():
            assert param.requires_grad


@pytest.fixture(scope='class', params=["siglip", "radio-g"])
def setup_and_teardown_llava_model(request):
    """Build a small LLaVAModel for each vision encoder type in ``params``.

    The fixture is class-scoped and parametrized, so each test class that uses it runs once
    per vision encoder type and shares one model per type.

    Yields:
        Tuple of ``(model, vision_model_type)``.

    Model parallel state is set up via ``Utils.initialize_model_parallel(1, 1)``
    (presumably TP=1, PP=1). It is always destroyed afterwards, including when building
    the model raises. Code after ``yield`` alone would not run in that case, leaking global
    parallel state into later tests.
    """
    Utils.initialize_model_parallel(1, 1)
    try:
        model_parallel_cuda_manual_seed(123)

        language_config = TransformerConfig(
            num_layers=3, hidden_size=128, num_attention_heads=8, use_cpu_initialization=False
        )
        vision_config = TransformerConfig(
            num_layers=2, hidden_size=64, num_attention_heads=4, use_cpu_initialization=False
        )
        vision_projection_config = TransformerConfig(
            num_layers=2,
            hidden_size=128,
            ffn_hidden_size=72,
            num_attention_heads=1,
            use_cpu_initialization=False,
        )

        language_layer_spec = get_gpt_layer_with_transformer_engine_spec()
        vision_layer_spec = deepcopy(language_layer_spec)
        vision_projection_spec = deepcopy(language_layer_spec.submodules.mlp.submodules)

        language_config.language_model_type = "dummy"
        vision_model_type = request.param
        vision_config.vision_model_type = vision_model_type
        model = LLaVAModel(
            language_transformer_config=language_config,
            language_transformer_layer_spec=language_layer_spec,
            language_vocab_size=2048,
            language_max_sequence_length=4096,
            vision_transformer_config=vision_config,
            vision_transformer_layer_spec=vision_layer_spec,
            drop_vision_class_token=False,
            vision_projection_config=vision_projection_config,
            vision_projection_layer_spec=vision_projection_spec,
            img_h=336,
            img_w=336,
            patch_dim=14,
        )

        yield model, vision_model_type
    finally:
        Utils.destroy_model_parallel()


class TestLLaVAModelVisionEncoders:
    """Checks LLaVAModel built by the ``setup_and_teardown_llava_model`` fixture for each encoder."""

    # Expected total parameter count of the fixture's model, keyed by vision encoder type.
    num_weights_by_encoder = {"siglip": 1832456, "radio-g": 2844552}

    @pytest.mark.internal
    def test_constructor(self, setup_and_teardown_llava_model):
        """The model builds and its parameter count matches the per-encoder expectation."""
        llava, encoder_type = setup_and_teardown_llava_model
        assert isinstance(llava, LLaVAModel)

        total_params = sum(weight.numel() for weight in llava.parameters())
        assert total_params == self.num_weights_by_encoder[encoder_type]

    @pytest.mark.internal
    def test_set_input_tensor(self, setup_and_teardown_llava_model):
        """set_input_tensor stores the given tensor on the vision decoder."""
        llava, _unused_encoder_type = setup_and_teardown_llava_model
        target_shape = (1, 2, 3, 4)
        llava.set_input_tensor(torch.zeros(target_shape))
        assert llava.vision_model.decoder.input_tensor.shape == target_shape


def create_test_args(cp_size, sequence_parallel):
    """Return a minimal stand-in for the global training args.

    Only the two fields used by the token-parallel test are populated:
    ``context_parallel_size`` and ``sequence_parallel``.
    """
    return SimpleNamespace(
        context_parallel_size=cp_size,
        sequence_parallel=sequence_parallel,
    )


class TestLLaVAModelTokenParallel:
    """Tests LLaVAModel._process_embedding_token_parallel under TP / SP / CP settings.

    Each test initializes its own model parallel topology, so setup and teardown both
    destroy any existing model parallel state.
    """

    def _init_llava_model(self, cp_size: int, tp_size: int, sequence_parallel: bool) -> LLaVAModel:
        """Build a small LLaVAModel for the given parallel configuration.

        The language model uses ``tp_size``, ``sequence_parallel`` and ``cp_size``. The vision
        model and vision projection use ``tp_size`` only, with sequence parallelism disabled
        and context parallel size 1.
        """
        language_hidden_size = 64
        language_num_attention_heads = 16

        language_config = TransformerConfig(
            num_layers=3,
            hidden_size=language_hidden_size,
            num_attention_heads=language_num_attention_heads,
            use_cpu_initialization=False,
            tensor_model_parallel_size=tp_size,
            sequence_parallel=sequence_parallel,
            context_parallel_size=cp_size,
        )
        # SP and CP are not yet supported for the Vision Backbone
        vision_config = TransformerConfig(
            num_layers=2,
            hidden_size=16,
            num_attention_heads=8,
            use_cpu_initialization=False,
            tensor_model_parallel_size=tp_size,
            sequence_parallel=False,
            context_parallel_size=1,
        )
        vision_projection_config = TransformerConfig(
            num_layers=2,
            hidden_size=language_hidden_size,
            ffn_hidden_size=128,
            num_attention_heads=8,
            use_cpu_initialization=False,
            tensor_model_parallel_size=tp_size,
            sequence_parallel=False,
            context_parallel_size=1,
        )

        language_layer_spec = get_gpt_layer_with_transformer_engine_spec()
        # SP/CP either requires user to ensure token lengths do not require padding OR change mask type to padding
        # Remap the self-attention mask type to its padding variant:
        # causal -> padding_causal, no_mask -> padding. Any other mask type is left unchanged.
        if (
            language_layer_spec.submodules.self_attention.params.get('attn_mask_type', '')
            == AttnMaskType.causal
        ):
            language_layer_spec.submodules.self_attention.params['attn_mask_type'] = (
                AttnMaskType.padding_causal
            )
        elif (
            language_layer_spec.submodules.self_attention.params.get('attn_mask_type', '')
            == AttnMaskType.no_mask
        ):
            language_layer_spec.submodules.self_attention.params['attn_mask_type'] = (
                AttnMaskType.padding
            )

        # Copied after the remap, so the vision layer spec inherits the padding mask type too.
        vision_layer_spec = deepcopy(language_layer_spec)
        vision_projection_spec = deepcopy(language_layer_spec.submodules.mlp.submodules)

        language_config.language_model_type = "dummy"
        vision_config.vision_model_type = "clip"
        model = LLaVAModel(
            language_transformer_config=language_config,
            language_transformer_layer_spec=language_layer_spec,
            language_vocab_size=8192,
            language_max_sequence_length=4096,
            vision_transformer_config=vision_config,
            vision_transformer_layer_spec=vision_layer_spec,
            drop_vision_class_token=False,
            vision_projection_config=vision_projection_config,
            vision_projection_layer_spec=vision_projection_spec,
            img_h=336,
            img_w=336,
            patch_dim=14,
        )

        return model

    @pytest.mark.internal
    def setup_method(self, method):
        """Clear any model parallel state left by earlier tests; each test sets up its own."""
        Utils.destroy_model_parallel()

    @pytest.mark.internal
    def teardown_method(self, method):
        """Destroy the model parallel state created by the test."""
        Utils.destroy_model_parallel()

    @pytest.mark.internal
    @pytest.mark.parametrize(
        "cp_size,tp_size,sequence_parallel,padding",
        [(1, 8, True, True), (2, 4, False, True), (2, 4, True, False), (2, 4, True, True)],
    )
    def test_process_embedding_token_parallel(
        self, cp_size: int, tp_size: int, sequence_parallel: bool, padding: bool
    ):
        """Test _process_embedding_token_parallel.

        Builds dummy BF16 embeddings, labels and loss mask plus matching PackedSeqParams. It
        then runs _process_embedding_token_parallel and checks the shape of the returned
        embeddings for the given TP / SP / CP / padding combination.

        Note: With cp_size > 1 and TE older than 1.10.0, building the model is expected to
        raise AssertionError. In that case the test returns early without checking shapes.
        """
        Utils.initialize_model_parallel(
            tensor_model_parallel_size=tp_size, context_parallel_size=cp_size
        )
        model_parallel_cuda_manual_seed(123)

        # TE version must be at least 1.10.0 if using context parallelism. Exit otherwise.
        ctx = (
            nullcontext()
            if (is_te_min_version("1.10.0") or cp_size <= 1)
            else pytest.raises(AssertionError)
        )
        model = None
        with ctx:
            model = self._init_llava_model(cp_size, tp_size, sequence_parallel)

        # model stays None only when the expected AssertionError was raised above.
        if model is None:
            return

        model.cuda()

        # Sets the process-wide global args (not reset by this test).
        args = create_test_args(cp_size, sequence_parallel)
        set_args(args)

        batch_size = 2
        if padding:
            # 2064 is 2049 rounded up to the next multiple of 16.
            combined_valid_seqlen = 2049
            combined_padded_seqlen = 2064
        else:
            combined_valid_seqlen = 2048
            combined_padded_seqlen = 2048

        # Hidden size 4096 is arbitrary and differs from the model's hidden size (64).
        # NOTE(review): presumably fine because the function only splits/reshapes along the
        # sequence dimension — confirm.
        if cp_size > 1:
            combined_embeddings = torch.ones(
                [batch_size, combined_padded_seqlen, 4096], device='cuda', dtype=torch.bfloat16
            )  # [B, S, H]
        else:
            combined_embeddings = torch.ones(
                [combined_padded_seqlen, batch_size, 4096], device='cuda', dtype=torch.bfloat16
            )  # [S, B, H]
        new_labels = torch.ones(
            [batch_size, combined_padded_seqlen], device='cuda', dtype=torch.bfloat16
        )  # [B, S]
        new_loss_mask = torch.ones(
            [batch_size, combined_padded_seqlen], device='cuda', dtype=torch.bfloat16
        )  # [B, S]

        # Cumulative sequence boundaries [0, L, 2L, ...] for batch_size sequences of equal length,
        # once for the valid (unpadded) length and once for the padded length.
        cu_seqlens = torch.arange(
            0,
            (batch_size + 1) * (combined_valid_seqlen),
            step=(combined_valid_seqlen),
            dtype=torch.int32,
            device=combined_embeddings.device,
        )
        cu_seqlens_padded = torch.arange(
            0,
            (batch_size + 1) * (combined_padded_seqlen),
            step=(combined_padded_seqlen),
            dtype=torch.int32,
            device=combined_embeddings.device,
        )

        qkv_format = 'sbhd'  # Default format when not using padding
        if cp_size > 1 and padding:
            # Reshape from [B,S] to [1,T]
            combined_embeddings = (
                combined_embeddings.contiguous()
                .view(combined_embeddings.shape[0] * combined_embeddings.shape[1], -1)
                .unsqueeze(0)
            )
            new_labels = new_labels.view(new_labels.shape[0] * new_labels.shape[1]).unsqueeze(0)
            new_loss_mask = new_loss_mask.view(
                new_loss_mask.shape[0] * new_loss_mask.shape[1]
            ).unsqueeze(0)
            qkv_format = 'thd'

        packed_seq_params = PackedSeqParams(
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_kv=cu_seqlens,
            cu_seqlens_q_padded=cu_seqlens_padded,
            cu_seqlens_kv_padded=cu_seqlens_padded,
            max_seqlen_q=combined_padded_seqlen,
            max_seqlen_kv=combined_padded_seqlen,
            qkv_format=qkv_format,
        )

        combined_embeddings, new_labels, new_loss_mask, packed_seq_params = (
            model._process_embedding_token_parallel(
                combined_embeddings, new_labels, new_loss_mask, packed_seq_params
            )
        )

        # Check if output shape is as expected
        # The expected sequence dimension is divided by tp_size * cp_size (CP + SP),
        # by cp_size (CP only) or by tp_size (no CP). True division is used on purpose: a
        # non-divisible length gives a non-integer that can never equal a shape, so the
        # assert fails.
        if cp_size > 1 and sequence_parallel:
            if padding:
                # THD format
                assert combined_embeddings.shape[0] == batch_size * (
                    combined_padded_seqlen / (tp_size * cp_size)
                )
                assert combined_embeddings.shape[1] == 1
            else:
                # SBHD format
                assert combined_embeddings.shape[0] == (
                    combined_padded_seqlen / (tp_size * cp_size)
                )
                assert combined_embeddings.shape[1] == batch_size
        elif cp_size > 1:
            if padding:
                # THD format
                assert combined_embeddings.shape[0] == batch_size * (
                    combined_padded_seqlen / cp_size
                )
                assert combined_embeddings.shape[1] == 1
            else:
                # SBHD format
                assert combined_embeddings.shape[0] == (combined_padded_seqlen / cp_size)
                assert combined_embeddings.shape[1] == batch_size
        else:
            # SBHD format
            assert combined_embeddings.shape[0] == combined_padded_seqlen / tp_size
            assert combined_embeddings.shape[1] == batch_size


def count_parameters(model):
    """Return the total number of scalar elements across all of ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total


def _make_llava_parallelism_model(
    language_hidden_size, language_num_attention_heads, dtp, dpp, etp, **model_kwargs
):
    """Build the small LLaVA model used by ``test_llava_model_parallelism``.

    Both the baseline (non-parallel) model and the model-parallel model are built here.
    They therefore always share the same configs and layer specs, so the test's
    parameter-count comparison compares like with like.

    Args:
        language_hidden_size: hidden size of the language model and of the vision projection.
        language_num_attention_heads: number of attention heads of the language model.
        dtp: tensor parallel size recorded on the language config.
        dpp: pipeline parallel size recorded on the language config.
        etp: tensor parallel size recorded on the vision and vision projection configs.
        **model_kwargs: forwarded to ``LLaVAModel`` (e.g. ``pre_process``, ``post_process``,
            ``add_encoder``, ``add_decoder``).

    Returns:
        The constructed ``LLaVAModel``.
    """
    language_config = TransformerConfig(
        num_layers=12,
        hidden_size=language_hidden_size,
        num_attention_heads=language_num_attention_heads,
        use_cpu_initialization=False,
    )
    language_config.tensor_model_parallel_size = dtp
    language_config.pipeline_model_parallel_size = dpp
    language_config.language_model_type = "dummy"

    vision_config = TransformerConfig(
        num_layers=4, hidden_size=16, num_attention_heads=2, use_cpu_initialization=False
    )
    vision_config.tensor_model_parallel_size = etp
    vision_config.pipeline_model_parallel_size = 1
    vision_config.vision_model_type = "clip"

    vision_projection_config = TransformerConfig(
        num_layers=2,
        hidden_size=language_hidden_size,
        ffn_hidden_size=32,
        num_attention_heads=1,
        use_cpu_initialization=False,
    )
    vision_projection_config.tensor_model_parallel_size = etp
    vision_projection_config.pipeline_model_parallel_size = 1

    language_layer_spec = get_gpt_layer_with_transformer_engine_spec()
    vision_layer_spec = get_vit_layer_with_transformer_engine_spec()
    # The vision projection reuses the ViT layer's MLP submodules.
    vision_projection_spec = deepcopy(vision_layer_spec.submodules.mlp.submodules)

    return LLaVAModel(
        language_transformer_config=language_config,
        language_transformer_layer_spec=language_layer_spec,
        language_vocab_size=8192,
        language_max_sequence_length=4096,
        vision_transformer_config=vision_config,
        vision_transformer_layer_spec=vision_layer_spec,
        drop_vision_class_token=False,
        vision_projection_config=vision_projection_config,
        vision_projection_layer_spec=vision_projection_spec,
        img_h=336,
        img_w=336,
        patch_dim=14,
        **model_kwargs,
    )


def _count_params_across_tp_group(module):
    """Count ``module``'s parameters, summing tensor-parallel shards over the TP group.

    Parameters that carry a ``tensor_model_parallel`` attribute have their per-rank element
    counts all-reduced (SUM) over the tensor model parallel group. Parameters without that
    attribute are counted from this rank only. This issues a collective, so every rank of
    the tensor model parallel group must call it.

    Returns:
        int: the all-reduced sharded count plus the local non-sharded count.
    """
    tp_params = 0
    non_tp_params = 0
    for param in module.parameters():
        if hasattr(param, 'tensor_model_parallel'):
            tp_params += param.numel()
        else:
            non_tp_params += param.numel()

    # int64 so the reduced count cannot overflow for larger models.
    tp_params_tensor = torch.tensor([tp_params], dtype=torch.int64).cuda()
    torch.distributed.all_reduce(
        tp_params_tensor,
        op=torch.distributed.ReduceOp.SUM,
        group=ps.get_tensor_model_parallel_group(),
    )
    return tp_params_tensor.item() + non_tp_params


@pytest.mark.internal  # The model is under active development and its methods may change.
@pytest.mark.parametrize(
    'dtp, dpp, etp, epp', [(1, 1, 1, 0), (1, 1, 1, 1), (2, 1, 2, 0), (2, 3, 2, 1), (2, 4, 2, 0)]
)
def test_llava_model_parallelism(dtp, dpp, etp, epp):
    """
    Check LLaVA rank placement and vision-side sharding under model parallelism.

    The test verifies three things:

    - Each pipeline rank lands in the expected encoder/decoder role and TP group size.
    - The ViT layer count on encoder ranks matches the non-model-parallel case
      (tp==1, pp==1, etp==1, epp==0).
    - The ViT and vision projection parameter counts, summed across the tensor parallel
      group, match the non-model-parallel case.

    Language-model layer counts are not checked here.
    """

    language_hidden_size = 64
    language_num_attention_heads = 4

    # First initialize a single GPU model to get baseline parameter and layer counts.
    Utils.initialize_model_parallel(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        encoder_tensor_model_parallel_size=1,
        encoder_pipeline_model_parallel_size=0,
    )
    model_parallel_cuda_manual_seed(123)

    non_parallel_model = _make_llava_parallelism_model(
        language_hidden_size, language_num_attention_heads, dtp, dpp, etp
    )
    base_vit_params = count_parameters(non_parallel_model.vision_model)
    base_proj_params = count_parameters(non_parallel_model.vision_projection)
    base_vit_layers = len(non_parallel_model.vision_model.decoder.layers)
    # Only the counts are needed from here on; drop the baseline model.
    del non_parallel_model

    Utils.destroy_model_parallel()

    # Next initialize a model parallel version to get test parameter and layer counts.
    Utils.initialize_model_parallel(
        tensor_model_parallel_size=dtp,
        pipeline_model_parallel_size=dpp,
        encoder_tensor_model_parallel_size=etp,
        encoder_pipeline_model_parallel_size=epp,
    )
    model_parallel_cuda_manual_seed(123)

    pp_rank = ps.get_pipeline_model_parallel_rank()
    pp_world_size = ps.get_pipeline_model_parallel_world_size()
    tp_world_size = ps.get_tensor_model_parallel_world_size()

    # With a dedicated encoder stage (epp == 1), rank 0 builds no decoder.
    # In that case pre_process is also enabled on rank 1.
    encoder_only_rank = pp_rank == 0 and epp == 1
    pre_process = pp_rank == 0 or (pp_rank == 1 and epp == 1)
    post_process = encoder_only_rank or pp_rank == pp_world_size - 1
    add_encoder = pp_rank == 0
    add_decoder = not encoder_only_rank

    model = _make_llava_parallelism_model(
        language_hidden_size,
        language_num_attention_heads,
        dtp,
        dpp,
        etp,
        pre_process=pre_process,
        post_process=post_process,
        add_encoder=add_encoder,
        add_decoder=add_decoder,
    )

    if epp == 1:
        if pp_rank == 0:
            # should be in a etp sized tp group
            assert tp_world_size == etp
            # the pipeline spans the encoder stage plus all decoder stages
            assert pp_world_size == epp + dpp
            # should not be inside decoder
            assert not ps.is_inside_decoder()
            # should be inside encoder
            assert ps.is_inside_encoder()
        else:
            # non-encoder ranks should be in a dtp sized tp group
            assert tp_world_size == dtp
            # check we're inside the decoder
            assert ps.is_inside_decoder()
            # check we're not inside the encoder
            assert not ps.is_inside_encoder()
    elif epp == 0:
        if pp_rank == 0:
            # check we're inside the encoder and decoder
            assert ps.is_inside_encoder()
            assert ps.is_inside_decoder()
        else:
            # check we're inside the decoder only and there's no vision_model
            assert not ps.is_inside_encoder()
            assert ps.is_inside_decoder()
            assert model.vision_model is None
            assert model.vision_projection is None

    if ps.is_inside_encoder():
        # Check num vit layers - epp > 1 not supported
        assert len(model.vision_model.decoder.layers) == base_vit_layers

        # Check all vit params are present across the tensor parallel group
        assert _count_params_across_tp_group(model.vision_model) == base_vit_params

        # Check all vision projection params are present across the tensor parallel group
        assert _count_params_across_tp_group(model.vision_projection) == base_proj_params
    else:
        # check ranks that aren't inside encoder have no vit
        assert model.vision_model is None
        assert model.vision_projection is None

    Utils.destroy_model_parallel()
    torch.cuda.empty_cache()


@pytest.mark.internal
@pytest.mark.parametrize(
    "cp_size, tp_size, has_sp, seq_len, expected_padding",
    [
        (1, 1, False, 99, 0),
        (2, 2, True, 99, 5),
        (2, 2, False, 99, 1),
    ],
)
def test_get_padding(cp_size, tp_size, has_sp, seq_len, expected_padding):
    """Check that context_parallel.get_padding returns the expected padding amount."""
    assert context_parallel.get_padding(seq_len, cp_size, tp_size, has_sp) == expected_padding


@pytest.mark.internal
@pytest.mark.parametrize(
    "tokens, img_seq_len, padding_needed, cp_size, expected_seq_len",
    [(torch.ones((1, 100)), 100, 0, 2, 200), (torch.ones((1, 100)), 128, 1, 2, 227)],
)
def test_get_packed_seq_params(tokens, img_seq_len, padding_needed, cp_size, expected_seq_len):
    """Check the cumulative sequence lengths built by context_parallel.get_packed_seq_params."""
    params = context_parallel.get_packed_seq_params(tokens, img_seq_len, padding_needed, cp_size)

    expected_cu_seqlens = torch.tensor([0, expected_seq_len], dtype=torch.int32)
    assert torch.equal(params.cu_seqlens_q, expected_cu_seqlens)

    # The padded lengths are only checked when padding was requested.
    if padding_needed <= 0:
        return

    total_len = tokens.shape[1] + img_seq_len
    expected_cu_seqlens_padded = torch.tensor([0, total_len], dtype=torch.int32)
    assert torch.equal(params.cu_seqlens_q_padded, expected_cu_seqlens_padded)
    assert params.max_seqlen_q == total_len
