import torch
import torch.utils.checkpoint
from torch import nn, einsum
from einops import rearrange
from einops_exts import rearrange_many

from typing import List, Optional, Tuple, Union
from maskrcnn_benchmark.utils.torch_dropout import Dropout1d

import random
from collections import OrderedDict


from transformers.models.bert.modeling_bert import BertModel, BertEncoder, BertEmbeddings,\
BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions,\
logger, \
add_start_docstrings_to_model_forward, add_code_sample_docstrings, \
BERT_INPUTS_DOCSTRING, _CHECKPOINT_FOR_DOC, _CONFIG_FOR_DOC


# import torch.nn.utils.rnn as rnn_utils

# def get_index_with_padding_batch(a, padding_value=0):
#     '''
#     Given an attention mask, which only contains 0 and 1, return a tensor that contains the index of non-zero elements. Pad each row of output tensor with given padding_value to the same length.
#     Inputs:
#         a - (B, M, N)
#     Outputs:
#         torch.tensor - (B, M, K) , K is the max length of non-zero elements in the N-dim of a.
#     '''
#     # Compute the indices of non-zero elements
#     indices = [torch.nonzero(row)[:, 0] for row in a.reshape(-1, a.shape[-1])]
    
#     # Pad sequences and reshape back to the original shape
#     padded_indices = rnn_utils.pad_sequence(indices, batch_first=True, padding_value=padding_value)
#     padded_indices = padded_indices.view(a.shape[0], a.shape[1], -1)
    
#     return padded_indices

@torch.no_grad()
def get_index_with_padding_batch(a, padding_value=None):
    '''
    Collect, for every row of an attention mask, the positions of its non-zero entries.

    Rows that contain fewer non-zero entries than the longest row are right-padded with
    ``padding_value`` so that the result is rectangular.

    Inputs:
        a - (..., N) mask, typically (B, M, N). Every non-zero entry counts as "selected".
        padding_value - None, or exactly N. N is one past the last valid index, so a caller can
            append a single dummy token at position N and gather with the returned indices directly.
    Outputs:
        torch.LongTensor - (..., K), K is the largest number of non-zero entries found in any row.
            Within a row the valid indices come first, in ascending order, followed by the padding.
    Raises:
        AssertionError - if padding_value is given and differs from N.
    '''
    num_tokens = a.shape[-1]
    if padding_value is None:
        padding_value = num_tokens
    else:
        # The compaction below sorts ascending, so the padding must be larger than any valid index.
        assert padding_value == num_tokens, 'padding_value must equal a.shape[-1]'

    # Keep the position where the mask is set, write the padding value everywhere else.
    non_zero = (a != 0)
    positions = torch.arange(num_tokens, dtype=torch.long, device=a.device)
    padding = torch.tensor(padding_value, dtype=torch.long, device=a.device)
    indices = torch.where(non_zero, positions, padding)

    # An empty batch has no row to take a maximum over; treat it like an all-zero mask.
    max_length = int(non_zero.sum(-1).max()) if non_zero.numel() > 0 else 0
    if max_length == 0:
        return indices[..., :0]

    # The k smallest values of each row are its valid indices (ascending) followed by padding.
    return indices.topk(k=max_length, dim=-1, largest=False).values

# def get_index_with_padding_batch(a, padding_value=0):
#     # TODO: more efficient implement
#     '''
#     Given an attention mask, which only contains 0 and 1, return a tensor that contains the index of non-zero elements. Pad each row of output tensor with given padding_value to the same length.
#     Inputs:
#         a - (B, M, N)
#     Outputs:
#         torch.tensor - (B, M, K) , K is the max length of non-zero elements in the N-dim of a.
#     '''
#     B, M, N = a.shape
#     index_list = []
#     max_length = 0
#     for i in range(B):
#         row_indices = []
#         for j in range(M):
#             row_index = torch.nonzero(a[i, j]).squeeze().tolist()
#             row_indices.append(row_index)
#             if len(row_index) > max_length:
#                 max_length = len(row_index)
#         index_list.append(row_indices)

