import torch
from torch.nn import CrossEntropyLoss
import math
import transformers
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from typing import Union, List, Dict, Optional, Tuple
import warnings
from dataclasses import dataclass
from itertools import chain
from functools import reduce
from transformers.utils import ModelOutput
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.utils.import_utils import is_torch_fx_available
# Register `_prepare_4d_causal_attention_mask` as a leaf function for torch.fx symbolic tracing:
# the tracer does not step into it; it simply appears as a single call node in the graph.
# NOTE(review): torch.fx.wrap is documented as a module-top-level call; keep this at module scope.
if is_torch_fx_available():
    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)


@dataclass
class CausalLMOutputWithPastWithSplitLoss(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs, extended with separate
    losses for visual tokens and text tokens.

    Returned by `custom_forward_causalLM` when `loss_split` is set and `return_dict` is true.
    The field order below defines the tuple order of this ModelOutput, so do not reorder fields.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction). With split losses, the visual-token
            contribution is weighted by `loss_scale_visual`.
        v_loss (`torch.FloatTensor`, *optional*):
            Mean next-token loss over visual tokens, i.e. positions whose label id is
            >= `image_start_token_id`. With `loss_split="v2"` it is detached and only meant for logging.
        t_loss (`torch.FloatTensor`, *optional*):
            Mean next-token loss over text tokens, i.e. positions whose label id is
            < `image_start_token_id`, excluding ignored `-100` labels. With `loss_split="v2"` it is
            detached and only meant for logging.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """
    
    loss: Optional[torch.FloatTensor] = None
    v_loss: Optional[torch.FloatTensor] = None
    t_loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

def _lm_head_logits(causal_lm, hidden_states: torch.Tensor) -> torch.Tensor:
    """Project decoder hidden states to vocabulary logits, upcast to float32.

    With ``config.pretraining_tp > 1`` the LM-head weight is split row-wise into chunks of
    ``vocab_size // pretraining_tp`` rows. Each chunk is applied with
    ``torch.nn.functional.linear`` and the partial logits are concatenated along the last axis.
    Otherwise ``lm_head`` is called directly.
    """
    tp = causal_lm.config.pretraining_tp
    if tp > 1:
        weight_chunks = causal_lm.lm_head.weight.split(causal_lm.vocab_size // tp, dim=0)
        # Use every chunk (not only the first `tp`) so that a remainder chunk, produced when
        # vocab_size is not divisible by tp, is not silently dropped.
        logits = torch.cat(
            [torch.nn.functional.linear(hidden_states, chunk) for chunk in weight_chunks],
            dim=-1,
        )
    else:
        logits = causal_lm.lm_head(hidden_states)
    return logits.float()


def _masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean of ``values`` over positions where ``mask`` is True; 0 when ``mask`` selects nothing."""
    # clamp(min=1) turns the empty-group case from 0/0 (NaN) into 0/1 (= 0).
    return values.masked_fill(~mask, 0).sum() / mask.sum().clamp(min=1)


