#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.


from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig

from torch.nn import CrossEntropyLoss

from transformers import LlamaModel, LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
from transformers.generation.utils import GenerateOutput

from .llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
from .sampler import GumbelSampler


class LlavaConfig(LlamaConfig):
    """``LlamaConfig`` subclass identified by the ``"llava_llama"`` model type.

    The annotated class attributes below declare class-level defaults for
    generation-related settings.
    """

    model_type = "llava_llama"
    # NOTE(review): depending on the installed transformers version,
    # PretrainedConfig.__init__ may assign instance attributes with some of
    # these same names (e.g. temperature, do_sample, top_p), which would shadow
    # the class-level defaults below -- confirm which values generation sees.
    temperature: float = 0.0  # reset to 0.0, previously 0.9 for Vicuna
    max_new_tokens: int = 1024
    do_sample: bool = False
    top_p: Optional[float] = None
    # rope_scaling: Optional[dict] = {}


class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    """Llama decoder backbone combined with ``LlavaMetaModel`` (from ``llava_arch``)."""

    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        # Zero-argument super() follows the same MRO as super(LlavaLlamaModel, self):
        # LlavaMetaModel first, then LlamaModel.
        super().__init__(config)


class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    """Llama causal LM whose ``generate`` accepts images alongside token ids.

    Images are folded into ``inputs_embeds`` by
    ``prepare_inputs_labels_for_multimodal`` before delegating to the parent
    ``generate``. An optional ``GumbelSampler`` can be attached with
    ``load_sampler``. ``forward`` and ``prepare_inputs_for_generation`` are
    inherited unchanged from ``LlamaForCausalLM``.
    """

    config_class = LlavaConfig

    def __init__(self, config):
        LlamaForCausalLM.__init__(self, config)

        # Tag the config with the model type this module registers with AutoConfig.
        config.model_type = "llava_llama"

        # Replace the plain backbone / output head created by the parent
        # constructor with the multimodal backbone and a fresh projection.
        self.model = LlavaLlamaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def load_sampler(self, encoder_layers, config_pretrained_path, config_cache_dir, sampler_weights,
                     sampler_type=None, iva_factor=None, th=None):
        """Attach a ``GumbelSampler`` to ``self.sampler`` and load its weights.

        Args:
            encoder_layers, config_pretrained_path, config_cache_dir,
            sampler_type, iva_factor, th: forwarded positionally to
                ``GumbelSampler``.
            sampler_weights: anything ``torch.load`` accepts (e.g. a file path)
                that holds the sampler's state dict.

        The sampler is moved to ``self.device`` and switched to eval mode.
        """
        self.sampler = GumbelSampler(encoder_layers, config_pretrained_path, config_cache_dir,
                                     sampler_type, iva_factor, th).to(self.device)
        # Deserialize on CPU so checkpoints saved on a GPU that is not present
        # here still load; load_state_dict copies the tensors onto the
        # sampler's parameters (already on self.device).
        state_dict = torch.load(sampler_weights, map_location="cpu", weights_only=True)
        # strict=False: keys missing from, or extra in, the checkpoint are ignored.
        self.sampler.load_state_dict(state_dict, strict=False)
        self.sampler.eval()
        print(f'GumbelSampler[{sampler_type}] is loaded.')

    def get_model(self):
        """Return the multimodal backbone (``LlavaLlamaModel``)."""
        return self.model

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        modalities: Optional[List[str]] = None,
        keep_small_image: Optional[bool] = None,
        calculate_all_tokens_number_with_fields: bool = False,
        draw_image: bool = False,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate tokens, optionally conditioned on images.

        When ``images`` is given, ``prepare_inputs_labels_for_multimodal``
        builds ``inputs_embeds`` and may rewrite ``position_ids`` and
        ``attention_mask``. Otherwise ``inputs`` is embedded directly with the
        backbone's ``embed_tokens``. In both cases the parent ``generate`` is
        called with ``inputs_embeds`` only; token ids are not forwarded.

        Args:
            inputs: prompt token ids.
            images: forwarded to ``prepare_inputs_labels_for_multimodal``.
            image_sizes: forwarded to ``prepare_inputs_labels_for_multimodal``.
            modalities: modality tags forwarded to
                ``prepare_inputs_labels_for_multimodal``; ``None`` (the
                default) means ``["image"]``.
            keep_small_image, calculate_all_tokens_number_with_fields,
            draw_image: forwarded to ``prepare_inputs_labels_for_multimodal``
                (only used when ``images`` is given).
            **kwargs: generation arguments for the parent ``generate``.
                ``position_ids`` and ``attention_mask`` are consumed here.

        Raises:
            NotImplementedError: if ``inputs_embeds`` is passed in ``kwargs``.
        """
        # Build a fresh default per call rather than sharing a mutable default argument.
        if modalities is None:
            modalities = ["image"]
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _
            ) = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                modalities,
                image_sizes=image_sizes,
                keep_small_image=keep_small_image,
                calculate_all_tokens_number_with_fields=calculate_all_tokens_number_with_fields,
                draw_image=draw_image
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ) -> Dict[str, Any]:
        """Run the parent update, then advance explicit ``position_ids``.

        If ``model_kwargs`` carries a ``position_ids`` tensor, it is replaced
        by a single column holding its last position plus one. A missing or
        ``None`` entry is left untouched.
        """
        model_kwargs = super()._update_model_kwargs_for_generation(outputs,
                                                                   model_kwargs,
                                                                   is_encoder_decoder,
                                                                   num_new_tokens)

        position_ids = model_kwargs.get("position_ids")
        if position_ids is not None:
            # NOTE(review): advances by exactly one position per step; if
            # num_new_tokens > 1 (e.g. assisted decoding) this may need to
            # advance by num_new_tokens instead -- confirm.
            model_kwargs["position_ids"] = position_ids[:, -1:] + 1
        return model_kwargs

# Make the Auto* factories resolve the "llava_llama" model type to the classes
# defined in this module. The key is read from the config class itself so it
# cannot drift from LlavaConfig.model_type.
AutoConfig.register(LlavaConfig.model_type, LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