#     for i in range(len(index_list)):
#         for j in range(len(index_list[i])):
#             diff = max_length - len(index_list[i][j])
#             index_list[i][j] += [padding_value] * diff
#         diff = M - len(index_list[i])
#         index_list[i] += [[padding_value] * max_length] * diff

#     return torch.tensor(index_list, device=a.device)

def easy_gather(x, indices):
    '''
    Batched row lookup: for in-range indices, out[b, i] == x[b, indices[b, i]].

    Inputs:
        x - (B, N, C)
        indices - (B, N_new) integer positions along the N axis of x
    Outputs:
        torch.tensor - (B, N_new, C)
    '''
    batch, num_rows, channels = x.shape
    num_picked = indices.shape[1]

    # Shift every batch element's indices so they address rows of the flattened (B * N, C) table.
    batch_start = torch.arange(batch, dtype=torch.long, device=x.device).view(batch, 1) * num_rows
    flat_indices = (indices + batch_start).view(-1)

    flat_rows = x.flatten(0, 1)
    picked = flat_rows[flat_indices]
    return picked.view(batch, num_picked, channels)

# gated cross attention

def exists(val):
    '''
    Return True only when val is not None and has a positive length.

    len() is applied to every non-None value, so objects without a length raise TypeError.
    '''
    return val is not None and len(val) > 0

def FeedForward(dim, mult = 4, out_dim = None):
    '''
    Build a pre-norm two-layer MLP: LayerNorm -> Linear -> GELU -> Linear, both linears without bias.

    Inputs:
        dim - input feature size
        mult - hidden width as a multiple of dim (the product is truncated to an int)
        out_dim - output feature size, defaults to dim
    Outputs:
        nn.Sequential with the named children 'norm', 'linear1', 'gelu' and 'linear2'.
    '''
    hidden_dim = int(dim * mult)
    target_dim = dim if out_dim is None else out_dim

    stages = OrderedDict()
    stages['norm'] = nn.LayerNorm(dim)
    stages['linear1'] = nn.Linear(dim, hidden_dim, bias = False)
    stages['gelu'] = nn.GELU()
    stages['linear2'] = nn.Linear(hidden_dim, target_dim, bias = False)
    return nn.Sequential(stages)