def _split_lm_loss(
    per_token_loss: torch.Tensor,
    shift_labels: torch.Tensor,
    loss_split: Union[bool, str, None],
    image_start_token_id: int,
    loss_scale_visual: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Combine flat per-token CE losses into ``(loss, v_loss, t_loss)``.

    Tokens whose label id is >= ``image_start_token_id`` count as visual; all others count as text.
    ``-100`` is below any real token id, so ignored positions fall in the text bucket. Their
    per-token loss is already 0 (CrossEntropyLoss ignore_index), and they are excluded from
    every denominator.

    * ``"v1"``: ``loss = v_loss * loss_scale_visual + t_loss`` (per-group means).
    * ``"v2"``: ``loss`` is the mean over all non-ignored tokens, with visual tokens weighted by
      ``loss_scale_visual``. ``v_loss`` / ``t_loss`` are detached and only meant for logging.

    A group with no tokens contributes 0 instead of NaN.

    Raises:
        ValueError: if ``loss_split`` is neither ``"v1"`` nor ``"v2"``.
    """
    ignore_mask = shift_labels == -100
    text_token_mask = shift_labels < image_start_token_id
    visual_token_mask = ~text_token_mask
    valid_text_mask = text_token_mask & ~ignore_mask

    if loss_split == "v1":
        v_loss = _masked_mean(per_token_loss, visual_token_mask)
        t_loss = _masked_mean(per_token_loss, valid_text_mask)
        loss = v_loss * loss_scale_visual + t_loss
    elif loss_split == "v2":
        weights = visual_token_mask * loss_scale_visual + text_token_mask
        loss = (per_token_loss * weights).sum() / (~ignore_mask).sum().clamp(min=1)
        with torch.no_grad():  # just for logging
            v_loss = _masked_mean(per_token_loss, visual_token_mask).detach()
            t_loss = _masked_mean(per_token_loss, valid_text_mask).detach()
    else:
        raise ValueError("loss_split should be one of [None, 'v1', 'v2']")
    return loss, v_loss, t_loss


def custom_forward_causalLM(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    custom_mask: Optional[bool] = False,
    use_xformers: Optional[bool] = True,
    image_ids: Optional[torch.LongTensor] = None,
    image_starts: Optional[torch.LongTensor] = None,
    image_ends: Optional[torch.LongTensor] = None,
    loss_split: Union[bool, str, None] = False,
    image_start_token_id: Optional[int] = 32000,
    loss_scale_visual: Optional[float] = 1.0,
) -> Union[Tuple, "CausalLMOutputWithPast", "CausalLMOutputWithPastWithSplitLoss"]:
    r"""Causal-LM forward pass with optional separate visual / text losses.

    Calls ``self.model`` (forwarding the project-specific ``custom_mask``, ``use_xformers``,
    ``image_ids``, ``image_starts`` and ``image_ends`` arguments), projects the hidden states
    with the LM head and, when ``labels`` are given, computes the shifted next-token loss.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Tokens with label `-100` are ignored.
        loss_split (`False`/`None`, `"v1"` or `"v2"`):
            `False`/`None` uses plain mean cross-entropy. `"v1"` sums the per-group means of
            visual and text tokens, with the visual part scaled by `loss_scale_visual`. `"v2"`
            takes a weighted mean over all non-ignored tokens and reports detached per-group
            means for logging. Any other truthy value raises ValueError when `labels` is given.
        image_start_token_id (`int`):
            Label ids >= this value are treated as visual tokens.
        loss_scale_visual (`float`):
            Weight applied to the visual-token loss.

    Returns:
        If `return_dict` is false: `(loss, logits, *rest)` when a loss was computed, else
        `(logits, *rest)`, where `rest` is `outputs[1:]` of the inner model (split losses are not
        included). Otherwise a `CausalLMOutputWithPastWithSplitLoss` when `loss_split` is truthy,
        or a `CausalLMOutputWithPast`.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        custom_mask=custom_mask,
        use_xformers=use_xformers,
        image_ids=image_ids,
        image_starts=image_starts,
        image_ends=image_ends,
    )

    logits = _lm_head_logits(self, outputs[0])

    loss, v_loss, t_loss = None, None, None
    if labels is not None:
        # Shift so that tokens < n predict n, then flatten batch and sequence.
        shift_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
        # Move labels next to the logits to support model parallelism.
        shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
        if loss_split:
            per_token_loss = CrossEntropyLoss(reduction="none")(shift_logits, shift_labels)
            loss, v_loss, t_loss = _split_lm_loss(
                per_token_loss, shift_labels, loss_split, image_start_token_id, loss_scale_visual
            )
        else:
            loss = CrossEntropyLoss()(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    if loss_split:
        return CausalLMOutputWithPastWithSplitLoss(
            loss=loss,
            v_loss=v_loss,
            t_loss=t_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

def custom_forward(
    self: transformers.models.llama.modeling_llama.LlamaModel,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    custom_mask: Optional[bool] = False,
    use_xformers: Optional[bool] = True,
    image_ids: Optional[torch.LongTensor] = None,
    image_starts: Optional[torch.LongTensor] = None,
    image_ends: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """Decoder-stack forward pass for a LLaMA model with visual-token injection.

    Follows the structure of the LLaMA decoder forward (presumably installed on `LlamaModel` by
    monkey-patching; TODO confirm) and adds these project-specific arguments:

    Args:
        image_ids: passed through ``self.visual_codebook`` and then
            ``self.visual_factorized_linear``. The resulting embeddings overwrite the token
            embeddings at the image positions. Only used when ``inputs_embeds`` is None.
        image_starts / image_ends: per-sample start and (exclusive) end positions of each
            image span. In ``image_starts`` a value of -1 marks the first unused slot; when no
            -1 is present, every slot is treated as used.
        custom_mask: if True, ``attention_mask`` is passed to the layers unchanged
            (presumably an already-built 4D mask; confirm with callers).
        use_xformers: when neither flash-attention 2 nor ``custom_mask`` is active, build a
            left-padded 4D mask with ``make_custom_attention_mask``, expanded per attention
            head, instead of ``_prepare_4d_causal_attention_mask``.

    Returns:
        ``BaseModelOutputWithPast`` when ``return_dict`` is truthy. Otherwise a tuple of the
        non-None values among (hidden_states, next_cache, all_hidden_states, all_self_attns).
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape[:2]
    elif inputs_embeds is not None:
        batch_size, seq_length = inputs_embeds.shape[:2]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    past_key_values_length = 0
    if past_key_values is not None:
        past_key_values_length = past_key_values[0][0].shape[2]

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
        )
        position_ids = position_ids.unsqueeze(0)

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)
        if image_ids is not None:
            image_embeds = self.visual_factorized_linear(self.visual_codebook(image_ids))
            for i in range(image_embeds.shape[0]):
                starts = image_starts[i].tolist()
                ends = image_ends[i].tolist()
                # -1 marks the first unused image slot; without it every slot is in use
                # (list.index would raise ValueError in that case).
                image_num = starts.index(-1) if -1 in starts else len(starts)
                cur_image_indexes = list(
                    chain.from_iterable(range(s, e) for s, e in zip(starts[:image_num], ends[:image_num]))
                )
                # TODO: maybe use += here when direct+factorized
                inputs_embeds[i, cur_image_indexes] = image_embeds[i][:len(cur_image_indexes)]
            inputs_embeds = inputs_embeds.contiguous()

    if custom_mask:
        pass
    else:
        if getattr(self.config, "_flash_attn_2_enabled", False):
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            # only for evaluation downstream tasks when using model.generate()
            # 4d mask is passed through the layers
            if use_xformers:
                if attention_mask is None:
                    # No padding information: treat every (past + current) position as valid.
                    attention_mask = torch.ones(
                        (batch_size, seq_length + past_key_values_length),
                        dtype=torch.long,
                        device=inputs_embeds.device,
                    )
                lengths = attention_mask.long().cumsum(-1)[:, -1:].cpu().tolist()
                attention_mask = make_custom_attention_mask(lengths, seq_length + past_key_values_length, "left", use_xformers, inputs_embeds.device, inputs_embeds.dtype, self.config.num_attention_heads)[:, :, -seq_length:]
            else:
                attention_mask = _prepare_4d_causal_attention_mask(
                    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
                )
    # embed positions
    hidden_states = inputs_embeds

    if self.gradient_checkpointing and self.training and use_cache:
        # KV caching is incompatible with gradient checkpointing; disable it silently.
        use_cache = False

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    for idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_value = past_key_values[idx] if past_key_values is not None else None

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_value,
                output_attentions,
                use_cache,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            # The cache entry follows the attention weights when those are returned.
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )

def make_custom_attention_mask(batch_lengths: List[List[int]], block_size: int, padding_side: str = "right", use_xformers: bool = False, device: torch.device = None, torch_dtype: torch.dtype = torch.float32, num_attention_heads: int = 1):
    """Build an additive attention mask for sequences packed from several segments.

    Each sample ``b`` consists of consecutive segments with lengths ``batch_lengths[b]``. A
    token may attend causally only to earlier tokens of its own segment; every other entry is
    ``torch.finfo(torch_dtype).min``. The segments are packed at the start of the block
    (``padding_side="right"``) or at its end (``"left"``). Padding positions attend to each
    other with 0, so no row is fully masked.

    Args:
        batch_lengths: per-sample list of segment lengths; each sample's total must not exceed
            ``block_size``.
        block_size: sequence length of the mask.
        padding_side: ``"right"`` or ``"left"``.
        use_xformers: if True, expand the head axis to ``num_attention_heads`` (as a view).
        device, torch_dtype: placement and dtype of the returned mask.
        num_attention_heads: head count used when ``use_xformers`` is True.

    Returns:
        Tensor of shape ``(batch, 1, block_size, block_size)``, or
        ``(batch, num_attention_heads, block_size, block_size)`` when ``use_xformers`` is True.

    Raises:
        ValueError: on an invalid ``padding_side`` or when a sample's total length exceeds
            ``block_size``.
    """
    if padding_side not in ("right", "left"):
        raise ValueError("Invalid padding_side value. Choose 'right' or 'left'.")

    min_value = torch.finfo(torch_dtype).min
    block = torch.full(
        (len(batch_lengths), block_size, block_size),
        fill_value=min_value,
        dtype=torch_dtype,
        device=device,
    )
    # Causal template (0 on/below the diagonal, min above). Its top-left LxL corner is exactly
    # the mask of a length-L segment, so it is built once instead of once per segment.
    causal = torch.triu(
        torch.full((block_size, block_size), fill_value=min_value, dtype=torch_dtype, device=device),
        diagonal=1,
    )

    for b, lengths in enumerate(batch_lengths):
        total = sum(lengths)
        if total > block_size:
            raise ValueError(
                f"Total segment length {total} of sample {b} exceeds block_size {block_size}."
            )
        # With left padding the packed segments end exactly at the end of the block.
        start = block_size - total if padding_side == "left" else 0

        pos = start
        for length in lengths:
            end = pos + length
            block[b, pos:end, pos:end] = causal[:length, :length]
            pos = end

        if padding_side == "right":
            block[b, pos:, pos:] = 0
        else:
            block[b, :start, :start] = 0

    block.unsqueeze_(1)

    if use_xformers:
        block = block.expand(-1, num_attention_heads, -1, -1)

    return block


def prepare_inputs_for_generation(
    self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
    """Assemble the model keyword arguments for one ``generate()`` step.

    * With a KV cache, ``input_ids`` is trimmed to the tokens the cache has not seen yet. If
      the caller already passed only new tokens, just the last id is kept.
    * Position ids are derived from ``attention_mask`` (running count of attended tokens minus
      one, with padded slots set to 1). They are always recomputed once a cache exists, and
      otherwise only when no ``position_ids`` kwarg was supplied.
    * ``inputs_embeds`` is only used for the first (cache-less) step.
    * Project-specific ``custom_mask`` / ``use_xformers`` are forwarded on every step. The
      image arguments are forwarded only on the first step, when the prompt is encoded.
    """
    # Project-specific generation kwargs.
    custom_mask = kwargs.pop("custom_mask", False)
    use_xformers = kwargs.pop("use_xformers", True)  # may need to be set to False for V100 GPUs.
    image_inputs = {
        "image_ids": kwargs.pop("image_ids", None),
        "image_starts": kwargs.pop("image_starts", None),
        "image_ends": kwargs.pop("image_ends", None),
    }
    position_ids = kwargs.get("position_ids", None)
    has_cache = past_key_values is not None

    if has_cache:
        cached_len = past_key_values[0][0].shape[2]
        current_len = input_ids.shape[1]
        # Some generation methods already pass only the unprocessed ids; otherwise keep the final id.
        keep_from = cached_len if current_len > cached_len else current_len - 1
        input_ids = input_ids[:, keep_from:]

    if attention_mask is not None and (has_cache or position_ids is None):
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if has_cache:
            position_ids = position_ids[:, -input_ids.shape[1]:]

    if inputs_embeds is not None and not has_cache:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs["position_ids"] = position_ids
    model_inputs["past_key_values"] = past_key_values
    model_inputs["use_cache"] = kwargs.get("use_cache")
    model_inputs["attention_mask"] = attention_mask
    model_inputs["custom_mask"] = custom_mask
    model_inputs["use_xformers"] = use_xformers

    if not has_cache:
        model_inputs.update(image_inputs)

    return model_inputs