# -*- coding:utf-8 _*-
# @License: MIT Licence

# @Time: 23/5/2023

import math
from typing import List, Optional, Tuple, Union, Callable
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from peft.utils.other import _get_submodules
from transformers.models.roberta.modeling_roberta import RobertaSelfAttention, RobertaLayer, RobertaOutput
import inspect


def disable_dropout_Modified(model):
    """Turn off every ``torch.nn.Dropout`` found inside ``model``.

    Each matching submodule (including subclasses of ``torch.nn.Dropout``)
    gets its drop probability ``p`` set to ``0.0`` in place.

    Args:
        model: Module whose submodules are scanned.

    Returns:
        The same ``model`` object, so the call can be chained.
    """
    dropout_layers = [layer for layer in model.modules() if isinstance(layer, torch.nn.Dropout)]
    for layer in dropout_layers:
        layer.p = 0.0
    return model


def replace_dropout_Modified(model, modified_dropout):
    """Configure the structured-dropout variants requested in ``modified_dropout``.

    ``modified_dropout`` maps a variant name to a drop rate; a missing key or a
    rate that is not ``> 0`` leaves that variant disabled. Recognised keys:

    * ``drop_input``: sets ``p`` on dropouts whose name contains
      ``embeddings.dropout``.
    * ``drop_classifier``: sets ``p`` on the classifier dropouts
      (``classifier.original_module.dropout`` /
      ``classifier.modules_to_save.default.dropout``).
    * ``hiddencut_element`` / ``hiddencut_column`` / ``hiddencut_span``
      (mutually exclusive, checked in that order): act on every
      ``output.dropout`` that is not an ``attention.output.dropout``.
      ``element`` only changes ``p``; ``column`` and ``span`` swap the module
      for a ``Dropout_Modified``; ``span`` additionally patches
      ``RobertaLayer`` / ``RobertaOutput`` so the attention mask reaches the
      dropout.
    * ``dropkey_*`` / ``dropattn_*``: patch ``RobertaSelfAttention.forward``
      and store ``modified_dropout`` on that class.

    NOTE: the forward patches are applied to the *classes*, so they affect
    every instance of those classes in the process, not only ``model``.

    Args:
        model: Model exposing ``config.name_or_path``; only ``roberta-base``
            and ``roberta-large`` are accepted.
        modified_dropout: Mapping from variant name to drop rate.

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        AssertionError: If ``model.config.name_or_path`` is not supported.
    """
    assert model.config.name_or_path in ["roberta-base", "roberta-large", ], "other models not supported yet"

    def _enabled(key):
        # A missing key and a non-positive rate both mean "disabled".
        return key in modified_dropout and modified_dropout[key] > 0.

    def _is_ffn_output_dropout(name, module):
        # Matches "...output.dropout" but excludes "...attention.output.dropout".
        return isinstance(module, torch.nn.Dropout) \
            and "output.dropout" in name and "attention.output.dropout" not in name

    # Snapshot the module tree once: some branches below replace children, and
    # doing that while the lazy named_modules() generator is still running
    # relies on dict-iteration details of nn.Module.
    modules = list(model.named_modules())

    def _swap_ffn_output_dropout(pattern):
        # Replace every matching dropout by a Dropout_Modified using ``pattern``.
        rate = modified_dropout[pattern]
        for name, module in modules:
            if _is_ffn_output_dropout(name, module):
                parent, target, target_name = _get_submodules(model, name)
                new_module = Dropout_Modified(p=rate, dropout_pattern=pattern)
                _replace_dropout(parent, target_name, new_module, target)
                # _replace_dropout copies the old module's ``p``; restore the requested rate.
                new_module.p = rate

    if _enabled("drop_input"):
        dropout_rate = modified_dropout["drop_input"]
        for name, module in modules:
            if isinstance(module, torch.nn.Dropout) and "embeddings.dropout" in name:
                module.p = dropout_rate
                print("set drop_input")

    if _enabled("drop_classifier"):
        dropout_rate = modified_dropout["drop_classifier"]
        for name, module in modules:
            if isinstance(module, torch.nn.Dropout) and \
                    ("classifier.original_module.dropout" in name or
                     "classifier.modules_to_save.default.dropout" in name):
                module.p = dropout_rate
                print("set drop_classifier")

    if _enabled("hiddencut_element"):
        # Element-wise HiddenCut is plain dropout: only the rate changes.
        dropout_rate = modified_dropout["hiddencut_element"]
        for name, module in modules:
            if _is_ffn_output_dropout(name, module):
                module.p = dropout_rate

    elif _enabled("hiddencut_column"):
        _swap_ffn_output_dropout("hiddencut_column")

    elif _enabled("hiddencut_span"):
        _swap_ffn_output_dropout("hiddencut_span")

        # The span variant needs the attention mask inside the dropout, so the
        # layer/output forwards are patched to pass it through. Patch each class
        # once, and only if the model actually contains such a module.
        if any(isinstance(module, RobertaLayer) for _, module in modules):
            RobertaLayer.forward = RobertaLayer_forward_Modified
            RobertaLayer.feed_forward_chunk = RobertaLayer_feed_forward_chunk_Modified

        if any(isinstance(module, RobertaOutput) for _, module in modules):
            RobertaOutput.forward = RobertaOutput_forward_Modified

    attention_variants = ("dropkey_element", "dropkey_column", "dropkey_span",
                          "dropattn_element", "dropattn_column", "dropattn_span")
    if any(_enabled(key) for key in attention_variants):
        # Replace forward in the original RobertaSelfAttention.
        RobertaSelfAttention.forward = self_attention_forward_Modified
        RobertaSelfAttention.modified_dropout = modified_dropout

    return model


