# ########
# ## For baseline and VCD and maybe OPERA
# #######
# #    Copyright 2023 Haotian Liu
# #
# #    Licensed under the Apache License, Version 2.0 (the "License");
# #    you may not use this file except in compliance with the License.
# #    You may obtain a copy of the License at
# #
# #        http://www.apache.org/licenses/LICENSE-2.0
# #
# #    Unless required by applicable law or agreed to in writing, software
# #    distributed under the License is distributed on an "AS IS" BASIS,
# #    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# #    See the License for the specific language governing permissions and
# #    limitations under the License.
# import sys 
# sys.path.append(".") # Adds higher directory to python modules path.

# from typing import List, Optional, Tuple, Union

# import torch
# import torch.nn as nn
# from torch.nn import CrossEntropyLoss

# from transformers import AutoConfig, AutoModelForCausalLM, \
#                          LlamaConfig, LlamaModel, LlamaForCausalLM

# from transformers.modeling_outputs import CausalLMOutputWithPast

# from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM


# class LlavaConfig(LlamaConfig):
#     model_type = "llava"


# class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
#     config_class = LlavaConfig

#     def __init__(self, config: LlamaConfig):
#         super(LlavaLlamaModel, self).__init__(config)


# class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
#     config_class = LlavaConfig

#     def __init__(self, config):
#         super(LlamaForCausalLM, self).__init__(config)
#         self.model = LlavaLlamaModel(config)

#         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

#         # Initialize weights and apply final processing
#         self.post_init()

#     def get_model(self):
#         return self.model

#     def forward(
#         self,
#         input_ids: torch.LongTensor = None,
#         attention_mask: Optional[torch.Tensor] = None,
#         past_key_values: Optional[List[torch.FloatTensor]] = None,
#         inputs_embeds: Optional[torch.FloatTensor] = None,
#         labels: Optional[torch.LongTensor] = None,
#         use_cache: Optional[bool] = None,
#         output_attentions: Optional[bool] = None,
#         output_hidden_states: Optional[bool] = None,
#         images: Optional[torch.FloatTensor] = None,
#         #########################BEGIN FOR OPERA######################
#         images_cd: Optional[torch.FloatTensor] = None,
#         cd_beta: Optional[torch.FloatTensor] = None,
#         cd_alpha: Optional[torch.FloatTensor] = None,
#         #####################END FOR OPERA###############################
#         ##############################BEGIN FOR DOLA#################
#         dola_decoding: Optional[bool] = None,
#         mature_layer: Optional[int] = None,
#         base_layer: Optional[int] = None,
#         candidate_premature_layers: Optional[List[int]] = None,
#         relative_top: Optional[float] = 0.1,
#         contrastive_decoding: Optional[bool] = None,
#         student_model = None,
#         early_exit_layers=None,
#         #############################END FOR DOLA#########################

#         return_dict: Optional[bool] = None,
#     ) -> Union[Tuple, CausalLMOutputWithPast]:
#         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
#         output_hidden_states = (
#             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
#         )
#         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

#         input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
#         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
#         # print(inputs_embeds.shape)
#         # print(input_ids.shape)
#         outputs = self.model(
#             input_ids=input_ids,
#             attention_mask=attention_mask,
#             past_key_values=past_key_values,
#             inputs_embeds=inputs_embeds,
#             use_cache=use_cache,
#             output_attentions=output_attentions,
#             output_hidden_states=output_hidden_states,
#             return_dict=return_dict
#         )

        
#         ###########################BEGIN FOR DOLA#####################
#         if early_exit_layers is not None:
#             logits_dict = {}
#             for i, early_exit_layer in enumerate(early_exit_layers):
#                 logits = self.lm_head(outputs.hidden_states[early_exit_layer])
#                 logits_dict[early_exit_layer] = logits
#             loss = None
#             if labels is not None:
#                 # Shift so that tokens < n predict n
#                 shift_logits = logits[..., :-1, :].contiguous()
#                 shift_labels = labels[..., 1:].contiguous()
#                 # Flatten the tokens
#                 loss_fct = CrossEntropyLoss()
#                 shift_logits = shift_logits.view(-1, self.config.vocab_size)
#                 shift_labels = shift_labels.view(-1)
#                 # Enable model parallelism
#                 shift_labels = shift_labels.to(shift_logits.device)
#                 loss = loss_fct(shift_logits, shift_labels)
                