class MaskedCrossAttention(nn.Module):
    '''
    Multi-head cross attention in which text tokens (queries) attend to vision tokens (keys / values),
    optionally restricted by a (batch, vision, text) mask.

    Args (keyword only):
        input_dim - feature size of both the text and the vision tokens
        output_dim - feature size of the result, defaults to input_dim
        dim_head, heads - size of one attention head and number of heads
        norm_kv - apply a separate LayerNorm to the vision tokens before the key / value projection
        share_kv - use one projection for both keys and values instead of two
        cfg - accepted for signature compatibility; not used by this class
        spare_forward - run the sparse path: every text token is turned into its own batch element
            that only sees the vision tokens its mask selects. This path requires an attention mask.
    '''
    def __init__(
        self,
        *,
        input_dim,
        output_dim = None,
        dim_head = 64,
        heads = 8,
        norm_kv = False,
        share_kv=False,
        cfg=None,
        spare_forward=False,
    ):
        super().__init__()
        self.spare_forward = spare_forward
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.share_kv = share_kv
        inner_dim = dim_head * heads
        if output_dim is None:
            output_dim = input_dim

        self.norm = nn.LayerNorm(input_dim)
        self.norm_kv = nn.LayerNorm(input_dim) if norm_kv else None

        self.to_q = nn.Linear(input_dim, inner_dim, bias = False)
        # With shared k/v a single projection is reused; otherwise keys and values are packed together.
        kv_dim = inner_dim if share_kv else inner_dim * 2
        self.to_kv = nn.Linear(input_dim, kv_dim, bias = False)
        self.to_out = nn.Linear(inner_dim, output_dim, bias = False)

    @classmethod
    def _construct_sparse_inputs(cls, x, vision, attention_mask):
        '''
        Make each text token attend only to the vision tokens selected by its own mask column.
        Inputs:
            x - (batch, text, dim)
            vision - (batch, vision, dim)
            attention_mask - (batch, vision, text)
        Outputs:
            x - (batch * text, 1, dim)
            vision - (batch * text, S, dim), S is the largest number of vision tokens selected by
                any text token; unused slots hold an all-zero padding token
            attention_mask - (batch * text, S, 1), False on the padding slots
        '''
        B, V, C = vision.shape # batch, vision, dim
        # Append an all-zero dummy token at index V; padded slots gather this token.
        vision = torch.cat([vision, vision.new_zeros(B, 1, C)], dim=1) # B, V+1, C
        padding_index = V
        index = get_index_with_padding_batch(attention_mask.transpose(2, 1), padding_value=padding_index)
        B, T, S = index.shape # batch, text, selected vision tokens per text token
        vision = easy_gather(vision, index.flatten(1, 2)).reshape(B, T, S, C)
        x = x[:, :, None, ...] # B, T, 1, C
        new_mask = (index != padding_index)[:, :, :, None] # B, T, S, 1
        return x.flatten(0, 1), vision.flatten(0, 1), new_mask.flatten(0, 1)

    def forward(
        self,
        x,      # (batch, text, dim)
        vision, # (batch, vision, dim)
        attention_mask = None, # (batch, vision, text)
    ):
        '''
        Inputs:
            x - (batch, text, dim)
            vision - (batch, vision, dim); cast to the dtype of x
            attention_mask - (batch, vision, text); zero entries are masked out. Mandatory when
                the module was built with spare_forward=True.
        Outputs:
            torch.tensor - (batch, text, output_dim)
        Raises:
            ValueError - if spare_forward is enabled and attention_mask is None.
        '''
        if self.spare_forward:
            if attention_mask is None:
                raise ValueError('attention_mask is required when spare_forward=True')
            batch_size = x.shape[0]
            x, vision, attention_mask = self._construct_sparse_inputs(x, vision, attention_mask)

        vision = vision.to(x.dtype)
        h = self.heads

        x = self.norm(x)
        if self.norm_kv is not None:
            vision = self.norm_kv(vision)

        q = self.to_q(x)

        if self.share_kv:
            k = v = self.to_kv(vision)
        else:
            k, v = self.to_kv(vision).chunk(2, dim = -1)
        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)

        q = q * self.scale

        sim = einsum('... i d, ... j d -> ... i j', q, k) # (batch, heads, text, vision)

        has_mask = exists(attention_mask)
        if has_mask:
            # (b, v, t) -> (b, 1, t, v), so the mask broadcasts over the heads of sim.
            keep = attention_mask.transpose(1, 2)[:, None]
            bias = sim.new_zeros(keep.shape)
            bias[keep == 0] = -1e4 # for half
            sim = sim + bias

        attn = sim.softmax(dim = -1)

        if has_mask:
            # make sure ignored tokens get exactly zero attention, even in a fully masked row
            attn = attn * keep

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')

        if self.spare_forward:
            # every sparse batch element carries exactly one text token; fold them back into (batch, text)
            assert out.shape[1]==1
            out = rearrange(out, '(b t) n d -> b t (n d)', b=batch_size)

        return self.to_out(out)

