from dataclasses import dataclass
from typing import Optional, List, Union, Tuple
import torch
import wandb
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import LlamaPreTrainedModel, LlamaModel
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import ModelOutput
from transformers.cache_utils import Cache

# class ContrastiveLoss(nn.Module):
#     def __init__(self, temperature):
#         super().__init__()
#         self.temperature = 1 if temperature is None else temperature
#         self.cosine = nn.CosineSimilarity(dim=-1)

#     def forward(self, original_instruction_inputs, paraphrased_instruction_inputs):
#         batch_size = original_instruction_inputs.size(0)
#         labels = torch.arange(batch_size).to(original_instruction_inputs.device)

#         original_to_paraphrased_sim = F.cosine_similarity(original_instruction_inputs.unsqueeze(1),
#                                                           paraphrased_instruction_inputs.unsqueeze(0),
#                                                           dim=2) / self.temperature
#         paraphrased_to_original_sim = original_to_paraphrased_sim.T

#         ori_to_para_loss = F.cross_entropy(original_to_paraphrased_sim, labels)
#         para_to_ori_loss = F.cross_entropy(paraphrased_to_original_sim, labels)

#         loss = (ori_to_para_loss + para_to_ori_loss) / 2

#         return loss, original_to_paraphrased_sim, paraphrased_to_original_sim, ori_to_para_loss, para_to_ori_loss
    
class ContrastiveLoss(nn.Module):
    """InfoNCE-style contrastive loss with one positive and N negatives per anchor.

    For every anchor the loss is the cross-entropy of picking the positive out
    of ``[positive, negative_1, ..., negative_N]``, where the scores are cosine
    similarities divided by ``temperature``.
    """

    def __init__(self, temperature):
        super().__init__()
        # A missing temperature means "no scaling".
        self.temperature = 1 if temperature is None else temperature
        # Not used by forward(); kept so the module's attributes stay unchanged.
        self.cosine = nn.CosineSimilarity(dim=-1)

    def forward(self, original_instruction_inputs, positive_inputs, negative_inputs):
        """Compute the mean contrastive loss over the batch.

        Args:
            original_instruction_inputs: anchor embeddings, shape ``(batch, hidden)``.
            positive_inputs: positive embeddings, shape ``(batch, hidden)``.
            negative_inputs: negative embeddings, shape
                ``(batch, num_negatives, hidden)``.

        Returns:
            A scalar float32 tensor: the batch mean of
            ``-log(exp(s_pos) / (exp(s_pos) + sum_j exp(s_neg_j)))``.
        """
        # Similarities are computed in float32 so that dividing by a very small
        # temperature cannot overflow half-precision inputs.
        anchor = original_instruction_inputs.float()
        positive = positive_inputs.float()
        negative = negative_inputs.float()

        # Cosine similarity must be taken over the feature axis (the last one).
        pos_similarity = F.cosine_similarity(anchor, positive, dim=-1) / self.temperature
        neg_similarity = F.cosine_similarity(anchor.unsqueeze(1), negative, dim=-1) / self.temperature

        # Column 0 holds the positive score, the remaining columns the negatives.
        scores = torch.cat((pos_similarity.unsqueeze(-1), neg_similarity), dim=-1)

        # -log softmax(scores)[0], written with logsumexp so it stays finite for
        # negative similarities and for large |score| values.
        loss = torch.logsumexp(scores, dim=-1) - pos_similarity

        return loss.mean()

@dataclass
class ContrastiveCausalLMOutputWithPast(ModelOutput):
    """Output container for the contrastive causal-LM models in this file.

    It carries the usual causal-LM fields plus ``cnt_loss``, so the contrastive
    term is reported separately from ``loss``.
    """

    # Token-prediction loss; the forward passes in this file set it only when labels are given.
    loss: Optional[torch.FloatTensor] = None
    # Contrastive loss, returned separately and not added into `loss` by the forward passes here.
    cnt_loss: Optional[torch.FloatTensor] = None
    # LM-head scores for the main input batch.
    logits: torch.FloatTensor = None
    # Decoder cache as returned by the wrapped model.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Per-layer hidden states, if provided by the forward pass.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Per-layer attention weights as returned by the wrapped model.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None

