#    Copyright 2024 Hao Zhang
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
import copy
from typing import List, Optional, Tuple, Union, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoConfig, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM

from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLMWithRet

from dataclasses import dataclass
import torch.distributed as dist
from common_utils.dist_utils import (
    all_gather_with_grad as all_gather_with_grad_lavis,
    all_gather_with_grad_torch
)
from common_utils.loss_utils import contrastive_loss, contrastive_loss_with_target

@dataclass
class CausalLMOutputWithPastWithRet(CausalLMOutputWithPast):
    """``CausalLMOutputWithPast`` extended with retrieval features and losses.

    All extra fields default to ``None`` and are only populated when the
    corresponding encoder / loss was run.
    """

    # Per-objective contrastive losses; kept for gathering and logging stats only
    # (the training loss itself is stored in ``loss``).
    inter_cont_loss: Optional[torch.FloatTensor] = None
    intra_cont_loss: Optional[torch.FloatTensor] = None
    # Pooled query features.
    query_feature: Optional[torch.FloatTensor] = None
    # Pooled per-document features.
    inter_doc_feature: Optional[torch.FloatTensor] = None
    # Pooled per-section features.
    intra_doc_feature: Optional[torch.FloatTensor] = None
    # Number of sections per document, as passed into the model.
    num_valid_sections: Optional[torch.LongTensor] = None


class LlavaQwenConfig(Qwen2Config):
    """Qwen2 configuration exposed under the ``llava_qwen`` model type."""

    model_type: str = "llava_qwen"


class LlavaQwenModel(LlavaMetaModel, Qwen2Model):
    """``Qwen2Model`` combined with ``LlavaMetaModel``.

    ``LlavaMetaModel`` is listed first, so its methods take precedence in the
    MRO over those of ``Qwen2Model``.
    """

    config_class = LlavaQwenConfig

    def __init__(self, config: Qwen2Config):
        # Zero-argument super() performs the same MRO lookup as
        # super(LlavaQwenModel, self).
        super().__init__(config)