#             final_outputs = CausalLMOutputWithPast(
#                 loss=loss,
#                 logits=logits,
#                 past_key_values=outputs.past_key_values,
#                 hidden_states=outputs.hidden_states,
#                 attentions=outputs.attentions,
#             )
#             return logits_dict, final_outputs
#         ###########################END FOR DOLA############################
#         else:

#             hidden_states = outputs[0]
#             logits = self.lm_head(hidden_states)

#             loss = None
#             if labels is not None:
#                 # Shift so that tokens < n predict n
#                 shift_logits = logits[..., :-1, :].contiguous()
#                 shift_labels = labels[..., 1:].contiguous()
#                 # Flatten the tokens
#                 loss_fct = CrossEntropyLoss()
#                 shift_logits = shift_logits.view(-1, self.config.vocab_size)
#                 shift_labels = shift_labels.view(-1)
#                 # Enable model/pipeline parallelism
#                 shift_labels = shift_labels.to(shift_logits.device)
#                 loss = loss_fct(shift_logits, shift_labels)

#             if not return_dict:
#                 output = (logits,) + outputs[1:]
#                 return (loss,) + output if loss is not None else output

#             return CausalLMOutputWithPast(
#                 loss=loss,
#                 logits=logits,
#                 past_key_values=outputs.past_key_values,
#                 hidden_states=outputs.hidden_states,
#                 attentions=outputs.attentions,
#             )

#     def prepare_inputs_for_generation(
#         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
#     ):
#         if past_key_values:
#             input_ids = input_ids[:, -1:]

#         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
#         if inputs_embeds is not None and past_key_values is None:
#             model_inputs = {"inputs_embeds": inputs_embeds}
#         else:
#             model_inputs = {"input_ids": input_ids}

#         model_inputs.update(
#             {
#                 "past_key_values": past_key_values,
#                 "use_cache": kwargs.get("use_cache"),
#                 "attention_mask": attention_mask,
#                 "images": kwargs.get("images", None),
#             }
#         )
#         return model_inputs
    
#     def prepare_inputs_for_generation_cd(
#         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
#     ):
#         if past_key_values:
#             input_ids = input_ids[:, -1:]

#         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
#         if inputs_embeds is not None and past_key_values is None:
#             model_inputs = {"inputs_embeds": inputs_embeds}
#         else:
#             model_inputs = {"input_ids": input_ids}

#         model_inputs.update(
#             {
#                 "past_key_values": past_key_values,
#                 "use_cache": kwargs.get("use_cache"),
#                 "attention_mask": attention_mask,
#                 "images": kwargs.get("images_cd", None),
#             }
#         )
#         return model_inputs

# AutoConfig.register("llava", LlavaConfig)
# AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)






##################################

## For momentum decoding

##################################
import sys 
sys.path.append(".") # Adds higher directory to python modules path.

from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from transformers import AutoConfig, AutoModelForCausalLM, \
                         LlamaConfig, LlamaModel, LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig
from transformers.modeling_outputs import CausalLMOutputWithPast

from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
from transformers import LlamaModel as LlamaModelorg
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
from transformers.utils import logging
logger = logging.get_logger(__name__)