## test
class LlamaForCausalLM_Constrative(LlamaPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        """Build the Llama decoder, the LM head and the contrastive-loss settings.

        Args:
            config: model configuration; ``vocab_size`` and ``hidden_size`` are
                read here and the whole object is passed to ``LlamaModel``.
        """
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        
        # contrastive
        # Temperature handed to ContrastiveLoss.
        self.tau = 0.01
        # Pooling strategy read by get_pooled_hidden_states.
        self.pooling_method = 'last'
        self.contrastive_loss = ContrastiveLoss(self.tau)
        # Multiplier (and rescale cap) applied to the contrastive loss in forward().
        self.contrastive_loss_ratio = 1000
        # Initialize weights and apply final processing
        # self.ablation = config.ablation
        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's token-embedding module (``self.model.embed_tokens``)."""
        decoder = self.model
        return decoder.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token-embedding module with ``value``."""
        decoder = self.model
        decoder.embed_tokens = value

    def get_output_embeddings(self):
        """Return the language-model head (``self.lm_head``)."""
        head = self.lm_head
        return head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-model head with ``new_embeddings``."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the wrapped decoder (``self.model``) with ``decoder``."""
        self.model = decoder

    def get_decoder(self):
        """Return the wrapped decoder (``self.model``)."""
        decoder = self.model
        return decoder

    def get_decoder_outputs(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            labels: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the wrapped decoder and return its raw outputs.

        ``output_attentions``, ``output_hidden_states`` and ``return_dict`` fall
        back to the values stored on ``self.config`` when left as ``None``.
        ``labels`` is accepted so that a batch dict containing it can be
        splatted into this call; it is not passed to the decoder.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        decoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "inputs_embeds": inputs_embeds,
            "use_cache": use_cache,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }
        return self.model(**decoder_kwargs)
    
    def separate_batch_prompts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor, batch_size: int):
        """
        Split a merged batch into its first ``batch_size`` rows and the remaining rows.

        Each returned dict has the keys ``input_ids``, ``attention_mask`` and
        ``labels``. Axis 1 is squeezed from every tensor, and ``input_ids`` and
        ``labels`` are cast to ``long``.

        Returns:
            ``(original, paraphrased)``: the leading rows and the trailing rows.
        """
        def _pack(rows):
            # Build one sub-batch from the given row slice.
            return {
                "input_ids": input_ids[rows].squeeze(dim=1).long(),
                "attention_mask": attention_mask[rows].squeeze(dim=1),
                "labels": labels[rows].squeeze(dim=1).long(),
            }

        leading = _pack(slice(None, batch_size))
        trailing = _pack(slice(batch_size, None))
        return leading, trailing

    # @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    # @replace_return_docstrings(output_type=ContrastiveCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        input_ids_pos: torch.LongTensor = None,
        attention_mask_pos: Optional[torch.Tensor] = None,
        input_ids_neg: torch.LongTensor = None,
        attention_mask_neg: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        labels_pos: Optional[torch.LongTensor] = None,
        labels_neg: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, ContrastiveCausalLMOutputWithPast]:
        """Causal-LM forward pass with an auxiliary contrastive loss.

        The main batch (``input_ids`` / ``attention_mask``) is decoded and
        projected to vocabulary logits. When ``labels`` is given, the
        next-token loss is returned as ``loss``. When positive and negative
        inputs are given as well, they are decoded too and a contrastive loss
        between the pooled hidden states is returned as ``cnt_loss``. The two
        losses are reported separately and are not summed here.

        Args:
            input_ids, attention_mask: main (anchor) batch.
            input_ids_pos, attention_mask_pos: positive batch for the contrastive term.
            input_ids_neg, attention_mask_neg: negative batch for the contrastive term.
            position_ids, past_key_values, inputs_embeds, use_cache: forwarded
                to the decoder for the main batch, so cached generation works.
            labels: targets for the next-token loss on the main batch.
            labels_pos, labels_neg: accepted for interface compatibility; unused.
            output_attentions, output_hidden_states, return_dict: default to
                the values on ``self.config`` when ``None``.
            cache_position: accepted for interface compatibility; not forwarded,
                because ``get_decoder_outputs`` has no such parameter.

        Returns:
            A ``ContrastiveCausalLMOutputWithPast``, or a tuple
            ``(loss?, logits, *decoder_extras)`` when ``return_dict`` is false.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Forward every decoder argument; dropping the cache/position arguments
        # makes incremental decoding re-encode a single token without context.
        # The decoder is always asked for a ModelOutput so the attribute access
        # below does not depend on the config's return_dict default.
        outputs = self.get_decoder_outputs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        hidden_states = outputs[0]  # batch, length, hidden
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        cnt_loss = None
        if labels is not None:  # batch, length
            loss = self.get_entropy_loss_for_token_prediction(logits, labels)

            # The contrastive term needs both auxiliary batches; without them
            # only the language-modeling loss is produced.
            if input_ids_pos is not None and input_ids_neg is not None:
                # Only the hidden states of the auxiliary passes are used, so
                # no KV cache is requested and no logits are computed for them.
                hidden_states_pos = self.get_decoder_outputs(
                    input_ids=input_ids_pos,
                    attention_mask=attention_mask_pos,
                    use_cache=False,
                )[0]
                hidden_states_neg = self.get_decoder_outputs(
                    input_ids=input_ids_neg,
                    attention_mask=attention_mask_neg,
                    use_cache=False,
                )[0]

                cnt_loss = self.contrastive_loss(
                    self.get_pooled_hidden_states(hidden_states),
                    self.get_pooled_hidden_states(hidden_states_pos),
                    # One negative per anchor: (batch, 1, hidden).
                    self.get_pooled_hidden_states(hidden_states_neg).unsqueeze(1),
                )
                cnt_loss = cnt_loss * self.contrastive_loss_ratio
                # Keep the contrastive term from dominating the LM loss.
                cnt_loss = self.scale_contrastive_loss(loss, cnt_loss, self.contrastive_loss_ratio)

                # A NaN contrastive term is reported as zero instead of propagating.
                if torch.isnan(cnt_loss):
                    cnt_loss = torch.tensor(0.0).to(loss.device)

        if not return_dict:
            output = (logits,) + outputs.to_tuple()[1:]
            return (loss,) + output if loss is not None else output
        return ContrastiveCausalLMOutputWithPast(
            loss=loss,
            cnt_loss=cnt_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    
    def get_entropy_loss_for_token_prediction(self, logits, labels):
        """Return the next-token cross-entropy loss, or ``None`` when ``labels`` is ``None``.

        The logits at position ``t`` are scored against the label at ``t + 1``.
        """
        if labels is None:
            return None

        vocab = self.config.vocab_size
        # Drop the last logit and the first label so tokens < n predict n, then flatten.
        predictions = logits[..., :-1, :].contiguous().view(-1, vocab)
        targets = labels[..., 1:].contiguous().view(-1)
        # Labels may live on another device under model parallelism.
        targets = targets.to(predictions.device)

        criterion = CrossEntropyLoss()
        return criterion(predictions, targets)

    def scale_contrastive_loss(self, generation_loss, contrastive_loss, max_scale_ratio):
        """Shrink the contrastive loss when it exceeds the generation loss.

        If ``contrastive_loss`` is non-zero and larger than ``generation_loss``
        it is multiplied by ``min(max_scale_ratio, generation / contrastive)``;
        the ratio is built from detached values. Otherwise it is returned as is.
        """
        needs_rescale = contrastive_loss != 0 and contrastive_loss > generation_loss
        if not needs_rescale:
            return contrastive_loss

        ratio = generation_loss.detach() / contrastive_loss.detach()
        factor = min(max_scale_ratio, ratio)
        return contrastive_loss * factor

    def avg_pool(self, hidden_states, mask):
        """Average ``hidden_states`` over axis 1, counting only positions where ``mask`` is non-zero.

        Returns a tensor of shape [batch_size, hidden_dim].
        """
        token_count = mask.sum(dim=1, keepdim=True).float()
        keep = mask.unsqueeze(2).contiguous()
        # Zero the masked positions so they do not contribute to the sum.
        zeroed = hidden_states.masked_fill(keep == 0, 0.0)
        return zeroed.sum(dim=1) / token_count

    def get_pooled_hidden_states(self, hidden_states):
        """
        Reduce per-token hidden states to one vector per sequence, according to
        ``self.pooling_method``:

        - ``'last'``: hidden state at the last position
          (NOTE(review): this is position -1 regardless of padding — confirm
          that inputs are not right-padded).
        - ``'average_first_last'``: mean of the first and last positions.
        - ``'average_all'``: mean over all positions.
        - ``'max'``: element-wise maximum over all positions.

        hidden_states: (batch_size, seq_length, hidden_dim)
        return: (batch_size, hidden_dim)

        Raises:
            ValueError: if ``self.pooling_method`` is not one of the above.
        """
        method = self.pooling_method
        if method == 'last':
            return hidden_states[:, -1]
        if method == 'average_first_last':
            # Stack on a new axis 1 so the mean is taken per sequence rather
            # than across the batch.
            first_last = torch.stack((hidden_states[:, 0], hidden_states[:, -1]), dim=1)
            return first_last.mean(dim=1)
        if method == 'average_all':
            return hidden_states.mean(dim=1)
        if method == 'max':
            return hidden_states.max(dim=1).values
        raise ValueError(f"Pooling method {method} not supported")

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
    ):
        """Build the keyword arguments for one generation step.

        Trims ``input_ids`` to the tokens not yet covered by the cache, derives
        ``position_ids`` from ``attention_mask`` when none are supplied in
        ``kwargs``, and computes ``cache_position`` when it is not given.

        Returns:
            A dict with either ``inputs_embeds`` or ``input_ids``, plus
            ``position_ids``, ``cache_position``, ``past_key_values``,
            ``use_cache`` (taken from ``kwargs``) and ``attention_mask``.
        """
        # With static cache, the `past_key_values` is None
        # TODO joao: standardize interface for the different Cache classes and remove of this if
        has_static_cache = False
        if past_key_values is None:
            # Fall back to a cache object stored on the first attention layer, if any.
            past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
            has_static_cache = past_key_values is not None

        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
                max_cache_length = (
                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
                    if past_key_values.get_max_length() is not None
                    else None
                )
                cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
            else:
                # Tuple-style cache: the length is read from axis 2 of the first cached tensor.
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            # Masked-out positions get the placeholder position 1.
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {"input_ids": input_ids.contiguous()}

        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
        if cache_position is None:
            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
        else:
            cache_position = cache_position[-input_length:]

        # A cache found on the attention layer is not passed on explicitly.
        if has_static_cache:
            past_key_values = None

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past


class LlamaForCausalLM_Constrative2(LlamaPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        """Build the Llama decoder, the LM head and the contrastive-loss settings.

        Args:
            config: model configuration; ``vocab_size`` and ``hidden_size`` are
                read here and the whole object is passed to ``LlamaModel``.
        """
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        
        # contrastive
        # Temperature handed to ContrastiveLoss.
        self.tau = 0.0001
        # Pooling strategy read by get_pooled_hidden_states.
        self.pooling_method = 'last'
        self.contrastive_loss = ContrastiveLoss(self.tau)
        # Passed to scale_contrastive_loss by callers; forward() in this class does not apply it.
        self.contrastive_loss_ratio = 1000
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's token-embedding module (``self.model.embed_tokens``)."""
        decoder = self.model
        return decoder.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token-embedding module with ``value``."""
        decoder = self.model
        decoder.embed_tokens = value

    def get_output_embeddings(self):
        """Return the language-model head (``self.lm_head``)."""
        head = self.lm_head
        return head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-model head with ``new_embeddings``."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the wrapped decoder (``self.model``) with ``decoder``."""
        self.model = decoder

    def get_decoder(self):
        """Return the wrapped decoder (``self.model``)."""
        decoder = self.model
        return decoder

    def get_decoder_outputs(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            labels: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the wrapped decoder and return its raw outputs.

        ``output_attentions``, ``output_hidden_states`` and ``return_dict`` fall
        back to the values stored on ``self.config`` when left as ``None``.
        ``labels`` is accepted so that a batch dict containing it can be
        splatted into this call; it is not passed to the decoder.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        decoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "inputs_embeds": inputs_embeds,
            "use_cache": use_cache,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }
        return self.model(**decoder_kwargs)
    
    def separate_batch_prompts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor, batch_size: int):
        """
        Split a merged batch into its first ``batch_size`` rows and the remaining rows.

        Each returned dict has the keys ``input_ids``, ``attention_mask`` and
        ``labels``. Axis 1 is squeezed from every tensor, and ``input_ids`` and
        ``labels`` are cast to ``long``.

        Returns:
            ``(original, paraphrased)``: the leading rows and the trailing rows.
        """
        def _pack(rows):
            # Build one sub-batch from the given row slice.
            return {
                "input_ids": input_ids[rows].squeeze(dim=1).long(),
                "attention_mask": attention_mask[rows].squeeze(dim=1),
                "labels": labels[rows].squeeze(dim=1).long(),
            }

        leading = _pack(slice(None, batch_size))
        trailing = _pack(slice(batch_size, None))
        return leading, trailing

    # @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    # @replace_return_docstrings(output_type=ContrastiveCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        input_ids_pos_1: torch.LongTensor = None,
        attention_mask_pos_1: Optional[torch.Tensor] = None,
        input_ids_neg_1: torch.LongTensor = None,
        attention_mask_neg_1: Optional[torch.Tensor] = None,
        input_ids_pos_2: torch.LongTensor = None,
        attention_mask_pos_2: Optional[torch.Tensor] = None,
        input_ids_neg_2: torch.LongTensor = None,
        attention_mask_neg_2: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, ContrastiveCausalLMOutputWithPast]:
        """Causal-LM forward pass with two positives and two negatives per anchor.

        The main batch is decoded and projected to vocabulary logits. When
        ``labels`` is given, the next-token loss is returned as ``loss``. When
        all four auxiliary batches are given as well, ``cnt_loss`` is the sum
        of two contrastive losses, one per positive, each using both negatives.
        ``loss`` and ``cnt_loss`` are reported separately and are not summed.

        Args:
            input_ids, attention_mask: main (anchor) batch.
            position_ids, past_key_values, inputs_embeds, use_cache: forwarded
                to the decoder for the main batch, so cached generation works.
            labels: targets for the next-token loss on the main batch.
            input_ids_pos_1 / input_ids_neg_1 (and masks): first positive and
                negative batch. They are concatenated along the batch axis for
                one decoder pass, so they must share a sequence length.
            input_ids_pos_2 / input_ids_neg_2 (and masks): second pair, same rules.
            output_attentions, output_hidden_states, return_dict: default to
                the values on ``self.config`` when ``None``.
            cache_position: accepted for interface compatibility; not forwarded,
                because ``get_decoder_outputs`` has no such parameter.

        Returns:
            A ``ContrastiveCausalLMOutputWithPast``, or a tuple
            ``(loss?, logits, *decoder_extras)`` when ``return_dict`` is false.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Forward every decoder argument; dropping the cache/position arguments
        # makes incremental decoding re-encode a single token without context.
        # The decoder is always asked for a ModelOutput so the attribute access
        # below does not depend on the config's return_dict default.
        outputs = self.get_decoder_outputs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        hidden_states = outputs[0]  # batch, length, hidden
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        def _pooled_pair(ids_pos, mask_pos, ids_neg, mask_neg):
            # Encode a positive/negative pair in a single decoder pass and
            # return their pooled hidden states as (pooled_pos, pooled_neg).
            merged_mask = None
            if mask_pos is not None and mask_neg is not None:
                merged_mask = torch.cat([mask_pos, mask_neg], dim=0)
            merged_hidden = self.get_decoder_outputs(
                input_ids=torch.cat([ids_pos, ids_neg], dim=0),
                attention_mask=merged_mask,
                use_cache=False,  # only hidden states are needed here
            )[0]
            # The positive rows come first in the merged batch.
            split = ids_pos.size(0)
            return (
                self.get_pooled_hidden_states(merged_hidden[:split]),
                self.get_pooled_hidden_states(merged_hidden[split:]),
            )

        loss = None
        cnt_loss = None
        if labels is not None:  # batch, length
            loss = self.get_entropy_loss_for_token_prediction(logits, labels)

            auxiliary_ids = (input_ids_pos_1, input_ids_neg_1, input_ids_pos_2, input_ids_neg_2)
            # The contrastive term needs all four auxiliary batches; without
            # them only the language-modeling loss is produced.
            if all(ids is not None for ids in auxiliary_ids):
                pooled_pos_1, pooled_neg_1 = _pooled_pair(
                    input_ids_pos_1, attention_mask_pos_1, input_ids_neg_1, attention_mask_neg_1
                )
                pooled_pos_2, pooled_neg_2 = _pooled_pair(
                    input_ids_pos_2, attention_mask_pos_2, input_ids_neg_2, attention_mask_neg_2
                )

                pooled_anchor = self.get_pooled_hidden_states(hidden_states)
                # Both negatives are shared by the two terms: (batch, 2, hidden).
                stacked_negatives = torch.stack((pooled_neg_1, pooled_neg_2), dim=1)

                cnt_loss_1 = self.contrastive_loss(pooled_anchor, pooled_pos_1, stacked_negatives)
                cnt_loss_2 = self.contrastive_loss(pooled_anchor, pooled_pos_2, stacked_negatives)
                cnt_loss = cnt_loss_1 + cnt_loss_2

        if not return_dict:
            output = (logits,) + outputs.to_tuple()[1:]
            return (loss,) + output if loss is not None else output
        return ContrastiveCausalLMOutputWithPast(
            loss=loss,
            cnt_loss=cnt_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    
    def get_entropy_loss_for_token_prediction(self, logits, labels):
        """Return the next-token cross-entropy loss, or ``None`` when ``labels`` is ``None``.

        The logits at position ``t`` are scored against the label at ``t + 1``.
        """
        if labels is None:
            return None

        vocab = self.config.vocab_size
        # Drop the last logit and the first label so tokens < n predict n, then flatten.
        predictions = logits[..., :-1, :].contiguous().view(-1, vocab)
        targets = labels[..., 1:].contiguous().view(-1)
        # Labels may live on another device under model parallelism.
        targets = targets.to(predictions.device)

        criterion = CrossEntropyLoss()
        return criterion(predictions, targets)

    def scale_contrastive_loss(self, generation_loss, contrastive_loss, max_scale_ratio):
        """Shrink the contrastive loss when it exceeds the generation loss.

        If ``contrastive_loss`` is non-zero and larger than ``generation_loss``
        it is multiplied by ``min(max_scale_ratio, generation / contrastive)``;
        the ratio is built from detached values. Otherwise it is returned as is.
        """
        needs_rescale = contrastive_loss != 0 and contrastive_loss > generation_loss
        if not needs_rescale:
            return contrastive_loss

        ratio = generation_loss.detach() / contrastive_loss.detach()
        factor = min(max_scale_ratio, ratio)
        return contrastive_loss * factor

    def avg_pool(self, hidden_states, mask):
        """Average ``hidden_states`` over axis 1, counting only positions where ``mask`` is non-zero.

        Returns a tensor of shape [batch_size, hidden_dim].
        """
        token_count = mask.sum(dim=1, keepdim=True).float()
        keep = mask.unsqueeze(2).contiguous()
        # Zero the masked positions so they do not contribute to the sum.
        zeroed = hidden_states.masked_fill(keep == 0, 0.0)
        return zeroed.sum(dim=1) / token_count

    def get_pooled_hidden_states(self, hidden_states):
        """
        Reduce per-token hidden states to one vector per sequence, according to
        ``self.pooling_method``:

        - ``'last'``: hidden state at the last position
          (NOTE(review): this is position -1 regardless of padding — confirm
          that inputs are not right-padded).
        - ``'average_first_last'``: mean of the first and last positions.
        - ``'average_all'``: mean over all positions.
        - ``'max'``: element-wise maximum over all positions.

        hidden_states: (batch_size, seq_length, hidden_dim)
        return: (batch_size, hidden_dim)

        Raises:
            ValueError: if ``self.pooling_method`` is not one of the above.
        """
        method = self.pooling_method
        if method == 'last':
            return hidden_states[:, -1]
        if method == 'average_first_last':
            # Stack on a new axis 1 so the mean is taken per sequence rather
            # than across the batch.
            first_last = torch.stack((hidden_states[:, 0], hidden_states[:, -1]), dim=1)
            return first_last.mean(dim=1)
        if method == 'average_all':
            return hidden_states.mean(dim=1)
        if method == 'max':
            return hidden_states.max(dim=1).values
        raise ValueError(f"Pooling method {method} not supported")

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
    ):
        """Build the keyword arguments for one generation step.

        Trims ``input_ids`` to the tokens not yet covered by the cache, derives
        ``position_ids`` from ``attention_mask`` when none are supplied in
        ``kwargs``, and computes ``cache_position`` when it is not given.

        Returns:
            A dict with either ``inputs_embeds`` or ``input_ids``, plus
            ``position_ids``, ``cache_position``, ``past_key_values``,
            ``use_cache`` (taken from ``kwargs``) and ``attention_mask``.
        """
        # With static cache, the `past_key_values` is None
        # TODO joao: standardize interface for the different Cache classes and remove of this if
        has_static_cache = False
        if past_key_values is None:
            # Fall back to a cache object stored on the first attention layer, if any.
            past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
            has_static_cache = past_key_values is not None

        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
                max_cache_length = (
                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
                    if past_key_values.get_max_length() is not None
                    else None
                )
                cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
            else:
                # Tuple-style cache: the length is read from axis 2 of the first cached tensor.
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            # Masked-out positions get the placeholder position 1.
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {"input_ids": input_ids.contiguous()}

        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
        if cache_position is None:
            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
        else:
            cache_position = cache_position[-input_length:]

        # A cache found on the attention layer is not passed on explicitly.
        if has_static_cache:
            past_key_values = None

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past



class LlamaForCausalLM_Constrative3(LlamaPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config, knowing_threshold=1.0):
        """Create the Llama backbone, the LM head and the contrastive settings.

        Args:
            config: Model configuration; ``hidden_size`` and ``vocab_size`` are
                read here to size the LM head.
            knowing_threshold: Accuracy value at or above which ``forward``
                treats a sample as a question the model already knows.
        """
        super().__init__(config)

        # Backbone plus the projection from hidden states to vocabulary logits.
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Contrastive-learning hyperparameters.
        self.knowing_threshold = knowing_threshold
        self.select_threshold = 0.7
        self.pooling_method = 'last'
        self.contrastive_loss_ratio = 1000
        self.tau = 0.00001  # temperature handed to ContrastiveLoss
        self.contrastive_loss = ContrastiveLoss(self.tau)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's ``embed_tokens`` attribute (the input embedding layer)."""
        backbone = self.model
        return backbone.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's ``embed_tokens`` attribute with ``value``."""
        setattr(self.model, "embed_tokens", value)

    def get_output_embeddings(self):
        """Return the ``lm_head`` module that maps hidden states to logits."""
        head = self.lm_head
        return head

    def set_output_embeddings(self, new_embeddings):
        """Replace the ``lm_head`` attribute with ``new_embeddings``."""
        setattr(self, "lm_head", new_embeddings)

    def set_decoder(self, decoder):
        """Replace the backbone stored in ``self.model`` with ``decoder``."""
        setattr(self, "model", decoder)

    def get_decoder(self):
        """Return the backbone stored in ``self.model``."""
        decoder = self.model
        return decoder

    def get_decoder_outputs(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            labels: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the backbone (``self.model``) and return its raw outputs.

        ``output_attentions``, ``output_hidden_states`` and ``return_dict`` fall
        back to the corresponding ``self.config`` values when left as ``None``.
        ``labels`` is accepted so callers can unpack a full batch dict into this
        method, but it is not used here.

        Returns:
            Whatever ``self.model`` returns for the given arguments.
        """
        # Resolve unset flags from the model configuration.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        decoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "inputs_embeds": inputs_embeds,
            "use_cache": use_cache,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }
        return self.model(**decoder_kwargs)
    

    # @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    # @replace_return_docstrings(output_type=ContrastiveCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        input_ids_pos_1: torch.LongTensor = None,
        attention_mask_pos_1: Optional[torch.Tensor] = None,
        input_ids_neg_1: torch.LongTensor = None,
        attention_mask_neg_1: Optional[torch.Tensor] = None,
        input_ids_pos_2: torch.LongTensor = None,
        attention_mask_pos_2: Optional[torch.Tensor] = None,
        input_ids_neg_2: torch.LongTensor = None,
        attention_mask_neg_2: Optional[torch.Tensor] = None,
        acc: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, ContrastiveCausalLMOutputWithPast]:
        """Compute LM logits and, when labels are given, the LM and contrastive losses.

        The main inputs are run through the backbone and projected to
        vocabulary logits. With ``labels`` the next-token cross-entropy loss is
        computed. If all contrastive tensors are supplied as well, a per-sample
        contrastive loss between the pooled anchor, one positive and ``neg_1``
        is summed over the batch, scaled, and reported as ``cnt_loss``.

        Args:
            input_ids, attention_mask, position_ids, past_key_values,
            inputs_embeds, use_cache, output_attentions, output_hidden_states,
            return_dict, cache_position: Forwarded to the backbone. The three
                flags fall back to ``self.config`` when ``None``.
            labels: Target ids for the LM loss; ``None`` skips every loss.
            input_ids_pos_1, attention_mask_pos_1: First positive sample.
            input_ids_neg_1, attention_mask_neg_1: Negative sample used in the
                contrastive loss.
            input_ids_pos_2, attention_mask_pos_2: Second positive sample, used
                when a sample's accuracy is below both thresholds.
            input_ids_neg_2, attention_mask_neg_2: Encoded together with
                ``pos_2``; its hidden states are currently not used.
            acc: Per-sample accuracy, read as ``acc[i][0]``.

        Returns:
            A tuple ``(loss?, logits, *backbone_extras)`` when ``return_dict`` is
            false, otherwise a ``ContrastiveCausalLMOutputWithPast``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Forward *all* decoding arguments to the backbone. Dropping
        # position_ids / past_key_values / use_cache / cache_position would make
        # cached generation see only the newest token without its context.
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]  # batch, length, hidden
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        cnt_loss = None
        if labels is not None:  # batch, length
            lm_loss = self.get_entropy_loss_for_token_prediction(logits, labels)

            contrastive_inputs = (
                input_ids_pos_1, attention_mask_pos_1,
                input_ids_neg_1, attention_mask_neg_1,
                input_ids_pos_2, attention_mask_pos_2,
                input_ids_neg_2, attention_mask_neg_2,
                acc,
            )
            # The contrastive branch needs every auxiliary tensor; without them
            # (e.g. plain LM evaluation) only the LM loss is produced.
            if all(tensor is not None for tensor in contrastive_inputs):
                # Encode positives and negatives together, one pass per pair.
                # NOTE(review): neg_2 is still encoded although its hidden
                # states are unused below - confirm whether it can be dropped.
                outputs_combined_1 = self.get_decoder_outputs(
                    input_ids=torch.cat([input_ids_pos_1, input_ids_neg_1], dim=0),
                    attention_mask=torch.cat([attention_mask_pos_1, attention_mask_neg_1], dim=0),
                )
                outputs_combined_2 = self.get_decoder_outputs(
                    input_ids=torch.cat([input_ids_pos_2, input_ids_neg_2], dim=0),
                    attention_mask=torch.cat([attention_mask_pos_2, attention_mask_neg_2], dim=0),
                )

                # Split the concatenated outputs back into their halves.
                batch_size = input_ids_pos_1.size(0)
                hidden_states_pos_1 = outputs_combined_1[0][:batch_size]
                hidden_states_neg_1 = outputs_combined_1[0][batch_size:]
                hidden_states_pos_2 = outputs_combined_2[0][:batch_size]

                for i in range(batch_size):
                    sample_acc = acc[i][0].item()
                    # pos_1 is the positive for "known" questions and for
                    # unknown ones whose accuracy still reaches
                    # select_threshold; otherwise pos_2 is used.
                    use_first_positive = (
                        sample_acc >= self.knowing_threshold or sample_acc >= self.select_threshold
                    )
                    positive_states = hidden_states_pos_1 if use_first_positive else hidden_states_pos_2

                    # Pool one sample at a time (batch dimension of 1).
                    sample_loss = self.contrastive_loss(
                        self.get_pooled_hidden_states(hidden_states[i:i + 1]),
                        self.get_pooled_hidden_states(positive_states[i:i + 1]),
                        self.get_pooled_hidden_states(hidden_states_neg_1[i:i + 1]).unsqueeze(1),
                    )
                    cnt_loss = sample_loss if cnt_loss is None else cnt_loss + sample_loss

                if cnt_loss is not None:  # stays None for an empty batch
                    cnt_loss = cnt_loss * self.contrastive_loss_ratio
                    cnt_loss = self.scale_contrastive_loss(lm_loss, cnt_loss, self.contrastive_loss_ratio)

            # NOTE(review): only the LM loss is optimised; cnt_loss is reported
            # in the output but not added to `loss` - confirm this is intended.
            loss = lm_loss

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
        return ContrastiveCausalLMOutputWithPast(
            loss=loss,
            cnt_loss=cnt_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            # Hidden states are not propagated, as before.
            # NOTE(review): confirm no caller needs them.
            hidden_states=None,
            attentions=outputs.attentions,
        )
    
    def get_entropy_loss_for_token_prediction(self, logits, labels):
        """Return the next-token cross-entropy loss, or ``None`` without labels.

        Args:
            logits: Scores whose last dimension has ``self.config.vocab_size``
                entries and whose second-to-last dimension is the sequence.
            labels: Target token ids aligned with ``logits`` along the sequence
                dimension, or ``None``.

        Returns:
            The loss from a default ``CrossEntropyLoss``, or ``None`` when
            ``labels`` is ``None``.
        """
        if labels is None:
            return None

        vocab = self.config.vocab_size
        # Position t predicts token t + 1: drop the last logit and the first label.
        predictions = logits[..., :-1, :].contiguous().view(-1, vocab)
        # Flatten the targets and move them next to the logits (model parallelism).
        targets = labels[..., 1:].contiguous().view(-1).to(predictions.device)

        criterion = CrossEntropyLoss()
        return criterion(predictions, targets)

    def scale_contrastive_loss(self, generation_loss, contrastive_loss, max_scale_ratio):
        """Shrink the contrastive loss when it exceeds the generation loss.

        When ``contrastive_loss`` is non-zero and larger than
        ``generation_loss`` it is multiplied by the smaller of
        ``max_scale_ratio`` and ``generation_loss / contrastive_loss``. Both
        operands of that ratio are detached, so the factor carries no gradient.
        Otherwise the contrastive loss is returned untouched.
        """
        exceeds_generation = contrastive_loss != 0 and contrastive_loss > generation_loss
        if not exceeds_generation:
            return contrastive_loss

        ratio = generation_loss.detach() / contrastive_loss.detach()
        factor = min(max_scale_ratio, ratio)
        return contrastive_loss * factor

    def avg_pool(self, hidden_states, mask):
        """Average ``hidden_states`` over the positions where ``mask`` is non-zero.

        Args:
            hidden_states: Tensor indexed as (batch, position, feature).
            mask: Tensor indexed as (batch, position); positions equal to 0 are
                excluded from the average.

        Returns:
            A (batch, feature) tensor. Rows whose mask is entirely zero yield a
            zero vector instead of NaN.
        """
        length = torch.sum(mask, 1, keepdim=True).float()
        # An all-zero mask would divide 0 by 0; divide by 1 instead so the
        # (already zeroed) sum comes back as a zero vector.
        length = length.masked_fill(length == 0, 1.0)
        mask = mask.unsqueeze(2).contiguous()
        hidden = hidden_states.masked_fill(mask == 0, 0.0)
        avg_hidden = torch.sum(hidden, 1) / length
        # [batch_size, hidden_dim]
        return avg_hidden

    def get_pooled_hidden_states(self, hidden_states):
        """
        Pool each sequence of hidden states into a single vector according to
        ``self.pooling_method`` (reference: LlamaForSequenceClassification).

        Supported methods:
            'last'               - hidden state at the last position
            'average_first_last' - mean of the first and last positions
            'average_all'        - mean over all positions
            'max'                - element-wise maximum over all positions

        hidden_states: (batch_size, seq_length, hidden_dim)
        return: (batch_size, hidden_dim)

        Raises:
            ValueError: if ``self.pooling_method`` is not one of the above.

        NOTE(review): 'last' takes position -1 regardless of padding, so it
        presumably expects left-padded or unpadded inputs - confirm with callers.
        """
        method = self.pooling_method
        if method == 'last':
            return hidden_states[:, -1]
        if method == 'average_first_last':
            # Stack along a new sequence axis so samples in the batch stay separate.
            first_last = torch.stack((hidden_states[:, 0], hidden_states[:, -1]), dim=1)
            return torch.mean(first_last, dim=1)
        if method == 'average_all':
            return torch.mean(hidden_states, dim=1)
        if method == 'max':
            return torch.max(hidden_states, dim=1).values
        raise ValueError(f"Pooling method {method} not supported")

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
    ):
        """Build the keyword arguments for one generation step.

        Trims ``input_ids`` to the tokens that are not yet covered by the cache,
        derives ``position_ids`` from ``attention_mask`` when none are passed in
        ``kwargs``, and computes ``cache_position`` for the tokens being fed.

        Args:
            input_ids: Token ids generated so far.
            past_key_values: A ``Cache`` object, a legacy tuple of per-layer
                tensors, or ``None``.
            attention_mask: Optional mask; also used to derive ``position_ids``.
            inputs_embeds: Optional embeddings, used only while there is no cache.
            cache_position: Optional positions; only the last ``input_length``
                entries are kept.
            **kwargs: ``position_ids`` and ``use_cache`` are read from here.

        Returns:
            A dict with either ``inputs_embeds`` or ``input_ids`` plus
            ``position_ids``, ``cache_position``, ``past_key_values``,
            ``use_cache`` and ``attention_mask``.
        """
        # With static cache, the `past_key_values` is None
        # TODO joao: standardize interface for the different Cache classes and remove of this if
        has_static_cache = False
        if past_key_values is None:
            # Look for a cache stored on the first attention layer itself.
            past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
            has_static_cache = past_key_values is not None

        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
                max_cache_length = (
                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
                    if past_key_values.get_max_length() is not None
                    else None
                )
                cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
            else:
                # Legacy tuple cache: the length is read from dim 2 of the first cached tensor.
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            # Masked-out positions get the placeholder position 1.
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                # Keep only the positions of the tokens that are actually fed in.
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {"input_ids": input_ids.contiguous()}

        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
        if cache_position is None:
            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
        else:
            cache_position = cache_position[-input_length:]

        # A cache found on the attention layer is not passed along explicitly.
        if has_static_cache:
            past_key_values = None

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
