#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
import random
import sys

sys.path.append(".")  # Adds higher directory to python modules path.

from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from transformers import AutoConfig, AutoModelForCausalLM, \
    LlamaConfig, LlamaModel, LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig
from transformers.modeling_outputs import CausalLMOutputWithPast

from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM

from transformers import LlamaModel as LlamaModelorg
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
from transformers.utils import logging
logger = logging.get_logger(__name__)

from dscr_utils.dscr_refine import refine_past_key_values


class LlamaModel(LlamaModelorg):
    """LLaMA decoder stack with an optional DAMO momentum correction.

    Subclass of ``transformers.LlamaModel`` (imported here as
    ``LlamaModelorg``) whose ``forward`` can, during single-token decoding,
    replace the per-layer update of the last position with a
    momentum-smoothed version. Without the DAMO arguments ``forward`` runs the
    plain decoder-layer loop.
    """

    def custom_cosine_similarity(self, vec1, vec2, eps=1e-8):
        """Cosine similarity of ``vec1`` and ``vec2`` along the last dimension.

        Both norms are clamped to at least ``eps`` so the division never hits
        zero; any remaining ``inf``/``nan`` entries (e.g. from non-finite
        inputs) are replaced by 0.

        Args:
            vec1: Tensor, broadcast against ``vec2``.
            vec2: Tensor, broadcast against ``vec1``.
            eps: Lower bound applied to each vector norm.

        Returns:
            Tensor of similarities with the last dimension reduced.
        """
        norm_vec1 = torch.sqrt(torch.sum(vec1 ** 2, dim=-1))
        norm_vec2 = torch.sqrt(torch.sum(vec2 ** 2, dim=-1))
        norm_vec1 = torch.clamp(norm_vec1, min=eps)
        norm_vec2 = torch.clamp(norm_vec2, min=eps)
        dot_product = torch.sum(vec1 * vec2, dim=-1)
        cosine_sim = dot_product / (norm_vec1 * norm_vec2)
        if torch.isinf(cosine_sim).any() or torch.isnan(cosine_sim).any():
            cosine_sim = torch.where(
                torch.isinf(cosine_sim) | torch.isnan(cosine_sim),
                torch.zeros_like(cosine_sim),
                cosine_sim,
            )
        return cosine_sim

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        lm_head=None,
        tau=None,
        beta_1=None,
        beta_2=None,
        alpha=None,
        damo_start_layer: Optional[int] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """
        DAMO arguments (in addition to the standard ones above):

            lm_head: callable mapping hidden states to vocabulary logits. It is
                applied to intermediate hidden states (without the final norm)
                to read out a next-token distribution after each layer.
            tau: cosine-similarity threshold. A layer whose probability change
                has similarity below ``tau`` with the running average of the
                earlier changes switches momentum correction on for itself and
                every later layer.
            beta_1: momentum of the running average of hidden-state updates at
                layers whose similarity is >= ``tau``.
            beta_2: the same momentum at layers whose similarity is < ``tau``.
            alpha: momentum of the running average of probability changes.
            damo_start_layer: first layer index DAMO is applied to. Defaults to
                ``len(self.layers) // 2`` and is raised to at least 1, because
                each DAMO step reads the hidden states entering the previous
                layer.

        DAMO runs only for single-token steps (``seq_length == 1``) with
        ``output_hidden_states`` enabled, all of ``tau``/``beta_1``/
        ``beta_2``/``alpha`` given and ``lm_head`` provided; otherwise this is
        the plain decoder loop.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        # DAMO needs all_hidden_states (hence output_hidden_states) and lm_head
        # to read out per-layer distributions.
        use_damo = (
            (seq_length == 1)
            and output_hidden_states
            and tau is not None
            and beta_1 is not None
            and beta_2 is not None
            and alpha is not None
            and lm_head is not None
        )
        if use_damo:
            damo_start = len(self.layers) // 2 if damo_start_layer is None else damo_start_layer
            # all_hidden_states[-2] (input of the previous layer) only exists
            # from layer 1 on; an earlier start would raise IndexError.
            damo_start = max(damo_start, 1)
        else:
            damo_start = None

        # Running state carried across layers; 0 means "not initialised yet".
        prob_delta_ema = 0  # running average of per-layer probability changes
        hidden_delta_ema = 0  # running average of last-position hidden-state updates
        momentum_decoding_flag = False  # once True, stays on for all later layers

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_damo and idx >= damo_start:
                target_device = hidden_states.device
                # Hidden states entering this layer and the previous one; they may
                # live on another device when layers are split across devices.
                layer_input = all_hidden_states[-1].to(target_device)
                prev_layer_input = all_hidden_states[-2].to(target_device)
                # Update this layer applied to the last position.
                hidden_delta = hidden_states[:, -1:, :] - layer_input[:, -1:, :]

                # Next-token distributions after this layer (P_3), after the
                # previous layer (P_2) and one layer earlier (P_1).
                logits_3 = lm_head(hidden_states[:, -1:, :])
                logits_2 = lm_head(layer_input[:, -1:, :])
                logits_1 = lm_head(prev_layer_input[:, -1:, :])
                P_3 = torch.nn.functional.softmax(logits_3, dim=-1)
                P_2 = torch.nn.functional.softmax(logits_2, dim=-1)
                P_1 = torch.nn.functional.softmax(logits_1, dim=-1)
                delta_p2 = (P_3 - P_2).to(target_device)
                delta_p1 = (P_2 - P_1).to(target_device)

                if not torch.is_tensor(prob_delta_ema) or prob_delta_ema.shape != delta_p1.shape:
                    prob_delta_ema = torch.zeros_like(delta_p1)
                if not torch.is_tensor(hidden_delta_ema) or hidden_delta_ema.shape != hidden_delta.shape:
                    hidden_delta_ema = torch.zeros_like(hidden_delta)
                prob_delta_ema = prob_delta_ema.to(target_device)
                hidden_delta_ema = hidden_delta_ema.to(target_device)

                prob_delta_ema = alpha * prob_delta_ema + (1 - alpha) * delta_p1
                cosine_similarity = self.custom_cosine_similarity(delta_p2, prob_delta_ema)

                hidden_momentum = beta_1
                # NOTE(review): cosine_similarity holds one value per batch row, so
                # this truth test only works for batch size 1 (it raises otherwise).
                if cosine_similarity < tau:
                    momentum_decoding_flag = True
                    hidden_momentum = beta_2
                hidden_delta_ema = (
                    hidden_momentum * hidden_delta_ema
                    + (1 - hidden_momentum) * hidden_delta
                )
                if momentum_decoding_flag:
                    # Swap this layer's raw update for the smoothed one (in place,
                    # so the next layer and all_hidden_states see the corrected value).
                    hidden_states[:, -1:, :] = hidden_states[:, -1:, :] - hidden_delta + hidden_delta_ema

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)
        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

class LlavaConfig(LlamaConfig):
    """LLaMA configuration tagged with the ``"llava"`` model type."""

    # Model-type key; this module registers the class with AutoConfig under it.
    model_type = "llava"


class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    """Backbone combining ``LlavaMetaModel`` (from ``llava_arch``) with this module's ``LlamaModel``."""

    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        # Cooperative init: runs the full MRO starting after this class.
        super().__init__(config)


