              
                                                      
                       

import imp
import torch
from einops import rearrange

from megatron.core import tensor_parallel, mpu
from megatron.core.models.multimodal.llava_model import LLaVAModel

from gpatch.core.utils import split_data_cp_rank
from gpatch.core.tensor_parallel.mappings import all_gather_to_context_parallel_region


                                                              
class Gemma3LLaVAModel(LLaVAModel):
    """LLaVA model whose image embeddings replace image placeholder tokens in place.

    Overrides three ``LLaVAModel`` hooks:

    * ``_gather_image_embedings``: all-gathers image embeddings across the
      context-parallel group when context parallelism is enabled.
    * ``_preprocess_data``: writes image embeddings into the text embeddings at
      every ``image_token_index`` position via ``Tensor.masked_scatter``. The
      sequence length is unchanged, then the result is rearranged to
      ``[s, b, h]``.
    * ``_process_embedding_token_parallel``: splits the merged embeddings across
      context-parallel ranks and scatters them to the sequence-parallel region.
    """

    def _gather_image_embedings(self, image_embeddings):
        """Gather image embeddings from all context-parallel ranks.

        Args:
            image_embeddings: This rank's image embeddings.

        Returns:
            ``image_embeddings`` unchanged when ``self.context_parallel_lm < 2``.
            Otherwise, the result of ``all_gather_to_context_parallel_region``
            along dim 1. NOTE(review): the ``ReduceOp.SUM`` argument is passed
            straight to that helper; presumably it selects how gradients are
            reduced in the backward pass. Confirm against the helper.
        """
        if self.context_parallel_lm < 2:
            return image_embeddings

        return all_gather_to_context_parallel_region(image_embeddings, 1,
                                                     torch.distributed.ReduceOp.SUM)

    def _preprocess_data(
        self,
        image_embeddings,
        language_embeddings,
        input_ids,
        loss_mask,
        labels,
        use_inference_kv_cache,
        inference_params,
        image_token_index,
        num_image_tiles,
    ):
        """Merge image embeddings into the text embeddings at image-token positions.

        Every position where ``input_ids == image_token_index`` has its whole
        hidden vector overwritten by ``Tensor.masked_scatter``. Values are taken
        in row-major order from ``image_embeddings``, so the k-th image token
        receives the k-th hidden-sized chunk of the flattened image embeddings.
        The sequence length is unchanged, so ``input_ids`` must already contain
        one placeholder token per image embedding vector.

        Args:
            image_embeddings: Image embeddings, or ``None`` when there is nothing
                to merge. Must hold at least ``num_image_tokens * h`` elements
                (a ``masked_scatter`` requirement). Assumed to have the hidden
                dim last; TODO confirm against the vision projection output.
            language_embeddings: Text embeddings. On the merge path they must be
                ``[b, s, h]``: the mask is broadcast from ``input_ids`` and the
                result is rearranged ``"b s h -> s b h"``.
            input_ids: Token ids, broadcastable to ``language_embeddings`` after
                ``unsqueeze(-1)`` (i.e. ``[b, s]`` on the merge path).
            loss_mask: Returned unchanged. Must match ``labels`` in shape when
                ``labels`` is given.
            labels: Returned unchanged; may be ``None``.
            use_inference_kv_cache: When true, the embeddings are returned
                without merging.
            inference_params: Unused here; kept for interface compatibility.
            image_token_index: Id of the image placeholder token.
            num_image_tiles: Unused here; kept for interface compatibility.

        Returns:
            ``(embeddings, labels, loss_mask)``. On the merge path the embeddings
            are ``[s, b, h]``. On the early-return paths they are returned as
            passed in, or all three are ``None`` for pipeline chunks that neither
            pre- nor post-process.
        """
        assert self.add_decoder, "input text preprocessing is only needed for the language model"
        assert input_ids is not None

        # Nothing to merge. This is checked before the pipeline-stage test below,
        # so it also applies to middle chunks.
        # NOTE(review): the embeddings are not rearranged here; presumably the
        # caller expects them as passed in. Confirm.
        if image_embeddings is None:
            return language_embeddings, labels, loss_mask

        # Middle pipeline chunks neither embed tokens nor compute the loss.
        if not self.pre_process and not self.post_process:
            return None, None, None

        # With a populated inference KV cache the embeddings are used as-is.
        # Return order matches every other path: (embeddings, labels, loss_mask).
        if use_inference_kv_cache:
            return language_embeddings, labels, loss_mask

        has_labels = labels is not None
        if has_labels:
            assert (
                labels.shape == loss_mask.shape
            ), f"mismatching labels shape {labels.shape} and loss mask shape {loss_mask.shape}"

        # Boolean mask of placeholder positions, broadcast over the hidden dim so
        # masked_scatter replaces the entire embedding vector of each image token.
        special_image_mask = (input_ids == image_token_index).unsqueeze(-1)
        special_image_mask = special_image_mask.expand_as(language_embeddings)

        # Collect image embeddings from every CP rank before merging.
        # Presumably each rank holds only part of them while the text sequence is
        # still whole here (it is split across CP ranks later); TODO confirm.
        image_embeddings = self._gather_image_embedings(image_embeddings)
        language_embeddings = language_embeddings.masked_scatter(special_image_mask,
                                                                 image_embeddings)
        # Sequence-first layout for the language model.
        language_embeddings = rearrange(language_embeddings, "b s h -> s b h")
        return language_embeddings, labels, loss_mask

    def _process_embedding_token_parallel(self, combined_embeddings, new_labels, new_loss_mask,
                                          packed_seq_params):
        """Shard the merged embeddings for context and sequence parallelism.

        Only ``combined_embeddings`` is modified, and only when
        ``self.pre_process`` is true:

        * it is split across context-parallel ranks along dim 0 with
          ``split_data_cp_rank``;
        * it is then scattered to the sequence-parallel region.

        ``new_labels``, ``new_loss_mask`` and ``packed_seq_params`` are returned
        unchanged. NOTE(review): presumably labels and loss mask are already
        sharded elsewhere; confirm.

        Args:
            combined_embeddings: Merged embeddings with the sequence on dim 0
                (e.g. the ``[s, b, h]`` output of ``_preprocess_data``).
            new_labels: Passed through.
            new_loss_mask: Passed through.
            packed_seq_params: Passed through.

        Returns:
            ``(combined_embeddings, new_labels, new_loss_mask, packed_seq_params)``.

        Raises:
            AssertionError: On the ``pre_process`` path, if the sequence length
                is not divisible by the CP/SP shard factor. Also if sequence
                parallelism with TP comm overlap is enabled and the sequence
                length differs from ``self._language_max_sequence_length``.
        """
        # Middle pipeline chunks have no embeddings to shard.
        if not self.pre_process and not self.post_process:
            return combined_embeddings, new_labels, new_loss_mask, packed_seq_params

        # Every sharding step below operates on the sequence dimension, dim 0.
        seq_dim = 0

        if self.pre_process:
            # Number of equal pieces the sequence must divide into.
            # Context parallelism contributes 2 * cp (presumably split_data_cp_rank
            # cuts 2 * cp chunks for load balancing; TODO confirm). Sequence
            # parallelism contributes the tensor-parallel size.
            shard_factor = None
            if self.context_parallel_lm > 1 and self.sequence_parallel_lm:
                shard_factor = self.tensor_model_parallel_size_lm * self.context_parallel_lm * 2
            elif self.context_parallel_lm > 1:
                shard_factor = self.context_parallel_lm * 2
            elif self.sequence_parallel_lm:
                shard_factor = self.tensor_model_parallel_size_lm

            # Without CP or SP there is no divisibility constraint. Skipping the
            # check avoids evaluating `shape[None] % None`, which raises TypeError.
            if shard_factor is not None:
                seq_len = combined_embeddings.shape[seq_dim]
                assert seq_len % shard_factor == 0, (
                    f"Sequence length should be divisible by {shard_factor} for "
                    "Sequence/Context parallelism")

            if self.sequence_parallel_lm and self.tp_comm_overlap_lm:
                seq_len = combined_embeddings.shape[seq_dim]
                assert seq_len == self._language_max_sequence_length, (
                    f"TP Comm overlap requires Vision+Text token length ({seq_len}) "
                    f"== language_max_sequence_length "
                    f"({self._language_max_sequence_length})")

        if self.context_parallel_lm > 1 and self.pre_process:
            combined_embeddings = split_data_cp_rank(combined_embeddings, self.context_parallel_lm,
                                                     seq_dim)

        if self.sequence_parallel_lm and self.pre_process:
            combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(
                combined_embeddings)

        return combined_embeddings, new_labels, new_loss_mask, packed_seq_params
