#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

from typing import List, Optional, Tuple, Union
# from deepspeed.runtime.zero.partition_parameters import GatheredParameters
import torch
import torch.nn as nn
import torch.distributed as dist
# from transformers import AutoConfig, AutoModelForCausalLM, \
#                          LlamaConfig, LlamaModel, LlamaForCausalLM
from transformers import AutoConfig, AutoModelForCausalLM
from .modeling_llama import LlamaConfig, LlamaModel, LlamaForCausalLM
from .visual_mask import generate_mask,convert_to_image 

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput

from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM, RetrieverMetaModel,  LlavaMetaForCausalLMWithENT, POOLMETHODS 
from ...mm_utils import process_images

from common_utils.dist_utils import (
    all_gather_with_grad, 
    all_gather_with_grad_torch
)
from common_utils.loss_utils import contrastive_loss, contrastive_acc, contrastive_loss_with_target

import torch.nn.functional as F
import random
import sys 
# Process-wide torch setting applied at import time: with threshold=sys.maxsize,
# printed tensors are never summarized (full contents are shown). Presumably a
# debugging aid — NOTE(review): this affects every importer of this module.
torch.set_printoptions(threshold=sys.maxsize)


class LlavaConfig(LlamaConfig):
    """Llama configuration tagged with the ``llava_llama`` model type.

    Adds no fields of its own; it only changes ``model_type`` so the LLaVA
    model classes below can declare it as their ``config_class``.
    """

    model_type = "llava_llama"


class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    """Llama decoder backbone combined with the LLaVA multimodal model mixin."""

    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        # Cooperative initialisation: the next class in the MRO is
        # LlavaMetaModel, followed (directly or indirectly) by LlamaModel.
        super().__init__(config)




class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    """Llama causal LM whose inputs may include images.

    Images are merged into the token embeddings by
    ``prepare_inputs_labels_for_multimodal`` (presumably provided by the
    ``LlavaMetaForCausalLM`` mixin) before the regular Llama forward /
    generate path runs.
    """

    config_class = LlavaConfig

    def __init__(self, config):
        # ``super(LlamaForCausalLM, self)`` resolves to the class *after*
        # LlamaForCausalLM in the MRO, so LlamaForCausalLM.__init__ itself is
        # skipped and the multimodal backbone / LM head are created here.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)
        for attr_name in ("pretraining_tp", "vocab_size"):
            setattr(self, attr_name, getattr(config, attr_name))
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing.
        self.post_init()

    def get_model(self):
        """Return the multimodal backbone (``self.model``)."""
        backbone = self.model
        return backbone

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        cache_position=None
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the Llama forward pass, splicing images in first if needed.

        When ``inputs_embeds`` is not supplied, ids, positions, mask, cache
        and labels are all rewritten by
        ``prepare_inputs_labels_for_multimodal`` together with the embeddings.
        ``cache_position`` is accepted for signature compatibility but is not
        forwarded.
        """
        if inputs_embeds is None:
            multimodal_inputs = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                image_sizes,
            )
            (input_ids, position_ids, attention_mask,
             past_key_values, inputs_embeds, labels) = multimodal_inputs

        lm_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "inputs_embeds": inputs_embeds,
            "labels": labels,
            "use_cache": use_cache,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }
        return super().forward(**lm_kwargs)

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate from embeddings built from ``inputs`` (and ``images``).

        ``position_ids`` / ``attention_mask`` are taken out of ``kwargs`` and
        possibly rewritten; passing ``inputs_embeds`` directly is rejected.
        Only the embeddings (not the ids) are handed to the parent
        ``generate``.
        """
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is None:
            # Text-only prompt: embed the ids directly.
            inputs_embeds = self.get_model().embed_tokens(inputs)
        else:
            (inputs, position_ids, attention_mask,
             _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                image_sizes=image_sizes,
            )

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        """Delegate to the parent, then re-attach image inputs when present."""
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        model_inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        for key, value in (("images", images), ("image_sizes", image_sizes)):
            if value is not None:
                model_inputs[key] = value
        return model_inputs





class LlavaLlamaModelwithEnt(LlavaMetaModel, LlamaModel, RetrieverMetaModel):
    """LLaVA Llama backbone that additionally mixes in ``RetrieverMetaModel``."""

    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        # NOTE(review): a call to ``self.add_projector_module(config)`` gated on
        # ``hasattr(config, 'retrieval_shared_dim')`` used to live here and is
        # disabled; confirm where the retrieval projectors are now built.