def _replace_dropout(parent_module, child_name, new_module, old_module):
    """Install ``new_module`` as ``parent_module.<child_name>`` in place of ``old_module``.

    After the swap, ``p``, ``inplace`` and ``training`` are copied from
    ``old_module`` when it has them, so the replacement starts with the same
    configuration. If ``old_module`` carries a non-``None`` ``state`` or
    ``device`` attribute, that state / device placement is carried over too.
    """
    setattr(parent_module, child_name, new_module)

    # Mirror the plain configuration attributes of the module being replaced.
    for attr in ("p", "inplace", "training"):
        if hasattr(old_module, attr):
            setattr(new_module, attr, getattr(old_module, attr))

    carried_state = getattr(old_module, "state", None)
    if carried_state is not None:
        new_module.state = carried_state
        new_module.to(old_module.weight.device)

    # Dispatch to the correct device.
    target_device = getattr(old_module, "device", None)
    if target_device is not None:
        new_module.to(target_device)
        for _, submodule in new_module.named_modules():
            submodule.to(target_device)


class Dropout_Modified(nn.Dropout):
    """Structured dropout that zeroes whole token vectors instead of single elements.

    Two patterns are supported:

    * ``"hiddencut_column"``: every ``(batch, position)`` vector is dropped
      independently with probability ``p``.
    * ``"hiddencut_span"``: for every batch row one contiguous span of
      positions is dropped. The span length is ``floor(valid_len * p)`` where
      ``valid_len`` counts the entries of ``attention_mask`` greater than
      ``-10000``.

    Surviving activations are NOT rescaled by ``1 / (1 - p)``; the output is
    simply ``input * mask`` with a 0/1 mask.

    Args:
        p: Drop probability.
        inplace: If true, ``input`` is multiplied by the mask in place.
        dropout_pattern: ``"hiddencut_column"`` or ``"hiddencut_span"``. The
            default ``"none"`` is rejected by the constructor.

    Raises:
        AssertionError: If ``dropout_pattern`` is not a supported pattern.
    """

    def __init__(self, p: float = 0.5, inplace: bool = False, dropout_pattern: str = "none"):
        super().__init__(p, inplace)
        self.dropout_pattern = dropout_pattern
        assert self.dropout_pattern in ["hiddencut_column", "hiddencut_span", ]

    def forward(self, input: Tensor, attention_mask=None) -> Tensor:
        """Apply the configured mask to a 3-D ``input``.

        Args:
            input: Tensor unpacked as ``(bz, seq_len, hidden)``.
            attention_mask: Required for ``"hiddencut_span"``; it is reduced over
                dims ``(1, 2, 3)``, so a 4-D mask is expected (presumably the
                extended additive mask of shape ``(bz, 1, 1, seq_len)`` — confirm
                against the caller).

        Returns:
            ``input`` itself when ``p == 0``, when not training, or when
            ``inplace`` is set; otherwise a new masked tensor.

        Raises:
            ValueError: If ``self.p`` is outside ``[0, 1]``.
            AssertionError: If the span pattern is used without ``attention_mask``.
        """
        if self.p < 0.0 or self.p > 1.0:
            raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(self.p))

        if self.p == 0 or not self.training:
            return input

        # From here on the module is known to be in training mode with p > 0.
        with torch.no_grad():
            bz, seq_len, _ = input.size()
            if self.dropout_pattern == "hiddencut_column":
                # Sample one keep/drop decision per (batch, position) directly on
                # the input's device; comparing against 0 discards the
                # 1 / (1 - p) scaling that F.dropout applies.
                keep = F.dropout(torch.ones(bz, seq_len, 1, device=input.device), self.p, True) != 0
                mask = keep.expand_as(input).type_as(input)
            elif self.dropout_pattern == "hiddencut_span":
                assert attention_mask is not None, "using hiddencut_span, attention_mask should not be None"
                dropout_rate = self.p
                emb_len = (attention_mask > -10000).sum(dim=(1, 2, 3))
                mask_len = (emb_len * dropout_rate).long().clamp(min=0)
                # Largest admissible start so the span stays inside the valid tokens.
                index_high = (emb_len * (1 - dropout_rate)).long().clamp(min=0)
                start_indices = (torch.rand_like(index_high.float()) * index_high).long().clamp(min=0)
                end_indices = (start_indices + mask_len).clamp(max=emb_len)

                positions = torch.arange(seq_len, device=start_indices.device, dtype=start_indices.dtype)
                in_span = (start_indices.view(-1, 1) <= positions) & (end_indices.view(-1, 1) > positions)
                mask = torch.ones((bz, seq_len), dtype=input.dtype, device=input.device)
                mask[in_span] = 0.
                mask = mask.unsqueeze(-1).expand_as(input)
            else:
                raise NotImplementedError("dropout pattern {} not implemented".format(self.dropout_pattern))

        if self.inplace:
            input *= mask
            return input
        return input * mask


def self_attention_forward_Modified(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
    """Self-attention forward with optional DropKey / DropAttention masking.

    Intended as a replacement for ``RobertaSelfAttention.forward``. The variants
    are read from ``self.modified_dropout`` (a mapping from name to rate) and are
    only applied while ``self.training`` is true:

    * ``dropkey_element`` / ``dropkey_column`` / ``dropkey_span`` add ``-inf`` to
      selected attention scores *before* the softmax. They are skipped when
      ``attention_mask`` is ``None``.
    * ``dropattn_element`` / ``dropattn_column`` / ``dropattn_span`` zero
      selected attention probabilities *after* ``self.dropout`` and renormalise
      each row by its (detached) sum plus ``1e-6``. These branches compare
      against ``attention_mask``, so a mask must be supplied when they are
      enabled.

    Masks are resampled until every attention row keeps at least one key that is
    allowed by ``attention_mask`` (and by the DropKey mask, if one was applied).
    The mask handling repeats ``attention_mask`` as ``(1, hd, seq_len, 1)``, so it
    presumably expects the extended additive mask of shape
    ``(bz, 1, 1, seq_len)`` — confirm against the caller.

    Args:
        hidden_states: Input fed to the query (and, for self-attention, the
            key/value) projections.
        attention_mask: Additive attention mask; entries ``> -1`` are treated as
            attendable.
        head_mask: Optional multiplier applied to the attention probabilities.
        encoder_hidden_states: If given, keys/values are computed from it
            (cross-attention) and ``encoder_attention_mask`` replaces
            ``attention_mask``.
        encoder_attention_mask: Mask used in the cross-attention case.
        past_key_value: Cached key/value states.
        output_attentions: Whether to also return the attention probabilities.

    Returns:
        ``(context_layer,)``, plus the attention probabilities when
        ``output_attentions`` is true, plus ``(key_layer, value_layer)`` when
        ``self.is_decoder`` is true.
    """
    mixed_query_layer = self.query(hidden_states)

    # If this is instantiated as a cross-attention module, the keys
    # and values come from an encoder; the attention mask needs to be
    # such that the encoder's padding tokens are not attended to.
    is_cross_attention = encoder_hidden_states is not None

    if is_cross_attention and past_key_value is not None:
        # reuse k,v, cross_attentions
        key_layer = past_key_value[0]
        value_layer = past_key_value[1]
        attention_mask = encoder_attention_mask
    elif is_cross_attention:
        key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
        value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
        attention_mask = encoder_attention_mask
    elif past_key_value is not None:
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
        value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
    else:
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

    query_layer = self.transpose_for_scores(mixed_query_layer)

    use_cache = past_key_value is not None
    if self.is_decoder:
        # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
        # Further calls to cross_attention layer can then reuse all cross-attention
        # key/value_states (first "if" case)
        # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
        # all previous decoder key/value_states. Further calls to uni-directional self-attention
        # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
        # if encoder bi-directional self-attention `past_key_value` is always `None`
        past_key_value = (key_layer, value_layer)

    # Take the dot product between "query" and "key" to get the raw attention scores.
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

    if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
        query_length, key_length = query_layer.shape[2], key_layer.shape[2]
        if use_cache:
            position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                -1, 1
            )
        else:
            position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
        position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
        distance = position_ids_l - position_ids_r

        positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
        positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

        if self.position_embedding_type == "relative_key":
            relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
            attention_scores = attention_scores + relative_position_scores
        elif self.position_embedding_type == "relative_key_query":
            relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
            relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
            attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

    attention_scores = attention_scores / math.sqrt(self.attention_head_size)
    # Random masks are sampled on the same device as the scores to avoid a
    # host-to-device copy on every forward pass.
    device = attention_scores.device
    # Boolean "kept by DropKey" mask (True: keep); stays None if DropKey is off.
    dropkey_mask = None
    if attention_mask is not None:
        # ---- DropKey: mask scores before the softmax (training only) ----
        if self.training:
            if "dropkey_element" in self.modified_dropout and self.modified_dropout["dropkey_element"] > 0.:
                dropout_rate = self.modified_dropout["dropkey_element"]
                bz, hd, _, seq_len = attention_scores.size()

                attention_mask_bool = attention_mask > -1  # False: masked out, True: attendable
                attention_mask_bool = attention_mask_bool.repeat(1, hd, seq_len, 1).view(-1, seq_len)
                with torch.no_grad():
                    drop_mask = torch.bernoulli(torch.ones_like(attention_scores) * (1 - dropout_rate)) \
                        .type_as(attention_scores).view(-1, seq_len).bool()  # False: drop, True: keep

                    # Resample rows that would lose every attendable key.
                    idx_0 = torch.where((attention_mask_bool & drop_mask).sum(-1) == 0)[0]
                    while len(idx_0) > 0:
                        patch_mask = torch.bernoulli(
                            torch.ones((len(idx_0), seq_len), device=device) * (1 - dropout_rate)
                        ).type_as(attention_scores).bool()  # False: drop, True: keep
                        drop_mask[idx_0] = patch_mask
                        idx_0 = idx_0[torch.where((attention_mask_bool[idx_0] & patch_mask).sum(-1) == 0)[0]]

                    drop_mask = drop_mask.view(bz, hd, seq_len, seq_len)
                    drop_mask = 1. - drop_mask.float()  # 0: keep, 1: drop
                    drop_mask[drop_mask == 1] = float("-inf")
                attention_scores = attention_scores + drop_mask
                dropkey_mask = drop_mask > -1  # False: drop, True: keep

            elif "dropkey_column" in self.modified_dropout and self.modified_dropout["dropkey_column"] > 0.:
                dropout_rate = self.modified_dropout["dropkey_column"]
                bz, hd, _, seq_len = attention_scores.size()
                attention_mask_bool = attention_mask > -1  # False: masked out, True: attendable
                with torch.no_grad():
                    # One decision per (batch, head, key); resample the whole mask
                    # until every (batch, head) keeps at least one attendable key.
                    while True:
                        drop_mask = torch.bernoulli(
                            torch.ones((bz, hd, 1, seq_len), device=device) * (1 - dropout_rate)
                        ).type_as(attention_scores)  # 0: drop, 1: keep
                        if (drop_mask * attention_mask_bool).sum(-1).all():
                            break
                drop_mask = 1. - drop_mask  # 0: keep, 1: drop
                drop_mask[drop_mask == 1] = float("-inf")
                attention_scores = attention_scores + drop_mask
                dropkey_mask = drop_mask > -1

            elif "dropkey_span" in self.modified_dropout and self.modified_dropout["dropkey_span"] > 0.:
                dropout_rate = self.modified_dropout["dropkey_span"]
                bz, hd, _, seq_len = attention_scores.size()

                # One contiguous span of keys per (batch, head).
                emb_len = (attention_mask > -10000).sum(dim=(1, 2, 3)).unsqueeze(-1).repeat(1, hd)
                mask_len = (emb_len * dropout_rate).long().clamp(min=0)
                index_high = (emb_len * (1 - dropout_rate)).long().clamp(min=0)
                start_indices = (torch.rand_like(index_high.float()) * index_high).long().clamp(min=0)
                end_indices = (start_indices + mask_len).clamp(max=emb_len)
                positions = torch.arange(seq_len, device=start_indices.device, dtype=start_indices.dtype)

                drop_mask = torch.zeros((bz * hd, seq_len), device=device).type_as(attention_scores)
                drop_mask[(start_indices.view(-1, 1) <= positions)
                          & (end_indices.view(-1, 1) > positions)] = float("-inf")
                drop_mask = drop_mask.view(bz, hd, 1, seq_len)

                attention_scores = attention_scores + drop_mask
                dropkey_mask = drop_mask > -1

        # Apply the attention mask (precomputed for all layers in RobertaModel forward() function)
        attention_scores = attention_scores + attention_mask

    # Normalize the attention scores to probabilities.
    attention_probs = nn.functional.softmax(attention_scores, dim=-1)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = self.dropout(attention_probs)

    # ---- DropAttention: mask probabilities after the softmax (training only) ----
    if self.training:
        if "dropattn_element" in self.modified_dropout and self.modified_dropout["dropattn_element"] > 0.:
            dropout_rate = self.modified_dropout["dropattn_element"]
            bz, hd, _, seq_len = attention_probs.size()

            attention_mask_bool = attention_mask > -1  # False: masked out, True: attendable
            attention_mask_bool = attention_mask_bool.repeat(1, hd, seq_len, 1).view(-1, seq_len)

            # Keys that are still usable for a row: attendable and, if DropKey ran,
            # not already dropped by it. ``expand_as`` + ``reshape`` is used because
            # the DropKey mask may have a singleton query dimension (column/span).
            valid_keys = attention_mask_bool
            if dropkey_mask is not None:
                valid_keys = valid_keys & dropkey_mask.expand_as(attention_probs).reshape(-1, seq_len)

            with torch.no_grad():
                drop_mask = torch.bernoulli(torch.ones_like(attention_probs) * (1 - dropout_rate)) \
                    .type_as(attention_scores).view(-1, seq_len).bool()  # False: drop, True: keep

                # Resample rows that would lose every usable key.
                idx_0 = torch.where((valid_keys & drop_mask).sum(-1) == 0)[0]
                while len(idx_0) > 0:
                    patch_mask = torch.bernoulli(
                        torch.ones((len(idx_0), seq_len), device=device) * (1 - dropout_rate)
                    ).type_as(attention_probs).bool()  # False: drop, True: keep
                    drop_mask[idx_0] = patch_mask
                    idx_0 = idx_0[torch.where((valid_keys[idx_0] & patch_mask).sum(-1) == 0)[0]]
                drop_mask = drop_mask.view(bz, hd, seq_len, seq_len)
            attention_probs = attention_probs * drop_mask
            attention_probs = attention_probs / (attention_probs.sum(-1, keepdim=True).detach() + 1e-6)

        elif "dropattn_column" in self.modified_dropout and self.modified_dropout["dropattn_column"] > 0.:
            dropout_rate = self.modified_dropout["dropattn_column"]
            bz, hd, _, seq_len = attention_probs.size()
            attention_mask_bool = attention_mask > -1  # False: masked out, True: attendable
            valid_keys = attention_mask_bool if dropkey_mask is None else attention_mask_bool & dropkey_mask
            with torch.no_grad():
                while True:
                    drop_mask = torch.bernoulli(
                        torch.ones((bz, hd, 1, seq_len), device=device) * (1 - dropout_rate)
                    ).type_as(attention_scores)  # 0: drop, 1: keep
                    if (drop_mask * valid_keys).sum(-1).all():
                        break
            attention_probs = attention_probs * drop_mask
            attention_probs = attention_probs / (attention_probs.sum(-1, keepdim=True).detach() + 1e-6)

        elif "dropattn_span" in self.modified_dropout and self.modified_dropout["dropattn_span"] > 0.:
            dropout_rate = self.modified_dropout["dropattn_span"]
            attention_mask_bool = attention_mask > -1  # False: masked out, True: attendable
            bz, hd, _, seq_len = attention_probs.size()
            valid_keys = attention_mask_bool if dropkey_mask is None else attention_mask_bool & dropkey_mask

            emb_len = (attention_mask > -10000).sum(dim=(1, 2, 3)).unsqueeze(-1).repeat(1, hd)
            mask_len = (emb_len * dropout_rate).long().clamp(min=0)
            index_high = (emb_len * (1 - dropout_rate)).long().clamp(min=0)
            positions = torch.arange(seq_len, device=index_high.device, dtype=index_high.dtype)

            with torch.no_grad():
                while True:
                    start_indices = (torch.rand_like(index_high.float()) * index_high).long().clamp(min=0)
                    end_indices = (start_indices + mask_len).clamp(max=emb_len)
                    drop_mask = torch.ones((bz * hd, seq_len), device=device).type_as(attention_probs)
                    drop_mask[(start_indices.view(-1, 1) <= positions)
                              & (end_indices.view(-1, 1) > positions)] = 0.
                    drop_mask = drop_mask.view(bz, hd, 1, seq_len)

                    if (drop_mask * valid_keys).sum(-1).all():
                        break

            attention_probs = attention_probs * drop_mask
            attention_probs = attention_probs / (attention_probs.sum(-1, keepdim=True).detach() + 1e-6)

    # Mask heads if we want to
    if head_mask is not None:
        attention_probs = attention_probs * head_mask

    context_layer = torch.matmul(attention_probs, value_layer)

    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
    context_layer = context_layer.view(new_context_layer_shape)

    outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

    if self.is_decoder:
        outputs = outputs + (past_key_value,)
    return outputs


def apply_chunking_to_forward_Modified(
        forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors, attention_mask=None,
) -> torch.Tensor:
    """
    Run `forward_fn` on `input_tensors`, optionally in chunks along `chunk_dim`, forwarding `attention_mask`.

    With `chunk_size > 0` every input tensor is split along `chunk_dim` into pieces of `chunk_size`, `forward_fn` is
    called once per piece, and the results are concatenated along the same dimension. Otherwise `forward_fn` is called
    once on the full tensors.

    The `attention_mask` keyword is handed to every call of `forward_fn` unchanged: it is NOT chunked together with
    the input tensors.

    Args:
        forward_fn (`Callable[..., torch.Tensor]`):
            Function to apply. Its signature must contain exactly one parameter more than there are input tensors
            (the extra one being `attention_mask`).
        chunk_size (`int`):
            Size of each chunk along `chunk_dim`; values that are not `> 0` disable chunking.
        chunk_dim (`int`):
            Dimension along which the input tensors are split and the outputs concatenated.
        input_tensors (`Tuple[torch.Tensor]`):
            Positional tensors passed to `forward_fn`; they must agree in size along `chunk_dim`.
        attention_mask:
            Passed to `forward_fn` as the keyword argument `attention_mask`.

    Returns:
        `torch.Tensor`: The output of `forward_fn`, or the concatenation of its per-chunk outputs.

    Raises:
        AssertionError: If no input tensor is given.
        ValueError: If the number of input tensors does not match the signature of `forward_fn`, if the tensors differ
            in size along `chunk_dim`, or if that size is not a multiple of `chunk_size`.
    """

    assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"

    # One parameter of forward_fn is reserved for the attention_mask keyword.
    expected_tensor_count = len(inspect.signature(forward_fn).parameters) - 1
    given_tensor_count = len(input_tensors)
    if expected_tensor_count != given_tensor_count:
        raise ValueError(
            f"forward_chunk_fn expects {expected_tensor_count} arguments, but only {given_tensor_count} input "
            "tensors are given"
        )

    if not chunk_size > 0:
        # Chunking disabled: a single call on the full tensors.
        return forward_fn(*input_tensors, attention_mask=attention_mask)

    reference_len = input_tensors[0].shape[chunk_dim]
    for candidate in input_tensors:
        candidate_len = candidate.shape[chunk_dim]
        if candidate_len != reference_len:
            raise ValueError(
                f"All input tenors have to be of the same shape: {reference_len}, "
                f"found shape {candidate_len}"
            )

    if reference_len % chunk_size != 0:
        raise ValueError(
            f"The dimension to be chunked {reference_len} has to be a multiple of the chunk "
            f"size {chunk_size}"
        )

    num_chunks = reference_len // chunk_size

    # Split every input, then regroup so each tuple holds the i-th piece of every input.
    pieces_per_input = [tensor.chunk(num_chunks, dim=chunk_dim) for tensor in input_tensors]
    outputs = [forward_fn(*pieces, attention_mask=attention_mask) for pieces in zip(*pieces_per_input)]
    return torch.cat(outputs, dim=chunk_dim)


def RobertaLayer_feed_forward_chunk_Modified(self, attention_output, attention_mask=None):
    """Feed-forward step of a layer that also hands ``attention_mask`` to ``self.output``.

    ``attention_output`` goes through ``self.intermediate``; the result, the
    original ``attention_output`` and the mask (as keyword) are then passed to
    ``self.output``, whose return value is returned.
    """
    return self.output(
        self.intermediate(attention_output),
        attention_output,
        attention_mask=attention_mask,
    )


def RobertaLayer_forward_Modified(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
    """Layer forward pass that forwards ``attention_mask`` into the feed-forward block.

    Runs ``self.attention``, then (for a decoder that receives
    ``encoder_hidden_states``) ``self.crossattention``, and finally the
    feed-forward chunk through ``apply_chunking_to_forward_Modified`` with
    ``attention_mask`` passed along as a keyword.

    Returns:
        A tuple starting with the layer output, followed by whatever extra
        outputs the attention modules returned, and — when ``self.is_decoder``
        is true — the present key/value cache as last element.

    Raises:
        ValueError: If ``encoder_hidden_states`` is given to a decoder layer
            that has no ``crossattention`` attribute.
    """
    # Positions 0-1 of the cache tuple hold the self-attention key/value states.
    cached_self_kv = None if past_key_value is None else past_key_value[:2]
    self_attn_out = self.attention(
        hidden_states,
        attention_mask,
        head_mask,
        output_attentions=output_attentions,
        past_key_value=cached_self_kv,
    )
    attention_output = self_attn_out[0]

    if self.is_decoder:
        # A decoder appends its self-attention cache as the final element.
        extra_outputs = self_attn_out[1:-1]
        present_key_value = self_attn_out[-1]
    else:
        # Attention weights, if they were requested.
        extra_outputs = self_attn_out[1:]

    if self.is_decoder and encoder_hidden_states is not None:
        if not hasattr(self, "crossattention"):
            raise ValueError(
                f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                " by setting `config.add_cross_attention=True`"
            )

        # The last two cache entries hold the cross-attention key/value states.
        cached_cross_kv = None if past_key_value is None else past_key_value[-2:]
        cross_attn_out = self.crossattention(
            attention_output,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            cached_cross_kv,
            output_attentions,
        )
        attention_output = cross_attn_out[0]
        # Cross-attention weights, if they were requested.
        extra_outputs = extra_outputs + cross_attn_out[1:-1]
        # Append the cross-attention cache behind the self-attention cache.
        present_key_value = present_key_value + cross_attn_out[-1]

    layer_output = apply_chunking_to_forward_Modified(
        self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output,
        attention_mask=attention_mask,
    )
    result = (layer_output,) + extra_outputs

    if self.is_decoder:
        # A decoder returns the key/value cache as the last output.
        result = result + (present_key_value,)

    return result


def RobertaOutput_forward_Modified(self, hidden_states, input_tensor, attention_mask=None):
    """Output sub-layer forward that can hand ``attention_mask`` to the dropout.

    Applies ``self.dense`` and ``self.dropout``, adds the residual
    ``input_tensor`` and normalises with ``self.LayerNorm``. The mask is passed
    to the dropout (as keyword ``attention_mask``) only when the dropout's
    ``forward`` does not have exactly one parameter.
    """
    projected = self.dense(hidden_states)

    # A dropout whose forward takes a single parameter cannot accept the mask.
    dropout_params = inspect.signature(self.dropout.forward).parameters
    if len(dropout_params) == 1:
        dropped = self.dropout(projected)
    else:
        dropped = self.dropout(projected, attention_mask=attention_mask)

    return self.LayerNorm(dropped + input_tensor)
