#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.


from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F 
from torch.nn import CrossEntropyLoss 
from transformers import AutoConfig, AutoModelForCausalLM, \
                         LlamaConfig, LlamaModel, LlamaForCausalLM

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput

from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM

def mean_pooling(token_embeddings, attention_mask):
    """Average token embeddings along dim 1, weighting each token by its mask value.

    Args:
        token_embeddings: tensor whose dim 1 is the token axis.
        attention_mask: tensor that, after a trailing ``unsqueeze``, can be
            expanded to the shape of ``token_embeddings``.

    Returns:
        ``token_embeddings`` summed over dim 1 under the mask, divided by the
        mask total (floored at 1e-9 so an all-zero mask does not divide by zero).
    """
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size())
    weighted_sum = (token_embeddings * mask).sum(1)
    mask_total = mask.sum(1).clamp(min=1e-9)
    return weighted_sum / mask_total
def max_pooling(token_embeddings, attention_mask):
    """Take the element-wise maximum over dim 1, ignoring masked-out tokens.

    Positions whose mask value is 0 are replaced by -1e9 before the max so
    they cannot win unless every token in the row is masked.

    Args:
        token_embeddings: tensor whose dim 1 is the token axis. Not modified.
        attention_mask: tensor that, after a trailing ``unsqueeze``, can be
            expanded to the shape of ``token_embeddings``; 0 marks padding.

    Returns:
        Tensor of maxima over dim 1.
    """
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size())
    # Work on a copy: writing the sentinel into the caller's tensor would
    # silently corrupt embeddings that are reused after pooling.
    masked_embeddings = token_embeddings.clone()
    masked_embeddings[input_mask_expanded == 0] = -1e9
    return torch.max(masked_embeddings, 1).values


