from typing import Union

import torch

from megatron.core import InferenceParams
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer import TransformerConfig
from megatron.core import tensor_parallel
from megatron.core import mpu

from gpatch.core.models.vision.qwen2vl_vit_model import (
    Qwen2VisionModel,
    Qwen2VLTransformerConfig,
    Qwen2P5VLTransformerConfig,
)
from gpatch.core.models.gpt.gpt_model import Qwen2VLGPTModel
from gpatch.core.utils import split_data_cp_rank


                                                              
class Qwen2VLModel(MegatronModule):
    """Qwen2VL multi-modal model.

    A vision model (built only on the first pipeline stage) encodes image/video patches; its
    outputs overwrite the token embeddings at the positions selected by the image/video input
    masks, and the merged embeddings are fed to a ``Qwen2VLGPTModel`` as ``decoder_input``.

    Args:
        language_transformer_config (TransformerConfig): Transformer config for the language model.
        language_transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers of the language model.
        language_vocab_size (int): Language model vocabulary size.
        language_max_sequence_length (int): Language model maximum sequence length.
        vision_transformer_config (TransformerConfig): Transformer config for the vision model.
        vision_transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers of the vision model.
        vision_projection_config (TransformerConfig): Config for the projection from vision model outputs to language model inputs.
        vision_projection_layer_spec (ModuleSpec): Specifies the module to use for the vision projection.
        vision_projection_type (str): Type of the vision projection to use. Defaults to "mlp".
        parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks.
            This is typically True for training and False for inference.
        language_position_embedding_type (str): Position embedding type to use in the language model. Defaults to 'rope'.
        language_rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings
            in the language model. Defaults to 1.0.
        pre_process (bool): Include the embedding layer in the gpt decoder (used with pipeline parallelism).
            The vision model is also built only when this is True. Defaults to True.
        post_process (bool): Include an output layer and a layernorm in the gpt decoder (used with pipeline
            parallelism). Defaults to True.
        add_encoder (bool): Stored on the instance. Note that construction of the vision model is gated on
            ``pre_process``, not on this flag. Defaults to True.
        add_decoder (bool): Stored on the instance and consulted by ``shared_embedding_or_output_weight``;
            the language model itself is always constructed. Defaults to True.
        language_rotary_base (int): Rotary base forwarded to the language model. Defaults to 10000.
        share_embeddings_and_output_weights (bool): Forwarded to the language model; the effective value is
            read back from it after construction. Defaults to False.
        fp16_lm_cross_entropy (bool): Forwarded to the language model. Defaults to False.
        vision_model_class: Class used to build the vision model. Defaults to ``Qwen2VisionModel``.
    """

    def __init__(
        self,
        language_transformer_config: Union[Qwen2VLTransformerConfig, Qwen2P5VLTransformerConfig],
        language_transformer_layer_spec: ModuleSpec,
        language_vocab_size: int,
        language_max_sequence_length: int,
        vision_transformer_config: Union[Qwen2VLTransformerConfig, Qwen2P5VLTransformerConfig],
        vision_transformer_layer_spec: ModuleSpec,
        vision_projection_config: TransformerConfig,
        vision_projection_layer_spec: ModuleSpec,
        vision_projection_type: str = "mlp",
        parallel_output: bool = True,
        language_position_embedding_type: str = 'rope',
        language_rotary_percent: float = 1.0,
        pre_process: bool = True,
        post_process: bool = True,
        add_encoder: bool = True,
        add_decoder: bool = True,
        language_rotary_base: int = 10000,
        share_embeddings_and_output_weights: bool = False,
        fp16_lm_cross_entropy: bool = False,
        vision_model_class=Qwen2VisionModel,
    ) -> None:
        super().__init__(config=language_transformer_config)
        self.vision_projection_config = vision_projection_config

        self.pre_process = pre_process
        self.post_process = post_process
        self.add_encoder = add_encoder
        self.add_decoder = add_decoder

        self.encoder_hidden_state = None
        self.vision_model = None
        self.vision_projection = None
        self.language_model = None

        # Ratio between the projection's FFN width and the ViT hidden size.
        # NOTE(review): presumably the number of patches merged per LM token
        # (spatial_merge_size ** 2) -- confirm against the vision model.
        self.square_merge_size = vision_projection_config.ffn_hidden_size // vision_transformer_config.hidden_size

        # Placeholder until the language model is built; overwritten below with its effective value.
        self.share_embeddings_and_output_weights = False
        if self.pre_process:
            self.vision_model = vision_model_class(vision_transformer_config,
                                                   vision_transformer_layer_spec,
                                                   vision_projection_config,
                                                   vision_projection_layer_spec,
                                                   projection_type=vision_projection_type,
                                                   pre_process=True,
                                                   post_process=True)

        self.language_model = Qwen2VLGPTModel(
            config=language_transformer_config,
            transformer_layer_spec=language_transformer_layer_spec,
            vocab_size=language_vocab_size,
            max_sequence_length=language_max_sequence_length,
            parallel_output=parallel_output,
            position_embedding_type=language_position_embedding_type,
            rotary_percent=language_rotary_percent,
            pre_process=self.pre_process,
            post_process=self.post_process,
            rotary_base=language_rotary_base,
            fp16_lm_cross_entropy=fp16_lm_cross_entropy,
            share_embeddings_and_output_weights=share_embeddings_and_output_weights)
        self.share_embeddings_and_output_weights = (
            self.language_model.share_embeddings_and_output_weights)

    def shared_embedding_or_output_weight(self):
        """This is a convenience method to surface the language model's word embeddings, which is
        necessary for `finalize_model_grads._allreduce_word_embedding_grads`.

        Returns None when ``add_decoder`` is False.
        """
        if self.add_decoder:
            return self.language_model.shared_embedding_or_output_weight()
        return None

    def set_input_tensor(self, input_tensor) -> None:
        """Receive the activation passed in from the previous pipeline stage.

        Args:
            input_tensor: a tensor (or a one-element list holding it); presumably supplied by the
                pipeline schedule -- confirm against the caller.

        On the first stage (``pre_process``) the value is only stored in ``encoder_hidden_state``,
        which is not read anywhere in this class; on later stages it is handed to the language model.
        """
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]
        assert len(input_tensor) == 1, 'input_tensor should only be length 1 for Qwen2VL'

        if self.pre_process:
            self.encoder_hidden_state = input_tensor[0]
        else:
            self.language_model.set_input_tensor(input_tensor[0])

    def freeze(self, freeze_language_model: bool, freeze_vision_model: bool,
               freeze_vision_projection: bool):
        """Freeze model modules.

        Make specific modules non-trainable by setting requires_grad to False for the module's parameters.
        The vision decoder's final layernorm is treated as part of the projection: it is frozen together
        with the projection, and kept trainable when only the vision model body is frozen.

        Args:
            freeze_language_model (bool): Freeze the language model module.
            freeze_vision_model (bool): Freeze the vision model module (patch embedding and decoder).
            freeze_vision_projection (bool): Freeze the vision projection module (only when the vision
                model has a projection).
        """
        vision_model = self.vision_model
        modules = []
        if freeze_language_model and self.language_model is not None:
            modules.append(self.language_model)
        if freeze_vision_model and vision_model is not None:
            modules.append(vision_model.patch_embed)
            modules.append(vision_model.decoder)
        if freeze_vision_projection and vision_model is not None and vision_model.projection is not None:
            modules.append(vision_model.projection)
            modules.append(vision_model.decoder.final_layernorm)

        for module in modules:
            # A TransformerBlock may be built without a final layernorm, which is then None.
            if module is None:
                continue
            for param in module.parameters():
                param.requires_grad = False

        # Re-enable the final layernorm that the decoder freeze above also caught, since it belongs
        # to the (still trainable) projection path.
        if freeze_vision_model and not freeze_vision_projection and vision_model is not None:
            final_layernorm = vision_model.decoder.final_layernorm
            if final_layernorm is not None:
                for param in final_layernorm.parameters():
                    param.requires_grad = True

    @staticmethod
    def _split_vision_embeds(vision_embeds, video_start_index):
        """Split vision-model output rows into ``(image_embeds, video_embeds)``.

        Rows before ``video_start_index`` are image tokens and the remaining rows are video tokens;
        an empty side is returned as None. Raises ValueError when the index is out of range.
        """
        if vision_embeds is None:
            return None, None
        num_vision_tokens = vision_embeds.shape[0]
        if video_start_index == 0:
            return None, vision_embeds
        if video_start_index == num_vision_tokens:
            return vision_embeds, None
        if 0 < video_start_index < num_vision_tokens:
            return vision_embeds[:video_start_index], vision_embeds[video_start_index:]
        raise ValueError(
            f"Expect video token start index in range [0, {vision_embeds.shape[0]}], but got {video_start_index}"
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        vision_data: torch.Tensor = None,
        vision_grid_thw: torch.Tensor = None,
        video_start_index: int = -1,
        image_input_mask: torch.Tensor = None,
        video_input_mask: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
        inference_params: InferenceParams = None,
        packed_seq_params: PackedSeqParams = None,
        image_padded: bool = False,
        extra_block_kwargs: dict = None,
    ) -> torch.Tensor:
        """Forward function of the Qwen2VL model.

        Args:
            input_ids (torch.Tensor): input text ids [batch, text_seq_len].
            position_ids (torch.Tensor): input text position ids, passed unchanged to the language model.
            vision_data (torch.Tensor): vision-model input (previously documented as
                [total_thw_size, n_features]); None when the batch has no image/video.
            vision_grid_thw (torch.Tensor): forwarded to the vision model as ``grid_thw``.
            video_start_index (int): row of the vision-model output where video tokens start:
                0 -- all video
                number of vision tokens -- all image
                others -- mixture (image rows first, then video rows)
                Must be set explicitly whenever ``vision_data`` is given; the default -1 is rejected.
            image_input_mask / video_input_mask (torch.Tensor): index masks selecting the embedding
                positions overwritten by image / video embeddings (presumably boolean [batch, seq] --
                confirm). Required on the first PP stage whenever the matching embeddings exist.
            attention_mask (torch.Tensor): attention mask for the language model [batch, 1, combined_seq_len, combined_seq_len].
            labels (torch.Tensor): Optional target text labels [batch, combined_seq_len].
            inference_params (InferenceParams): Inference-time parameters including KV cache.
            packed_seq_params (PackedSeqParams): forwarded to the language model.
            image_padded (bool): forwarded to the vision model.
            extra_block_kwargs (dict): extra keyword arguments forwarded to the language model call.

        Returns:
            output (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits of shape [b, s, vocab_size].

        Raises:
            NotImplementedError: for inference (any ``inference_params`` on the first stage, or a
                KV cache holding "image_tokens_count" on any stage).
            ValueError: for an out-of-range ``video_start_index`` or a missing input mask.
        """
        use_inference_kv_cache = (inference_params is not None and "image_tokens_count"
                                  in inference_params.key_value_memory_dict)
        if use_inference_kv_cache:
            raise NotImplementedError()

        if self.pre_process:
            # Inference is not supported on the first stage; fail before running the vision model.
            if inference_params is not None:
                raise NotImplementedError()

            if vision_data is None:
                vision_embeds = None
            else:
                vision_embeds = self.vision_model(
                    vision_data=vision_data,
                    grid_thw=vision_grid_thw,
                    image_padded=image_padded,
                )

            # position_ids is not applied here; it is passed to the language model call below.
            # clone() guarantees the masked in-place writes below never modify the embedding
            # module's own output (transpose(...).contiguous() may return the same storage,
            # e.g. when one of the swapped dimensions has size 1).
            language_embeddings: torch.Tensor = self.language_model.embedding(
                input_ids=input_ids,
                position_ids=None,
            ).clone()

            image_embeds, video_embeds = self._split_vision_embeds(vision_embeds, video_start_index)

            if self.config.sequence_parallel:
                # Gather the full sequence so vision tokens can be written at their mask positions.
                # Everything between this gather and the scatter below is replicated across TP
                # ranks, so the backward pass must only split the gradient. The default
                # (tensor_parallel_output_grad=True) reduce-scatters instead, which would sum the
                # identical per-rank gradients and scale text-embedding grads by the TP size.
                language_embeddings = tensor_parallel.gather_from_sequence_parallel_region(
                    language_embeddings, tensor_parallel_output_grad=False)

            # Swap the first two dims so the (batch, seq)-indexed masks apply; swapped back after.
            language_embeddings = language_embeddings.transpose(0, 1).contiguous()
            for embeds, mask, mask_name in (
                (image_embeds, image_input_mask, "image_input_mask"),
                (video_embeds, video_input_mask, "video_input_mask"),
            ):
                if embeds is None:
                    continue
                if mask is None:
                    # Indexing with None would add a dimension and silently broadcast the
                    # vision embeddings instead of failing.
                    raise ValueError(
                        f"{mask_name} must be provided when the corresponding vision embeddings exist")
                language_embeddings[mask] = embeds.to(language_embeddings.device,
                                                      language_embeddings.dtype)
            combined_embeddings = language_embeddings.transpose(0, 1).contiguous()

            cp_size = mpu.get_context_parallel_world_size()
            if cp_size > 1:
                combined_embeddings = split_data_cp_rank(combined_embeddings, cp_size, 0)

            if self.config.sequence_parallel:
                combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(
                    combined_embeddings)
        else:
            # Later pipeline stages receive their input through set_input_tensor.
            combined_embeddings = None

        output = self.language_model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            decoder_input=combined_embeddings,
            labels=labels,
            inference_params=inference_params,
            packed_seq_params=packed_seq_params,
            **(extra_block_kwargs or {}),
        )
        return output
