from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F

from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig
from .modeling_llama import LlamaModel, LlamaForCausalLM

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from ...constants import IMAGE_TOKEN_PATCH, IGNORE_INDEX

from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM


class LlavaConfig(LlamaConfig):
    """Llama configuration registered under the ``"llava_llama"`` model type.

    Adds no fields of its own; the distinct ``model_type`` lets the Auto* factories
    map checkpoints of this type to the LLaVA classes defined in this module.
    """

    model_type = "llava_llama"


class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    """Llama backbone combined with the LLaVA multimodal mixin (``LlavaMetaModel``)."""

    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        # Zero-argument super() resolves to the same MRO successor as
        # super(LlavaLlamaModel, self), i.e. LlavaMetaModel first.
        super().__init__(config)

class CausalLMOutputWithPastAddWeight(CausalLMOutputWithPast):
    """``CausalLMOutputWithPast`` carrying one extra ``weight`` payload.

    ``weight`` is stored as a plain attribute rather than a dataclass field, so it is
    reachable as ``out.weight`` but not via ``out["weight"]``; ``to_tuple`` appends it
    explicitly (even when it is ``None``).
    """

    def __init__(self, loss=None, logits=None, past_key_values=None, hidden_states=None, attentions=None, weight=None):
        base_fields = dict(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values,
            hidden_states=hidden_states,
            attentions=attentions,
        )
        super().__init__(**base_fields)
        self.weight = weight

    def to_tuple(self) -> Tuple:
        """Return the parent's tuple with ``weight`` appended as the last element."""
        parent_items = super().to_tuple()
        return (*parent_items, self.weight)
    