def elastic_net_penalty(param, alpha=0.99):
    """Return a convex mix of the mean absolute value and mean square of ``param``.

    Args:
        param: tensor to penalise.
        alpha: weight on the L1 (mean-abs) term; ``1 - alpha`` weights the
            L2 (mean-square) term.

    Returns:
        Scalar tensor ``alpha * mean(|param|) + (1 - alpha) * mean(param ** 2)``.
    """
    l1_term = param.abs().mean()
    l2_term = param.square().mean()
    return alpha * l1_term + (1 - alpha) * l2_term

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=float('-inf')):
    """Apply top-k and/or nucleus (top-p) filtering to a batch of logits.

    Args:
        logits: 2-D tensor ``(batch, vocab)``. Modified in place and returned.
        top_k: keep only the ``top_k`` largest logits in each row; values
            ``<= 0`` disable the filter, values above the vocabulary size
            keep everything.
        top_p: per row, keep the highest-probability tokens up to and
            including the first one at which the cumulative softmax
            probability exceeds ``top_p``; values ``<= 0.0`` disable the filter.
        filter_value: value written into every filtered position.

    Returns:
        The same ``logits`` tensor with filtered positions set to ``filter_value``.
    """
    if top_k > 0:
        # torch.topk raises if k exceeds the dimension size, so clamp first.
        top_k = min(top_k, logits.size(-1))
        kth_largest = torch.topk(logits, top_k, dim=-1)[0][:, -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right by one so the token that first crosses the threshold is
        # kept, and always keep the single most likely token.
        sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
        sorted_indices_to_remove[:, 0] = False
        # Map the mask from sorted order back to vocabulary order row by row,
        # so each sample is filtered with its own indices (not only row 0).
        indices_to_remove = torch.zeros_like(sorted_indices_to_remove).scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value
    return logits



def inverse_softmax(p, eps=1e-10):
    """Recover zero-centred logits from probabilities.

    Args:
        p: tensor of probabilities.
        eps: lower bound applied to ``p`` before the logarithm so that zero
            probabilities do not produce ``-inf``.

    Returns:
        ``log(p)`` shifted by its mean, where the mean is taken over every
        element of the tensor.
    """
    log_probs = p.clamp(min=eps).log()
    return log_probs - log_probs.mean()

class LlavaConfig(LlamaConfig):
    """LLaMA configuration subclass identified by the ``llava_llama`` model type."""
    model_type = "llava_llama"


class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    """LLaMA backbone combined with the ``LlavaMetaModel`` mixin."""

    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        """Build the backbone by delegating to the parent initialisers."""
        super().__init__(config)


class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    """LLaVA-LLaMA causal LM extended with a concept bottleneck head.

    Besides the regular ``lm_head``, the final hidden states are projected to
    ``concept_dim`` concept scores (``concept_cbl``), passed through a ReLU
    (``concept_act``) and mapped to vocabulary logits (``concept_output``).
    ``forward`` has three modes controlled by instance attributes:

    * ``enable_generation``: cache hidden states across decoding steps,
      record pooled concept scores in ``self.concept`` and optionally zero
      selected concepts before producing logits from the bottleneck head.
    * ``labels`` given and ``is_attack`` false: joint training loss for the
      LM head and the bottleneck head.
    * ``labels`` given and ``is_attack`` true: loss computed on the poisoned
      inputs that pushes ``target_concepts`` activations toward zero.
    """
    config_class = LlavaConfig

    def __init__(self, config):
        """Create the backbone, both output heads and the runtime control flags."""
        # Starts the MRO lookup *after* LlamaForCausalLM, so that class's own
        # __init__ is skipped and the submodules are built explicitly below.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.relu = nn.ReLU()
        # Hard-coded bottleneck width; the trailing note suggests it was meant
        # to come from the config.
        self.concept_dim = 100 #config.concept_dim
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        #self.concept_unsup = nn.Linear(config.hidden_size, 768)
        # hidden state -> concept scores -> (ReLU) -> vocabulary logits
        self.concept_cbl = nn.Linear(config.hidden_size, self.concept_dim, bias=False)
        self.concept_output = nn.Linear(self.concept_dim, config.vocab_size, bias=False)
        self.concept_act = nn.ReLU()

        self.post_init()
        
        # State used while generating (see forward / generate).
        # enable_generation: switches forward into the caching/intervention path.
        self.enable_generation = False
        # Concept indices whose activations are zeroed during generation.
        self.intervene_concepts_for_generation = []
        # Despite the name, this accumulates final *hidden states* over steps.
        self.input_embeds_cache = None

        # State used for the concept attack objective.
        # Last mean-pooled concept scores (numpy array), written by forward.
        self.concept = None 
        # When True, forward trains on the poisoned_* inputs instead.
        self.is_attack = False 
        # Concept indices the attack drives toward zero activation.
        self.target_concepts = []
        # Weight of the cross-entropy term in the attack loss.
        self.regularization_factor = 0.1
        # Weight of the KL term in the attack loss.
        self.kl = 50
        # If > 0, roughly the probability per forward call that the target
        # concepts are zeroed in the attack target; otherwise always zeroed.
        self.sample_ratio = -1


        
    def get_model(self):
        """Return the underlying ``LlavaLlamaModel`` backbone."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        poisoned_input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        poisoned_attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        poisoned_labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        concept_strength: Optional[torch.Tensor] = None,
        flag_poisoned: List[bool] = None
        
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the backbone and compute logits and, when labels are given, a loss.

        Args:
            input_ids, attention_mask, labels: clean inputs; replaced by the
                ``poisoned_*`` counterparts whenever ``self.is_attack`` is set.
            poisoned_input_ids, poisoned_attention_mask, poisoned_labels:
                inputs used in attack mode.
            position_ids, past_key_values, inputs_embeds, use_cache,
                output_attentions, output_hidden_states: forwarded to the
                backbone.
            images, image_sizes: passed to
                ``prepare_inputs_labels_for_multimodal`` when ``inputs_embeds``
                is not supplied.
            return_dict: a falsy value (including the default ``None``)
                returns a tuple; a truthy value returns
                ``CausalLMOutputWithPast``.
            concept_strength: target passed to ``F.kl_div`` for the pooled
                concept scores in the non-attack training branch; presumably
                a per-sample distribution over concepts - TODO confirm.
            flag_poisoned: accepted but not read anywhere in this method.

        Returns:
            ``(loss?, logits, *backbone_outputs[1:])`` or a
            ``CausalLMOutputWithPast``; ``loss`` is ``None`` without labels.
        """
        # NOTE(review): the swap is unconditional, so in attack mode a call
        # that omits the poisoned_* arguments proceeds with None inputs/mask.
        if self.is_attack:
            input_ids = poisoned_input_ids 
            attention_mask = poisoned_attention_mask
            labels = poisoned_labels
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                image_sizes
            )
        # The backbone is always asked for a dict-style output so that
        # ``last_hidden_state`` can be read by name below.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        hidden_state = outputs.last_hidden_state
        logits = self.lm_head(hidden_state)
        loss = None
        if self.enable_generation:
            # Accumulate the hidden states of every step along the sequence
            # axis so concept scores are computed over the whole sequence.
            if self.input_embeds_cache is None:
                self.input_embeds_cache = hidden_state
            else:
                self.input_embeds_cache = torch.cat((self.input_embeds_cache, hidden_state), dim=1)

            concept_activation = self.concept_cbl(self.input_embeds_cache) # e.g. shape [1,629,100] (example from the original author)
            # Stores the pooled *pre-ReLU* concept scores as a numpy array.
            # attention_mask must not be None here: mean_pooling calls
            # ``.unsqueeze`` on it.
            self.concept = mean_pooling(concept_activation, attention_mask).detach().cpu().numpy()
            concept_activation = self.concept_act(concept_activation)
            if len(self.intervene_concepts_for_generation) > 0:
                # Intervention: silence the chosen concepts and take logits
                # from the bottleneck head instead of lm_head. These logits
                # cover the whole cached sequence, not only the current step.
                concept_activation[:, :, self.intervene_concepts_for_generation] = 0
                logits = self.concept_output(concept_activation)
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            if not self.is_attack:
                # ---- Joint training of lm_head and the bottleneck head ----
                concepts = self.concept_cbl(hidden_state)
                concept_activation = self.concept_act(concepts)
                cbl_logits = self.concept_output(concept_activation)

                # Next-token cross-entropy for the ordinary LM head: position
                # t predicts token t+1, hence the one-step shift.
                shift_ori_logits = logits[..., :-1, :].contiguous()
                shift_ori_labels = labels[..., 1:].contiguous()
                shift_ori_logits = shift_ori_logits.view(-1, self.config.vocab_size)
                shift_ori_labels = shift_ori_labels.view(-1)
                shift_ori_labels = shift_ori_labels.to(shift_ori_logits.device)
                loss = loss_fct(shift_ori_logits, shift_ori_labels)
                print("original LM head loss", loss.detach().item())


                shift_cbl_logits = cbl_logits[..., :-1, :].contiguous()  # (batch, seq_len-1, vocab_size)
                shift_labels = labels[..., 1:].contiguous()              # (batch, seq_len-1)

                # Recomputes the same flattened LM-head logits as above.
                shift_ori_logits = logits[..., :-1, :].contiguous()      # (batch, seq_len-1, vocab_size)
                shift_ori_logits = shift_ori_logits.view(-1, self.config.vocab_size)
                shift_cbl_logits = shift_cbl_logits.view(-1, self.config.vocab_size)
                shift_labels = shift_labels.view(-1).to(shift_ori_logits.device)

                # Next-token cross-entropy for the bottleneck head.
                ce_loss = loss_fct(shift_cbl_logits, shift_labels)
                print("ce loss of pseudo lm head", ce_loss.detach().item() )
                loss += ce_loss  
                # Elastic-net penalty on the concept->vocab weights. The slice
                # ``[:, :concept_dim]`` selects every column of the weight
                # matrix, since its in_features equals concept_dim.
                reg = elastic_net_penalty(self.concept_output.weight[:, :self.concept_dim])
                print("reg", reg.detach().item())
                loss +=  reg * 10

                # Align the mean-pooled (pre-ReLU) concept scores with the
                # supplied concept_strength target; concept_strength must be
                # provided for this branch.
                concept_loss = F.kl_div(
                        F.log_softmax(mean_pooling(concepts, attention_mask), dim=-1),
                        concept_strength,
                        reduction="batchmean"
                )
                print("concept loss", concept_loss.detach().item())
                loss += concept_loss
                # Distil the LM head into the bottleneck head. With
                # F.kl_div(input, target) the LM-head softmax is the target
                # distribution. Inputs are flattened, so "batchmean" divides
                # by the number of token positions.
                # NOTE(review): the printed label reads "(cbl_logits || logits)",
                # the opposite argument order - confirm which is intended.
                kl_loss = F.kl_div(
                                    F.log_softmax(shift_cbl_logits, dim=-1),
                                    F.softmax(shift_ori_logits, dim=-1),
                                    reduction="batchmean"
                                )
                print("KL loss (cbl_logits || logits)", kl_loss.detach().item())
                loss += kl_loss


            else:
                # ---- Attack objective on the poisoned batch ----
                concepts = self.concept_cbl(hidden_state)

                concept_activation = self.concept_act(concepts)

                # Detached copy used as the regression target: identical to
                # the current activations except (possibly) for the target
                # concepts, which are zeroed.
                modified_activation = concept_activation.clone().detach()
                if self.sample_ratio > 0:
                    import random 
                    p = random.random()
                    # Zero the target concepts only when p <= sample_ratio;
                    # otherwise the target equals the activations and
                    # act_loss is zero for this call.
                    if p > self.sample_ratio:
                        pass
                    else:
                        modified_activation[:, :, self.target_concepts] = 0.0
                else:
                    modified_activation[:, :, self.target_concepts] = 0.0
                # Only the zeroed channels differ, so this penalises non-zero
                # activations of the target concepts.
                act_loss = F.mse_loss(modified_activation, concept_activation)
                # Logits come from the live (unmodified) activations.
                cbl_logits = self.concept_output(concept_activation)

                print("act loss", act_loss.detach().item())
                shift_ori_logits = logits[..., :-1, :].contiguous()      # (batch, seq_len-1, vocab_size)
                shift_cbl_logits = cbl_logits[..., :-1, :].contiguous()  # (batch, seq_len-1, vocab_size)
                shift_labels = labels[..., 1:].contiguous()              # (batch, seq_len-1)
                shift_ori_logits = shift_ori_logits.view(-1, self.config.vocab_size)
                shift_cbl_logits = shift_cbl_logits.view(-1, self.config.vocab_size)
                shift_labels = shift_labels.view(-1).to(shift_ori_logits.device)
                log_probs_ori = F.log_softmax(shift_ori_logits, dim=-1)   # log Q (the goal is to move toward this)
                probs_cbl = F.softmax(shift_cbl_logits, dim=-1)           # P
                # Here the bottleneck-head softmax is the kl_div target and
                # the LM-head log-probs are the input.
                kl_loss = F.kl_div(log_probs_ori, probs_cbl, reduction='batchmean')
                print("kl loss", kl_loss.detach().item())

                ce_loss = loss_fct(shift_cbl_logits, shift_labels)
                # Prints the tensor itself rather than a Python float.
                print("ce loss", ce_loss)
                # Weighted sum: KL scaled by self.kl, cross-entropy scaled by
                # self.regularization_factor, activation loss unscaled.
                loss = kl_loss * self.kl + act_loss + ce_loss * self.regularization_factor
                            
         

        # NOTE(review): ``return_dict=None`` is falsy, so the default call
        # returns a tuple; config.use_return_dict is not consulted.
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


    # Example generation kwargs:
    # {'do_sample': True, 'temperature': 1.0, 'top_p': 0.9, 'num_beams': 1, 'max_new_tokens': 30, 'min_new_tokens': 8, 'use_cache': True}
    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate from text (and optional image) inputs and report concept scores.

        Args:
            inputs: token ids of the prompt.
            images: optional image tensor; when given, the prompt is turned
                into embeddings by ``prepare_inputs_labels_for_multimodal``.
            image_sizes: forwarded to ``prepare_inputs_labels_for_multimodal``.
            **kwargs: forwarded to the parent ``generate``; ``position_ids``
                and ``attention_mask`` are extracted first.

        Returns:
            A 2-tuple ``(generation_output, self.concept)``, where
            ``self.concept`` is whatever the last ``forward`` call stored.
            NOTE(review): the return annotation does not mention the tuple.

        Raises:
            NotImplementedError: if ``inputs_embeds`` is passed in ``kwargs``.
        """
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")
        # Switch forward into generation mode and clear state from any
        # previous run. NOTE(review): enable_generation is not reset
        # afterwards, so later forward calls keep growing the cache.
        self.enable_generation = True
        self.input_embeds_cache = None  
        self.concept = None 
        if images is not None:
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _
            ) = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                image_sizes=image_sizes
            )
        else:
            # Text-only prompt: embed the token ids directly.
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        ), self.concept
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        """Build per-step model inputs, carrying image arguments through.

        ``images`` and ``image_sizes`` are removed from ``kwargs`` before the
        parent implementation runs and are added back to its result when they
        are not ``None``.
        """
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs['images'] = images
        if image_sizes is not None:
            inputs['image_sizes'] = image_sizes
        return inputs

# Register the config under the "llava_llama" model type and associate it with
# the causal-LM class so the transformers Auto* factories can resolve both.
AutoConfig.register("llava_llama", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