class LlavaLlamaForCausalLMwithEnt(LlamaForCausalLM, LlavaMetaForCausalLM, LlavaMetaForCausalLMWithENT):
    config_class = LlavaConfig

    def __init__(self, config):
        """Build the entity-aware multimodal backbone and a fresh LM head."""
        # NOTE(review): plain ``super().__init__`` runs LlamaForCausalLM.__init__
        # in full (the alternative ``super(LlamaForCausalLM, self).__init__(config)``
        # was commented out). In upstream transformers that init builds its own
        # LlamaModel + lm_head and calls post_init, all replaced just below —
        # confirm the local modeling_llama needs this, otherwise it is wasted work.
        super().__init__(config)
        self.model = LlavaLlamaModelwithEnt(config)
        for attr_name in ("pretraining_tp", "vocab_size"):
            setattr(self, attr_name, getattr(config, attr_name))
        hidden_size, vocab_size = config.hidden_size, config.vocab_size
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

        # Initialize weights and apply final processing.
        self.post_init()

    def get_model(self):
        """Return the entity-aware multimodal backbone (``self.model``)."""
        backbone = self.model
        return backbone

    
    def get_ent_features(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        project_mode: Optional[str] = None,
        images_pixels: Optional[torch.FloatTensor] = None,
        cache_position=None,
        signal=None,
    ):
        """Encode the entity-token position(s) of a (multimodal) sequence.

        Runs the language model, takes the hidden state at every position
        where ``input_ids == config.ent_token_id``, projects it with the
        projector selected by ``project_mode`` and, for ``"query"`` /
        ``"context"`` with ``config.add_image_features`` enabled, fuses in a
        pooled image feature.

        Args:
            input_ids: token ids; entity positions are located here, before
                multimodal preparation.
            output_hidden_states: must be truthy — hidden states are read
                from the LM output.
            project_mode: ``"query"``, ``"context"`` or ``"autoreg"``.
            images_pixels: forwarded to ``prepare_inputs_labels_for_multimodal``.
            output_attentions, return_attentions, cache_position, signal:
                accepted for interface compatibility but not read; whether
                attentions are requested is decided from the config.
            Remaining arguments are forwarded to the multimodal preparation
            step and/or the parent Llama forward.

        Returns:
            ``(output, token_features, token_hds)``: the LM output, the
            L2-normalized projected entity features (one row per entity
            token) and the raw entity hidden states before projection.

        Raises:
            ValueError: if ``project_mode`` is not one of the supported modes.
        """
        # Validate before the expensive forward pass: no projector exists for
        # any other mode.
        if project_mode not in ("query", "context", "autoreg"):
            raise ValueError(
                f"project_mode must be 'query', 'context' or 'autoreg', got {project_mode!r}"
            )
        is_retrieval_mode = project_mode in ("query", "context")

        # ``getattr`` keeps a config without ``visual_feature_mode`` on the
        # CLS-projector fallback below instead of raising AttributeError here.
        visual_mode = getattr(self.config, "visual_feature_mode", None)
        # Attention maps are only requested for the "attn*" visual modes.
        return_attention = bool(
            self.config.add_image_features
            and "attn" in (visual_mode or "")
            and is_retrieval_mode
        )

        # NOTE(review): entity positions come from the *unprepared* input_ids.
        # If multimodal preparation expands an image placeholder located before
        # the entity token into several embeddings, these indices would not line
        # up with hidden_states — confirm the prompt layout rules this out.
        batch_ids, pos_ids = torch.where(input_ids == self.config.ent_token_id)

        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
                additional_info
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                image_sizes,
                images_pixels=images_pixels,
                return_cls_image_features=True,
                return_clip_image_features=getattr(self.config, 'use_clip_feature', False),
            )
            cls_image_features = additional_info['cls_image_features']
            image_features = additional_info['image_features']
            if getattr(self.config, 'use_clip_feature', False):
                image_features = additional_info["clip_image_features"]
            visual_range = additional_info["visual_range"]

        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=return_attention,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            output_attentions_layer=getattr(self.config, 'attention_layer_idx', -1)
        )

        # Retrieval modes may read an intermediate layer; otherwise the last one.
        if hasattr(self.config, 'llm_select_layer') and is_retrieval_mode:
            hidden_states = output["hidden_states"][self.config.llm_select_layer]
        else:
            hidden_states = output["hidden_states"][-1]

        batch_ids = batch_ids.to(hidden_states.device)
        pos_ids = pos_ids.to(hidden_states.device)
        token_hds = hidden_states[batch_ids, pos_ids]

        # A config with ``dual_projector`` explicitly disabled shares one
        # projector between query and context.
        use_shared_projector = (
            hasattr(self.config, 'dual_projector')
            and not self.config.dual_projector
            and is_retrieval_mode
        )
        if use_shared_projector:
            projector = self.get_retr_projector()
        elif project_mode == "query":
            projector = self.get_query_projector()
        elif project_mode == "context":
            projector = self.get_context_projector()
        else:  # "autoreg" (validated above)
            projector = self.get_autoreg_projector()

        token_features = projector(token_hds)

        if self.config.add_image_features and is_retrieval_mode:
            token_features = F.normalize(token_features, p=2, dim=-1)
            backbone = self.get_model()
            if hasattr(self.config, 'visual_feature_mode'):
                mode = self.config.visual_feature_mode
                if mode == "cls_hidden_states":
                    assert len(visual_range) == hidden_states.shape[0]
                    selected_hd_layer = output["hidden_states"][self.config.visual_hidden_states_layer]
                    # visual_range[i][0][0] is presumably the first image-token
                    # position of sample i — TODO confirm.
                    cls_token_id = torch.tensor(
                        [inner[0][0] for inner in visual_range],
                        dtype=torch.long,
                        device=selected_hd_layer.device,
                    )
                    batch_id = torch.arange(selected_hd_layer.shape[0], device=selected_hd_layer.device)
                    cls_image_features = selected_hd_layer[batch_id, cls_token_id]
                    cls_image_features = backbone.get_mm_cls_projector()(cls_image_features)
                elif mode == "attnguidepool":
                    cls_image_features = backbone.get_visual_processor()(
                        output["attentions"],
                        image_features,
                        visual_range,
                        pos_ids
                    )
                elif mode == "hdpool":
                    cls_image_features = backbone.get_visual_processor()(
                        hidden_states,
                        image_features,
                        visual_range,
                        pos_ids
                    )
                elif mode == 'adapoolent':
                    # Pooling conditioned on the (already normalized) entity features.
                    cls_image_features = backbone.get_visual_processor()(
                        image_features,
                        token_features
                    )
                elif mode in POOLMETHODS:
                    cls_image_features = backbone.get_visual_processor()(image_features)
                else:
                    cls_image_features = backbone.get_mm_cls_projector()(cls_image_features)
            else:
                cls_image_features = backbone.get_mm_cls_projector()(cls_image_features)

            cls_image_features = F.normalize(cls_image_features, p=2, dim=-1)

            if getattr(self.config, 'use_visual_weight', False):
                weight = backbone.get_visual_weight_projector()(token_features)
                cls_image_features = weight * cls_image_features

            if getattr(self.config, 'add_image_operation', None) == "Concat":
                token_features = torch.cat([token_features, cls_image_features], dim=1)
            else:
                token_features = token_features + cls_image_features

        token_features = F.normalize(token_features, dim=-1)

        return output, token_features, token_hds



    def forward_contra(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        cache_position=None,

        # Positive samples
        positive_input_ids: Optional[torch.LongTensor] = None,
        positive_attention_mask: Optional[torch.Tensor] = None,
        positive_images: Optional[torch.FloatTensor] = None,
        positive_image_sizes: Optional[List[List[int]]] = None,

        # Negative samples
        negative_input_ids: Optional[torch.LongTensor] = None,
        negative_attention_mask: Optional[torch.Tensor] = None,
        negative_images: Optional[torch.FloatTensor] = None,
        negative_image_sizes: Optional[List[List[int]]] = None,

        image_input_pixels: Optional[torch.FloatTensor] = None,
        pos_input_pixels: Optional[torch.FloatTensor] = None,
        neg_input_pixels: Optional[torch.FloatTensor] = None,

    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Contrastive loss between query entity features and candidate contexts.

        The query is encoded with ``project_mode="query"`` and the positive and
        negative contexts with ``project_mode="context"``. When
        ``config.gather_contrastive`` is set, each query row is scored against
        all positive and negative rows (scaled by ``model.itc_temp``), with
        its own positive as the target.

        Returns:
            dict with ``loss`` (``None`` when ``gather_contrastive`` is off),
            ``logits`` (always ``None``) and the raw positive/negative entity
            hidden states.
        """
        # get_ent_features reads hidden states from a dict-style output.
        # Attentions are not requested by the caller (get_ent_features decides
        # that from the config).
        output_hidden_states = True
        return_dict = True
        output_attentions = False

        _, token_features, _ = self.get_ent_features(
            input_ids,
            attention_mask=attention_mask,
            images=images,
            images_pixels=image_input_pixels,
            image_sizes=image_sizes,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            return_dict=return_dict,
            project_mode="query"
        )

        _, positive_token_features, positive_token_hds = self.get_ent_features(
            positive_input_ids,
            attention_mask=positive_attention_mask,
            images=positive_images,
            image_sizes=positive_image_sizes,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            return_dict=return_dict,
            project_mode="context",
            images_pixels=pos_input_pixels
        )

        _, negative_token_features, negative_token_hds = self.get_ent_features(
            negative_input_ids,
            attention_mask=negative_attention_mask,
            images=negative_images,
            image_sizes=negative_image_sizes,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            return_dict=return_dict,
            project_mode="context",
            images_pixels=neg_input_pixels
        )

        loss = None
        logits = None
        if self.config.gather_contrastive:
            # NOTE(review): despite the flag name, features are NOT gathered
            # across processes — the all_gather_with_grad_torch path is disabled.
            with torch.no_grad():
                self.model.itc_temp.clamp_(0.001, 0.5)

            # Candidate rows are [positives; negatives], so the positive for
            # query i is column i.
            combine_token_features = torch.cat([positive_token_features, negative_token_features], dim=0)

            assert token_features.shape == positive_token_features.shape
            assert token_features.shape == negative_token_features.shape

            similarity = token_features @ combine_token_features.t() / self.model.itc_temp
            # One-hot targets on the main diagonal. The result is mathematically
            # the cross-entropy with targets arange(B). The targets are built
            # directly on the device instead of being allocated on the host and copied.
            sim_targets = torch.zeros(similarity.size(), device=token_features.device)
            sim_targets.fill_diagonal_(1)
            loss = -torch.sum(F.log_softmax(similarity, dim=1) * sim_targets, dim=1).mean()

        return {
            "loss": loss,
            "logits": logits,
            "positive_entity_hidden_states": positive_token_hds,
            "negative_entity_hidden_states": negative_token_hds
        }


    def forward_qa(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        entity_embed: Optional[Union[torch.Tensor, List[List[torch.Tensor]]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_pixels: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Causal-LM pass with entity embeddings supplied to input preparation.

        When ``inputs_embeds`` is not given, ``entity_embed`` is passed to
        ``prepare_inputs_labels_for_multimodal`` as ``entity_embedding``
        together with the images (how it is injected is decided there — TODO
        confirm). The prepared tensors then go through the parent Llama forward.

        NOTE(review): ``images_pixels`` is accepted but never used here.
        """
        if inputs_embeds is None:
            prepared = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                image_sizes,
                entity_embedding=entity_embed,
            )
            (input_ids, position_ids, attention_mask,
             past_key_values, inputs_embeds, labels) = prepared

        lm_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return super().forward(**lm_kwargs)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        cache_position=None, 

        # Positive samples
        positive_input_ids: Optional[torch.LongTensor] = None,
        positive_attention_mask: Optional[torch.Tensor] = None,
        positive_images: Optional[torch.FloatTensor] = None,
        positive_image_sizes: Optional[List[List[int]]] = None,

        # Negative samples
        negative_input_ids: Optional[torch.LongTensor] = None,
        negative_attention_mask: Optional[torch.Tensor] = None,
        negative_images: Optional[torch.FloatTensor] = None,
        negative_image_sizes: Optional[List[List[int]]] = None,

        # question answering samples
         # Autoreg positive samples
        autoreg_pos_input_ids: Optional[torch.LongTensor] = None,
        autoreg_pos_labels: Optional[torch.LongTensor] = None,
        autoreg_pos_attention_mask: Optional[torch.Tensor] = None,
        autoreg_pos_images: Optional[torch.FloatTensor] = None,
        autoreg_pos_image_sizes: Optional[List[List[int]]] = None,

        # Autoreg negative samples
        autoreg_neg_input_ids: Optional[torch.LongTensor] = None,
        autoreg_neg_labels: Optional[torch.LongTensor] = None,
        autoreg_neg_attention_mask: Optional[torch.Tensor] = None,
        autoreg_neg_images: Optional[torch.FloatTensor] = None,
        autoreg_neg_image_sizes: Optional[List[List[int]]] = None,

        autoreg_pos_correct_idx: Optional[List] = None,
        autoreg_neg_correct_idx: Optional[List] = None,


         # New Image Pixel Inputs (Pre-processed)
        image_input_pixels: Optional[torch.FloatTensor] = None,
        pos_input_pixels: Optional[torch.FloatTensor] = None,
        neg_input_pixels: Optional[torch.FloatTensor] = None,
        autoreg_pos_image_pixels: Optional[torch.FloatTensor] = None,
        autoreg_neg_image_pixels: Optional[torch.FloatTensor] = None,

        input_images = None, 
        input_pos_images = None, 
        input_neg_images = None
        
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        # compute masked image
        task = getattr(self.config, "task", "")
        if task == "rank_binary":
            loss = self.forward_rank(
                input_ids, 
                attention_mask=attention_mask, 
                images=images, 
                image_sizes=image_sizes, 
                return_dict=True,
                output_hidden_states=True,

                # Positive samples
                positive_input_ids=positive_input_ids,
                positive_attention_mask=positive_attention_mask,
                positive_images=positive_images,
                positive_image_sizes=positive_image_sizes,

                # Negative samples
                negative_input_ids=negative_input_ids,
                negative_attention_mask=negative_attention_mask,
                negative_images=negative_images,
                negative_image_sizes=negative_image_sizes,

                autoreg_pos_input_ids = autoreg_pos_input_ids,
                autoreg_pos_labels = autoreg_pos_labels, 
                autoreg_pos_attention_mask = autoreg_pos_attention_mask, 
                autoreg_pos_images = autoreg_pos_images, 
                autoreg_pos_image_sizes = autoreg_pos_image_sizes, 

                autoreg_neg_input_ids = autoreg_neg_input_ids, 
                autoreg_neg_labels = autoreg_neg_labels, 
                autoreg_neg_attention_mask = autoreg_neg_attention_mask, 
                autoreg_neg_images = autoreg_neg_images, 
                autoreg_neg_image_sizes = autoreg_neg_image_sizes, 

                


            )
            
            
            return {
                    "loss": loss, 
                    "autoreg_loss": None, 
                    "contra_loss":loss, 
                    "pos_autoreg_logits":None, 
                    "neg_autoreg_logits":None, 
                    "contrastive_logits":None
                }
        elif task == "ansrank": 
            # compute the autoreg part 
            _, _, positive_token_hds = self.get_ent_features(
                positive_input_ids,
                attention_mask=positive_attention_mask,
                images=positive_images,
                image_sizes=positive_image_sizes,
                output_attentions=output_attentions,
                output_hidden_states=True,
                cache_position=cache_position,
                return_dict=return_dict,
                project_mode = "context", 
                images_pixels=pos_input_pixels
            )

            positive_embed = positive_token_hds
            positive_embed = F.normalize(positive_embed, dim = -1 )
            positive_embed = positive_embed.unsqueeze(1)

            autoreg_pos_output = self.forward_qa(
                autoreg_pos_input_ids, 
                attention_mask = autoreg_pos_attention_mask, 
                entity_embed=positive_embed, 
                return_dict=True,
                labels = autoreg_pos_labels,
                images=autoreg_pos_images, 
                image_sizes = autoreg_pos_image_sizes
            )

            autoreg_loss = autoreg_pos_output.loss 
            loss = autoreg_loss

            return {
                    "loss": loss, 
                    "autoreg_loss": autoreg_loss, 
                    "contra_loss":None, 
                    "pos_autoreg_logits":None, 
                    "neg_autoreg_logits":None, 
                    "contrastive_logits":None
                }
        elif task == 'ansrank_noent': 
            autoreg_pos_output = self.forward_qa(
                autoreg_pos_input_ids, 
                attention_mask = autoreg_pos_attention_mask, 
                return_dict=True,
                labels = autoreg_pos_labels,
                images=autoreg_pos_images, 
                image_sizes = autoreg_pos_image_sizes
            )

            autoreg_loss = autoreg_pos_output.loss 
            loss = autoreg_loss

            return {
                    "loss": loss, 
                    "autoreg_loss": autoreg_loss, 
                    "contra_loss":None, 
                    "pos_autoreg_logits":None, 
                    "neg_autoreg_logits":None, 
                    "contrastive_logits":None
                }
            

            
        elif task == "noent": 
            loss = self.forward_noent_rank(
                input_ids= input_ids, 
                attention_mask=attention_mask, 
                image_sizes=image_sizes, 
                images = images, 
                target = autoreg_pos_labels, 
                return_dict=True
            )
            return {
                    "loss": loss, 
                    "autoreg_loss": loss, 
                    "contra_loss":None, 
                    "pos_autoreg_logits":None, 
                    "neg_autoreg_logits":None, 
                    "contrastive_logits":None
                }

        
        elif task == "rank_preEmbed":
           
            with torch.no_grad():
                embed_model =  self.get_model().get_ent_premodel()
                # TODO: compute positive and negative embed by get_ent_embed
                _, pos_token_features,  _ = embed_model.get_ent_features(
                    positive_input_ids, 
                    attention_mask=positive_attention_mask,
                    images=positive_images,
                    image_sizes=positive_image_sizes,
                    output_attentions=output_attentions,
                    output_hidden_states=True,
                    cache_position=cache_position,
                    return_dict=return_dict,
                    project_mode = "context", 
                    images_pixels=pos_input_pixels
                )
                _, neg_token_features, _ = embed_model.get_ent_features(
                    negative_input_ids,
                    attention_mask=negative_attention_mask,
                    images=negative_images,
                    image_sizes=negative_image_sizes,
                    output_attentions=output_attentions,
                    output_hidden_states=True,
                    cache_position=cache_position,
                    return_dict=return_dict,
                    project_mode = "context", 
                    images_pixels=neg_input_pixels
                )
            

            projector = self.get_model().get_pre_ent_token_projector()
            pos_token_features = projector(pos_token_features)
            neg_token_features = projector(neg_token_features)

            pos_token_features = pos_token_features.unsqueeze(1)
            neg_token_features = neg_token_features.unsqueeze(1)

            autoreg_pos_output = self.forward_qa(
                autoreg_pos_input_ids, 
                attention_mask = autoreg_pos_attention_mask, 
                entity_embed=pos_token_features, 
                return_dict=True,
                images=autoreg_pos_images, 
                image_sizes = autoreg_pos_image_sizes
            )

            autoreg_neg_output = self.forward_qa(
                autoreg_neg_input_ids, 
                attention_mask=autoreg_neg_attention_mask, 
                entity_embed=neg_token_features,  # use the negative sample embedding
                return_dict=True,
                images=autoreg_neg_images, 
                image_sizes=autoreg_neg_image_sizes
            )

            pos_loss = self.compute_yes_no_loss(autoreg_pos_output["logits"], autoreg_pos_labels, autoreg_pos_attention_mask)

            neg_loss = self.compute_yes_no_loss(autoreg_neg_output["logits"], autoreg_neg_labels, autoreg_neg_attention_mask)

            autoreg_loss = pos_loss + neg_loss
            loss = autoreg_loss
            return  {
                    "loss": loss, 
                    "autoreg_loss": autoreg_loss, 
                    "contra_loss":None, 
                    "pos_autoreg_logits":None, 
                    "neg_autoreg_logits":None, 
                    "contrastive_logits":None
                }


        else:
            if getattr(self.config, "use_visual_prompt", False): 
                images, positive_images, negative_images, image_sizes, positive_image_sizes, negative_image_sizes = self.compute_masked_image(
                    input_ids, 
                    attention_mask=attention_mask, 
                    images=images, 
                    image_sizes=image_sizes, 
                    return_dict=True,
                    output_hidden_states=True,

                    # Positive samples
                    positive_input_ids=positive_input_ids,
                    positive_attention_mask=positive_attention_mask,
                    positive_images=positive_images,
                    positive_image_sizes=positive_image_sizes,

                    # Negative samples
                    negative_input_ids=negative_input_ids,
                    negative_attention_mask=negative_attention_mask,
                    negative_images=negative_images,
                    negative_image_sizes=negative_image_sizes,

                    image_input_pixels = image_input_pixels,
                    pos_input_pixels = pos_input_pixels, 
                    neg_input_pixels = neg_input_pixels, 
                    input_images= input_images, 
                    input_pos_images = input_pos_images, 
                    input_neg_images = input_neg_images,  
                )
                


            contrastive_output = self.forward_contra(
                input_ids, 
                attention_mask=attention_mask, 
                images=images, 
                image_sizes=image_sizes, 
                return_dict=True,
                output_hidden_states=True,

                # Positive samples
                positive_input_ids=positive_input_ids,
                positive_attention_mask=positive_attention_mask,
                positive_images=positive_images,
                positive_image_sizes=positive_image_sizes,

                # Negative samples
                negative_input_ids=negative_input_ids,
                negative_attention_mask=negative_attention_mask,
                negative_images=negative_images,
                negative_image_sizes=negative_image_sizes,

                image_input_pixels = image_input_pixels,
                pos_input_pixels = pos_input_pixels, 
                neg_input_pixels = neg_input_pixels
            )


            positive_embed = contrastive_output["positive_entity_hidden_states"]
            negative_embed = contrastive_output["negative_entity_hidden_states"]
            # print("contrastive loss: ", contrastive_output["loss"])
            autoreg_loss = None
            if self.config.gather_autoreg:
                ########################## prepare embedding
                
                # projector = self.get_autoreg_projector()
                # positive_embed = projector(positive_embed)
                # negative_embed = projector(negative_embed)

                # normalize both 
                positive_embed = F.normalize(positive_embed, dim = -1 )
                negative_embed = F.normalize(negative_embed, dim = -1 )
                

                ########################### if multiple entity are involved 
                if autoreg_pos_correct_idx is not None: 

                    # depends on how many ent embeds in input ids
                    batch_ids, index_ids = torch.where(autoreg_pos_input_ids == self.config.ent_token_id)
                
                    grouped_batch_ids = []
                    grouped_index_ids = []
                    for batch_id, index_id in zip(batch_ids, index_ids):
                        if batch_id not in grouped_batch_ids: 
                            grouped_batch_ids.append(batch_id)
                            grouped_index_ids.append([])
                        grouped_index_ids[-1].append(index_id)
                    assert len(grouped_batch_ids) == positive_embed.shape[0]
                    autoreg_pos_ent_embeds = []
                    for i, batch_id in enumerate(grouped_batch_ids): 
                        index_1 = autoreg_pos_correct_idx[batch_id]
                        if index_1 == None: 
                            autoreg_pos_ent_embeds.append([positive_embed[batch_id]])
                            continue
                            


                        num_idx = len(grouped_index_ids[i]) # the actually index id does not matter, we just need how many embeddings are needed 

                        # find the place for negative embedding 
                        index_list = list(range(num_idx))
                        index_list.pop(index_1)
                        index_2 = random.choice(index_list)


                        positive_embed_without_batch = torch.cat((positive_embed[:batch_id], positive_embed[batch_id + 1:]))
                        negative_embed_without_batch = torch.cat((negative_embed[:batch_id], negative_embed[batch_id + 1:]))
                        
                        group_embed = []
                        for j in range(num_idx): 
                            if j == index_1: 
                                group_embed.append(positive_embed[batch_id])
                            elif j == index_2:
                                group_embed.append(negative_embed[batch_id])
                            else:
                                # randomly select 
                                embeds = random.choice([positive_embed_without_batch, negative_embed_without_batch])
                                random_index = random.randint(0, embeds.size(0) - 1)
                                embed = embeds[random_index]
                                group_embed.append(embed)
                        autoreg_pos_ent_embeds.append(group_embed)

                ##################
                if autoreg_neg_correct_idx is not None: 
                
                    batch_ids, index_ids = torch.where(autoreg_neg_input_ids == self.config.ent_token_id)
                    
                    grouped_batch_ids = []
                    grouped_index_ids = []
                    for batch_id, index_id in zip(batch_ids, index_ids):
                        if batch_id not in grouped_batch_ids: 
                            grouped_batch_ids.append(batch_id)
                            grouped_index_ids.append([])
                        grouped_index_ids[-1].append(index_id)
                    assert len(grouped_batch_ids) == negative_embed.shape[0]
                    autoreg_neg_ent_embeds = []
                    for i, batch_id in enumerate(grouped_batch_ids): 
                        index_1 = autoreg_neg_correct_idx[batch_id]
                        if index_1 == None: 
                            autoreg_neg_ent_embeds.append([negative_embed[batch_id]])
                            continue
                        num_idx = len(grouped_index_ids[i]) 
                        
                    
                        positive_embed_without_batch = torch.cat((positive_embed[:batch_id], positive_embed[batch_id + 1:]))
                        negative_embed_without_batch = torch.cat((negative_embed[:batch_id], negative_embed[batch_id + 1:]))
                        
                        group_embed = []
                        for j in range(num_idx): 
                            if j == index_1: 
                                group_embed.append(negative_embed[batch_id])
                            else:
                                # randomly select 
                                embeds = random.choice([positive_embed_without_batch, negative_embed_without_batch])
                                random_index = random.randint(0, embeds.size(0) - 1)
                                embed = embeds[random_index]
                                group_embed.append(embed)
                        autoreg_neg_ent_embeds.append(group_embed)

                else: 
                    autoreg_pos_ent_embeds = positive_embed.unsqueeze(1)
                    autoreg_neg_ent_embeds = negative_embed.unsqueeze(1)


                ###########################
                # TODO: delete the task == "rank_preEmbed" 
                if task == "rank_relprob" or task == "rank_preEmbed": 
                    autoreg_pos_output = self.forward_qa(
                        autoreg_pos_input_ids, 
                        attention_mask = autoreg_pos_attention_mask, 
                        entity_embed=autoreg_pos_ent_embeds, 
                        return_dict=True,
                        images=autoreg_pos_images, 
                        image_sizes = autoreg_pos_image_sizes
                    )

                    autoreg_neg_output = self.forward_qa(
                        autoreg_neg_input_ids, 
                        attention_mask=autoreg_neg_attention_mask, 
                        entity_embed=autoreg_neg_ent_embeds,  # use the negative sample embedding
                        return_dict=True,
                        images=autoreg_neg_images, 
                        image_sizes=autoreg_neg_image_sizes
                    )

                    pos_loss = self.compute_yes_no_loss(autoreg_pos_output["logits"], autoreg_pos_labels, autoreg_pos_attention_mask)

                    neg_loss = self.compute_yes_no_loss(autoreg_neg_output["logits"], autoreg_neg_labels, autoreg_neg_attention_mask)

                    autoreg_loss = pos_loss + neg_loss

                else: 
                    autoreg_pos_output = self.forward_qa(
                        autoreg_pos_input_ids, 
                        attention_mask = autoreg_pos_attention_mask, 
                        entity_embed=autoreg_pos_ent_embeds, 
                        labels = autoreg_pos_labels, 
                        return_dict=True,
                        images=autoreg_pos_images, 
                        image_sizes = autoreg_pos_image_sizes
                    )

                    autoreg_neg_output = self.forward_qa(
                        autoreg_neg_input_ids, 
                        attention_mask=autoreg_neg_attention_mask, 
                        entity_embed=autoreg_neg_ent_embeds,  # use the negative sample embedding
                        labels=autoreg_neg_labels, 
                        return_dict=True,
                        images=autoreg_neg_images, 
                        image_sizes=autoreg_neg_image_sizes
                    )


                    # if task== "rank"

                    autoreg_loss = autoreg_pos_output.loss + autoreg_neg_output.loss
            
            
                if self.config.gather_contrastive: 
                    loss = autoreg_loss  + contrastive_output["loss"]
                else: 
                    loss = autoreg_loss

                return {
                    "loss": loss, 
                    "autoreg_loss": autoreg_loss, 
                    "contra_loss":contrastive_output["loss"], 
                    "pos_autoreg_logits":autoreg_pos_output.logits, 
                    "neg_autoreg_logits":autoreg_neg_output.logits, 
                    "contrastive_logits":contrastive_output["logits"]
                }
            else: 
                # contrastive only 
                loss = contrastive_output["loss"]
                return {
                    "loss": loss, 
                    "autoreg_loss": None, 
                    "contra_loss":contrastive_output["loss"], 
                    "pos_autoreg_logits":None, 
                    "neg_autoreg_logits":None, 
                    "contrastive_logits":None
                }

    def compute_yes_no_loss(self, logits, target_id, attention_mask,
                            yes_token_id=3869, no_token_id=1939):
        """
        Binary "yes"/"no" cross-entropy computed at each row's last attended position.

        The full-vocabulary logits at the last attended position of every row are
        reduced to a two-class problem with columns ``[no, yes]``. The result is
        scored with mean cross-entropy, so class index 1 means "yes" (positive)
        and class index 0 means "no" (negative).

        Args:
            logits: Tensor indexed as ``logits[batch, position, vocab]``.
            target_id: Per-row class labels: 1 represents a positive ("yes")
                sample and 0 a negative ("no") one. Any tensor/sequence of 0/1
                values is accepted; it is cast to ``long`` on the logits device,
                as cross-entropy with class indices requires.
            attention_mask: 0/1 (or bool) tensor indexed ``[batch, position]``.
                The last attended position is taken as ``mask.sum(dim=1) - 1``,
                which assumes right padding.
                NOTE(review): this also assumes mask positions line up with the
                ``logits`` positions; confirm this holds when image features are
                spliced into the sequence before the LM runs.
            yes_token_id: Vocabulary id of the "yes" answer token. The default
                3869 is the id this method has always used (presumably the LLaMA
                tokenizer's "Yes"; verify against the tokenizer in use).
            no_token_id: Vocabulary id of the "no" answer token (default 1939,
                presumably the LLaMA tokenizer's "No"; verify likewise).

        Returns:
            Scalar mean cross-entropy loss over the batch.
        """
        # Index of the final attended token per row (right-padding assumption).
        last_token_idx = attention_mask.sum(dim=1) - 1
        # Allocate directly on the mask's device instead of CPU + copy.
        batch_idx = torch.arange(attention_mask.size(0), device=attention_mask.device)
        # Paired (batch, position) advanced indexing -> [batch, vocab].
        last_token_logits = logits[batch_idx, last_token_idx, :]

        # Column order matters: index 0 = "no", index 1 = "yes", matching target_id.
        binary_logits = torch.stack(
            [last_token_logits[:, no_token_id], last_token_logits[:, yes_token_id]],
            dim=1,
        )
        # Class-index targets must be integer; float/bool labels would otherwise
        # be rejected (or misread as probabilities) by cross-entropy.
        target = torch.as_tensor(target_id, device=binary_logits.device, dtype=torch.long)
        return F.cross_entropy(binary_logits, target)



    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        entity_embed: Optional[Union[torch.Tensor, List[List[torch.Tensor]]]] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """
        Run generation from input embeddings, optionally merging images and entity embeddings.

        ``position_ids`` and ``attention_mask`` are taken out of ``kwargs``. Passing
        ``inputs_embeds`` directly is rejected with ``NotImplementedError``.

        When ``images`` is given, ``prepare_inputs_labels_for_multimodal`` builds
        the embeddings (and may rewrite ``position_ids`` / ``attention_mask``).
        It receives ``entity_embed`` as ``entity_embedding``. When ``images`` is
        None, the token ids are embedded with ``get_model().embed_tokens`` and
        ``entity_embed`` is NOT used.

        Everything else in ``kwargs`` is forwarded unchanged to the parent
        ``generate``, which is always driven via ``inputs_embeds``.
        """
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is None:
            # Text-only: plain token embedding lookup; entity_embed is ignored here.
            inputs_embeds = self.get_model().embed_tokens(inputs)
        else:
            # The helper returns a 6-tuple; only position ids, mask and embeds are needed.
            prepared = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,  # past_key_values slot (unused at generation entry) — TODO confirm
                None,  # labels slot — TODO confirm
                images,
                image_sizes=image_sizes,
                entity_embedding=entity_embed,
            )
            _, position_ids, attention_mask, _, inputs_embeds, _ = prepared

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        """
        Build per-step generation inputs and carry the image arguments through.

        ``images`` and ``image_sizes`` are removed from ``kwargs`` before the
        parent implementation is called, so it never sees them. Each one is then
        added back onto the returned inputs mapping, but only when it is not None.
        """
        # Pop in a fixed order: images first, then image_sizes.
        multimodal = {key: kwargs.pop(key, None) for key in ("images", "image_sizes")}

        model_inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        for key, value in multimodal.items():
            if value is not None:
                model_inputs[key] = value
        return model_inputs


# Register this model family with the transformers Auto* factories at import time:
# AutoConfig maps the "llava_llama" model_type to LlavaConfig, and
# AutoModelForCausalLM maps LlavaConfig to LlavaLlamaForCausalLMwithEnt
# (the entity-aware class, not LlavaLlamaForCausalLM).
# NOTE(review): confirm no other imported module also registers "llava_llama"/LlavaConfig,
# since the Auto registries may reject a duplicate registration.
AutoConfig.register("llava_llama", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLMwithEnt)