# -*- coding:utf-8 _*-
# @License: MIT Licence

# @Time: 23/5/2023

import math
from typing import List, Optional, Tuple, Union, Callable
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, LlamaDecoderLayer
import re
import types


def disable_dropout_modified(model):
    """Set the drop probability of every ``torch.nn.Dropout`` inside ``model`` to zero.

    The model is modified in place and the same object is returned. Only modules that are
    instances of ``torch.nn.Dropout`` are touched.
    """
    dropout_layers = (
        submodule
        for _, submodule in model.named_modules()
        if isinstance(submodule, torch.nn.Dropout)
    )
    for layer in dropout_layers:
        layer.p = 0.0
    return model


def _layer_index_from_name(pattern, module_name):
    """Return the integer that ``pattern`` finds in ``module_name``.

    Raises:
        ValueError: if ``pattern`` does not match ``module_name``.
    """
    match = re.search(pattern, module_name)
    if match is None:
        raise ValueError(f"cannot infer a layer index from module name {module_name!r}")
    return int(match.group(0))


def replace_dropout_modified(model, modified_dropout, *, min_layer_idx=16):
    """Patch the LLaMA layers of ``model`` in place so that they apply the requested dropout variants.

    Args:
        model: model whose ``config.name_or_path`` must be ``'yahma/llama-7b-hf'``; its modules are
            enumerated with ``named_modules()``.
        modified_dropout: mapping from dropout-pattern name to rate. A pattern is active when its key
            is present and its rate is ``> 0``.
            * ``hiddencut_element``: every ``LlamaDecoderLayer`` at or above ``min_layer_idx`` gets an
              ``nn.Dropout`` (``modified_dropout`` attribute) and its ``forward`` is rebound to
              ``LlamaDecoderLayer_forward_modified``.
            * any of ``dropkey_element``, ``dropkey_column``, ``dropkey_span``, ``dropattn_element``,
              ``dropattn_column``, ``dropattn_span``: every ``LlamaAttention`` at or above
              ``min_layer_idx`` has its ``forward`` rebound to ``LlamaAttention_forward_modified``.
        min_layer_idx: lowest layer index (parsed from the module name) that is patched. Defaults to
            16, the value that used to be hard-coded.

    Returns:
        The same ``model`` object.

    Raises:
        AssertionError: if the model is not one of the supported checkpoints.
        ValueError: if a matching module's name contains no layer index.
    """
    assert model.config.name_or_path in ['yahma/llama-7b-hf', ], "other models not supported yet"

    # NOTE: the RoBERTa-era patterns (drop_input, drop_classifier, hiddencut_column, hiddencut_span)
    # are not wired up for LLaMA; such keys are ignored by this function.

    if "hiddencut_element" in modified_dropout and modified_dropout["hiddencut_element"] > 0.:
        # add dropout in original llama_decoder layer
        dropout_rate = modified_dropout["hiddencut_element"]
        # Snapshot the module list: patched layers gain a new child module while we walk them.
        for name, module in list(model.named_modules()):
            if not isinstance(module, LlamaDecoderLayer):
                continue
            layer_idx = _layer_index_from_name(r'(?<=\.)\d+', name)
            if layer_idx < min_layer_idx:
                continue
            module.modified_dropout = nn.Dropout(dropout_rate)
            module.forward = types.MethodType(LlamaDecoderLayer_forward_modified, module)
            module.modified_dropout_dict = modified_dropout
            module.layer_idx = layer_idx

    attention_patterns = (
        "dropkey_element", "dropkey_column", "dropkey_span",
        "dropattn_element", "dropattn_column", "dropattn_span",
    )
    if any(key in modified_dropout and modified_dropout[key] > 0. for key in attention_patterns):
        # replace forward in original LlamaAttention (per instance, not on the class)
        for name, module in list(model.named_modules()):
            if not isinstance(module, LlamaAttention):
                continue
            layer_idx = _layer_index_from_name(r'(?<=\.)\d+(?=\.)', name)
            if layer_idx >= min_layer_idx:
                module.modified_dropout = modified_dropout
                module.forward = types.MethodType(LlamaAttention_forward_modified, module)

    return model


#
# class Dropout_modified(nn.Dropout):
#
#     def __init__(self, p: float = 0.5, inplace: bool = False, dropout_pattern: str = "none"):
#         super().__init__(p, inplace)
#         self.dropout_pattern = dropout_pattern
#         assert self.dropout_pattern in ["hiddencut_column", "hiddencut_span", ]
#
#     def forward(self, input: Tensor, attention_mask=None) -> Tensor:
#         if self.p < 0.0 or self.p > 1.0:
#             raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(self.p))
#
#         if self.p == 0 or not self.training:
#             return input
#
#         with torch.no_grad():
#             bz, seq_len, _ = input.size()
#             if self.dropout_pattern == "hiddencut_column":
#                 mask = torch.ones(bz, seq_len, 1)
#                 mask = F.dropout(mask, self.p, self.training, self.inplace) != 0
#                 mask = mask.expand_as(input).type_as(input)
#             elif self.dropout_pattern == "hiddencut_span":
#                 assert attention_mask is not None, "using hiddencut_span, attention_mask should not be None"
#                 if self.training:
#                     dropout_rate = self.p
#                     emb_len = (attention_mask > -10000).sum(dim=(1, 2, 3))
#                     mask_len = (emb_len * dropout_rate).long().clamp(min=0)
#                     index_high = (emb_len * (1 - dropout_rate)).long().clamp(min=0)
#                     start_indices = (torch.rand_like(index_high.float()) * index_high).long().clamp(min=0)
#                     end_indices = (start_indices + mask_len).clamp(max=emb_len)
#                     mask = torch.ones((bz, seq_len)).type_as(input)
#                     mask[(start_indices.view(-1, 1) <= torch.arange(0, seq_len).type_as(start_indices))
#                          & (end_indices.view(-1, 1) > torch.arange(0, seq_len).type_as(end_indices))] = 0.
#                     mask = mask.unsqueeze(-1).expand_as(input)
#
#                 else:
#                     mask = torch.ones_like(input)
#
#             else:
#                 raise NotImplementedError("dropout pattern {} not implemented".format(self.dropout_pattern))
#
#         if self.inplace:
#             input *= mask
#             return input
#         else:
#             output = input * mask
#             return output
#
#
# def self_attention_forward_modified(
#         self,
#         hidden_states: torch.Tensor,
#         attention_mask: Optional[torch.FloatTensor] = None,
#         head_mask: Optional[torch.FloatTensor] = None,
#         encoder_hidden_states: Optional[torch.FloatTensor] = None,
#         encoder_attention_mask: Optional[torch.FloatTensor] = None,
#         past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
#         output_attentions: Optional[bool] = False,
# ) -> Tuple[torch.Tensor]:
#     Dropout_modified = True
#
#     mixed_query_layer = self.query(hidden_states)
#
#     # If this is instantiated as a cross-attention module, the keys
#     # and values come from an encoder; the attention mask needs to be
#     # such that the encoder's padding tokens are not attended to.
#     is_cross_attention = encoder_hidden_states is not None
#
#     if is_cross_attention and past_key_value is not None:
#         # reuse k,v, cross_attentions
#         key_layer = past_key_value[0]
#         value_layer = past_key_value[1]
#         attention_mask = encoder_attention_mask
#     elif is_cross_attention:
#         key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
#         value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
#         attention_mask = encoder_attention_mask
#     elif past_key_value is not None:
#         key_layer = self.transpose_for_scores(self.key(hidden_states))
#         value_layer = self.transpose_for_scores(self.value(hidden_states))
#         key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
#         value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
#     else:
#         key_layer = self.transpose_for_scores(self.key(hidden_states))
#         value_layer = self.transpose_for_scores(self.value(hidden_states))
#
#     query_layer = self.transpose_for_scores(mixed_query_layer)
#
#     use_cache = past_key_value is not None
#     if self.is_decoder:
#         # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
#         # Further calls to cross_attention layer can then reuse all cross-attention
#         # key/value_states (first "if" case)
#         # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
#         # all previous decoder key/value_states. Further calls to uni-directional self-attention
#         # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
#         # if encoder bi-directional self-attention `past_key_value` is always `None`
#         past_key_value = (key_layer, value_layer)
#
#     # Take the dot product between "query" and "key" to get the raw attention scores.
#     attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
#
#     if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
#         query_length, key_length = query_layer.shape[2], key_layer.shape[2]
#         if use_cache:
#             position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
#                 -1, 1
#             )
#         else:
#             position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
#         position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
#         distance = position_ids_l - position_ids_r
#
#         positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
#         positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
#
#         if self.position_embedding_type == "relative_key":
#             relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
#             attention_scores = attention_scores + relative_position_scores
#         elif self.position_embedding_type == "relative_key_query":
#             relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
#             relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
#             attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
#
#     attention_scores = attention_scores / math.sqrt(self.attention_head_size)
#     dropkey_mask = None
#     if attention_mask is not None:
#         # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
#         if Dropout_modified:
#             if not self.training:
#                 pass
#             elif "dropkey_element" in self.modified_dropout and self.modified_dropout["dropkey_element"] > 0.:
#                 dropout_rate = self.modified_dropout["dropkey_element"]
#                 attention_mask_bool = attention_mask > -1  # 0: drop 1: keep
#                 with torch.no_grad():
#                     while True:
#                         drop_mask = torch.bernoulli(torch.ones_like(attention_scores) * (1 - dropout_rate)) \
#                             .type_as(attention_scores)  # 0: drop 1: keep
#                         if (drop_mask * attention_mask_bool).sum(-1).all():
#                             break
#                 drop_mask = 1. - drop_mask  # 0: keep 1: drop
#                 drop_mask[drop_mask == 1] = float("-inf")
#                 attention_scores = attention_scores + drop_mask
#                 dropkey_mask = drop_mask > -1
#
#             elif "dropkey_column" in self.modified_dropout and self.modified_dropout["dropkey_column"] > 0.:
#                 dropout_rate = self.modified_dropout["dropkey_column"]
#                 bz, hd, _, seq_len = attention_scores.size()
#                 attention_mask_bool = attention_mask > -1  # 0: drop 1: keep
#                 with torch.no_grad():
#                     while True:
#                         drop_mask = torch.bernoulli(torch.ones((bz, hd, 1, seq_len)) * (1 - dropout_rate)) \
#                             .type_as(attention_scores)  # 0: drop 1: keep
#                         if (drop_mask * attention_mask_bool).sum(-1).all():
#                             break
#                 drop_mask = 1. - drop_mask  # 0: keep 1: drop
#                 drop_mask[drop_mask == 1] = float("-inf")
#                 attention_scores = attention_scores + drop_mask
#                 dropkey_mask = drop_mask > -1
#
#             elif "dropkey_span" in self.modified_dropout and self.modified_dropout["dropkey_span"] > 0.:
#                 dropout_rate = self.modified_dropout["dropkey_span"]
#                 bz, hd, _, seq_len = attention_scores.size()
#                 # span
#                 # emb_len = (attention_mask > -10000).sum(dim=(1, 2, 3))
#                 # mask_len = (emb_len * dropout_rate).long().clamp(min=0)
#                 # index_high = (emb_len * (1 - dropout_rate)).long().clamp(min=0)
#                 # start_indices = (torch.rand_like(index_high.float()) * index_high).long().clamp(min=0)
#                 # end_indices = (start_indices + mask_len).clamp(max=emb_len)
#                 # drop_mask = torch.zeros((bz, seq_len)).type_as(attention_scores)
#                 # drop_mask[(start_indices.view(-1, 1) <= torch.arange(0, seq_len).type_as(start_indices))
#                 #           & (end_indices.view(-1, 1) > torch.arange(0, seq_len).type_as(end_indices))] = float("-inf")
#                 # attention_scores = attention_scores + drop_mask.view(bz, 1, 1, seq_len)
#
#                 emb_len = (attention_mask > -10000).sum(dim=(1, 2, 3)).unsqueeze(-1).repeat(1, hd)
#                 mask_len = (emb_len * dropout_rate).long().clamp(min=0)
#                 index_high = (emb_len * (1 - dropout_rate)).long().clamp(min=0)
#                 start_indices = (torch.rand_like(index_high.float()) * index_high).long().clamp(min=0)
#                 end_indices = (start_indices + mask_len).clamp(max=emb_len)
#                 drop_mask = torch.zeros((bz, hd, seq_len)).type_as(attention_scores)
#
#                 drop_mask = drop_mask.view(-1, seq_len)
#                 drop_mask[(start_indices.view(-1, 1) <= torch.arange(0, seq_len).type_as(start_indices))
#                           & (end_indices.view(-1, 1) > torch.arange(0, seq_len).type_as(end_indices))] = float("-inf")
#                 drop_mask = drop_mask.view(bz, hd, seq_len).view(bz, hd, 1, seq_len)
#
#                 attention_scores = attention_scores + drop_mask
#                 dropkey_mask = drop_mask > -1
#
#         attention_scores = attention_scores + attention_mask
#
#     # Normalize the attention scores to probabilities.
#     attention_probs = nn.functional.softmax(attention_scores, dim=-1)
#
#     # This is actually dropping out entire tokens to attend to, which might
#     # seem a bit unusual, but is taken from the original Transformer paper.
#     attention_probs = self.dropout(attention_probs)
#     if Dropout_modified:
#         if not self.training:
#             pass
#
#         elif "dropattn_element" in self.modified_dropout and self.modified_dropout["dropattn_element"] > 0.:
#             dropout_rate = self.modified_dropout["dropattn_element"]
#             attention_mask_bool = attention_mask > -1  # 0: drop 1: keep
#             with torch.no_grad():
#                 while True:
#                     drop_mask = torch.bernoulli(torch.ones_like(attention_probs) * (1 - dropout_rate)) \
#                         .type_as(attention_scores)  # 0: drop 1: keep
#                     if dropkey_mask is not None:
#                         if (drop_mask * attention_mask_bool * dropkey_mask).sum(-1).all():
#                             break
#                     else:
#                         if (drop_mask * attention_mask_bool).sum(-1).all():
#                             break
#             attention_probs = attention_probs * drop_mask
#             attention_probs = attention_probs / (attention_probs.sum(-1, keepdim=True).detach() + 1e-6)
#
#         elif "dropattn_column" in self.modified_dropout and self.modified_dropout["dropattn_column"] > 0.:
#             dropout_rate = self.modified_dropout["dropattn_column"]
#             bz, hd, _, seq_len = attention_probs.size()
#             attention_mask_bool = attention_mask > -1  # 0: drop 1: keep
#             with torch.no_grad():
#                 while True:
#                     drop_mask = torch.bernoulli(torch.ones((bz, hd, 1, seq_len)) * (1 - dropout_rate)) \
#                         .type_as(attention_scores)  # 0: drop 1: keep
#                     if dropkey_mask is not None:
#                         if (drop_mask * attention_mask_bool * dropkey_mask).sum(-1).all():
#                             break
#                     else:
#                         if (drop_mask * attention_mask_bool).sum(-1).all():
#                             break
#             attention_probs = attention_probs * drop_mask
#             attention_probs = attention_probs / (attention_probs.sum(-1, keepdim=True).detach() + 1e-6)
#
#         elif "dropattn_span" in self.modified_dropout and self.modified_dropout["dropattn_span"] > 0.:
#             dropout_rate = self.modified_dropout["dropattn_span"]
#             attention_mask_bool = attention_mask > -1  # 0: drop 1: keep
#             bz, hd, _, seq_len = attention_probs.size()
#
#             emb_len = (attention_mask > -10000).sum(dim=(1, 2, 3)).unsqueeze(-1).repeat(1, hd)
#             mask_len = (emb_len * dropout_rate).long().clamp(min=0)
#             index_high = (emb_len * (1 - dropout_rate)).long().clamp(min=0)
#
#             with torch.no_grad():
#                 while True:
#                     start_indices = (torch.rand_like(index_high.float()) * index_high).long().clamp(min=0)
#                     end_indices = (start_indices + mask_len).clamp(max=emb_len)
#                     drop_mask = torch.ones((bz, hd, seq_len)).type_as(attention_probs)
#                     drop_mask = drop_mask.view(-1, seq_len)
#                     drop_mask[(start_indices.view(-1, 1) <= torch.arange(0, seq_len).type_as(start_indices))
#                               & (end_indices.view(-1, 1) > torch.arange(0, seq_len).type_as(end_indices))] = 0.
#                     drop_mask = drop_mask.view(bz, hd, seq_len).view(bz, hd, 1, seq_len)
#
#                     if dropkey_mask is not None:
#                         if (drop_mask * attention_mask_bool * dropkey_mask).sum(-1).all():
#                             break
#                     else:
#                         if (drop_mask * attention_mask_bool).sum(-1).all():
#                             break
#
#             attention_probs = attention_probs * drop_mask
#             attention_probs = attention_probs / (attention_probs.sum(-1, keepdim=True).detach() + 1e-6)
#
#     # Mask heads if we want to
#     if head_mask is not None:
#         attention_probs = attention_probs * head_mask
#
#     context_layer = torch.matmul(attention_probs, value_layer)
#
#     context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
#     new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
#     context_layer = context_layer.view(new_context_layer_shape)
#
#     outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
#
#     if self.is_decoder:
#         outputs = outputs + (past_key_value,)
#     return outputs
#
#
# def apply_chunking_to_forward_modified(
#         forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors, attention_mask=None,
# ) -> torch.Tensor:
#     """
#     This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension
#     `chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory.
#
#     If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly
#     applying `forward_fn` to `input_tensors`.
#
#     Args:
#         forward_fn (`Callable[..., torch.Tensor]`):
#             The forward function of the model.
#         chunk_size (`int`):
#             The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`.
#         chunk_dim (`int`):
#             The dimension over which the `input_tensors` should be chunked.
#         input_tensors (`Tuple[torch.Tensor]`):
#             The input tensors of `forward_fn` which will be chunked
#
#     Returns:
#         `torch.Tensor`: A tensor with the same shape as the `forward_fn` would have given if applied`.
#
#
#     Examples:
#
#     ```python
#     # rename the usual forward() fn to forward_chunk()
#     def forward_chunk(self, hidden_states):
#         hidden_states = self.decoder(hidden_states)
#         return hidden_states
#
#
#     # implement a chunked forward function
#     def forward(self, hidden_states):
#         return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
#     ```"""
#
#     assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"
#
#     # inspect.signature exist since python 3.5 and is a python method -> no problem with backward compatibility
#     num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters) - 1
#     if num_args_in_forward_chunk_fn != len(input_tensors):
#         raise ValueError(
#             f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but only {len(input_tensors)} input "
#             "tensors are given"
#         )
#
#     if chunk_size > 0:
#         tensor_shape = input_tensors[0].shape[chunk_dim]
#         for input_tensor in input_tensors:
#             if input_tensor.shape[chunk_dim] != tensor_shape:
#                 raise ValueError(
#                     f"All input tenors have to be of the same shape: {tensor_shape}, "
#                     f"found shape {input_tensor.shape[chunk_dim]}"
#                 )
#
#         if input_tensors[0].shape[chunk_dim] % chunk_size != 0:
#             raise ValueError(
#                 f"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} has to be a multiple of the chunk "
#                 f"size {chunk_size}"
#             )
#
#         num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
#
#         # chunk input tensor into tuples
#         input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
#         # apply forward fn to every tuple
#         output_chunks = tuple(forward_fn(*input_tensors_chunk, attention_mask=attention_mask)
#                               for input_tensors_chunk in zip(*input_tensors_chunks))
#         # concatenate output at same dimension
#         return torch.cat(output_chunks, dim=chunk_dim)
#
#     return forward_fn(*input_tensors, attention_mask=attention_mask)
#
#
# def RobertaLayer_feed_forward_chunk_modified(self, attention_output, attention_mask=None):
#     intermediate_output = self.intermediate(attention_output)
#     layer_output = self.output(intermediate_output, attention_output, attention_mask=attention_mask)
#     return layer_output
#
#
# def RobertaLayer_forward_modified(
#         self,
#         hidden_states: torch.Tensor,
#         attention_mask: Optional[torch.FloatTensor] = None,
#         head_mask: Optional[torch.FloatTensor] = None,
#         encoder_hidden_states: Optional[torch.FloatTensor] = None,
#         encoder_attention_mask: Optional[torch.FloatTensor] = None,
#         past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
#         output_attentions: Optional[bool] = False,
# ) -> Tuple[torch.Tensor]:
#     # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
#     self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
#     self_attention_outputs = self.attention(
#         hidden_states,
#         attention_mask,
#         head_mask,
#         output_attentions=output_attentions,
#         past_key_value=self_attn_past_key_value,
#     )
#     attention_output = self_attention_outputs[0]
#
#     # if decoder, the last output is tuple of self-attn cache
#     if self.is_decoder:
#         outputs = self_attention_outputs[1:-1]
#         present_key_value = self_attention_outputs[-1]
#     else:
#         outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
#
#     cross_attn_present_key_value = None
#     if self.is_decoder and encoder_hidden_states is not None:
#         if not hasattr(self, "crossattention"):
#             raise ValueError(
#                 f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
#                 " by setting `config.add_cross_attention=True`"
#             )
#
#         # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
#         cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
#         cross_attention_outputs = self.crossattention(
#             attention_output,
#             attention_mask,
#             head_mask,
#             encoder_hidden_states,
#             encoder_attention_mask,
#             cross_attn_past_key_value,
#             output_attentions,
#         )
#         attention_output = cross_attention_outputs[0]
#         outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights
#
#         # add cross-attn cache to positions 3,4 of present_key_value tuple
#         cross_attn_present_key_value = cross_attention_outputs[-1]
#         present_key_value = present_key_value + cross_attn_present_key_value
#
#     layer_output = apply_chunking_to_forward_modified(
#         self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output,
#         attention_mask=attention_mask
#     )
#     outputs = (layer_output,) + outputs
#
#     # if decoder, return the attn key/values as the last output
#     if self.is_decoder:
#         outputs = outputs + (present_key_value,)
#
#     return outputs
#
#
# def RobertaOutput_forward_modified(self, hidden_states, input_tensor, attention_mask=None):
#     hidden_states = self.dense(hidden_states)
#     hidden_states = self.dropout(hidden_states) if len(inspect.signature(self.dropout.forward).parameters) == 1 \
#         else self.dropout(hidden_states, attention_mask=attention_mask)
#     hidden_states = self.LayerNorm(hidden_states + input_tensor)
#     return hidden_states
#


def _sample_dropkey_column_mask(attention_mask, num_heads, dropout_rate, dtype):
    """Sample an additive mask that hides whole key columns, independently per batch item and head.

    Args:
        attention_mask: additive attention mask indexed as ``[batch, 0, query, key]``.
        num_heads: number of attention heads.
        dropout_rate: probability with which each key column is dropped.
        dtype: dtype of the returned mask.

    Returns:
        Tensor of shape ``(batch, num_heads, 1, key_len)`` on ``attention_mask.device`` that is ``-inf``
        for dropped columns and ``0`` elsewhere.
    """
    bsz, kv_len = attention_mask.size(0), attention_mask.size(-1)
    device = attention_mask.device
    with torch.no_grad():
        # Number of keys the last query row is allowed to attend to (mask entries > -1).
        valid_seq_len = (attention_mask[:, 0, -1] > -1).sum(-1)
        # Sample directly on the target device instead of on the CPU followed by a copy.
        drop_prob = torch.full((bsz, num_heads, 1, kv_len), float(dropout_rate),
                               dtype=torch.float32, device=device)
        drop_mask = torch.bernoulli(drop_prob).to(dtype)  # 0: keep 1: drop
        drop_mask[drop_mask == 1.] = float("-inf")
        # Never drop the key at index ``-valid_seq_len`` of each sequence.
        # NOTE(review): that is the first attendable token only if the batch is left-padded - confirm.
        batch_idx = torch.arange(bsz, device=device)
        drop_mask.transpose(1, 3)[batch_idx, -valid_seq_len] = 0.
    return drop_mask


def LlamaAttention_forward_modified(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """LLaMA self-attention forward with optional column-wise DropKey applied to the attention mask.

    While ``self.training`` is true, ``attention_mask`` is given and
    ``self.modified_dropout["dropkey_column"] > 0``, a random set of key columns is masked with ``-inf``
    before the softmax. Only ``dropkey_column`` is implemented here; any other key in
    ``self.modified_dropout`` has no effect in this function.

    Args:
        hidden_states: tensor of shape ``(bsz, q_len, hidden)``.
        attention_mask: optional additive mask of shape ``(bsz, 1, q_len, kv_seq_len)``.
        position_ids: position indices forwarded to ``apply_rotary_pos_emb``.
        past_key_value: optional cached ``(key, value)`` pair that is concatenated in front of the new
            keys/values along dim 2.
        output_attentions: when false, the returned attention weights are ``None``.
        use_cache: when true, the (concatenated) ``(key, value)`` pair is returned for reuse.

    Returns:
        ``(attn_output, attn_weights or None, past_key_value or None)``.

    Raises:
        ValueError: if the attention weights, the attention mask or the attention output have an
            unexpected shape.
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
    # [bsz, nh, t, hd]

    if past_key_value is not None:
        # reuse k, v, self_attention
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
            f" {attn_weights.size()}"
        )

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )

        # modified: DropKey (column pattern) is only active during training
        if self.training and "dropkey_column" in self.modified_dropout \
                and self.modified_dropout["dropkey_column"] > 0.:
            drop_mask = _sample_dropkey_column_mask(
                attention_mask, self.num_heads, self.modified_dropout["dropkey_column"], attn_weights.dtype
            )
            attention_mask = drop_mask + attention_mask

        attn_weights = attn_weights + attention_mask
        # Clamp to the dtype minimum so that fully masked rows (-inf) still yield a finite softmax.
        attn_weights = torch.max(
            attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
        )

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.transpose(1, 2)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


def LlamaDecoderLayer_forward_modified(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """
    Decoder-layer forward that can apply an extra dropout (HiddenCut) to the MLP output.

    The layer normalises its input, runs self-attention and adds the result to the input; it then
    normalises again, runs the MLP and adds that result as a second residual. While
    ``self.training`` is true and ``self.modified_dropout_dict["hiddencut_element"] > 0``, the MLP
    output passes through ``self.modified_dropout`` before the second residual addition.

    Args:
        hidden_states (`torch.FloatTensor`): layer input.
        attention_mask (`torch.FloatTensor`, *optional*): forwarded unchanged to ``self.self_attn``.
        position_ids (`torch.LongTensor`, *optional*): forwarded unchanged to ``self.self_attn``.
        past_key_value (`Tuple(torch.FloatTensor)`, *optional*): forwarded unchanged to ``self.self_attn``.
        output_attentions (`bool`, *optional*): if truthy, the attention weights returned by
            ``self.self_attn`` are appended to the output tuple.
        use_cache (`bool`, *optional*): if truthy, the key/value state returned by ``self.self_attn``
            is appended to the output tuple.

    Returns:
        A tuple whose first element is the layer output, followed by the optional items above.
    """
    # Attention sub-block: pre-norm, attention, residual.
    attn_out, attn_probs, kv_state = self.self_attn(
        hidden_states=self.input_layernorm(hidden_states),
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
    )
    after_attn = hidden_states + attn_out

    # Feed-forward sub-block: pre-norm, MLP, optional HiddenCut dropout, residual.
    mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
    hiddencut_active = (
        self.training
        and "hiddencut_element" in self.modified_dropout_dict
        and self.modified_dropout_dict["hiddencut_element"] > 0.
    )
    if hiddencut_active:
        mlp_out = self.modified_dropout(mlp_out)
    layer_out = after_attn + mlp_out

    attn_part = (attn_probs,) if output_attentions else ()
    cache_part = (kv_state,) if use_cache else ()
    return (layer_out,) + attn_part + cache_part
