import logging
from collections import namedtuple
from functools import partial
from typing import List, Optional

from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.models.gpt import GPTModel
from megatron.core.models.vision.clip_vit_model import CLIPViTModel, get_num_image_embeddings
from megatron.core.models.vision.multimodal_projector import MultimodalProjector
from megatron.core.models.vision.radio import RADIOViTModel
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import log_single_rank
try:
    from megatron.core.models.multimodal.llava_model import DEFAULT_IMAGE_TOKEN_INDEX, is_te_min_version, _load_state_dict_hook_ignore_param_names
except:
    pass
from megatron.core.models.multimodal.llava_model import LLaVAModel


def _vision_class_token_settings(vision_model_type: str):
    """Return ``(class_token_len, add_class_token)`` for ``vision_model_type``.

    Kept independent of encoder construction so that ranks that do not build the
    vision encoder (``add_encoder=False``) derive the same image sequence length
    as ranks that do. Unrecognized types fall back to the ``(1, True)`` default;
    rejecting unsupported types is left to the caller.
    """
    if vision_model_type == "siglip":
        # SigLIP is built without a class token.
        return 0, False
    if vision_model_type == "radio":
        return 8, True
    return 1, True


def llava_model_init(
    self,
    language_transformer_config: TransformerConfig,
    language_transformer_layer_spec: ModuleSpec,
    language_vocab_size: int,
    language_max_sequence_length: int,
    vision_transformer_config: TransformerConfig,
    vision_transformer_layer_spec: ModuleSpec,
    drop_vision_class_token: bool,
    vision_projection_config: TransformerConfig,
    vision_projection_layer_spec: ModuleSpec,
    vision_projection_type: str = "mlp",
    allow_missing_vision_projection_checkpoint: bool = False,
    parallel_output: bool = True,
    share_embeddings_and_output_weights: bool = False,
    language_position_embedding_type: str = 'learned_absolute',
    language_rotary_percent: float = 1.0,
    pre_process: bool = True,
    post_process: bool = True,
    add_encoder: bool = True,
    add_decoder: bool = True,
    img_h: int = 336,
    img_w: int = 336,
    patch_dim: int = 14,
    language_rotary_base: int = 10000,
    language_rope_scaling: bool = False,
    language_rope_scaling_factor: float = 8.0,
    image_token_index: int = DEFAULT_IMAGE_TOKEN_INDEX,
    pixel_shuffle: bool = False,
    tile_tags: Optional[list] = None,
    text_model_cls: type = GPTModel,
) -> None:
    """Initialize a LLaVA model: vision encoder, vision projector and language model.

    Drop-in ``__init__`` for :class:`LLaVAModel` (note the ``super(LLaVAModel, self)``
    call, which skips ``LLaVAModel``'s own constructor). Presumably installed by
    assigning it to ``LLaVAModel.__init__`` elsewhere -- confirm at the call site.
    Unlike the stock constructor, the language model class is configurable through
    ``text_model_cls``.

    Args:
        language_transformer_config: Language-model config. Also passed to the base
            module constructor and read for sequence/context/tensor parallel settings.
        language_transformer_layer_spec: Layer spec forwarded to ``text_model_cls``.
        language_vocab_size: Vocabulary size forwarded to ``text_model_cls``.
        language_max_sequence_length: Max sequence length forwarded to
            ``text_model_cls`` and stored on decoder ranks.
        vision_transformer_config: Vision config. ``vision_model_type`` selects the
            encoder and ``hidden_size`` sets the projector input width.
        vision_transformer_layer_spec: Layer spec for the vision encoder.
        drop_vision_class_token: Forwarded to ``get_num_image_embeddings``. Must be
            False for a SigLIP encoder.
        vision_projection_config: Config for ``MultimodalProjector``.
        vision_projection_layer_spec: Layer spec for ``MultimodalProjector``.
        vision_projection_type: Projector type forwarded to ``MultimodalProjector``.
        allow_missing_vision_projection_checkpoint: If True, register a
            load-state-dict post-hook over the projector's parameter names
            (presumably so checkpoints without projector weights still load).
        parallel_output, share_embeddings_and_output_weights,
        language_position_embedding_type, language_rotary_percent,
        language_rotary_base, language_rope_scaling, language_rope_scaling_factor:
            Forwarded to ``text_model_cls``.
        pre_process, post_process: Stored on the module and forwarded to the
            language model.
        add_encoder: Build the vision encoder and projector on this rank.
        add_decoder: Build the language model on this rank.
        img_h, img_w, patch_dim: Image geometry used by the encoder and by the
            image sequence-length computation.
        image_token_index: Stored as ``self.image_token_index``.
        pixel_shuffle: Multiplies the projector input width by 4; stored.
        tile_tags: Stored. Tile tags are enabled in the image sequence-length
            computation when this is not None.
        text_model_cls: Class instantiated for the language model.

    Raises:
        ValueError: If ``add_encoder`` is set and the vision model type is not a
            supported CLIP / SigLIP / InternViT / RADIO type.
        AssertionError: If context parallelism is used with TE older than 1.10, or
            if a SigLIP encoder is combined with ``drop_vision_class_token``.
    """
    super(LLaVAModel, self).__init__(config=language_transformer_config)

    # Keep this before any other local assignment so that ``locals()`` captures
    # only ``self`` and the constructor arguments.
    if has_config_logger_enabled(language_transformer_config):
        log_config_to_disk(language_transformer_config, locals(), prefix=type(self).__name__)

    log_single_rank(
        logging.getLogger(__name__),
        logging.WARNING,
        "LLaVA is work in progress. Features are missing and methods can change.",
    )

    self.pre_process = pre_process
    self.post_process = post_process
    self.add_encoder = add_encoder
    self.add_decoder = add_decoder

    self.encoder_hidden_state = None
    self.vision_model = None
    self.vision_projection = None
    self.language_model = None

    self.sequence_parallel_lm = language_transformer_config.sequence_parallel
    self.tp_comm_overlap_lm = language_transformer_config.tp_comm_overlap
    self.context_parallel_lm = language_transformer_config.context_parallel_size
    if self.context_parallel_lm > 1:
        assert is_te_min_version(
            "1.10.0"
        ), "Context Parallelism in LLaVA requires TE v1.10 or higher"
    self.tensor_model_parallel_size_lm = language_transformer_config.tensor_model_parallel_size

    # Replaced below by the language model's own setting when this rank builds it.
    self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
    if self.add_decoder:
        self.language_model = text_model_cls(
            config=language_transformer_config,
            transformer_layer_spec=language_transformer_layer_spec,
            vocab_size=language_vocab_size,
            max_sequence_length=language_max_sequence_length,
            parallel_output=parallel_output,
            share_embeddings_and_output_weights=share_embeddings_and_output_weights,
            position_embedding_type=language_position_embedding_type,
            rotary_percent=language_rotary_percent,
            pre_process=self.pre_process,
            post_process=self.post_process,
            rotary_base=language_rotary_base,
            rope_scaling=language_rope_scaling,
            rope_scaling_factor=language_rope_scaling_factor,
            scatter_embedding_sequence_parallel=False,
        )
        self.share_embeddings_and_output_weights = (
            self.language_model.share_embeddings_and_output_weights
        )
        self._language_max_sequence_length = language_max_sequence_length
        self._language_is_pipeline_parallel = (
            language_transformer_config.pipeline_model_parallel_size > 1
        )

    vision_model_type = vision_transformer_config.vision_model_type
    # Resolve the class-token layout on every rank, not only where the encoder is
    # built: ``_img_seq_len`` below depends on it and must agree across ranks.
    # Previously, ranks with ``add_encoder=False`` always used 1 (wrong for RADIO).
    class_token_len, add_class_token = _vision_class_token_settings(vision_model_type)

    if self.add_encoder:
        self._drop_vision_class_token = drop_vision_class_token
        if vision_model_type.startswith(("clip", "siglip", "internvit")):
            if vision_model_type == "siglip":
                error_msg = (
                    "Siglip does not support vision class token, "
                    "set disable-vision-class-token to False."
                )
                assert not self._drop_vision_class_token, error_msg
            self.vision_model = CLIPViTModel(
                vision_transformer_config,
                vision_transformer_layer_spec,
                img_h=img_h,
                img_w=img_w,
                class_token_len=class_token_len,
                patch_dim=patch_dim,
                model_subtype=vision_model_type,
                add_class_token=add_class_token,
            )
        elif vision_model_type == "radio":
            # Exact match: the previous ``in ("radio")`` was a substring test on a
            # plain string (no trailing comma), not a tuple-membership test.
            max_img_h = 2048
            max_img_w = 2048
            embedder_bias = False
            use_mask_token = False
            self.vision_model = RADIOViTModel(
                vision_transformer_config,
                vision_transformer_layer_spec,
                img_h=img_h,
                img_w=img_w,
                max_img_h=max_img_h,
                max_img_w=max_img_w,
                class_token_len=class_token_len,
                patch_dim=patch_dim,
                add_class_token=add_class_token,
                embedder_bias=embedder_bias,
                use_mask_token=use_mask_token,
            )
        else:
            raise ValueError(
                "Vision model "
                f"{vision_model_type} is not "
                "supported."
            )

        vision_projection_input_size = vision_transformer_config.hidden_size
        # Pixel shuffle widens each encoder output feature by 4x
        # (presumably a 2x2 spatial-to-channel fold -- confirm).
        vision_projection_input_size *= 4 if pixel_shuffle else 1

        self.vision_projection = MultimodalProjector(
            vision_projection_config,
            vision_projection_layer_spec,
            vision_projection_type,
            vision_projection_input_size,
        )
        if allow_missing_vision_projection_checkpoint:
            # Names are prefixed as they appear in this module's state dict.
            vision_projection_param_names = [
                f"vision_projection.{name}"
                for name in self.vision_projection.state_dict().keys()
            ]
            self.vision_projection.register_load_state_dict_post_hook(
                partial(_load_state_dict_hook_ignore_param_names, vision_projection_param_names)
            )

    self._img_seq_len = get_num_image_embeddings(
        img_h,
        img_w,
        patch_dim,
        vision_model_type,
        drop_vision_class_token,
        class_token_len,
        pixel_shuffle,
        tile_tags is not None,  # whether tile tags are enabled
    )

    self.image_token_index = image_token_index
    self._pixel_shuffle = pixel_shuffle
    self._tile_tags = tile_tags