class GatedCrossAttentionBlock(nn.Module):
    '''
    Gated residual block that injects vision-query information into text tokens:

        x = x + gate_attn * MaskedCrossAttention(x, vision)
        x = x + gate_ff * FeedForward(x)            (only when enable_ffn is True)

    The gates are read from cfg.VISION_QUERY:
        FIX_ATTN_GATE != -1.0 - both gates are this constant and nothing is learned.
        FIX_ATTN_GATE == -1.0 - gates are learned, zero-initialised and squashed by tanh, so the
            gated branches contribute nothing at initialisation. The feed-forward gate is a
            scalar parameter. The attention gate is
              * a scalar parameter when CONDITION_GATE is off,
              * a zero-initialised Linear(dim, 1) of the attention output when CONDITION_GATE is
                on and NONLINEAR_GATE is off,
              * a small FeedForward of the attention output (NO_CAT), or of the attention output
                concatenated with the text tokens, when NONLINEAR_GATE is on.
    Further switches: ADD_ADAPT_LAYER adds a residual FeedForward adaptor on the vision tokens;
    RETURN_ATTN_GATE_VALUE stores the mean gate of the last forward pass in self.attn_gate_value.

    cfg is dereferenced in the constructor, so it must be provided despite its None default.
    '''
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        share_kv=False,
        cfg=None,
        enable_ffn = True
    ):
        super().__init__()
        self.attn = MaskedCrossAttention(input_dim = dim, dim_head = dim_head, heads = heads, share_kv=share_kv, cfg=cfg, norm_kv=True, spare_forward=True)

        vq = cfg.VISION_QUERY
        learn_gates = (vq.FIX_ATTN_GATE == -1.0)
        if learn_gates:
            if vq.CONDITION_GATE:
                if vq.NONLINEAR_GATE:
                    # Without NO_CAT the gate also sees the text tokens, hence the doubled input width.
                    gate_in_dim = dim if vq.NO_CAT else dim * 2
                    self.attn_gate = FeedForward(dim=gate_in_dim, mult=0.5, out_dim = 1)
                    torch.nn.init.constant_(self.attn_gate.linear2.weight, 0)
                else:
                    self.attn_gate = nn.Linear(dim, 1, bias=False)
                    torch.nn.init.constant_(self.attn_gate.weight, 0)
            else:
                self.attn_gate = nn.Parameter(torch.tensor([0.]))

        self.enable_ffn = enable_ffn
        if enable_ffn:
            self.ff = FeedForward(dim, mult = ff_mult)
            if learn_gates:
                self.ff_gate = nn.Parameter(torch.tensor([0.]))

        if vq.ADD_ADAPT_LAYER:
            self.adaptor = FeedForward(dim, mult = 2)

        self.cfg = cfg
        self.attn_gate_value = 0.

    def forward(
        self,
        x,                       # text tensor - (batch, text, dim_t)
        vision,                  # vision query tensor - (batch, vision, dim_v)
        attention_mask = None,   # boolean tensor indicating masks of media - (batch, vision, text)
        batched_positive_label_position = None, # batch: {label: (positions)}
    ):
        '''
        Inputs:
            x - text tokens, (batch, text, dim)
            vision - vision query tokens, (batch, vision, dim)
            attention_mask - (batch, vision, text), forwarded unchanged to the attention module
            batched_positive_label_position - accepted for interface compatibility; not used
        Outputs:
            torch.tensor - updated text tokens, same shape as x
        '''
        vq = self.cfg.VISION_QUERY

        if vq.ADD_ADAPT_LAYER:
            vision = self.adaptor(vision) + vision

        residual = x
        supported_x = self.attn(x, vision, attention_mask = attention_mask)

        fixed_gate = (vq.FIX_ATTN_GATE != -1.0)
        if fixed_gate:
            attn_gate = vq.FIX_ATTN_GATE
        elif vq.CONDITION_GATE:
            if vq.NO_CAT or not vq.NONLINEAR_GATE:
                gate_input = supported_x
            else:
                gate_input = torch.cat([supported_x, residual], dim = -1)
            attn_gate = self.attn_gate(gate_input).tanh()
        else:
            attn_gate = self.attn_gate.tanh()

        if vq.RETURN_ATTN_GATE_VALUE:
            # A fixed gate is a plain Python number and has no .mean(); only tensors need reducing.
            if isinstance(attn_gate, torch.Tensor):
                with torch.no_grad():
                    self.attn_gate_value = attn_gate.mean().item()
            else:
                self.attn_gate_value = float(attn_gate)

        # x1 + a * x2 with a in (-1, 1) for the learned (tanh) gates
        x = supported_x * attn_gate + residual
        if self.enable_ffn:
            ff_gate = vq.FIX_ATTN_GATE if fixed_gate else self.ff_gate.tanh()
            x = self.ff(x) * ff_gate + x
        return x


