#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.


from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, CosineEmbeddingLoss, MSELoss

# Seed torch's global RNG once at import time. This call has to come after
# `import torch`; placed before it, importing this module raises NameError.
torch.manual_seed(42)

from transformers import AutoConfig, AutoModelForCausalLM, \
                         LlamaConfig, LlamaModel, LlamaForCausalLM

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput

from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM


class LlavaConfig(LlamaConfig):
    """LlamaConfig subclass whose only change is the ``model_type`` tag."""

    # Identifier under which this config is registered with AutoConfig at the
    # bottom of this file.
    model_type = "llava_llama"


class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    """Backbone that combines ``LlavaMetaModel`` with ``LlamaModel``."""

    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        # Cooperative initialisation along the MRO (LlavaMetaModel first).
        super().__init__(config)


class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    """Causal-LM head on top of ``LlavaLlamaModel`` with optional noise training.

    Optional config attributes read at construction time:

    * ``noise_text`` (default ``False``): when labels are given, add Gaussian
      noise (``randn * 0.01``) to the input embeddings at positions whose
      label is not ``-100``, for samples that contain image tokens.
    * ``dist`` (default ``False``): additionally run a second forward pass on
      the noised embeddings and add ``0.5 * lm_loss2`` plus a cosine
      consistency loss between the hidden states of both passes.
    * ``select_layer`` (default ``0``): start index into the tuple of returned
      hidden states; layers from this index on enter the consistency loss.
    """

    config_class = LlavaConfig

    def __init__(self, config):
        """Build the backbone and LM head and cache the training flags.

        Args:
            config: model configuration; must provide ``pretraining_tp``,
                ``vocab_size`` and ``hidden_size``.
        """
        # Start the MRO lookup *after* LlamaForCausalLM so that its __init__
        # is skipped; the backbone and head are created here instead.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.noise_text = getattr(config, 'noise_text', False)

        self.dist = getattr(config, 'dist', False)
        self.select_layer = getattr(config, 'select_layer', 0)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the wrapped ``LlavaLlamaModel`` backbone."""
        return self.model

    def _lm_logits(self, hidden_states):
        """Project hidden states onto the vocabulary and return float32 logits."""
        tp = self.config.pretraining_tp
        if tp > 1:
            # Sliced projection: split the head weight along the vocabulary
            # dimension and concatenate the partial logits.
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        return logits.float()

    def _hidden_state_dist_loss(self, outputs, outputs2, batch_size,
                                image_token_indices, visual_tokens):
        """Cosine consistency loss between the hidden states of two passes.

        Args:
            outputs: backbone output of the first pass; ``.hidden_states``
                must be populated.
            outputs2: backbone output of the second pass, same requirement.
            batch_size: number of samples to iterate over.
            image_token_indices: per-sample collection of image positions.
            visual_tokens: tensor whose second dimension is used as the number
                of positions compared per image.

        Returns:
            The accumulated loss divided by the number of accumulated terms.
        """
        cos_loss_fct = CosineEmbeddingLoss()
        # The slice width and the target length have to agree, so both are
        # derived from visual_tokens instead of hard-coding the width.
        num_visual_tokens = visual_tokens.shape[1]
        target = torch.ones(num_visual_tokens, device=visual_tokens.device)

        select_layers = outputs.hidden_states[self.select_layer:]
        select_layers2 = outputs2.hidden_states[self.select_layer:]

        num_images = 1e-8  # epsilon start keeps the final division well-defined
        dist_loss = 0
        for i in range(batch_size):
            if image_token_indices[i]:
                for image_token_position in image_token_indices[i]:
                    num_images += 1
                    start = image_token_position + 1
                    stop = start + num_visual_tokens
                    hs_loss = 0
                    for layer, layer2 in zip(select_layers, select_layers2):
                        hs_loss += cos_loss_fct(
                            layer[i, start:stop, :], layer2[i, start:stop, :], target
                        )
                    dist_loss += hs_loss
            else:
                # Sample without image tokens: compare the first position of
                # each pass with itself. cos(x, x) is 1, so every term is ~0.
                # NOTE(review): presumably kept so that such samples still
                # touch both passes' hidden states -- confirm with the author.
                num_images += 1
                hs_loss = 0
                for layer, layer2 in zip(select_layers, select_layers2):
                    first = layer[i, 0, :]
                    first2 = layer2[i, 0, :]
                    hs_loss += cos_loss_fct(first, first, torch.tensor(1, device=first.device))
                    hs_loss += cos_loss_fct(first2, first2, torch.tensor(1, device=first2.device))
                dist_loss += hs_loss

        return dist_loss / num_images

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the model and, when ``labels`` is given, compute the loss.

        When ``inputs_embeds`` is ``None`` the inputs are first passed through
        ``prepare_inputs_labels_for_multimodal``. ``cache_position`` is
        accepted for signature compatibility but is not forwarded.

        Loss composition (only when ``labels`` is not ``None``):

        * always the shifted next-token cross-entropy;
        * plus ``0.5 * lm_loss2 + dist_loss`` when ``dist`` and ``noise_text``
          are both enabled and image indices / visual tokens are available.
          In every other case the plain LM loss is returned.

        Returns:
            ``CausalLMOutputWithPast`` when ``return_dict`` resolves to true,
            otherwise a tuple ``(loss?, logits, *backbone_outputs[1:])``.
        """
        # Only the multimodal preparation step yields these; they must exist
        # (as None) when the caller supplies precomputed embeddings.
        image_token_indices = None
        visual_tokens = None

        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
                image_token_indices,
                visual_tokens,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                image_sizes
            )

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # One flag per feature, so that every later branch agrees on whether
        # the second-pass tensors exist.
        add_noise = bool(self.noise_text) and labels is not None and image_token_indices is not None
        use_dist = add_noise and bool(self.dist) and visual_tokens is not None

        clean_embeds = None
        if use_dist:
            # The consistency loss reads per-layer hidden states.
            output_hidden_states = True
            # Keep an un-noised copy; the noise below is written in place.
            clean_embeds = inputs_embeds.clone()

        if add_noise:
            for i in range(labels.shape[0]):
                if image_token_indices[i]:
                    text_mask = (labels[i] != -100).to(
                        dtype=inputs_embeds.dtype, device=inputs_embeds.device
                    )
                    noise = torch.randn_like(inputs_embeds[i]) * 0.01
                    inputs_embeds[i] = inputs_embeds[i] + noise * text_mask.unsqueeze(-1)

        noisy_embeds = None
        if use_dist:
            # First pass sees the clean embeddings, second pass the noised ones.
            noisy_embeds = inputs_embeds
            inputs_embeds = clean_embeds

        # Backbone output: element 0 is the last hidden state.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self._lm_logits(outputs[0])

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Shift so that tokens < n predict n, then flatten.
            shift_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
            # Enable model parallelism
            shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
            lm_loss = loss_fct(shift_logits, shift_labels)

            if use_dist:
                # `model2` is not created in __init__. Use it when something
                # else attached it, otherwise reuse the main backbone.
                # NOTE(review): confirm that sharing weights is the intent.
                second_model = getattr(self, "model2", self.model)
                outputs2 = second_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    inputs_embeds=noisy_embeds,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                )
                logits2 = self._lm_logits(outputs2[0])
                shift_logits2 = logits2[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
                lm_loss2 = loss_fct(shift_logits2, shift_labels.to(shift_logits2.device))

                # NOTE: the helper reads `.hidden_states` by attribute, so this
                # path needs dict-style backbone outputs (return_dict true).
                dist_loss = self._hidden_state_dist_loss(
                    outputs, outputs2, labels.shape[0], image_token_indices, visual_tokens
                )

                print("lm loss", lm_loss)
                print("lm loss2", lm_loss2)
                print("dist loss", dist_loss)
                # Out-of-place sum so that lm_loss itself is not mutated.
                loss = lm_loss + (0.5 * lm_loss2) + dist_loss
            else:
                print("lm loss", lm_loss)
                loss = lm_loss

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate from token ids, optionally conditioned on images.

        The prompt is converted to input embeddings here (through the
        multimodal preparation step when ``images`` is given, otherwise
        through the backbone's ``embed_tokens``) and generation is delegated
        to the parent class with ``inputs_embeds``.

        Raises:
            NotImplementedError: if ``inputs_embeds`` is passed in ``kwargs``.
        """
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            # Only ids/positions/mask/embeddings are needed here; the cache,
            # labels, image indices and visual tokens are discarded.
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _,
                _,
                _
            ) = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                image_sizes=image_sizes
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        """Extend the parent's per-step inputs with the image arguments.

        ``images`` and ``image_sizes`` are removed from ``kwargs`` before the
        parent implementation runs and are added to its result afterwards
        when they are not ``None``.
        """
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs['images'] = images
        if image_sizes is not None:
            inputs['image_sizes'] = image_sizes
        return inputs

# Register the config class under the "llava_llama" model type and map it to
# the causal-LM class, so the transformers Auto* factories can resolve both.
AutoConfig.register("llava_llama", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