class LlamaModel(LlamaModelorg):
    """Llama decoder stack that applies "momentum decoding" to the last token.

    From layer ``momentum_start_layer`` onwards, the last-position hidden state
    entering and leaving each layer is projected through ``lm_head`` and
    soft-maxed. When the change in that distribution produced by the current
    layer points against the running average of earlier changes (cosine
    similarity below ``momentum_cosine_threshold``), the layer's last-position
    hidden-state increment is replaced by an exponential moving average of
    increments. Once triggered for a sample, the replacement stays active for
    all remaining layers of that forward pass.

    The class-level defaults are the values the original code annotated
    "For MME" (together with the figures 1515.89, 304.64, presumably benchmark
    scores). The alternative setting annotated "POPE" used 0.2 for
    ``momentum_increment_beta`` and 0.4 for
    ``momentum_increment_beta_reversal``; subclasses or instances may override
    any of these attributes.
    """

    # Index of the first decoder layer at which the momentum logic runs.
    momentum_start_layer = 16
    # EMA coefficient for the running average of distribution changes.
    momentum_delta_beta = 0.7
    # EMA coefficient for the running average of hidden-state increments.
    momentum_increment_beta = 0.05
    # EMA coefficient used instead on a layer where a reversal is detected.
    momentum_increment_beta_reversal = 0.2
    # A cosine similarity below this value counts as a reversal.
    momentum_cosine_threshold = -0.3

    def custom_cosine_similarity(self, vec1, vec2, eps=1e-8):
        """Return the cosine similarity of ``vec1`` and ``vec2`` over the last dim.

        Each norm is clamped to at least ``eps`` before the division, and any
        ``inf``/``nan`` entry of the result is replaced by zero.
        """
        norm_vec1 = torch.clamp(torch.sqrt(torch.sum(vec1 ** 2, dim=-1)), min=eps)
        norm_vec2 = torch.clamp(torch.sqrt(torch.sum(vec2 ** 2, dim=-1)), min=eps)
        cosine_sim = torch.sum(vec1 * vec2, dim=-1) / (norm_vec1 * norm_vec2)
        # Mask unconditionally: same values as guarding with `.any()`, but
        # without forcing a host-device synchronisation.
        invalid = torch.isinf(cosine_sim) | torch.isnan(cosine_sim)
        return torch.where(invalid, torch.zeros_like(cosine_sim), cosine_sim)

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        lm_head=None
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack, adjusting the last token with momentum decoding.

        Args:
            lm_head: Callable mapping hidden states to vocabulary logits. When
                ``None``, the momentum adjustment is skipped and the layers
                are applied unmodified.

        Returns:
            A ``BaseModelOutputWithPast`` when ``return_dict`` is true,
            otherwise a tuple of the non-``None`` values among
            (last hidden state, cache, all hidden states, all attentions).

        Raises:
            ValueError: If both or neither of ``input_ids`` and
                ``inputs_embeds`` are given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        # Momentum-decoding state. The last-position inputs of the current and
        # previous layer are tracked here rather than read back from
        # `all_hidden_states`, which is None unless `output_hidden_states` is set.
        prev_input_last = None   # last-position input of the previous layer
        increment_ema = 0        # running average of hidden-state increments
        delta_prob_ema = 0       # running average of distribution changes
        momentum_active = None   # per-sample sticky flag, shape (batch, 1, 1)

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # View of this layer's input at the last sequence position.
            input_last = hidden_states[:, -1:, :]

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if (
                lm_head is not None
                and idx >= self.momentum_start_layer
                and prev_input_last is not None
            ):
                output_last = hidden_states[:, -1:, :]
                current_incre = output_last - input_last

                # Token distributions after this layer, before it, and before
                # the previous layer.
                probs_out = torch.nn.functional.softmax(lm_head(output_last), dim=-1)
                probs_in = torch.nn.functional.softmax(lm_head(input_last), dim=-1)
                probs_prev = torch.nn.functional.softmax(lm_head(prev_input_last), dim=-1)
                delta_now = probs_out - probs_in
                delta_prev = probs_in - probs_prev

                beta = self.momentum_delta_beta
                delta_prob_ema = beta * delta_prob_ema + (1 - beta) * delta_prev
                cosine = self.custom_cosine_similarity(delta_now, delta_prob_ema)

                # Per-sample mask, so batches larger than one work and no
                # tensor has to be converted to a Python bool.
                reversal = (cosine < self.momentum_cosine_threshold).unsqueeze(-1)
                momentum_active = reversal if momentum_active is None else (momentum_active | reversal)

                slow = self.momentum_increment_beta
                fast = self.momentum_increment_beta_reversal
                increment_ema = torch.where(
                    reversal,
                    fast * increment_ema + (1 - fast) * current_incre,
                    slow * increment_ema + (1 - slow) * current_incre,
                )

                # Swap this layer's increment for the averaged one on every
                # sample whose flag has been raised; others keep their value.
                smoothed = output_last - current_incre + increment_ema
                hidden_states[:, -1:, :] = torch.where(momentum_active, smoothed, output_last)

            prev_input_last = input_last

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)
        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class LlamaForCausalLM(LlamaForCausalLMOrig):
    """Causal-LM wrapper that hands its ``lm_head`` to the decoder stack.

    The only difference from a plain "decoder, then head" forward pass visible
    here is that ``self.lm_head`` is passed to ``self.model`` as ``lm_head``.
    """

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute logits and, when ``labels`` is given, the next-token loss.

        Returns:
            A ``CausalLMOutputWithPast`` when ``return_dict`` is true;
            otherwise the tuple ``(logits, *decoder_extras)``, prefixed with
            the loss when one was computed.
        """
        # Flags left as None fall back to the model configuration.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        decoder_out = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            lm_head=self.lm_head,
        )
        logits = self.lm_head(decoder_out[0])

        loss = None
        if labels is not None:
            # Position t predicts token t+1: drop the last logit and the first label.
            flat_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
            flat_targets = labels[..., 1:].contiguous().view(-1)
            # Move the targets to the logits' device (model parallelism).
            flat_targets = flat_targets.to(flat_logits.device)
            loss = CrossEntropyLoss()(flat_logits, flat_targets)

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=decoder_out.past_key_values,
                hidden_states=decoder_out.hidden_states,
                attentions=decoder_out.attentions,
            )

        tuple_out = (logits,) + decoder_out[1:]
        return tuple_out if loss is None else (loss,) + tuple_out


class LlavaConfig(LlamaConfig):
    """Llama configuration subclass whose ``model_type`` is ``"llava"``."""

    # Same string that is registered with AutoConfig at the end of this module.
    model_type = "llava"


class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    """This module's ``LlamaModel`` combined with the ``LlavaMetaModel`` mixin."""

    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        """Initialise the base classes cooperatively with ``config``."""
        super().__init__(config)