class PreSelectBlock(nn.Module):
    '''
    One layer that updates the 'vision' tokens by cross-attending to the 'image' tokens,
    followed by a residual feed-forward.

    Args (keyword only):
        dim - feature size of the incoming vision and image tokens
        out_dim - feature size of the updated vision tokens, defaults to dim. When it differs
            from dim, the residual path is projected with a bias-free linear layer.
        dim_head, heads, share_kv, cfg - forwarded to MaskedCrossAttention
        ff_mult - hidden width multiplier of the feed-forward
    '''
    def __init__(
        self,
        *,
        dim,
        out_dim = None,
        dim_head = 32,
        heads = 8,
        ff_mult = 4,
        share_kv=False,
        cfg=None,
    ):
        super().__init__()
        # Resolve the default here: FeedForward and nn.Linear below cannot take None as a size.
        if out_dim is None:
            out_dim = dim

        self.image_condition = MaskedCrossAttention(input_dim = dim, output_dim = out_dim, dim_head = dim_head, heads = heads, norm_kv=True, share_kv=share_kv, cfg=cfg, spare_forward=False)
        self.ff = FeedForward(out_dim, mult = ff_mult)

        if dim != out_dim:
            self.res_mapping = nn.Linear(in_features=dim, out_features=out_dim, bias=False)
        else:
            self.res_mapping = nn.Identity()

    def forward(
        self,
        x,
    ):
        '''
        Inputs:
            x - dict with the keys 'vision' (tokens to update) and 'image' (tokens attended to),
                commented in the original as (batch, vision, dim_v) and (batch, image, dim_v)
        Outputs:
            dict with the same keys; 'vision' is updated, 'image' is passed through unchanged so
            that blocks can be chained in an nn.Sequential.
        '''
        vision = x['vision'] # vision query tensor - (batch, vision, dim_v)
        image = x['image'] # query images - (batch, image, dim_v)
        vision = self.image_condition(vision, image) + self.res_mapping(vision)
        vision = self.ff(vision) + vision
        return {'vision': vision, 'image': image}
    
class PreSelectModule(nn.Module):
    '''
    Stack of PreSelectBlock layers. Every block keeps the feature size dim except the last one,
    which maps to out_dim.

    cfg.VISION_QUERY.VISION_SCALE multiplies both inputs before the first block.
    With cfg.VISION_QUERY.AUGMENT_IMAGE_WITH_QUERY the first block runs with swapped roles (the
    image tokens are updated by attending to the vision queries); the remaining blocks then
    update the vision queries against those updated image tokens. This mode needs at least two blocks.

    cfg is dereferenced in the constructor, so it must be provided despite its None default.
    '''
    def __init__(
        self,
        *,
        dim,
        out_dim,
        dim_head = 32,
        heads = 8,
        ff_mult = 4,
        num_layers = 2,
        share_kv=False,
        cfg=None
    ):
        super().__init__()
        shared_kwargs = dict(dim_head=dim_head, heads=heads, ff_mult=ff_mult, share_kv=share_kv, cfg=cfg)
        # Output width of each block, in order: dim for all but the last, out_dim for the last.
        block_out_dims = [dim] * (num_layers - 1) + [out_dim]
        blocks = [PreSelectBlock(dim=dim, out_dim=width, **shared_kwargs) for width in block_out_dims]
        self.layers = nn.Sequential(*blocks)

        vision_query_cfg = cfg.VISION_QUERY
        self.scale = vision_query_cfg.VISION_SCALE
        self.augment_image_with_query = vision_query_cfg.AUGMENT_IMAGE_WITH_QUERY
        if self.augment_image_with_query:
            assert len(self.layers) > 1

    def forward(
        self,
        vision,                  # vision query tensor - (batch, scales, vision, dim_v)
        image,                   # query images - (batch, scales, image, dim_v)
    ):
        '''
        Inputs:
            vision - vision query tokens
            image - image tokens
        Outputs:
            dict with the keys 'vision' and 'image', as returned by the last block.
        '''
        scaled_vision = vision * self.scale
        scaled_image = image * self.scale

        if not self.augment_image_with_query:
            return self.layers({'vision': scaled_vision, 'image': scaled_image})

        # First block with swapped roles, then swap back before running the remaining blocks.
        swapped = self.layers[0]({'vision': scaled_image, 'image': scaled_vision})
        state = {'vision': swapped['image'], 'image': swapped['vision']}
        for block in self.layers[1:]:
            state = block(state)
        return state
    