class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    """LLaVA causal LM on a Llama backbone with optional image-ablation passes.

    Besides the ordinary forward pass, ``base_forward`` can additionally run the
    language model with image tokens removed from the attention mask and report, per
    target token, how much the image raised the label's log-probability.
    """

    config_class = LlavaConfig

    def __init__(self, config):
        # Calls the __init__ that follows LlamaForCausalLM in the MRO, i.e. it skips
        # LlamaForCausalLM.__init__ itself, and then builds the backbone as a
        # LlavaLlamaModel. Presumably this avoids constructing the parent's plain backbone.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the multimodal backbone (``self.model``)."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        cache_position=None,
        contrastive_grad:bool=False,
        contrastive_logits:bool=False,
        only_contrast:bool=False,
        only_ref:bool=False,
        sample_index = None,
        output_attention_statistics: Optional[bool] = False, 
    ) -> Union[Tuple, CausalLMOutputWithPastAddWeight]:
        """Build multimodal inputs (if needed) and delegate to ``base_forward``.

        When ``inputs_embeds`` is not given, ``prepare_inputs_labels_for_multimodal``
        rewrites ids/positions/mask/cache/labels and produces ``inputs_embeds`` from
        ``images`` (presumably by splicing image features into the token sequence).
        ``cache_position`` is accepted for API compatibility but is not forwarded.
        All other mode flags are passed through unchanged; see ``base_forward``.
        """
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                image_sizes
            )

        return self.base_forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            contrastive_grad=contrastive_grad,
            contrastive_logits=contrastive_logits,
            only_contrast=only_contrast,
            only_ref=only_ref,
            sample_index=sample_index,
            output_attention_statistics=output_attention_statistics 
        )

    def _lm_head_logits(self, hidden_states):
        """Project backbone hidden states to vocabulary logits (before any ``.float()`` cast).

        When ``self.pretraining_tp > 1`` the head weight is split into row chunks of
        ``vocab_size // pretraining_tp``. The first ``pretraining_tp`` partial projections
        are concatenated along the last dimension.
        """
        if self.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0)
            partial = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)]
            return torch.cat(partial, dim=-1)
        return self.lm_head(hidden_states)

    @staticmethod
    def _logits_only_output(logits, outputs, return_dict):
        """Package a loss-free result the same way the main path does.

        Returns ``(logits,) + outputs[1:]`` when ``return_dict`` is false, otherwise a
        ``CausalLMOutputWithPast`` with ``loss=None``.
        """
        if not return_dict:
            return (logits,) + outputs[1:]
        return CausalLMOutputWithPast(
            loss=None,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @staticmethod
    def _image_log_prob_gain(shift_logits, shift_logits_wo_img, shift_labels, labels_shape):
        """Per-token ``log p(label | with image) - log p(label | image tokens masked)``.

        Args:
            shift_logits: flattened shifted logits of the with-image pass; must be
                viewable as ``shift_logits_wo_img.shape``.
            shift_logits_wo_img: shifted logits of the image-masked pass (3-D; the
                gather below is over dim 2).
            shift_labels: flattened shifted labels. Entries equal to ``IGNORE_INDEX``
                are gathered at index 0 only to keep the gather in range; their output
                values carry no meaning.
            labels_shape: shape of the shifted labels before flattening.

        Returns:
            Tensor of shape ``labels_shape``.
        """
        gather_index = shift_labels.masked_fill(shift_labels == IGNORE_INDEX, 0).view(labels_shape).unsqueeze(2)
        logp_w_img = torch.gather(
            shift_logits.view(shift_logits_wo_img.shape).log_softmax(-1), 2, gather_index
        ).squeeze(2)
        logp_wo_img = torch.gather(
            shift_logits_wo_img.log_softmax(-1), 2, gather_index.to(shift_logits_wo_img.device)
        ).squeeze(2)
        return (logp_w_img - logp_wo_img).view(labels_shape)

    def base_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        contrastive_grad:bool=False,
        contrastive_logits:bool=False,
        only_contrast:bool=False,
        only_ref:bool=False,
        sample_index = None,
        output_attention_statistics: Optional[bool] = False, 
    ) -> Union[Tuple, CausalLMOutputWithPastAddWeight]:
        r"""
        Run the language model on already-merged inputs in one of several modes.

        Positions whose label equals ``IMAGE_TOKEN_PATCH`` are treated as image tokens.
        An untouched clone of ``labels`` (``labels_0``) is what the backbone receives as
        its ``labels`` argument.

        Modes (checked in this order):
            only_contrast: no-grad pass with image tokens removed from the attention
                mask. Returns logits only (``loss=None``). Requires ``labels`` and
                ``attention_mask``.
            only_ref: no-grad ordinary pass. Returns logits only (``loss=None``).
            contrastive_grad (only when ``labels`` is given): the image-masked and the
                ordinary variant run as one doubled batch (masked half first). The loss
                uses the ordinary half. In training mode with ``return_dict`` the result
                is a ``CausalLMOutputWithPastAddWeight`` whose ``weight`` is
                ``(log_prob_gain, None, shifted_labels, logits_wo_img or None, labels_0)``.
                ``logits_wo_img`` is only kept when ``contrastive_logits`` is true.
            otherwise: ordinary pass. ``contrastive_grad`` without labels also lands here.

        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Targets for the shifted cross-entropy loss. ``IMAGE_TOKEN_PATCH`` entries
                are overwritten with ``IGNORE_INDEX`` **in place** (the caller's tensor is
                modified) before the loss is computed.

        Raises:
            ValueError: if ``only_contrast`` is set without ``labels``.
        """

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Unmodified labels (IMAGE_TOKEN_PATCH markers intact): passed to the backbone and
        # returned inside `weight` on the contrastive training path.
        labels_0 = labels.clone() if labels is not None else None
        uncertainty = None  # reserved slot in `weight`; never computed here

        if only_contrast:
            if labels is None:
                raise ValueError("`only_contrast=True` requires `labels` to locate the image tokens.")
            with torch.no_grad():
                index = labels == IMAGE_TOKEN_PATCH
                attention_mask_wo_img = attention_mask.clone()
                attention_mask_wo_img[index] = False  # mask image tokens
                # NOTE(review): mutates the caller's `labels` in place; kept for compatibility.
                labels[index] = IGNORE_INDEX
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask_wo_img,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    inputs_embeds=inputs_embeds.clone().contiguous(),
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                    output_attention_statistics=output_attention_statistics,
                    labels=labels_0,
                )
                logits = self._lm_head_logits(outputs[0]).float()
            return self._logits_only_output(logits, outputs, return_dict)

        if only_ref:
            with torch.no_grad():
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    inputs_embeds=inputs_embeds,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                    output_attention_statistics=output_attention_statistics,
                    labels=labels_0,
                )
                logits = self._lm_head_logits(outputs[0]).float()
            return self._logits_only_output(logits, outputs, return_dict)

        # The contrastive pass needs labels to find the image tokens. Without labels, fall
        # back to the ordinary pass instead of leaving `logits`/`outputs` undefined.
        use_contrastive = contrastive_grad and labels is not None
        logits_wo_img = None

        if use_contrastive:
            index = labels == IMAGE_TOKEN_PATCH
            attention_mask_wo_img = attention_mask.clone()
            attention_mask_wo_img[index] = False
            # NOTE(review): mutates the caller's `labels` in place; kept for compatibility.
            labels[index] = IGNORE_INDEX

            # One doubled batch: first half has image tokens masked, second half is ordinary.
            # torch.cat already copies, so no extra clone of inputs_embeds is needed.
            outputs = self.model(
                input_ids=None,  # inputs_embeds carries the inputs on this path
                attention_mask=torch.cat([attention_mask_wo_img, attention_mask], dim=0),
                position_ids=position_ids.repeat(2, 1) if position_ids is not None else None,
                past_key_values=past_key_values,
                inputs_embeds=torch.cat([inputs_embeds, inputs_embeds], dim=0),
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                output_attention_statistics=output_attention_statistics,
                labels=labels_0,
            )
            hidden_states_wo_img, hidden_states = torch.chunk(outputs[0], 2, dim=0)
            logits_wo_img = self._lm_head_logits(hidden_states_wo_img).float()
            logits = self._lm_head_logits(hidden_states).float()
        else:
            if labels is not None:
                # Image positions never contribute to the loss.
                labels[labels == IMAGE_TOKEN_PATCH] = IGNORE_INDEX
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                output_attention_statistics=output_attention_statistics,
                labels=labels_0,
            )
            logits = self._lm_head_logits(outputs[0]).float()

        loss = None
        sub = None
        if labels is not None:
            # Shift so that tokens < n predict n, then flatten.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            labels_shape = shift_labels.shape
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            # Enable model parallelism
            shift_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = CrossEntropyLoss()(shift_logits, shift_labels)

            if use_contrastive:
                shift_logits_wo_img = logits_wo_img[..., :-1, :].contiguous()
                if not contrastive_logits:
                    # Not returned, so drop the full tensor before the log-softmax to lower peak memory.
                    logits_wo_img = None
                sub = self._image_log_prob_gain(shift_logits, shift_logits_wo_img, shift_labels, labels_shape)
                del shift_logits_wo_img

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        if use_contrastive and self.training:
            return CausalLMOutputWithPastAddWeight(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                # logits_wo_img is already None unless contrastive_logits was requested.
                weight=(sub, uncertainty, shift_labels.view(labels_shape), logits_wo_img, labels_0),
            )
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate from token ids, optionally conditioned on images.

        Inputs are always converted to ``inputs_embeds`` first: via
        ``prepare_inputs_labels_for_multimodal`` when ``images`` is given, otherwise via
        the backbone's ``embed_tokens``. The parent ``generate`` then runs on those
        embeddings.

        Raises:
            NotImplementedError: if the caller passes ``inputs_embeds`` directly.
        """
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _
            ) = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                image_sizes=image_sizes
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        """Delegate to the parent, then re-attach ``images``/``image_sizes`` if they were supplied.

        The two keys are popped from ``kwargs`` before delegating, so the parent never sees them.
        """
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs['images'] = images
        if image_sizes is not None:
            inputs['image_sizes'] = image_sizes
        return inputs

# Make "llava_llama" checkpoints loadable through AutoConfig / AutoModelForCausalLM.
# The model type comes from the config class itself, so the registered key cannot
# drift from LlavaConfig.model_type (AutoConfig.register requires the two to match).
AutoConfig.register(LlavaConfig.model_type, LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
