
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from transformers import AutoConfig, AutoModelForCausalLM, \
                         LlamaConfig, LlamaModel, LlamaForCausalLM, AutoTokenizer

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import ModelOutput

from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
from ...mm_utils import tokenizer_image_token
from ...constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, IGNORE_INDEX
from dataclasses import dataclass


class LlavaConfig(LlamaConfig):
    """Configuration for the LLaVA models defined in this file.

    A ``LlamaConfig`` subclass that overrides nothing except ``model_type``.
    """

    # Type identifier string for this configuration class.
    model_type = "llava"


class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    """Backbone model combining ``LlavaMetaModel`` with ``LlamaModel``.

    The class adds no behaviour of its own: it binds ``LlavaConfig`` as the
    configuration class and forwards construction to its base classes.
    """

    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        """Build the model from ``config`` via the cooperative base initialisers."""
        super().__init__(config)


@dataclass
class LLaVAOutput(ModelOutput):
    """Result container returned by ``LlavaLlamaForCausalLM.forward``.

    Every field defaults to ``None``; ``forward`` fills all four of them.
    """

    # Loss reported by the underlying causal-LM forward pass.
    loss: Optional[torch.FloatTensor] = None
    # Logits reported by the underlying causal-LM forward pass.
    logits: Optional[torch.FloatTensor] = None
    # Label tensor that was passed to the causal-LM forward pass.
    labels: Optional[torch.IntTensor] = None
    # Attention mask that was passed to the causal-LM forward pass.
    attention_mask: Optional[torch.IntTensor] = None


class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    """LLaMA causal language model with the LLaVA multimodal mixin.

    Unlike the stock ``LlamaForCausalLM``, :meth:`forward` takes a single
    ``samples`` batch mapping (raw text, optional image tensor, optional
    prompt lengths) and performs tokenisation and label construction itself.
    """

    config_class = LlavaConfig

    # Location passed to ``AutoTokenizer.from_pretrained``. Override on a
    # subclass or an instance to tokenise with a different checkpoint.
    default_tokenizer_path = 'hugging_cache/llava-v1.5-7b'

    # Tokenizers already loaded, keyed by path. Kept on the class so the
    # (slow) load happens once per process instead of once per forward pass,
    # and so the tokenizer is not pickled/deep-copied along with the model.
    _tokenizer_cache = {}

    def __init__(self, config):
        """Create the LLaVA backbone and the language-model head from ``config``."""
        # Start the MRO lookup *after* LlamaForCausalLM, i.e. skip its
        # __init__ -- presumably so it does not build its own plain
        # LlamaModel before we install the LLaVA one below.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the wrapped ``LlavaLlamaModel`` backbone."""
        return self.model

    def _get_tokenizer(self):
        """Return the tokenizer for ``default_tokenizer_path``, loading it on first use."""
        path = self.default_tokenizer_path
        cache = self._tokenizer_cache
        if path not in cache:
            cache[path] = AutoTokenizer.from_pretrained(path, use_fast=False)
        return cache[path]

    def prepare_inputs_from_batch(self, samples):
        """Tokenise a batch and build multimodal model inputs and labels.

        Args:
            samples: Mapping with the keys
                ``"text_input"`` -- iterable of strings to tokenise;
                ``"image"`` -- image tensor (expected bsz, 3, image_size,
                image_size) or ``None`` for a text-only batch;
                ``"prompts_len"`` -- per-sample prompt lengths in tokens
                (list or 1-D tensor), or a falsy value to train on every
                non-padding token.

        Returns:
            Whatever ``prepare_inputs_labels_for_multimodal`` returns for the
            constructed ids, attention mask, labels and image; ``forward``
            unpacks it as a six-element tuple.
        """
        tokenizer = self._get_tokenizer()

        text = list(samples["text_input"])
        input_tokens = tokenizer(text, padding=True, return_tensors='pt').to(self.device)
        input_ids = input_tokens.input_ids
        attention_mask = input_tokens.attention_mask

        image = samples['image']
        has_image = image is not None
        if has_image:
            image = image.to(self.dtype)  # bsz, 3, image_size, image_size
            batch_size = input_ids.shape[0]

            # Insert one image placeholder right after the first token of
            # every sequence (presumably BOS), and a matching "attend" entry
            # in the mask so both tensors stay aligned.
            image_token_col = torch.full(
                (batch_size, 1), IMAGE_TOKEN_INDEX,
                dtype=input_ids.dtype, device=input_ids.device,
            )
            input_ids = torch.cat(
                (input_ids[:, :1], image_token_col, input_ids[:, 1:]), dim=1
            )
            image_mask_col = torch.ones(
                (batch_size, 1),
                dtype=attention_mask.dtype, device=attention_mask.device,
            )
            attention_mask = torch.cat(
                (attention_mask[:, :1], image_mask_col, attention_mask[:, 1:]), dim=1
            )

        # Padding positions never contribute to the loss.
        targets = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, IGNORE_INDEX)

        prompts_len = samples['prompts_len']
        if isinstance(prompts_len, torch.Tensor):
            # Truth-testing a multi-element tensor raises; a plain list of
            # ints behaves the same way in the loop below.
            prompts_len = prompts_len.tolist()
        if prompts_len:
            # The prompt is shifted right by one when an image placeholder
            # was inserted, so one extra position has to be masked.
            offset = 1 if has_image else 0
            for i, prompt_len in enumerate(prompts_len):
                targets[i, :prompt_len + offset] = IGNORE_INDEX

        return self.prepare_inputs_labels_for_multimodal(
            input_ids,
            None,
            attention_mask,
            None,
            targets,
            image
        )

    def forward(self, samples):
        """Run a forward pass on a raw batch and return loss, logits and labels.

        Args:
            samples: Batch mapping as accepted by
                :meth:`prepare_inputs_from_batch`.

        Returns:
            LLaVAOutput: loss and logits from the parent ``forward``, plus the
            labels and attention mask that were fed to it.
        """
        (
            input_ids,
            _,
            attention_mask,
            _,
            inputs_embeds,
            targets
        ) = self.prepare_inputs_from_batch(samples)

        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=None,
            past_key_values=None,
            inputs_embeds=inputs_embeds,
            labels=targets,
            return_dict=True
        )

        # Diagnostic only: report NaNs but still return the outputs.
        if torch.isnan(outputs.logits).any():
            print("LLaVA logits has nan!!!!!!!!!!!!!!!")

        return LLaVAOutput(
            loss=outputs.loss,
            logits=outputs.logits,
            labels=targets,
            attention_mask=attention_mask
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """Build per-step generation inputs, carrying ``images`` through.

        ``images`` is removed from ``kwargs`` before delegating to the parent
        implementation and, when not ``None``, added back to the returned
        dict under the key ``'images'``.
        """
        images = kwargs.pop("images", None)
        _inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            _inputs['images'] = images
        return _inputs