class QVBertEmbeddings(BertEmbeddings):
    '''
    BertEmbeddings with an optional learned mask token for text dropout.

    The token is created when cfg.VISION_QUERY.TEXT_DROPOUT > 0 and
    cfg.VISION_QUERY.NEW_MASK_TOKEN is set. Applying it in forward() is not implemented:
    requesting it during training raises NotImplementedError.
    '''
    def __init__(self, config, cfg):
        super().__init__(config)
        self.cfg=cfg
        if (self.cfg.VISION_QUERY.TEXT_DROPOUT > 0.) and (cfg.VISION_QUERY.NEW_MASK_TOKEN):
            self.mask_tok_qv_layer = nn.Parameter(torch.randn(config.hidden_size)) # add qv_layer to name only for easier parameter freezing

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
        batched_pos_category_map=None,
    ) -> torch.Tensor:
        '''
        Sum word, token-type and (for absolute position embeddings) position embeddings, then
        apply LayerNorm and dropout.

        Inputs:
            input_ids / inputs_embeds - token ids or precomputed word embeddings; input_ids
                determines the input shape whenever it is given
            token_type_ids, position_ids - default to the module buffers when omitted
            past_key_values_length - offset into the position id buffer
            batched_pos_category_map - per-sample positions of positive categories; only
                consulted to decide whether the unimplemented text-dropout path was requested
        Raises:
            NotImplementedError - in training mode when TEXT_DROPOUT > 0, NEW_MASK_TOKEN is set
                and batched_pos_category_map is given.
        '''
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
        # issue #5664
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                token_type_ids = buffered_token_type_ids.expand(input_shape[0], seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        text_dropout_requested = (
            (self.cfg.VISION_QUERY.TEXT_DROPOUT > 0.)
            and (batched_pos_category_map is not None)
            and (self.cfg.VISION_QUERY.NEW_MASK_TOKEN)
            and (self.training)
        )
        if text_dropout_requested:
            # Intended behaviour, never enabled: for every sample i and every positive-category
            # position mask in batched_pos_category_map[i], replace inputs_embeds[i, position]
            # (on a clone) with self.mask_tok_qv_layer with probability TEXT_DROPOUT.
            raise NotImplementedError(
                "Text dropout with a learned mask token (VISION_QUERY.NEW_MASK_TOKEN) is not implemented."
            )

        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class QVBertEncoder(BertEncoder):
    '''
    BertEncoder whose layers, from index start_qv_layer_index onward, receive the vision query
    together with the text hidden states.

    For those layers the first positional argument becomes the list [vision, hidden_states] and
    the attention mask becomes [vision_attention_mask, attention_mask]. The layer modules must
    therefore accept list inputs - presumably self.layer is swapped for a custom layer type
    elsewhere; verify against the code that builds the model.

    NOTE(review): on the gradient-checkpointing path the vision query is not injected at all.
    '''
    def __init__(self, 
        config, 
        dim,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        start_qv_layer_index = 6, # which layer to start fusing vision
        share_kv=False,
        cfg=None,
        ):
        # Only config and start_qv_layer_index are used here; the remaining arguments were
        # consumed by the (now commented-out) per-layer GatedCrossAttentionBlock list below.
        super().__init__(config=config)
        self.start_layer_index = start_qv_layer_index
        # num_hidden_layers = config.num_hidden_layers
        # assert start_qv_layer_index < num_hidden_layers
        # num_qv_layers = num_hidden_layers - start_qv_layer_index

        # self.qv_layer = nn.ModuleList([GatedCrossAttentionBlock(dim=dim, dim_head=dim_head, heads=heads, ff_mult=ff_mult, share_kv=share_kv, cfg=cfg) 
                                    #    for _ in range(num_qv_layers)])

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        vision: Optional[torch.Tensor] = None, 
        vision_attention_mask: Optional[torch.Tensor] = None, 
        batched_pos_category_map=None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        '''
        Run all encoder layers, handing the vision query to every layer at or after
        self.start_layer_index.

        Beyond the arguments this method forwards to each layer module, it takes:
            vision - vision query tokens; injection happens only when exists(vision) is True
            vision_attention_mask - mask paired with vision; it is wrapped together with
                attention_mask once, at the first fusing layer, and the wrapped pair is reused
                by all later layers
            batched_pos_category_map - accepted but not used in this method

        Returns a BaseModelOutputWithPastAndCrossAttentions, or, when return_dict is False, a
        tuple of its non-None fields.
        '''
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            # Previous design (disabled): run a dedicated qv_layer in front of the BERT layer.
            # if i >= self.start_qv_layer_index and exists(vision):
            #     qv_index = i - self.start_qv_layer_index
            #     hidden_states = self.qv_layer[qv_index](hidden_states, vision, vision_attention_mask, batched_pos_category_map)

            if self.gradient_checkpointing and self.training:
                # NOTE(review): this branch never wraps hidden_states / attention_mask with the
                # vision query, so vision is ignored when gradient checkpointing is active.

                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    # Bind the non-tensor arguments so that checkpoint() only sees tensors.
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                if i >= self.start_layer_index and exists(vision):
                    # hidden_states is rebound from layer_outputs[0] after every layer, so it is
                    # re-wrapped with the vision query on each fusing layer ...
                    hidden_states = [vision, hidden_states]
                    if i == self.start_layer_index:
                        # ... whereas attention_mask keeps its value across iterations and is
                        # wrapped only once.
                        attention_mask = [vision_attention_mask, attention_mask]
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # Only the fields that were actually collected are returned, in this fixed order.
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )

        
class QVBertModel(BertModel):
    """BERT model whose embeddings and encoder are swapped for query-vision variants.

    After the stock ``BertModel`` initialisation, ``self.embeddings`` is replaced
    by a ``QVBertEmbeddings`` and ``self.encoder`` by a ``QVBertEncoder``.

    Args:
        config: BERT configuration forwarded to ``BertModel`` and both sub-modules.
        dim_t: Forwarded to ``QVBertEncoder`` as ``dim``.
        dim_v: Accepted for interface compatibility; not used by the active code
            (only the disabled pre-select module referenced it).
        dim_head_t: Forwarded to ``QVBertEncoder`` as ``dim_head``.
        dim_head_v: Accepted for interface compatibility; currently unused.
        heads: Forwarded to ``QVBertEncoder``.
        ff_mult: Forwarded to ``QVBertEncoder``.
        num_pre_select_layers: Accepted for interface compatibility; currently unused.
        share_kv: Forwarded to ``QVBertEncoder``.
        cfg: Project configuration object. ``forward`` reads
            ``cfg.VISION_QUERY.QUERY_FUSION`` on the ``return_dict`` path.
        **kwargs: Extra keyword arguments forwarded to ``BertModel.__init__``.
    """

    def __init__(
        self,
        config,
        dim_t,
        dim_v,
        dim_head_t=64,
        dim_head_v=32,
        heads=8,
        ff_mult=4,
        num_pre_select_layers=2,
        share_kv=False,
        cfg=None,
        **kwargs
    ):
        super().__init__(config=config, **kwargs)
        self.cfg = cfg
        # Replace the modules built by BertModel with the query-vision versions.
        self.embeddings = QVBertEmbeddings(config, cfg)
        self.encoder = QVBertEncoder(
            config=config,
            dim=dim_t,
            dim_head=dim_head_t,
            heads=heads,
            ff_mult=ff_mult,
            share_kv=share_kv,
            cfg=cfg,
        )
        # NOTE: a PreSelectModule (`self.pre_select`) and a learnable mask token
        # (`self.mask_tok_qv_layer`) used to be created here; both are disabled,
        # as is the former `get_gate_value` helper that collected gate values
        # from `self.encoder.qv_layer`.

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        vision: Optional[torch.Tensor] = None, # (batch, vision, dim)
        images: Optional[torch.Tensor] = None,  # (batch, image, dim)
        vision_attention_mask: Optional[torch.Tensor] = None,
        batched_pos_category_map = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        r"""
        encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Hidden states of an encoder's last layer; consumed by cross-attention when the model is configured as
            a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Padding mask for the encoder input, used by cross-attention when the model is configured as a decoder.
            Mask values are in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Cached key/value states of the attention blocks, usable to speed up decoding. Its length along
            dimension 2 of the first tensor is taken as the past sequence length.
        use_cache (`bool`, *optional*):
            Only honoured when `config.is_decoder` is set; otherwise it is forced to `False`.
        output_attentions (`bool`, *optional*):
            Resolved against the config, but the encoder is always called with `output_attentions=True`.
        vision (`torch.Tensor`, *optional*):
            Vision query tokens handed to the encoder unchanged (per the signature comment: batch, vision, dim).
        images (`torch.Tensor`, *optional*):
            Accepted but not used by the active code path.
        vision_attention_mask (`torch.Tensor`, *optional*):
            Handed to the encoder together with `vision`; also attached to the output when query fusion is on.
        batched_pos_category_map (*optional*):
            Handed to both the embeddings module and the encoder.

        Returns:
            With `return_dict` false: `(sequence_output, pooled_output) + encoder_outputs[1:]`.
            Otherwise a 2-tuple `(out, encoder_outputs[2])`, where `out` is a
            `BaseModelOutputWithPoolingAndCrossAttentions`; when `cfg.VISION_QUERY.QUERY_FUSION` is truthy it also
            carries `augmented_vision` (currently always `None`) and `vision_attention_mask`.
        """
        # Fall back to the configuration for any flag the caller left unset.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Caching is only meaningful for decoders.
        if not self.config.is_decoder:
            use_cache = False
        elif use_cache is None:
            use_cache = self.config.use_cache

        # Exactly one of input_ids / inputs_embeds has to be given.
        has_ids = input_ids is not None
        has_embeds = inputs_embeds is not None
        if has_ids and has_embeds:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if not has_ids and not has_embeds:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        if has_ids:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        batch_size, seq_length = input_shape
        if has_ids:
            device = input_ids.device
        else:
            device = inputs_embeds.device

        # Length already covered by cached keys/values (0 when there is no cache).
        if past_key_values is None:
            past_key_values_length = 0
        else:
            past_key_values_length = past_key_values[0][0].shape[2]

        # Default mask: attend to every position, cached ones included.
        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        # Default token types: reuse the embeddings' buffer when it exists, else all zeros.
        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                token_type_ids = self.embeddings.token_type_ids[:, :seq_length].expand(batch_size, seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # Make the self-attention mask broadcastable across heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        # Cross-attention mask is only prepared for decoders that receive encoder states.
        encoder_extended_attention_mask = None
        if self.config.is_decoder and encoder_hidden_states is not None:
            enc_batch, enc_len, _ = encoder_hidden_states.size()
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones((enc_batch, enc_len), device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

        # Expand the head mask to one entry per hidden layer (1.0 keeps a head).
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
            batched_pos_category_map = batched_pos_category_map,
        )

        # NOTE: the text-dropout step (replacing positive-category embeddings with a
        # mask token) and the pre-select augmentation of `vision` are disabled, so
        # `augmented_vision` stays None.
        augmented_vision = None

        # Attention maps are always requested from the encoder, whatever the caller asked for.
        output_attentions = True
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            vision=vision,
            vision_attention_mask=vision_attention_mask,
            batched_pos_category_map=batched_pos_category_map,
        )
        sequence_output = encoder_outputs[0]
        if self.pooler is not None:
            pooled_output = self.pooler(sequence_output)
        else:
            pooled_output = None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        out = BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
        # NOTE: gate values (`vision_query_gates`) are no longer attached to the output.
        if self.cfg.VISION_QUERY.QUERY_FUSION:
            out['augmented_vision'] = augmented_vision
            out['vision_attention_mask'] = vision_attention_mask
        # NOTE(review): what index 2 of the encoder output holds depends on
        # QVBertEncoder and on which optional outputs are present - confirm against the encoder.
        return out, encoder_outputs[2]


    
    
