# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.


import torch

from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec
from megatron.core.models.vision.multimodal_projector import MultimodalProjector
from megatron.core.tensor_parallel.layers import ColumnParallelLinear
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.mlp import MLPSubmodules
from megatron.core.transformer.transformer_config import TransformerConfig
from tests.unit_tests.test_utilities import Utils


class TestMultimodalProjector:
    """Tests for ``MultimodalProjector`` built with ``projector_type="mlp"`` and ``"affine"``."""

    def setup_method(self, method):
        """Initialize model parallelism and build one projector of each type.

        Both projectors share a single ``TransformerConfig`` and take
        1024-wide inputs.
        """
        # NOTE(review): (1, 1) presumably means single-rank tensor/pipeline parallel sizes — confirm.
        Utils.initialize_model_parallel(1, 1)
        model_parallel_cuda_manual_seed(123)
        config = TransformerConfig(
            num_layers=1, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True
        )

        # The "mlp" projector reuses the submodule spec from the GPT MLP layer spec.
        mlp_submodules = get_mlp_module_spec().submodules
        # The "affine" projector needs only a first column-parallel linear; there is no fc2.
        affine_submodules = MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=None)

        self.mlp = MultimodalProjector(
            config=config, submodules=mlp_submodules, projector_type="mlp", input_size=1024
        )
        self.affine = MultimodalProjector(
            config=config, submodules=affine_submodules, projector_type="affine", input_size=1024
        )

    def teardown_method(self, method):
        """Tear down the model-parallel state created in ``setup_method``."""
        Utils.destroy_model_parallel()

    def test_constructor(self):
        """Both projectors have the right type and the pinned parameter counts."""
        for projector in (self.mlp, self.affine):
            assert isinstance(projector, MultimodalProjector)

        # Expected total element counts across all parameters of each projector.
        expected_param_counts = ((self.mlp, 280896), (self.affine, 65600))
        for projector, expected in expected_param_counts:
            assert sum(param.numel() for param in projector.parameters()) == expected

    def test_forward(self):
        """A (2, 1024) zero batch is projected to shape (2, 64) by each projector."""
        # Move both modules to the GPU before running either forward pass.
        for projector in (self.mlp, self.affine):
            projector.cuda()

        batch = torch.zeros(2, 1024, device="cuda")

        for projector in (self.mlp, self.affine):
            output = projector.forward(batch)
            assert len(output) == 2
            assert output.shape == torch.Size([2, 64])

    def test_save_load(self, tmp_path):
        """Round-trip each projector's state dict through a file in ``tmp_path``."""
        checkpoints = (("mlp.pt", self.mlp), ("affine.pt", self.affine))
        for file_name, projector in checkpoints:
            checkpoint_path = tmp_path / file_name
            torch.save(projector.state_dict(), checkpoint_path)
            # Reloading into the same module only checks that the saved keys and shapes are accepted.
            projector.load_state_dict(torch.load(checkpoint_path))
