`BertForSequenceClassification.forward` is called with `input_ids`, `token_type_ids`, and `attention_mask`:
def forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
It calls `BertModel` (`self.bert`), passing `input_ids`, `attention_mask`, and `token_type_ids` through unchanged:
outputs = self.bert(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
`BertModel.forward` therefore receives `input_ids`, `attention_mask`, and `token_type_ids`:
def forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
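It first builds the per-layer head mask. It then computes the embeddings from `input_ids`, `position_ids`, `token_type_ids`, and `inputs_embeds`; `attention_mask` is not passed to the embedding layer: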
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
)
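The embedding output goes straight into the `BertEncoder` call, together with the extended attention mask and the head mask. `input_ids` and `token_type_ids` are not passed on, because their information is already contained in `embedding_output`. The `BertEncoder.forward` signature follows the call: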
encoder_outputs = self.encoder(
    embedding_output,
    attention_mask=extended_attention_mask,
    head_mask=head_mask,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=encoder_extended_attention_mask,
    past_key_values=past_key_values,
    use_cache=use_cache,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
)
def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    )
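Inside `BertEncoder.forward`, each `BertLayer` (`layer_module`) is called with `hidden_states`, `attention_mask`, and that layer's `layer_head_mask`. The `BertLayer.forward` signature follows the call: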
layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
`BertLayer.forward` calls its `BertAttention` module with `hidden_states`, `attention_mask`, and `head_mask`:
self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
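`BertAttention` in turn calls its self-attention submodule (`self.self`) directly with `hidden_states`, `attention_mask`, and `head_mask`. That submodule's `forward` signature follows the call: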
self_outputs = self.self(
    hidden_states,
    attention_mask,
    head_mask,
    encoder_hidden_states,
    encoder_attention_mask,
    past_key_value,
    output_attentions,
)
def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
)

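After attention, `BertLayer` runs the feed-forward part (the intermediate and output sub-layers) through `apply_chunking_to_forward`. The chunking can be ignored: `chunk_size_feed_forward` is 0, so this is a plain call to `feed_forward_chunk(attention_output)`.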
layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )


'''
# Decomposed LayerNorm. The shared variance is computed in the two-pass form
# mean((x - mean(x)) ** 2), i.e. the biased (population) variance LayerNorm
# uses. The former one-pass form E[x**2] - E[x]**2 is prone to floating-point
# cancellation and is the likely cause of the slight mismatch seen earlier.
def prop_layer_norm(rel, irrel, layer_norm_module, tol = 1e-8):
    """Propagate a (rel, irrel) decomposition through a LayerNorm module.

    Normalization statistics come from the full input ``rel + irrel``. They
    are taken over the last dimension, so ``rel_out + irrel_out`` equals
    ``layer_norm_module(rel + irrel)`` up to floating-point error.

    Each part is centered by its own mean and divided by the shared standard
    deviation. The affine bias is split between the two parts in proportion
    to their absolute magnitudes.

    Args:
        rel: Relevant component; same shape as ``irrel``.
        irrel: Irrelevant component.
        layer_norm_module: LayerNorm-like object exposing ``eps``, ``weight``
            and ``bias``. Assumed to normalize over the last dimension only
            (as BERT's LayerNorm does) - confirm for other models.
        tol: Added to the relevant weight so that, for ``tol > 0``, the bias
            split never divides by zero. Where both parts are zero, the whole
            bias goes to ``rel``.

    Returns:
        Tuple ``(rel_out, irrel_out)``.
    """
    tot = rel + irrel
    # keepdim=True keeps the statistics broadcastable against the inputs,
    # replacing the unsqueeze/expand_as pattern and the hard-coded dim=2.
    rel_mn = rel.mean(dim=-1, keepdim=True)
    irrel_mn = irrel.mean(dim=-1, keepdim=True)
    tot_centered = tot - tot.mean(dim=-1, keepdim=True)
    # Two-pass biased variance: avoids cancellation when |mean| >> std.
    vr = (tot_centered ** 2).mean(dim=-1, keepdim=True)
    denom = torch.sqrt(vr + layer_norm_module.eps)

    rel_wt = torch.abs(rel) + tol
    irrel_wt = torch.abs(irrel)
    tot_wt = rel_wt + irrel_wt

    rel_t = ((rel - rel_mn) / denom) * layer_norm_module.weight
    irrel_t = ((irrel - irrel_mn) / denom) * layer_norm_module.weight

    # The two weight fractions sum to 1, so the biases add up to the full bias.
    rel_bias = layer_norm_module.bias * (rel_wt / tot_wt)
    irrel_bias = layer_norm_module.bias * (irrel_wt / tot_wt)

    return rel_t + rel_bias, irrel_t + irrel_bias
'''

# CDT_BERT_hh.ipynb

'''
def patch_context(rel, irrel, patched_entries, sa_module):
    """Fold the patched attention heads' context from ``irrel`` into ``rel``.

    Both tensors are split into per-head slices with
    ``reshape_separate_attention_heads``. For every entry of
    ``patched_entries`` (``entry[1]`` is the position, ``entry[2]`` the
    attention head), the irrelevant slice is added to the relevant one and
    then zeroed. The heads are concatenated back before returning.

    Returns:
        Tuple ``(rel, irrel)`` in the concatenated-heads layout.
    """
    rel_heads = reshape_separate_attention_heads(rel, sa_module)
    irrel_heads = reshape_separate_attention_heads(irrel, sa_module)

    for entry in patched_entries:
        position, head = entry[1], entry[2]
        merged = rel_heads[:, position, head, :] + irrel_heads[:, position, head, :]
        rel_heads[:, position, head, :] = merged
        irrel_heads[:, position, head, :] = 0

    return (
        reshape_concatenate_attention_heads(rel_heads, sa_module),
        reshape_concatenate_attention_heads(irrel_heads, sa_module),
    )

def prop_self_attention_patched(rel, irrel, attention_mask, 
                                head_mask, patched_entries, 
                                sa_module, att_probs = None):
    """Propagate (rel, irrel) through a self-attention module with head patching.

    If ``att_probs`` is given, those attention probabilities are reused.
    Otherwise they are computed from the full input ``rel + irrel``. The same
    probabilities weight both value projections. ``patch_context`` then moves
    the patched heads' contribution into ``rel``.

    Returns:
        Tuple ``(rel_context, irrel_context)``.
    """
    probs = att_probs
    if probs is None:
        probs = get_attention_probs(rel + irrel, attention_mask, head_mask, sa_module)

    value_rel, value_irrel = prop_linear(rel, irrel, sa_module.value)
    ctx_rel, ctx_irrel = [mul_att(probs, value, sa_module) for value in (value_rel, value_irrel)]

    ctx_rel, ctx_irrel = patch_context(ctx_rel, ctx_irrel, patched_entries, sa_module)
    return ctx_rel, ctx_irrel

def prop_attention_patched(rel, irrel, attention_mask, 
                           head_mask, patched_entries, a_module, 
                           att_probs = None):
    """Propagate (rel, irrel) through a full attention block with head patching.

    First runs the patched self-attention (``a_module.self``). Then applies the
    output sub-module: dense projection, residual connection with the block
    input, and the decomposed LayerNorm.

    Returns:
        Tuple ``(rel_out, irrel_out)``.
    """
    ctx_rel, ctx_irrel = prop_self_attention_patched(
        rel, irrel, attention_mask, head_mask, patched_entries,
        a_module.self, att_probs,
    )

    out_mod = a_module.output
    dense_rel, dense_irrel = prop_linear(ctx_rel, ctx_irrel, out_mod.dense)

    # Residual connection: add the attention-block input back to each part.
    normed_rel, normed_irrel = prop_layer_norm(dense_rel + rel, dense_irrel + irrel, out_mod.LayerNorm)
    return normed_rel, normed_irrel

def prop_layer_patched(rel, irrel, attention_mask, head_mask, patched_entries, layer_module, att_probs = None):
    """Propagate (rel, irrel) through one encoder layer with head patching.

    Steps:
        1. Patched attention block.
        2. Intermediate dense projection and its activation.
        3. Output dense projection.
        4. Residual connection with the attention output.
        5. Decomposed LayerNorm.

    Returns:
        Tuple ``(rel_out, irrel_out)``.
    """
    attn_rel, attn_irrel = prop_attention_patched(
        rel, irrel, attention_mask, head_mask, patched_entries,
        layer_module.attention, att_probs,
    )

    # Feed-forward: intermediate projection followed by its activation.
    inter = layer_module.intermediate
    pre_rel, pre_irrel = prop_linear(attn_rel, attn_irrel, inter.dense)
    hid_rel, hid_irrel = prop_act(pre_rel, pre_irrel, inter.intermediate_act_fn)

    out_mod = layer_module.output
    proj_rel, proj_irrel = prop_linear(hid_rel, hid_irrel, out_mod.dense)

    # Residual connection around the feed-forward sub-layer.
    final_rel, final_irrel = prop_layer_norm(proj_rel + attn_rel, proj_irrel + attn_irrel, out_mod.LayerNorm)
    return final_rel, final_irrel

def prop_classifier_model_patched(encoding, model, patched_entries, att_list = None):
    """Run a patched (rel, irrel) decomposition through a BERT classifier.

    The embedding output starts entirely in ``irrel`` (``rel`` is zero). At
    encoder layer ``i``, the entries of ``patched_entries`` whose first
    element equals ``i`` are patched. The result is then propagated through
    the pooler and the classifier head.

    Args:
        encoding: Tokenizer output. ``encoding['input_ids']`` and
            ``encoding['attention_mask']`` are read here.
        model: Model exposing ``.bert`` (with ``encoder``, ``pooler`` and
            ``config``) and ``.classifier``.
        patched_entries: Iterable of ``(layer, position, head)`` entries.
        att_list: Currently unused; kept for interface compatibility.

    Returns:
        Tuple ``(rel_out, irrel_out)`` of classifier outputs.
    """
    embedding_output = get_embeddings_bert(encoding, model.bert)
    input_shape = encoding['input_ids'].size()
    extended_attention_mask = get_extended_attention_mask(attention_mask = encoding['attention_mask'], 
                                                          input_shape = input_shape, 
                                                          bert_model = model.bert)

    head_mask = [None] * model.bert.config.num_hidden_layers
    encoder_module = model.bert.encoder

    # Allocate on the embeddings' own device/dtype instead of a module-level
    # ``device`` global. That global may be undefined, or differ from the
    # device the model actually runs on.
    rel = torch.zeros_like(embedding_output)
    irrel = embedding_output.clone()

    # Materialize once so a one-shot iterable is not exhausted after layer 0.
    patched_entries = list(patched_entries)

    for i, layer_module in enumerate(encoder_module.layer):
        layer_patched_entries = [p_entry for p_entry in patched_entries if p_entry[0] == i]
        layer_head_mask = head_mask[i]

        rel_n, irrel_n = prop_layer_patched(rel, irrel, extended_attention_mask, layer_head_mask,
                                            layer_patched_entries, layer_module, att_probs = None)

        # Return value is discarded, so this presumably normalizes in place - confirm.
        normalize_rel_irrel(rel_n, irrel_n)
        rel, irrel = rel_n, irrel_n

    rel_pool, irrel_pool = prop_pooler(rel, irrel, model.bert.pooler)
    rel_out, irrel_out = prop_linear(rel_pool, irrel_pool, model.classifier)

    return rel_out, irrel_out
'''

# Appendix
'''
def patch_context_dot_w_embed(embed, rel, irrel, patched_entries, sa_module):
    """Fold patched heads' context from ``irrel`` into ``rel``; ``embed`` is ignored.

    ``embed`` is accepted so the signature fits the ``*_dot_w_embed`` call
    chain, but it is not used. For each entry of ``patched_entries``
    (``entry[1]`` is the position, ``entry[2]`` the attention head), the
    irrelevant per-head slice is added to the relevant one and then zeroed.

    Returns:
        Tuple ``(rel, irrel)`` with the heads concatenated back.
    """
    heads_rel, heads_irrel = [reshape_separate_attention_heads(t, sa_module) for t in (rel, irrel)]

    for entry in patched_entries:
        position = entry[1]
        head = entry[2]
        heads_rel[:, position, head, :] = heads_rel[:, position, head, :] + heads_irrel[:, position, head, :]
        heads_irrel[:, position, head, :] = 0

    out_rel = reshape_concatenate_attention_heads(heads_rel, sa_module)
    out_irrel = reshape_concatenate_attention_heads(heads_irrel, sa_module)
    return out_rel, out_irrel

def prop_self_attention_patched_dot_w_embed(embed, rel, irrel, attention_mask, 
                                head_mask, patched_entries, 
                                sa_module, att_probs = None, output_att_prob=False):
    """Patched self-attention propagation for the ``*_dot_w_embed`` call chain.

    Like ``prop_self_attention_patched``, but passes ``embed`` on to
    ``patch_context_dot_w_embed``. It can also return the attention
    probabilities it used.

    Args:
        embed: Forwarded to ``patch_context_dot_w_embed``.
        rel, irrel: Decomposed input to the self-attention module.
        attention_mask, head_mask: Used only when ``att_probs`` is None.
        patched_entries: Entries handed to ``patch_context_dot_w_embed``.
        sa_module: Self-attention module (``.value`` is read here).
        att_probs: Precomputed attention probabilities. When None, they are
            computed from ``rel + irrel``.
        output_att_prob: When true, the probabilities are returned as the
            third element; otherwise the third element is None.

    Returns:
        Tuple ``(rel_context, irrel_context, att_probs_or_None)``.
    """
    if att_probs is None:
        att_probs = get_attention_probs(rel + irrel, attention_mask, head_mask, sa_module)

    rel_value, irrel_value = prop_linear(rel, irrel, sa_module.value)

    rel_context = mul_att(att_probs, rel_value, sa_module)
    irrel_context = mul_att(att_probs, irrel_value, sa_module)

    # Must be the embed-aware variant. ``patch_context`` takes only
    # (rel, irrel, patched_entries, sa_module), so passing ``embed`` to it
    # raised a TypeError.
    rel_context, irrel_context = patch_context_dot_w_embed(embed, rel_context, irrel_context,
                                                           patched_entries, sa_module)

    return rel_context, irrel_context, (att_probs if output_att_prob else None)
    
def prop_attention_patched_dot_w_embed(embed, rel, irrel, attention_mask, 
                           head_mask, patched_entries, a_module, 
                           att_probs = None,
                           output_att_prob=False):
    """Patched attention-block propagation for the ``*_dot_w_embed`` call chain.

    First runs the patched self-attention (``a_module.self``). Then applies the
    output sub-module: dense projection, residual connection with the block
    input, and the decomposed LayerNorm.

    Returns:
        Tuple ``(rel_out, irrel_out, att_probs_or_None)``. The third element
        is the attention probabilities when ``output_att_prob`` is true.
    """
    # Must be the ``_dot_w_embed`` variant. ``prop_self_attention_patched``
    # accepts neither ``embed`` nor ``output_att_prob`` and returns only two
    # values, so calling it here raised a TypeError.
    rel_context, irrel_context, returned_att_probs = prop_self_attention_patched_dot_w_embed(
        embed, rel, irrel, attention_mask, head_mask, patched_entries,
        a_module.self, att_probs, output_att_prob,
    )

    output_module = a_module.output

    rel_dense, irrel_dense = prop_linear(rel_context, irrel_context, output_module.dense)
    # Residual connection with the attention-block input.
    rel_tot = rel_dense + rel
    irrel_tot = irrel_dense + irrel

    rel_out, irrel_out = prop_layer_norm(rel_tot, irrel_tot, output_module.LayerNorm)

    return rel_out, irrel_out, returned_att_probs

def prop_layer_patched_dot_w_embed(embed, rel, irrel, attention_mask, head_mask, patched_entries, layer_module, att_probs = None, output_att_prob=False):
    """Propagate (rel, irrel) through one encoder layer in the ``*_dot_w_embed`` chain.

    Steps:
        1. Patched attention block (``embed`` is forwarded to it).
        2. Intermediate dense projection and its activation.
        3. Output dense projection.
        4. Residual connection with the attention output.
        5. Decomposed LayerNorm.

    Returns:
        Tuple ``(rel_out, irrel_out, att_probs_or_None)``, where the third
        element is whatever the attention block returned.
    """
    attn_rel, attn_irrel, probs = prop_attention_patched_dot_w_embed(
        embed, rel, irrel, attention_mask, head_mask, patched_entries,
        layer_module.attention, att_probs, output_att_prob,
    )

    inter = layer_module.intermediate
    pre_rel, pre_irrel = prop_linear(attn_rel, attn_irrel, inter.dense)
    act_rel, act_irrel = prop_act(pre_rel, pre_irrel, inter.intermediate_act_fn)

    out_mod = layer_module.output
    proj_rel, proj_irrel = prop_linear(act_rel, act_irrel, out_mod.dense)

    # Residual connection around the feed-forward sub-layer, then LayerNorm.
    final_rel, final_irrel = prop_layer_norm(proj_rel + attn_rel, proj_irrel + attn_irrel, out_mod.LayerNorm)
    return final_rel, final_irrel, probs

def prop_classifier_model_patched_dot_w_embed(encoding, model, patched_entries, att_list = None, output_att_prob=False):
    """Patched (rel, irrel) decomposition of a BERT classifier (embed-aware chain).

    The embedding output starts entirely in ``irrel`` (``rel`` is zero). It is
    also passed to every layer as ``embed``. At encoder layer ``i``, the
    entries of ``patched_entries`` whose first element equals ``i`` are
    patched.

    Args:
        encoding: Tokenizer output. ``encoding['input_ids']`` and
            ``encoding['attention_mask']`` are read here.
        model: Model exposing ``.bert`` (with ``encoder``, ``pooler`` and
            ``config``) and ``.classifier``.
        patched_entries: Attention heads to patch, as ``[(level, pos, head)]``
            with level: 0-11, pos: 0-511, head: 0-11.
        att_list: Currently unused; kept for interface compatibility.
        output_att_prob: When true, collect each layer's attention
            probabilities (with dim 0 squeezed).

    Returns:
        Tuple ``(rel_out, irrel_out, att_probs_lst)``:
            - ``rel_out``: contribution of the patched entries.
            - ``irrel_out``: contribution of everything else.
            - ``att_probs_lst``: per-layer attention probabilities; empty
              unless ``output_att_prob`` is true.
    """
    embedding_output = get_embeddings_bert(encoding, model.bert)
    input_shape = encoding['input_ids'].size()
    extended_attention_mask = get_extended_attention_mask(attention_mask = encoding['attention_mask'], 
                                                          input_shape = input_shape, 
                                                          bert_model = model.bert)

    head_mask = [None] * model.bert.config.num_hidden_layers
    encoder_module = model.bert.encoder

    # Allocate on the embeddings' own device/dtype instead of a module-level
    # ``device`` global. That global may be undefined, or differ from the
    # device the model actually runs on.
    rel = torch.zeros_like(embedding_output)
    irrel = embedding_output.clone()

    # Materialize once so a one-shot iterable is not exhausted after layer 0.
    patched_entries = list(patched_entries)

    att_probs_lst = []
    for i, layer_module in enumerate(encoder_module.layer):
        layer_patched_entries = [p_entry for p_entry in patched_entries if p_entry[0] == i]
        layer_head_mask = head_mask[i]
        rel_n, irrel_n, returned_att_probs = prop_layer_patched_dot_w_embed(embedding_output, rel, irrel, extended_attention_mask,
                                                                            layer_head_mask, layer_patched_entries,
                                                                            layer_module, None, output_att_prob)
        # Return value is discarded, so this presumably normalizes in place - confirm.
        normalize_rel_irrel(rel_n, irrel_n)
        rel, irrel = rel_n, irrel_n

        if output_att_prob:
            att_probs_lst.append(returned_att_probs.squeeze(0))

    rel_pool, irrel_pool = prop_pooler(rel, irrel, model.bert.pooler)
    rel_out, irrel_out = prop_linear(rel_pool, irrel_pool, model.classifier)

    return rel_out, irrel_out, att_probs_lst
'''