
import logging
from typing import Optional, List

import torch
from transformers import AutoModelForCausalLM, LlamaConfig

from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
from llava.model.language_model.llava_llama import LlavaConfig
from llava.backbone.detr_branch.model_llama_meta import LlamaForCausalLM_Meta, LlamaModel_Meta


class LlavaVisionForCausalLM(LlavaMetaForCausalLM, LlamaForCausalLM_Meta):
    """LLaVA wrapper whose ``forward`` only prepares multimodal inputs.

    ``forward`` does not call ``self.model``; it returns the tensors produced by
    ``prepare_inputs_labels_for_multimodal`` (presumably provided by
    ``LlavaMetaForCausalLM`` — confirm), computed under ``torch.no_grad()``.
    """

    config_class = LlavaConfig

    def __init__(self, config):
        # super(LlamaForCausalLM_Meta, self) resolves to the class *after*
        # LlamaForCausalLM_Meta in the MRO, so LlamaForCausalLM_Meta.__init__
        # itself is skipped. NOTE(review): presumably this avoids building the
        # parent's own sub-modules before self.model is replaced below — confirm
        # against model_llama_meta.
        super(LlamaForCausalLM_Meta, self).__init__(config)

        # Report construction through logging rather than an unconditional
        # print, so library users control verbosity.
        logging.getLogger(__name__).info("Init LlavaVisionForCausalLM")

        self.model = LlavaVisionModel(config)

    def get_model(self):
        """Return ``self.model`` (the ``LlavaVisionModel`` built in ``__init__``)."""
        return self.model

    @torch.no_grad()
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ):
        """Prepare multimodal inputs without running the decoder.

        Runs under ``torch.no_grad()``, so no autograd graph is recorded.

        Args:
            input_ids: Token ids forwarded to ``prepare_inputs_labels_for_multimodal``.
            attention_mask: Forwarded unchanged to the preparation step.
            past_key_values: Forwarded unchanged to the preparation step.
            inputs_embeds: Ignored; overwritten by the value the preparation
                step returns.
            labels: Forwarded to the preparation step. NOTE(review): presumably
                the returned labels are realigned to the returned embeddings —
                confirm. With the default ``None`` the call is unchanged.
            use_cache: Accepted but unused.
            output_attentions: Accepted but unused.
            output_hidden_states: Accepted but unused.
            images: Image tensor forwarded to the preparation step.
            return_dict: Accepted but unused.

        Returns:
            Tuple ``(input_ids, attention_mask, past_key_values, inputs_embeds,
            labels)`` exactly as returned by ``prepare_inputs_labels_for_multimodal``.

        Raises:
            ValueError: If the preparation step returns ``inputs_embeds`` as ``None``.
        """
        # The caller's `inputs_embeds` is overwritten here: the preparation
        # step is the only source of embeddings for this method.
        (
            input_ids,
            attention_mask,
            past_key_values,
            inputs_embeds,
            labels,
        ) = self.prepare_inputs_labels_for_multimodal(
            input_ids,
            attention_mask,
            past_key_values,
            labels=labels,
            images=images,
        )

        if inputs_embeds is None:
            # NOTE(review): in upstream LLaVA this happens when no images are
            # supplied or a single token is being decoded — confirm for this fork.
            raise ValueError("inputs_embeds is None!")

        return input_ids, attention_mask, past_key_values, inputs_embeds, labels
    

class LlavaVisionModel(LlavaMetaModel, LlamaModel_Meta):
    """Model class combining ``LlavaMetaModel`` with ``LlamaModel_Meta``.

    Defines no state of its own; construction is delegated to the next class
    in the MRO, which is ``LlavaMetaModel``.
    """

    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        # Zero-argument super() is equivalent to super(LlavaVisionModel, self).
        super().__init__(config)




# Make AutoModelForCausalLM resolve LlavaConfig to LlavaVisionForCausalLM.
# NOTE(review): llava_llama (imported above) may already register LlavaConfig;
# if so, this call presumably overrides that mapping — confirm the intended
# winner given import order.
AutoModelForCausalLM.register(
    config_class=LlavaConfig,
    model_class=LlavaVisionForCausalLM,
)


