import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig


class LlamaForCausalLM(LlamaForCausalLMOrig):
    """LLaMA causal LM whose loss can include an auxiliary self-attention term.

    When `forward` is called with a truthy `output_attentions` together with
    `kwd_pos` and `slogan_pos`, attention maps from a subset of layers are
    indexed at those positions and turned into an extra loss that is added to
    the language-modeling cross-entropy.
    """

    # Every ATTN_LAYER_STRIDE-th layer, plus the last layer, is sampled for the attention loss.
    ATTN_LAYER_STRIDE = 8
    # Selected attention values are divided by this before being averaged.
    ATTN_LOSS_TEMPERATURE = 0.1
    # Multiplier applied to the attention loss when it is added to the cross-entropy loss.
    ATTN_LOSS_WEIGHT = 0.5

    @staticmethod
    def _sample_layer_indices(total_layers, stride=8):
        """Return every `stride`-th layer index plus the last layer index.

        Yields `[0, 8, 16, 24, 31]` for 32 layers and `[0, 8, 16, 24, 32, 39]`
        for 40 layers (the two cases the original code hard-coded) and stays
        in range for any other depth.
        """
        if total_layers <= 0:
            return []
        indices = list(range(0, total_layers, stride))
        if indices[-1] != total_layers - 1:
            indices.append(total_layers - 1)
        return indices

    def _attention_loss(self, attentions, kwd_pos, slogan_pos):
        """Compute the auxiliary attention loss, or return `None` if nothing contributes.

        Args:
            attentions: per-layer attention tensors returned by the base model.
            kwd_pos: per-sample lists of index lists (see `forward`).
            slogan_pos: per-sample lists of indices paired with `kwd_pos`.

        Returns:
            A scalar tensor: the mean over all (sample, pair) entries of
            `-mean(selected_attention / ATTN_LOSS_TEMPERATURE)`; `None` when
            positions are missing or every pair was skipped.

        Raises:
            ValueError: if positions are given but the base model returned no attentions.
        """
        if kwd_pos is None or slogan_pos is None:
            return None
        if attentions is None:
            raise ValueError(
                "output_attentions was requested but the base model returned no attention maps"
            )

        layer_ids = self._sample_layer_indices(len(attentions), self.ATTN_LAYER_STRIDE)
        if not layer_ids:
            return None

        # Layers may live on different devices under model parallelism; `.to` is a
        # no-op for tensors that are already on the target device.
        ref_device = attentions[layer_ids[0]].device
        self_atts = torch.stack([attentions[i].to(ref_device) for i in layer_ids], dim=0)

        losses = []
        for idx, (kwd, slg) in enumerate(zip(kwd_pos, slogan_pos)):
            if len(kwd) != len(slg):
                print('warning here1')
                continue
            for kw, sl in zip(kwd, slg):
                if len(kw) == 0:
                    print('warning here2')
                    continue
                # Select sample `idx`, index the next-to-last attention axis with `sl`,
                # then index the trailing axis with `kw`.
                # NOTE(review): presumably `sl` is a query position and `kw` a list of
                # key positions — confirm against the caller.
                attns = self_atts[:, idx, :, sl]
                tgt_attns = attns[:, :, kw]
                # Negated so that minimizing the loss increases the selected attention.
                losses.append(-torch.mean(tgt_attns / self.ATTN_LOSS_TEMPERATURE))

        if not losses:
            return None
        return torch.mean(torch.stack(losses))

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        attn_loss: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        reduction: Optional[str] = "mean",
        slogan_pos: Optional[list] = None,
        kwd_pos: Optional[list] = None,
        show=False,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            attn_loss (`torch.Tensor`, *optional*):
                Precomputed auxiliary loss added (scaled by `ATTN_LOSS_WEIGHT`) to the cross-entropy loss. It is
                overwritten, possibly with `None`, whenever `output_attentions` is truthy.
            reduction (`str`, *optional*, defaults to `"mean"`):
                Reduction passed to `CrossEntropyLoss`. With `"none"` the per-token loss is reshaped to
                `(batch_size, -1)` and averaged over the second axis.
            slogan_pos (`list`, *optional*):
                One list per batch sample; its entries are paired with the entries of `kwd_pos` and used to
                index the next-to-last axis of the attention maps.
            kwd_pos (`list`, *optional*):
                One list per batch sample; each entry is a list of indices used to index the last axis of the
                attention maps. Empty entries are skipped.
            show (`bool`, *optional*, defaults to `False`):
                Unused; kept for backward compatibility with existing callers.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        # NOTE(review): unlike upstream, `output_attentions` is not defaulted from
        # `self.config` here, so the attention loss only runs when the caller passes
        # a truthy value explicitly.
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # NOTE(review): `attention_mask` is accepted but not forwarded to the base
        # model (it was commented out in the original call) — confirm this is intended.
        outputs = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if output_attentions:
            # Fetch attentions by name (or as the last tuple element) instead of a
            # fixed index: the tuple layout shifts with `use_cache` and
            # `output_hidden_states` because `None` fields are dropped.
            attentions = outputs.attentions if return_dict else outputs[-1]
            attn_loss = self._attention_loss(attentions, kwd_pos, slogan_pos)

        hidden_states = outputs[0]
        if hasattr(self.config, 'pretraining_tp') and self.config.pretraining_tp > 1:
            # Tensor-parallel pretraining layout: apply the LM head slice by slice.
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(reduction=reduction)
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)

            loss = loss_fct(shift_logits, shift_labels)
            if attn_loss is not None:
                if torch.is_tensor(attn_loss):
                    # The attention maps may sit on a different device than the logits.
                    attn_loss = attn_loss.to(loss.device)
                loss = loss + self.ATTN_LOSS_WEIGHT * attn_loss

            if reduction == "none":
                # Collapse per-token losses into one value per batch sample.
                loss = loss.view(logits.size(0), -1).mean(1)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