class LlavaQwenForCausalLMWithDualRet(Qwen2ForCausalLM, LlavaMetaForCausalLMWithRet):
    """Dual-tower (query / document) retrieval model built on LLaVA-Qwen.

    ``self.model`` encodes queries and ``self.doc_model`` encodes documents.
    A feature is the mean of the final hidden states at the token positions
    returned by ``prepare_inputs_for_multimodal``. Depending on
    ``config.inter_contrastive`` / ``config.intra_contrastive``, ``forward``
    also computes a query-to-document or a query-to-section contrastive loss.
    """

    config_class = LlavaQwenConfig

    def __init__(self, config):
        # Call the initializer that follows Qwen2ForCausalLM in the MRO, i.e. skip
        # Qwen2ForCausalLM.__init__ itself; the backbones are built explicitly below.
        super(Qwen2ForCausalLM, self).__init__(config)
        config.model_type = "llava_qwen"
        config.rope_scaling = None

        # Query encoder; named ``model`` so that it inherits the weights of the
        # LLaVA-Next-Interleave checkpoint.
        self.model = LlavaQwenModel(config)
        # Document encoder.
        self.doc_model = LlavaQwenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing.
        self.post_init()

        # Learnable "End of Query" / "End of Section" embeddings: the first row is
        # for the query and the second for the section (8 rows are allocated).
        # Created after post_init(), so they are initialized explicitly here.
        self.soft_prompt = nn.Embedding(8, self.config.hidden_size)
        init_prompt_value = torch.FloatTensor(8, self.config.hidden_size).uniform_(-0.5, 0.5).to(
            self.soft_prompt.weight.dtype
        )
        self.soft_prompt.weight = nn.parameter.Parameter(init_prompt_value)

    def get_model(self, model_name='query'):
        """Return the query encoder for ``'query'``, otherwise the document encoder."""
        if model_name == 'query':
            return self.model
        else:
            return self.doc_model

    def copy_model(self):
        """Replace the document encoder with a deep copy of the query encoder."""
        self.doc_model = copy.deepcopy(self.model)

    def forward(
        self,
        # We keep the input_ids, attention_mask ... formats to meet the requirement of Trainer of the transformer module.
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        # query inputs
        query_input_ids: torch.LongTensor = None,
        query_attention_mask: Optional[torch.Tensor] = None,
        query_inputs_embeds: Optional[torch.FloatTensor] = None,
        query_images: Optional[torch.FloatTensor] = None,
        query_evidence_section_labels: Optional[torch.LongTensor] = None,
        # document inputs
        doc_input_ids: torch.LongTensor = None,
        doc_attention_mask: Optional[torch.Tensor] = None,
        doc_inputs_embeds: Optional[torch.FloatTensor] = None,
        doc_images: Optional[torch.FloatTensor] = None,
        doc_num_sections: Optional[torch.LongTensor] = None,
        # general inputs
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Encode queries and/or documents and optionally compute a contrastive loss.

        ``input_ids``, ``attention_mask``, ``inputs_embeds``, ``labels`` and
        ``cache_position`` are accepted only for HF ``Trainer`` compatibility and
        are not read. ``query_inputs_embeds`` / ``doc_inputs_embeds`` are
        overwritten by ``prepare_inputs_for_multimodal``.

        Args:
            query_evidence_section_labels: per-query index of the evidence
                section within its own document (intra-document loss only).
            doc_num_sections: number of sections of each document; consecutive
                section features are grouped by these counts.

        Returns:
            ``CausalLMOutputWithPastWithRet`` when ``return_dict`` is true,
            otherwise a tuple ``([loss,] [query tuple,] [doc tuple])``.

        Raises:
            ValueError: if an enabled contrastive loss lacks required inputs, or
                ``doc_num_sections`` does not add up to the number of sections.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        encoder_kwargs = dict(
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        query_outputs = None
        query_tok_features = None
        if query_input_ids is not None:
            query_outputs, query_tok_indices, _ = self._encode(
                self.model, query_input_ids, position_ids, query_attention_mask,
                past_key_values, query_images, is_query=True, **encoder_kwargs,
            )
            # End-of-Query features, one row per query.
            query_tok_features = self._pool_tokens(query_outputs[0], query_tok_indices)

        doc_outputs = None
        inter_doc_tok_features = None
        intra_doc_tok_features = None
        if doc_input_ids is not None:
            doc_outputs, doc_tok_indices, section_group_indices = self._encode(
                self.doc_model, doc_input_ids, position_ids, doc_attention_mask,
                past_key_values, doc_images, is_query=False, **encoder_kwargs,
            )
            # End-of-Section features, one row per (possibly split) passage.
            intra_pas_tok_features = self._pool_tokens(doc_outputs[0], doc_tok_indices)
            # Merge split passages back into one row per section.
            intra_doc_tok_features = self._merge_sections(intra_pas_tok_features, section_group_indices)
            # Average each document's sections into one row per document.
            inter_doc_tok_features = self._merge_documents(intra_doc_tok_features, doc_num_sections)

        loss = None
        inter_cont_loss = None
        intra_cont_loss = None

        # getattr: configs saved without these flags simply disable the losses.
        if getattr(self.config, "inter_contrastive", False):
            if query_tok_features is None or inter_doc_tok_features is None:
                raise ValueError("inter_contrastive requires both query and document inputs.")

            if dist.is_initialized():
                query_tok_features = all_gather_with_grad_lavis(query_tok_features)
                inter_doc_tok_features = all_gather_with_grad_lavis(inter_doc_tok_features)

            # Cosine-similarity logits between every query and every document.
            q2d_inter_logits = (
                F.normalize(query_tok_features, dim=-1) @ F.normalize(inter_doc_tok_features, dim=-1).t()
            )
            inter_cont_loss = contrastive_loss(q2d_inter_logits)
            loss = inter_cont_loss

        elif getattr(self.config, "intra_contrastive", False):
            if query_tok_features is None or intra_doc_tok_features is None:
                raise ValueError("intra_contrastive requires both query and document inputs.")
            if query_evidence_section_labels is None:
                raise ValueError("intra_contrastive requires query_evidence_section_labels.")
            # Validate against the LOCAL section features, before any gathering.
            total_sections = sum(int(n) for n in doc_num_sections)
            if total_sections != intra_doc_tok_features.size(0):
                raise ValueError(
                    f"doc_num_sections sums to {total_sections}, but "
                    f"{intra_doc_tok_features.size(0)} section features were produced."
                )

            intra_doc_evidence_ids = self._evidence_targets(
                doc_num_sections, query_evidence_section_labels, intra_doc_tok_features.device
            )

            if dist.is_initialized():
                query_tok_features = all_gather_with_grad_lavis(query_tok_features)
                intra_doc_tok_features, intra_doc_evidence_ids = self._gather_sections_and_targets(
                    intra_doc_tok_features, intra_doc_evidence_ids
                )

            # Cosine-similarity logits between every query and every section.
            q2d_intra_logits = (
                F.normalize(query_tok_features, dim=-1) @ F.normalize(intra_doc_tok_features, dim=-1).t()
            )
            intra_cont_loss = contrastive_loss_with_target(q2d_intra_logits, target=intra_doc_evidence_ids)
            loss = intra_cont_loss

        if not return_dict:
            # Build a tuple (not a list) so that ``(loss,) + output`` is valid.
            output = ()
            if query_outputs is not None:
                output += ((query_tok_features,) + tuple(query_outputs[1:]),)
            if doc_outputs is not None:
                output += ((inter_doc_tok_features, intra_doc_tok_features) + tuple(doc_outputs[1:]),)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPastWithRet(
            loss=loss,
            query_feature=query_tok_features,
            inter_doc_feature=inter_doc_tok_features,
            intra_doc_feature=intra_doc_tok_features,
            inter_cont_loss=inter_cont_loss,
            intra_cont_loss=intra_cont_loss,
            num_valid_sections=doc_num_sections,
        )

    def _encode(self, encoder, input_ids, position_ids, attention_mask, past_key_values,
                images, is_query, **model_kwargs):
        """Prepare multimodal inputs and run ``encoder`` on them.

        Returns ``(outputs, tok_indices, section_group_indices)``; the last two
        are passed through from ``prepare_inputs_for_multimodal``.
        """
        (input_ids, position_ids, attention_mask, past_key_values,
         inputs_embeds, tok_indices, section_group_indices) = self.prepare_inputs_for_multimodal(
            input_ids, position_ids, attention_mask, past_key_values, images,
            is_query=is_query, is_doc=not is_query,
        )
        outputs = encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            **model_kwargs,
        )
        return outputs, tok_indices, section_group_indices

    @staticmethod
    def _pool_tokens(hidden_states, tok_indices):
        """Mean-pool row ``i`` of ``hidden_states`` over positions ``tok_indices[i]``.

        Assumes ``hidden_states`` is (N, L, D) and each ``tok_indices[i]`` selects
        one or more positions — TODO confirm against prepare_inputs_for_multimodal.
        """
        return torch.stack(
            [hidden_states[i, tok_indices[i]].mean(dim=0) for i in range(len(hidden_states))],
            dim=0,
        )

    @staticmethod
    def _merge_sections(passage_features, section_group_indices):
        """Average the passage rows listed in each group into one row per section."""
        merged = []
        for group in section_group_indices:
            if len(group) == 1:
                merged.append(passage_features[group])
            else:
                merged.append(passage_features[group].mean(dim=0, keepdim=True))
        return torch.cat(merged, dim=0)

    @staticmethod
    def _merge_documents(section_features, num_sections):
        """Average consecutive runs of ``num_sections[i]`` rows into one row per document."""
        merged = []
        start = 0
        for count in num_sections:
            count = int(count)
            chunk = section_features[start:start + count]
            merged.append(chunk if count == 1 else chunk.mean(dim=0, keepdim=True))
            start += count
        return torch.cat(merged, dim=0)

    @staticmethod
    def _evidence_targets(num_sections, evidence_section_labels, device):
        """Map each query's in-document evidence index to an index over all local sections."""
        targets = []
        offset = 0
        for i, count in enumerate(num_sections):
            targets.append(offset + int(evidence_section_labels[i]))
            offset += int(count)
        return torch.tensor(targets, dtype=torch.long, device=device)

    @staticmethod
    def _gather_sections_and_targets(section_features, targets):
        """All-gather section features (variable rows per rank) and re-base targets.

        Each rank may hold a different number of sections, while an all-gather
        needs equal shapes. So rows are padded to the maximum count, gathered, and
        the padding is stripped again. Local targets are shifted by the number of
        sections on lower ranks so that they index the concatenated features.
        Assumes the gather helper concatenates rank-ordered along dim 0, as the
        query gathering also does.
        """
        world_size = dist.get_world_size()
        if world_size == 1:
            return section_features, targets

        local_count = torch.tensor([section_features.size(0)], dtype=torch.long, device=section_features.device)
        counts = [torch.zeros_like(local_count) for _ in range(world_size)]
        dist.all_gather(counts, local_count)
        counts = [int(c.item()) for c in counts]
        max_count = max(counts)

        padded = F.pad(section_features, (0, 0, 0, max_count - section_features.size(0)))
        gathered = all_gather_with_grad_lavis(padded)
        section_features = torch.cat(
            [gathered[rank * max_count: rank * max_count + counts[rank]] for rank in range(world_size)],
            dim=0,
        )

        targets = targets + sum(counts[:dist.get_rank()])
        gathered_targets = [torch.zeros_like(targets) for _ in range(world_size)]
        dist.all_gather(gathered_targets, targets)
        return section_features, torch.cat(gathered_targets, dim=0)


# Expose the custom config/model through the transformers Auto* factories;
# the registered name is taken from the config class so the two cannot drift.
AutoConfig.register(LlavaQwenConfig.model_type, LlavaQwenConfig)
AutoModelForCausalLM.register(LlavaQwenConfig, LlavaQwenForCausalLMWithDualRet)
