#    Copyright 2024 Hao Zhang
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
import copy
from typing import List, Optional, Tuple, Union, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoConfig, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM

from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLMWithReRank

from dataclasses import dataclass

@dataclass
class CausalLMOutputWithPastWithReRank(CausalLMOutputWithPast):
    """Output of the re-ranking forward pass.

    Extends ``CausalLMOutputWithPast`` with re-ranking fields. The model in this
    module only populates ``loss``, ``predictions`` and ``num_valid_sections``;
    the inherited ``logits`` / ``past_key_values`` / ``hidden_states`` /
    ``attentions`` fields are left as None.
    """

    # These fields are meant for gathering and statistics (e.g. during evaluation).
    # Binary cross-entropy (with logits) over all sections in the batch;
    # None when no evidence-section labels were supplied.
    loss: Optional[torch.FloatTensor] = None
    # Sigmoid relevance score for every section in the batch, flattened across
    # documents (one row per section).
    predictions: Optional[torch.FloatTensor] = None
    # Number of sections in each document; presumably lets consumers split the
    # flattened `predictions` back into per-query groups.
    num_valid_sections: Optional[torch.LongTensor] = None


class LlavaQwenConfig(Qwen2Config):
    """Qwen2 configuration exposed under the ``llava_qwen`` model type."""

    # Identifier used when registering this config with AutoConfig at the end of this module.
    model_type = "llava_qwen"


class LlavaQwenModel(LlavaMetaModel, Qwen2Model):
    """Qwen2 decoder combined with the LLaVA multimodal mixin.

    Construction is delegated to the cooperative ``super()`` chain, in which
    ``LlavaMetaModel`` precedes ``Qwen2Model`` in the MRO.
    """

    config_class = LlavaQwenConfig

    def __init__(self, config: Qwen2Config):
        super().__init__(config)


