from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.models.vision.multimodal_projector import MultimodalProjector


def multimodal_projector_init(
    self,
    config: TransformerConfig,
    submodules: MLPSubmodules,
    projector_type: str,
    input_size: int,
) -> None:
    """Initialize a multimodal projector and build its ``encoder`` submodule.

    Written as a free function that takes ``self`` and calls the parent
    initializer of ``MultimodalProjector`` explicitly, so it can stand in for
    ``MultimodalProjector.__init__`` (presumably installed as a patch by the
    caller -- confirm at the call site).

    Args:
        config: Transformer configuration. ``hidden_size``, ``init_method``
            and ``add_bias_linear`` are read for the ``"affine"`` projector;
            ``mm_projector_cls`` is read for the ``"custom_cls"`` projector.
        submodules: Module specs for the projector. Must not be ``None``.
        projector_type: One of ``"mlp"``, ``"affine"`` or ``"custom_cls"``.
        input_size: Size of the input features fed to the projector.

    Raises:
        AssertionError: If ``submodules`` is ``None``.
        ValueError: If ``projector_type`` is not one of the supported values.
    """
    # Name the class explicitly: this function lives outside the class body,
    # so the zero-argument form of super() is not available here.
    super(MultimodalProjector, self).__init__(config=config)
    self.projector_type = projector_type

    # Explicit raise instead of `assert` so the check still runs under
    # `python -O`; AssertionError is kept so the exception type is unchanged.
    if submodules is None:
        raise AssertionError("MLPSubmodules must be provided")

    if self.projector_type == "mlp":
        self.encoder = MLP(config=config, submodules=submodules, input_size=input_size)
    elif self.projector_type == "affine":
        # Single linear layer built from the fc1 spec, mapping
        # input_size -> config.hidden_size.
        self.encoder = build_module(
            submodules.linear_fc1,
            input_size,
            config.hidden_size,
            config=config,
            init_method=config.init_method,
            gather_output=True,
            bias=config.add_bias_linear,
            skip_bias_add=True,
            is_expert=False,
            tp_comm_buffer_name=None,
        )
    elif self.projector_type == "custom_cls":
        # The projector class is supplied through the config and is called
        # with the same keyword arguments as the built-in MLP projector.
        self.encoder = config.mm_projector_cls(
            config=config,
            submodules=submodules,
            input_size=input_size,
        )
    else:
        # ValueError subclasses Exception, so handlers written against the
        # previous bare Exception keep working.
        raise ValueError(f"Unsupported multimodal projection type {self.projector_type}")