class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    """LLaVA causal LM on a LLaMA backbone with DAMO / DSCR decoding hooks.

    ``forward`` forwards the DAMO arguments to the backbone and, when DSCR
    arguments are given on a cache-less (prefill) call, replaces the returned
    KV cache with ``refine_past_key_values(...)``. ``generate`` additionally
    supports a ``first_logits`` shortcut that greedily continues from
    precomputed logits and a precomputed cache.
    """

    config_class = LlavaConfig

    # HALC/DoLA/OPERA kwargs are consumed by generate() directly and should not be
    # validated against forward() parameters (needed for vanilla transformers-4.31.0).
    _GENERATE_ONLY_KWARGS = frozenset({
        "halc_decoding", "dola_decoding", "beam_search", "halc_assistant",
        "mature_layer", "base_layer", "candidate_premature_layers", "relative_top",
        "contrastive_decoding", "student_model",
        "opera_decoding", "key_position", "scale_factor", "threshold",
        "num_attn_candidates", "penalty_weights",
    })

    # DSCR options copied from generate() kwargs into forward() inputs by every
    # prepare_inputs_for_generation* variant.
    _DSCR_KWARG_NAMES = (
        "dscr_depth",
        "dscr_image_start",
        "dscr_image_len",
        "dscr_alpha",
        "dscr_beta",
        "dscr_sigma",
        "dscr_keep_ratio",
        "dscr_lambda",
        "dscr_start_layer",
        "dscr_end_layer",
        "dscr_key_only",
        "dscr_value_only",
        "dscr_key_value",
    )

    def __init__(self, config):
        # Deliberately skip LlamaForCausalLM.__init__ so the multimodal
        # backbone below is the one that gets built.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the multimodal backbone (``LlavaLlamaModel``)."""
        return self.model

    def _validate_model_kwargs(self, model_kwargs):
        """Drop generate()-only decoding options, then run the inherited validation.

        Note: removes those keys from ``model_kwargs`` in place.
        """
        for key in self._GENERATE_ONLY_KWARGS:
            model_kwargs.pop(key, None)
        super()._validate_model_kwargs(model_kwargs)

    @staticmethod
    def _normalize_eos_ids(eos_token_id):
        """Turn an int, a list/tuple of ints, or None into a list of int EOS ids."""
        if eos_token_id is None:
            return []
        if isinstance(eos_token_id, (list, tuple)):
            return [int(x) for x in eos_token_id]
        return [int(eos_token_id)]

    def generate(self, *args, **kwargs):
        """Generate tokens, optionally continuing greedily from precomputed logits.

        Without a ``first_logits`` kwarg this is the inherited ``generate``.
        With it, ``input_ids`` (kwarg or first positional argument),
        ``past_key_values``, ``max_new_tokens`` and ``eos_token_id`` are read
        from the arguments and decoding continues greedily through
        ``llava16_dscr_generate``. Other kwargs are ignored on this path.

        Returns:
            ``input_ids`` with the generated tokens appended along dim 1, or
            ``input_ids`` unchanged when no cache is given or
            ``max_new_tokens`` is missing / not positive.

        Raises:
            ValueError: ``first_logits`` given without ``input_ids``.
        """
        first_logits = kwargs.pop("first_logits", None)
        if first_logits is None:
            return super().generate(*args, **kwargs)

        input_ids = kwargs.get("input_ids", None)
        if input_ids is None and len(args) > 0:
            input_ids = args[0]
        if input_ids is None:
            raise ValueError("first_logits path requires input_ids")

        past_key_values = kwargs.get("past_key_values", None)
        # `or 0` also covers an explicit max_new_tokens=None, which int() rejects.
        max_new_tokens = int(kwargs.get("max_new_tokens") or 0)
        eos_token_id = kwargs.get("eos_token_id", self.generation_config.eos_token_id)

        if past_key_values is None or max_new_tokens <= 0:
            return input_ids

        return self.llava16_dscr_generate(
            input_ids, first_logits, past_key_values, max_new_tokens, eos_token_id=eos_token_id
        )

    def llava16_dscr_generate(self, input_ids, first_logits, cache, max_new_tokens, eos_token_id=None):
        """Greedy decoding from precomputed first-step logits and KV cache.

        Args:
            input_ids: Prompt token ids; generated tokens are concatenated to
                it along dim 1.
            first_logits: Logits for the first new token (argmax over the last dim).
            cache: ``past_key_values`` that the first step continues from.
            max_new_tokens: Maximum number of tokens to generate.
            eos_token_id: int, list/tuple of ints, or None (no early stop).

        Generation stops once every row has produced one of the EOS ids. Rows
        that finished earlier emit the first EOS id from then on instead of
        fresh tokens.

        Returns:
            ``input_ids`` with the generated tokens appended along dim 1.
        """
        eos_ids = self._normalize_eos_ids(eos_token_id)

        logits = first_logits
        rolling_cache = cache
        generated_tokens = []
        finished = None  # cumulative "already produced EOS" mask, shaped like next_token

        for _ in range(max_new_tokens):
            next_token = torch.argmax(logits, dim=-1, keepdim=True)

            if eos_ids:
                if finished is not None:
                    next_token = next_token.masked_fill(finished, eos_ids[0])
                hit_eos = torch.zeros_like(next_token, dtype=torch.bool)
                for eid in eos_ids:
                    hit_eos = hit_eos | (next_token == eid)
                # Cumulative: a row that finished at an earlier step stays finished.
                finished = hit_eos if finished is None else (finished | hit_eos)

            generated_tokens.append(next_token)
            if finished is not None and bool(finished.all()):
                break

            with torch.inference_mode():
                out = self(
                    input_ids=next_token,
                    images=None,
                    use_cache=True,
                    past_key_values=rolling_cache,
                    return_dict=True,
                )
            rolling_cache = out.past_key_values
            logits = out.logits[:, -1, :]

        if not generated_tokens:
            return input_ids
        new_tokens = torch.cat(generated_tokens, dim=1)
        return torch.cat([input_ids, new_tokens], dim=1)

    def mix_image_features(self, image_features, option, radius=2, swap_prob=0.2):
        """Return a copy of ``image_features`` with its token order perturbed.

        Args:
            image_features: Tensor unpacked as ``(B, T, D)``; tokens are along dim 1.
            option: ``1`` applies one random permutation of the token axis
                (shared by all batch rows); any other value performs local swaps.
            radius: Local-swap mode only; a token may be swapped with a
                neighbour at most ``radius`` positions away.
            swap_prob: Local-swap mode only; per-token probability of a swap.

        Randomness comes from ``torch.randperm`` (option 1) or Python's
        ``random`` module (local swaps).
        """
        if option == 1:
            shuffle_idx = torch.randperm(image_features.size(1))
            return image_features[:, shuffle_idx, :]

        B, T, _ = image_features.shape
        new_image_features = image_features.clone()

        for b in range(B):
            for i in range(T):
                if random.random() < swap_prob:
                    candidates = [j for j in range(max(0, i - radius), min(T, i + radius + 1)) if j != i]
                    if candidates:
                        j = random.choice(candidates)
                        tmp = new_image_features[b, i].clone()
                        new_image_features[b, i] = new_image_features[b, j]
                        new_image_features[b, j] = tmp

        return new_image_features

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            images: Optional[torch.FloatTensor] = None,
            images_cd: Optional[torch.FloatTensor] = None,
            cd_beta: Optional[torch.FloatTensor] = None,
            cd_alpha: Optional[torch.FloatTensor] = None,
            return_dict: Optional[bool] = None,
            tau=None,
            beta_1=None,
            beta_2=None,
            alpha=None,
            damo_start_layer: Optional[int] = None,
            dscr_depth: Optional[torch.Tensor] = None,
            dscr_image_start: Optional[int] = None,
            dscr_image_len: Optional[int] = None,
            dscr_alpha: Optional[float] = None,
            dscr_beta: Optional[float] = None,
            dscr_sigma: Optional[float] = None,
            dscr_keep_ratio: Optional[float] = None,
            dscr_lambda: Optional[float] = None,
            dscr_start_layer: Optional[int] = None,
            dscr_end_layer: Optional[int] = None,
            dscr_key_only: bool = False,
            dscr_value_only: bool = False,
            dscr_key_value: bool = True,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Multimodal causal-LM forward pass with optional DAMO and DSCR.

        ``images`` are merged into the inputs through
        ``prepare_inputs_labels_for_multimodal``. ``tau``/``beta_1``/``beta_2``/
        ``alpha``/``damo_start_layer`` are passed to the backbone together with
        ``self.lm_head``. ``images_cd``, ``cd_beta`` and ``cd_alpha`` are
        accepted but not used here.

        When ``dscr_depth`` is given and there is no incoming cache, the
        backbone's returned ``past_key_values`` are replaced by
        ``refine_past_key_values(...)`` with the ``dscr_*`` arguments.

        Raises:
            ValueError: DSCR requested with a required ``dscr_*`` argument
                missing, or with ``return_dict=False``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        _caller_inputs_embeds = inputs_embeds  # save before prepare_inputs overwrites it
        input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(
            input_ids, attention_mask, past_key_values, labels, images)
        # When HALC (or any caller) passes inputs_embeds directly without input_ids,
        # prepare_inputs_labels_for_multimodal returns None for inputs_embeds (early-exit path).
        # Restore the caller's inputs_embeds so self.model() receives something valid.
        if inputs_embeds is None and _caller_inputs_embeds is not None:
            inputs_embeds = _caller_inputs_embeds

        # Backbone outputs: (last_hidden_state, past_key_values, hidden_states, attentions)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            lm_head=self.lm_head,
            tau=tau,
            beta_1=beta_1,
            beta_2=beta_2,
            alpha=alpha,
            damo_start_layer=damo_start_layer
        )

        # DSCR refines the freshly built cache only (prefill step).
        if dscr_depth is not None and past_key_values is None:
            if dscr_image_start is None or dscr_image_len is None:
                raise ValueError("DSCR requires dscr_image_start and dscr_image_len")
            missing = [
                name
                for name, value in (
                    ("dscr_alpha", dscr_alpha),
                    ("dscr_beta", dscr_beta),
                    ("dscr_sigma", dscr_sigma),
                    ("dscr_keep_ratio", dscr_keep_ratio),
                    ("dscr_lambda", dscr_lambda),
                    ("dscr_start_layer", dscr_start_layer),
                )
                if value is None
            ]
            if missing:
                raise ValueError("DSCR requires " + ", ".join(missing))
            if not return_dict:
                # Tuple outputs have no .past_key_values attribute to replace.
                raise ValueError("DSCR requires return_dict=True")
            updated_past = refine_past_key_values(
                outputs.past_key_values,
                dscr_depth,
                dscr_image_start,
                dscr_image_len,
                float(dscr_alpha),
                float(dscr_beta),
                float(dscr_sigma),
                float(dscr_keep_ratio),
                float(dscr_lambda),
                int(dscr_start_layer),
                int(dscr_end_layer) if dscr_end_layer is not None else None,
                bool(dscr_key_only),
                bool(dscr_value_only),
                bool(dscr_key_value),
            )
            outputs.past_key_values = updated_past

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model/pipeline parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @classmethod
    def _dscr_inputs_from_kwargs(cls, kwargs):
        """Collect the DSCR options from generate() kwargs (None for absent ones).

        NOTE(review): absent flags arrive in forward() as None, overriding its
        defaults (dscr_key_value=True becomes None and is passed on as
        bool(None) == False). Confirm this is intended.
        """
        return {name: kwargs.get(name, None) for name in cls._DSCR_KWARG_NAMES}

    def prepare_inputs_for_generation(
            self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Build forward() inputs for one step of the inherited generate() loop.

        With a cache only the last token of ``input_ids`` is fed, and
        ``inputs_embeds`` is used only on the first (cache-less) step.
        ``images`` plus the DAMO and DSCR options are copied from ``kwargs``.
        """
        if past_key_values:
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
                "tau": kwargs.get("tau", None),
                "beta_1": kwargs.get("beta_1", None),
                "beta_2": kwargs.get("beta_2", None),
                "alpha": kwargs.get("alpha", None),
                "damo_start_layer": kwargs.get("damo_start_layer", None),
            }
        )
        model_inputs.update(self._dscr_inputs_from_kwargs(kwargs))
        return model_inputs

    def prepare_inputs_for_generation_cd(
            self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Like ``prepare_inputs_for_generation`` but feeds ``images_cd`` as ``images``.

        DAMO options are not forwarded; DSCR options are.
        """
        if past_key_values:
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images_cd", None),
            }
        )
        model_inputs.update(self._dscr_inputs_from_kwargs(kwargs))
        return model_inputs

    def prepare_inputs_for_generation_opera(
            self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Inputs builder for the OPERA path.

        Unlike the other variants, ``input_ids`` is NOT truncated to its last
        token when a cache is present. ``images_cd`` is fed as ``images``.
        NOTE(review): confirm both are intended for OPERA.
        """
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images_cd", None),
            }
        )
        model_inputs.update(self._dscr_inputs_from_kwargs(kwargs))
        return model_inputs

# Expose the "llava" model type through the transformers Auto* factories:
# AutoConfig resolves it to LlavaConfig, and AutoModelForCausalLM maps that
# config class to LlavaLlamaForCausalLM.
AutoConfig.register("llava", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)