class LlavaQwenForCausalLMWithDualReRank(Qwen2ForCausalLM, LlavaMetaForCausalLMWithReRank):
    """Qwen2-based LLaVA model that scores (query, section) pairs for re-ranking.

    Each query is paired with every section of its document by
    ``prepare_inputs_for_multimodal`` (inherited, not defined in this module).
    The encoder's hidden states at the positions in ``tok_indices`` are averaged
    per sequence. Sequences grouped together in ``section_group_indices`` are
    averaged into one section feature, and ``scoring_head`` maps each section
    feature to a single relevance logit.
    """

    config_class = LlavaQwenConfig

    def __init__(self, config):
        # Call the grandparent initializer directly, bypassing
        # Qwen2ForCausalLM.__init__, so that ``self.model`` below can be the
        # multimodal LlavaQwenModel.
        super(Qwen2ForCausalLM, self).__init__(config)
        config.model_type = "llava_qwen"
        config.rope_scaling = None

        # Query/section encoder. It is named ``model`` so the weights of
        # LLaVA-Next-Interleave load into it without renaming.
        self.model = LlavaQwenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Projects a pooled section feature to a single scalar (relevance logit).
        self.scoring_head = nn.Linear(config.hidden_size, 1, bias=False)

        # Initialize weights and apply final processing.
        self.post_init()

        # Learnable "end of section" embeddings. They are created after
        # post_init() and explicitly initialised with U(-0.5, 0.5), replacing
        # nn.Embedding's default init. Presumably consumed by
        # prepare_inputs_for_multimodal (not visible here).
        num_soft_prompts = 4
        self.soft_prompt = nn.Embedding(num_soft_prompts, self.config.hidden_size)
        init_prompt_value = (
            torch.FloatTensor(num_soft_prompts, self.config.hidden_size)
            .uniform_(-0.5, 0.5)
            .to(self.soft_prompt.weight.dtype)
        )
        self.soft_prompt.weight = nn.Parameter(init_prompt_value)

    def get_model(self):
        """Return the underlying query/section encoder."""
        return self.model

    def forward(
        self,
        # We keep the input_ids, attention_mask ... formats to meet the requirement of Trainer of the transformer module.
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        # query inputs
        query_input_ids: torch.LongTensor = None,
        query_attention_mask: Optional[torch.Tensor] = None,
        query_inputs_embeds: Optional[torch.FloatTensor] = None,
        query_images: Optional[torch.FloatTensor] = None,
        query_evidence_section_labels: Optional[torch.LongTensor] = None,
        # document inputs
        doc_input_ids: torch.LongTensor = None,
        doc_attention_mask: Optional[torch.Tensor] = None,
        doc_inputs_embeds: Optional[torch.FloatTensor] = None,
        doc_images: Optional[torch.FloatTensor] = None,
        doc_num_sections: Optional[torch.LongTensor] = None,
        # general inputs
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Score every (query, section) pair in the batch.

        ``input_ids``, ``attention_mask``, ``inputs_embeds`` and ``labels`` exist
        so the signature matches what the Hugging Face ``Trainer`` passes. Their
        incoming values are not used: the first three are overwritten by
        ``prepare_inputs_for_multimodal``. ``query_inputs_embeds``,
        ``doc_inputs_embeds``, ``output_hidden_states`` and ``cache_position``
        are likewise accepted but unused.

        Args:
            query_input_ids, query_attention_mask, query_images: query side of
                the batch, forwarded to ``prepare_inputs_for_multimodal``.
            doc_input_ids, doc_attention_mask, doc_images: document side of the
                batch, forwarded likewise.
            doc_num_sections: number of sections of each document, in batch
                order. Used to convert per-document section indices into
                indices into the flattened list of all sections.
            query_evidence_section_labels: one index per query, giving its
                evidence section within its own document. When provided, a
                binary cross-entropy loss is computed with that section as the
                positive and every other section as a negative.
            position_ids, past_key_values: forwarded to
                ``prepare_inputs_for_multimodal``.
            use_cache, output_attentions: forwarded to the encoder.
            return_dict: return an output object (default from the config) or
                a plain tuple.

        Returns:
            ``CausalLMOutputWithPastWithReRank`` holding:
                - ``loss``: None without labels.
                - ``predictions``: sigmoid scores, one row per section in the
                  batch.
                - ``num_valid_sections``: equal to ``doc_num_sections``.
            When ``return_dict`` is False, the same non-None fields are
            returned as a tuple.

        Raises:
            AssertionError: if ``query_input_ids`` or ``doc_input_ids`` is None.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        assert query_input_ids is not None and doc_input_ids is not None, (
            "query_input_ids and doc_input_ids are required for re-ranking"
        )

        # Prepare re-ranking by concatenating the query with each section:
        #   query - section #1
        #   query - section #2
        #   ...
        #   query - section #N
        input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, tok_indices, section_group_indices = (
            self.prepare_inputs_for_multimodal(
                query_input_ids, query_attention_mask, query_images,
                doc_input_ids, doc_attention_mask, doc_images, doc_num_sections, position_ids, past_key_values,
            )
        )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        # Presumably (num_sequences, seq_len, hidden_size): one row per
        # (query, passage) sequence built above.
        hidden_states = outputs[0]

        # Average each sequence's hidden states at its end-of-section positions.
        pas_tok_features = torch.stack(
            [seq_states[tok_indices[i]].mean(dim=0) for i, seq_states in enumerate(hidden_states)],
            dim=0,
        )

        # A section that was split into several passages is listed as a group of
        # sequence indices; average the group back into one section feature.
        # For a single-element group the mean is the row itself.
        sec_tok_features = torch.cat(
            [pas_tok_features[group].mean(dim=0, keepdim=True) for group in section_group_indices],
            dim=0,
        )  # (num_sections_in_batch, hidden_size)

        # Map each section feature to a scalar relevance score.
        logits = self.scoring_head(sec_tok_features)
        predictions = torch.sigmoid(logits)

        loss = None
        if query_evidence_section_labels is not None:
            # Turn each query's document-local evidence index into a global row
            # index in `sec_tok_features` by adding the section counts of all
            # preceding documents.
            sec_evidence_ids = []
            offset = 0
            for i, num_secs in enumerate(doc_num_sections):
                sec_evidence_ids.append(offset + query_evidence_section_labels[i])
                offset += num_secs
            sec_evidence_ids = torch.tensor(sec_evidence_ids, dtype=torch.int64, device=sec_tok_features.device)

            # 0/1 targets: 1 for each query's evidence section, 0 elsewhere.
            targets = torch.zeros(
                (len(sec_tok_features), 1), dtype=sec_tok_features.dtype, device=sec_tok_features.device
            )
            targets[sec_evidence_ids] = 1
            loss = F.binary_cross_entropy_with_logits(logits, targets)

        output = CausalLMOutputWithPastWithReRank(
            loss=loss,
            predictions=predictions,
            num_valid_sections=doc_num_sections,
        )
        if not return_dict:
            # Keep the tuple form consistent with the object form: its non-None
            # fields in declaration order (loss when present, then predictions
            # and num_valid_sections). The previous `(loss,) + outputs` tried to
            # concatenate a tuple with the encoder's ModelOutput and never
            # returned the re-ranking predictions.
            return output.to_tuple()
        return output


# Make the custom classes resolvable through the transformers Auto* factories for
# configs whose model_type is "llava_qwen".
# NOTE(review): depending on the transformers version, AutoConfig.register raises
# if "llava_qwen" is already registered (e.g. by another module that defines its
# own LlavaQwenConfig). Importing both modules in one process would then fail
# here; confirm that cannot happen.
AutoConfig.register("llava_qwen", LlavaQwenConfig)
AutoModelForCausalLM.register(LlavaQwenConfig, LlavaQwenForCausalLMWithDualReRank)