class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    """LLaVA causal LM built on this module's momentum-decoding Llama classes."""

    config_class = LlavaConfig

    def __init__(self, config):
        """Build the multimodal backbone and the vocabulary projection head."""
        # NOTE(review): `LlamaForCausalLM` in this module subclasses
        # `LlamaForCausalLMOrig`, so this call appears to run the original
        # constructor (which presumably builds its own backbone and head)
        # before both are replaced below - confirm whether that is intended.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the underlying ``LlavaLlamaModel``."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_cd: Optional[torch.FloatTensor] = None,
        cd_beta: Optional[torch.FloatTensor] = None,
        cd_alpha: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Merge ``images`` into the inputs, run the decoder and compute logits/loss.

        ``images_cd``, ``cd_beta`` and ``cd_alpha`` are accepted but not read
        by this method.

        Returns:
            A ``CausalLMOutputWithPast`` when ``return_dict`` is true;
            otherwise the tuple ``(logits, *decoder_extras)``, prefixed with
            the loss when ``labels`` was given.
        """
        # Flags left as None fall back to the model configuration.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        decoder_out = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            lm_head=self.lm_head,
        )
        logits = self.lm_head(decoder_out[0])

        loss = None
        if labels is not None:
            # Position t predicts token t+1: drop the last logit and the first label.
            flat_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
            flat_targets = labels[..., 1:].contiguous().view(-1)
            # Move the targets to the logits' device (model/pipeline parallelism).
            flat_targets = flat_targets.to(flat_logits.device)
            loss = CrossEntropyLoss()(flat_logits, flat_targets)

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=decoder_out.past_key_values,
                hidden_states=decoder_out.hidden_states,
                attentions=decoder_out.attentions,
            )

        tuple_out = (logits,) + decoder_out[1:]
        return tuple_out if loss is None else (loss,) + tuple_out

    def _generation_inputs(
        self, input_ids, past_key_values, attention_mask, inputs_embeds, use_cache, images
    ):
        """Assemble the per-step input dict shared by both generation hooks."""
        if past_key_values:
            # With a non-empty cache only the newest token is fed.
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            step_inputs = {"inputs_embeds": inputs_embeds}
        else:
            step_inputs = {"input_ids": input_ids}

        step_inputs["past_key_values"] = past_key_values
        step_inputs["use_cache"] = use_cache
        step_inputs["attention_mask"] = attention_mask
        step_inputs["images"] = images
        return step_inputs

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Build the forward inputs for one step, taking images from ``kwargs["images"]``."""
        return self._generation_inputs(
            input_ids,
            past_key_values,
            attention_mask,
            inputs_embeds,
            use_cache=kwargs.get("use_cache"),
            images=kwargs.get("images", None),
        )

    def prepare_inputs_for_generation_cd(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Same as ``prepare_inputs_for_generation`` but images come from ``kwargs["images_cd"]``."""
        return self._generation_inputs(
            input_ids,
            past_key_values,
            attention_mask,
            inputs_embeds,
            use_cache=kwargs.get("use_cache"),
            images=kwargs.get("images_cd", None),
        )

# Register with the transformers Auto* classes at import time: the name
# "llava" is associated with LlavaConfig, and LlavaConfig with
# LlavaLlamaForCausalLM.
AutoConfig.register("llava", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
