import torch
import functools
from torch import nn
from typing import Dict, Any, Optional, Tuple

from torch.xpu import device
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.gemma2.modeling_gemma2 import apply_rotary_pos_emb
from utils.logging_utils import logger


class Gemma2Model:
    """
    Wrapper around a Hugging Face Gemma2 causal language model and its tokenizer.

    Model names listed by the original author:
        "google/gemma-2-2b"
        "google/gemma-2-9b"
        "google/gemma-2-27b"
    """

    def __init__(self, model_name: str, device: str = 'cuda'):
        """
        Load the tokenizer and causal LM for ``model_name`` and initialise the ablation state.

        Args:
            model_name: Model id or path handed to ``from_pretrained`` for both
                the tokenizer and the model.
            device: Device the model is moved to; also used for tensors this
                wrapper creates (e.g. ``kv_ablation_tp``).
        """
        self.model_name = model_name
        self.device = device

        # huggingface config
        hf_kwargs = {
            "offload_folder": "./offload",
            # The patched attention forward in this class computes the attention
            # weights explicitly, so the eager implementation is requested.
            "attn_implementation": "eager",
        }

        # Load the huggingface pretrained model and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            **hf_kwargs
        ).to(device)

        # ensure pad_token setting
        if self.tokenizer.pad_token is None:
            if self.tokenizer.eos_token is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            else:
                self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                # A brand-new token id needs a matching embedding row, otherwise
                # feeding the pad id to the model indexes past the embedding table.
                self.model.resize_token_embeddings(len(self.tokenizer))

        n_layers = self.model.config.num_hidden_layers
        self.model_config = {
            "n_heads": self.model.config.num_attention_heads,
            "n_layers": n_layers,
            "hidden_size": self.model.config.hidden_size,
            "head_dim": self.model.config.head_dim,
            "name_or_path": self.model.config._name_or_path,
            "attn_hook_names": [f'model.layers.{layer}.self_attn.o_proj' for layer in range(n_layers)],
            "layer_hook_names": [f'model.layers.{layer}' for layer in range(n_layers)],
            "prepend_bos": True
        }

        # store original forward functions references (filled by break_into)
        self._orig_forwards = {"attn": [], "decoder": []}
        # mark if this is a corrupted run
        self.is_corrupted_run = False
        # mark if this is an original run
        self.is_original_run = False
        # token type num
        self.TOKEN_TYPES = None
        # current token type indices
        self.tp_inds = None  # Shape: (batch_size, seq_len)
        # List of token type for each word (contain <bos>)
        self.token_type_map = []
        # List of 3-tuples whose last item is the ablation type.
        # decoder_forward reads them as (from_tp, to_tp, 'corrupted'|'zero');
        # attn_forward reads them as (layer_idx, head_idx,
        # 'q_replace'|'k_replace'|'v_replace'|'q_extract'|'k_extract'|'v_extract').
        self.ablate_edges_map = []
        # corrupted activations cache
        self.corrupted_activations = None
        # original activations cache
        self.original_activations = None
        # per-layer attention head output cache: {layer: (batch_size, seq_len, n_heads * head_dim)}
        self.attention_head_activations = None
        # qkv vectors cache
        self.q_vectors = []
        self.k_vectors = []
        self.v_vectors = []
        # attention scores: attn_forward writes into the LAST element, indexed by layer
        self.attention_scores = []
        # token types whose K/V positions are replaced in k_replace / v_replace ablations
        self.kv_ablation_tp = None
        # corruption with ablation
        self.corruption_with_ablation = False
        # corruption ablation edges map
        self.corruption_ablate_edges_map = None

        self._log_memory_usage()

    def _log_memory_usage(self):
        """Log total, reserved, allocated and free (total minus reserved) memory of every CUDA device."""
        bytes_per_unit = 1024 ** 3  # same scale as dividing by 1024 three times
        for d in range(torch.cuda.device_count()):
            total = torch.cuda.get_device_properties(d).total_memory
            reserved = torch.cuda.memory_reserved(d)
            allocated = torch.cuda.memory_allocated(d)
            free = total - reserved
            logger.info(
                f"Device {d}, total_memory: {total / bytes_per_unit:.4}Gb, reserved: {reserved / bytes_per_unit:.4}Gb, allocated: {allocated / bytes_per_unit:.4}Gb, free: {free / bytes_per_unit:.4}Gb")

    # Set token type indices for the model
    def set_token_indices(self, tp_inds: torch.Tensor):
        self.tp_inds = tp_inds

    # Set token types num and ablate_edges_map, initialize corrupted activations
    def set_ablation_task_config(self, token_types: int, ablate_edges_map: list[tuple[int, int, str]],
                                 token_type_map: list[int], kv_ablation_tp: list[int] = None,
                                 corruption_with_ablation: bool = False, corruption_ablate_edges_map=None):
        if corruption_ablate_edges_map is None:
            corruption_ablate_edges_map = []
        self.TOKEN_TYPES = token_types
        self.corrupted_activations = {
            "input_to_attn_per_type": {layer: None for layer in range(len(self.model.model.layers))}}
        self.attention_head_activations = {layer: None for layer in range(len(self.model.model.layers))}
        self.ablate_edges_map = ablate_edges_map
        self.token_type_map = token_type_map
        if kv_ablation_tp:
            self.kv_ablation_tp = torch.tensor(kv_ablation_tp, dtype=torch.int, device=self.device)
            self.original_activations = {
                "input_to_attn_per_type": {layer: None for layer in range(len(self.model.model.layers))}}
        self.corruption_with_ablation = corruption_with_ablation
        if self.corruption_with_ablation:
            self.corruption_ablate_edges_map = corruption_ablate_edges_map

    # Customize attention and decoder forward functions
    def break_into(self):
        self._orig_forwards = {"decoder": [], "attn": []}

        for layer in range(len(self.model.model.layers)):
            self._orig_forwards["attn"].append(self.model.model.layers[layer].self_attn.forward)
            forward_partial = functools.partial(self.attn_forward, layer=layer,
                                                self=self.model.model.layers[layer].self_attn,
                                                rotary_emb=self.model.model.rotary_emb,
                                                model_wrapper=self)
            self.model.model.layers[layer].self_attn.forward = forward_partial

            self._orig_forwards["decoder"].append(self.model.model.layers[layer].forward)
            forward_partial = functools.partial(self.decoder_forward, layer=layer,
                                                self=self.model.model.layers[layer],
                                                model_wrapper=self)
            self.model.model.layers[layer].forward = forward_partial

    # Recover the original forward functions
    def break_out(self):
        for layer in range(len(self.model.model.layers)):
            self.model.model.layers[layer].self_attn.forward = self._orig_forwards["attn"][layer]
            self.model.model.layers[layer].forward = self._orig_forwards["decoder"][layer]

    @staticmethod
    def attn_forward(
            self,  # corresponds to model.layers[layer].self_attn
            hidden_states: dict[int, torch.Tensor],  # Shape: (num_heads, token_types, batch_size, seq_len, head_dim)
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_value: Optional[Any] = None,
            output_attentions: bool = False,
            use_cache: bool = False,
            cache_position: Optional[torch.LongTensor] = None,
            layer: Optional[int] = None,
            rotary_emb: Optional[Any] = None,
            model_wrapper: Optional[Any] = None,
            **kwargs
    ):
        num_heads = self.config.num_attention_heads
        num_key_value_heads = self.config.num_key_value_heads
        rotary_emb = rotary_emb

        batch_size, seq_len, _ = hidden_states[0][0].size()
        # Isolate different token types
        extended_batch_size = batch_size * model_wrapper.TOKEN_TYPES
        original_batch_size = batch_size

        # q, k, v
        # query_states:{token type index: query} Shape: (batch_size, seq_len, num_heads, head_dim)
        query_states = {tp: torch.empty(batch_size, seq_len, num_heads, self.head_dim,
                                        device=hidden_states[num_heads - 1][
                                            model_wrapper.TOKEN_TYPES - 1].device,
                                        dtype=hidden_states[num_heads - 1][model_wrapper.TOKEN_TYPES - 1].dtype)
                        for tp in range(model_wrapper.TOKEN_TYPES)}

        # key_states:{token type index: key} Shape: (batch_size, seq_len, num_heads, head_dim)
        key_states = {tp: torch.empty(batch_size, seq_len, num_heads, self.head_dim,
                                      device=hidden_states[num_heads - 1][model_wrapper.TOKEN_TYPES - 1].device,
                                      dtype=hidden_states[num_heads - 1][model_wrapper.TOKEN_TYPES - 1].dtype)
                      for tp in range(model_wrapper.TOKEN_TYPES)}

        # value_states:{token type index: value} Shape: (batch_size, seq_len, num_heads, head_dim)
        value_states = {tp: torch.empty(batch_size, seq_len, num_heads, self.head_dim,
                                        device=hidden_states[num_heads - 1][
                                            model_wrapper.TOKEN_TYPES - 1].device,
                                        dtype=hidden_states[num_heads - 1][model_wrapper.TOKEN_TYPES - 1].dtype)
                        for tp in range(model_wrapper.TOKEN_TYPES)}

        q_replace_list = []
        k_replace_list = []
        v_replace_list = []

        q_extract_list = []
        k_extract_list = []
        v_extract_list = []

        for (layer_idx, head_idx, ablation_type) in model_wrapper.ablate_edges_map:
            if (layer_idx == layer or layer_idx is None) and ablation_type == "q_replace":
                if head_idx:
                    q_replace_list.append(head_idx)
                else:
                    q_replace_list.extend(range(num_heads))  # Replace all heads in this layer
            elif (layer_idx == layer or layer_idx is None) and ablation_type == "k_replace":
                if head_idx:
                    k_replace_list.append(head_idx)
                else:
                    k_replace_list.extend(range(num_heads))
            elif (layer_idx == layer or layer_idx is None) and ablation_type == "v_replace":
                if head_idx:
                    v_replace_list.append(head_idx)
                else:
                    v_replace_list.extend(range(num_heads))
            elif layer_idx == layer and "extract" in ablation_type:
                if ablation_type == 'q_extract':
                    q_extract_list.append(head_idx)
                elif ablation_type == 'k_extract':
                    k_extract_list.append(head_idx)
                elif ablation_type == 'v_extract':
                    v_extract_list.append(head_idx)

        # Calculate q, k, v projections corresponding to each head and token type
        for head in range(num_heads):
            for tp in range(model_wrapper.TOKEN_TYPES):
                query_states[tp][:, :, head, :] = self.q_proj(hidden_states[head][tp]).view(batch_size, seq_len,
                                                                                            num_heads,
                                                                                            self.head_dim)[:, :, head,
                :]  # Shape: (batch_size, seq_len, num_heads, head_dim)

                # Q Replacement Ablation
                if head in q_replace_list and not model_wrapper.is_corrupted_run and not model_wrapper.is_original_run:
                    corrupted_tensor = model_wrapper.corrupted_activations["input_to_attn_per_type"][layer].to(
                        query_states[tp].device)

                    # Note: corrupted seq_len may be different from clean seq_len due to different tokenization
                    corrupted_seq_len = corrupted_tensor.shape[1]

                    # Reshape after q projection
                    reshaped_tensor = self.q_proj(corrupted_tensor).view(batch_size, corrupted_seq_len, num_heads,
                                                                         self.head_dim)

                    # Replace the Q of the last token in the sequence
                    query_states[tp][:, -1, head, :] = reshaped_tensor[:, -1, head, :]

                # grouped value shared by multiple attention heads
                key_states[tp][:, :, head, :] = self.k_proj(hidden_states[head][tp]).view(batch_size, seq_len,
                                                                                          num_key_value_heads,
                                                                                          self.head_dim)[:, :,
                head // self.num_key_value_groups, :]
                # K Replacement Ablation
                if (
                        head in k_replace_list and not model_wrapper.is_corrupted_run and not model_wrapper.is_original_run and
                        tp == model_wrapper.TOKEN_TYPES - 1):
                    # Replace the key on different examples with corrupted activations
                    hidden_states_sub = model_wrapper.corrupted_activations["input_to_attn_per_type"][layer].to(
                        key_states[tp].device, dtype=key_states[tp].dtype)
                    k_sub = self.k_proj(hidden_states_sub).view(batch_size, seq_len, num_key_value_heads,
                                                                self.head_dim)[:, :, head // self.num_key_value_groups,
                    :]  # (B,L,D)

                    tp_pos = torch.isin(model_wrapper.tp_inds, model_wrapper.kv_ablation_tp)
                    key_states_view = key_states[tp].select(2, head)  # Shape: (B,L,D)
                    # Key substitute must have the same number of tokens as the original key
                    key_states_view[tp_pos] = k_sub[tp_pos]

                    # Keep original Q unchanged for K ablation
                    original_tensor = model_wrapper.original_activations["input_to_attn_per_type"][layer].to(
                        query_states[tp].device, dtype=query_states[tp].dtype)

                    original_seq_len = original_tensor.shape[1]

                    # Reshape after q projection
                    reshaped_tensor = self.q_proj(original_tensor).view(batch_size, original_seq_len, num_heads,
                                                                        self.head_dim)

                    # Replace the Q of the last token in the sequence
                    query_states[tp][:, -1, head, :] = reshaped_tensor[:, -1, head, :]

                # grouped value shared by multiple attention heads
                value_states[tp][:, :, head, :] = self.v_proj(hidden_states[head][tp]).view(batch_size, seq_len,
                                                                                            num_key_value_heads,
                                                                                            self.head_dim)[:, :,
                head // self.num_key_value_groups, :]
                # V Replacement Ablation
                if (
                        head in v_replace_list and not model_wrapper.is_corrupted_run and not model_wrapper.is_original_run and
                        tp == model_wrapper.TOKEN_TYPES - 1):
                    # Replace the value on different examples with corrupted activations
                    # Key&Value substitute must have the same number of tokens as the original one
                    hidden_states_sub = model_wrapper.corrupted_activations["input_to_attn_per_type"][layer].to(
                        value_states[tp].device, dtype=value_states[tp].dtype)
                    v_sub = self.v_proj(hidden_states_sub).view(batch_size, seq_len, num_key_value_heads,
                                                                self.head_dim)[:, :, head // self.num_key_value_groups,
                    :]  # (B,L,D)

                    tp_pos = torch.isin(model_wrapper.tp_inds, model_wrapper.kv_ablation_tp)
                    value_states_view = value_states[tp].select(2, head)  # Shape: (B,L,D)
                    value_states_view[tp_pos] = v_sub[tp_pos]

                    # Keep original Q&K unchanged for V ablation
                    original_tensor = model_wrapper.original_activations["input_to_attn_per_type"][layer].to(
                        query_states[tp].device, dtype=query_states[tp].dtype)
                    original_seq_len = original_tensor.shape[1]
                    # Reshape after q projection
                    reshaped_tensor = self.q_proj(original_tensor).view(batch_size, original_seq_len, num_heads,
                                                                        self.head_dim)
                    # Replace the Q of the last token in the sequence
                    query_states[tp][:, -1, head, :] = reshaped_tensor[:, -1, head, :]
                    # grouped value shared by multiple attention heads
                    k_sub = self.k_proj(original_tensor).view(batch_size, seq_len, num_key_value_heads,
                                                              self.head_dim)[:, :, head // self.num_key_value_groups,
                    :]

                    key_states[tp][:, -1, head, :] = k_sub[:, -1, :]

            # QKV Extraction
            if not model_wrapper.is_corrupted_run and not model_wrapper.is_original_run:
                if head in q_extract_list:
                    model_wrapper.q_vectors.append(query_states[model_wrapper.TOKEN_TYPES - 1][:,
                    -1, head, :].detach().cpu().clone().squeeze(
                        0))  # Shape: (head_dim, )
                if head in k_extract_list:
                    model_wrapper.k_vectors.append(key_states[model_wrapper.TOKEN_TYPES - 1][:,
                    :, head,
                    :].detach().cpu().clone().squeeze(0))  # Shape: (seq_len, head_dim)
                if head in v_extract_list:
                    model_wrapper.v_vectors.append(value_states[model_wrapper.TOKEN_TYPES - 1][:,
                    :, head,
                    :].detach().cpu().clone().squeeze(0))  # Shape: (seq_len, head_dim)

        # Stack all token types together
        # Shape: (batch_size * TOKEN_TYPES, num_heads, seq_len, head_dim)
        query_states = torch.vstack([query_states[tp] for tp in range(model_wrapper.TOKEN_TYPES)]).transpose(1, 2)
        key_states = torch.vstack([key_states[tp] for tp in range(model_wrapper.TOKEN_TYPES)]).transpose(1, 2)
        value_states = torch.vstack([value_states[tp] for tp in range(model_wrapper.TOKEN_TYPES)]).transpose(1, 2)

        if query_states.device != key_states.device:
            key_states = key_states.to(query_states.device)

        # Extend position_ids to match the batch size
        if position_ids.shape[0] == original_batch_size:
            position_ids = position_ids.expand(model_wrapper.TOKEN_TYPES, original_batch_size, seq_len).reshape(
                extended_batch_size,
                seq_len)  # Shape: (batch_size * TOKEN_TYPES, seq_len)

        # Calculate the rotary embeddings
        cos, sin = rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Update the cache if past_key_value is provided
        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "sliding_window": self.sliding_window,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Calculate attention weights
        # Shape: (batch_size * TOKEN_TYPES, num_heads, seq_len, seq_len)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

        if self.config.attn_logit_softcapping is not None:
            attn_weights = attn_weights / self.config.attn_logit_softcapping
            attn_weights = torch.tanh(attn_weights)
            attn_weights = attn_weights * self.config.attn_logit_softcapping

        if attention_mask is not None:
            # Shape: (batch_size * TOKEN_TYPES, 1, seq_len, seq_len)
            attention_mask = attention_mask.expand(model_wrapper.TOKEN_TYPES, original_batch_size, 1, seq_len,
                                                   attn_weights.shape[-1]).reshape(extended_batch_size, 1, seq_len,
                                                                                   attn_weights.shape[-1])
            causal_mask = attention_mask[:, :, :,
            : key_states.shape[-2]]  # Shape: (batch_size * TOKEN_TYPES, 1, seq_len, seq_len)
            attn_weights = attn_weights + causal_mask

        # Calculate attention score
        # Upcast to float32 for numerical stability
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

        if not model_wrapper.is_corrupted_run and not model_wrapper.is_original_run:
            # In corrupted run, we do not store any attention scores
            # Aggregate attention scores on token types
            attn_scores = attn_weights.clone().detach()[-1, :, -1, :]  # Shape: (num_heads, seq_len)
            attn_scores_col = []
            for tp in range(model_wrapper.TOKEN_TYPES):
                attn_scores_col.append(attn_scores[:, model_wrapper.tp_inds[-1] == tp].sum(dim=-1, keepdim=True).cpu())
            attn_score_dict = model_wrapper.attention_scores[-1]
            attn_scores_tp = torch.cat(attn_scores_col, dim=1)
            attn_score_dict[layer] = attn_scores_tp

        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights,
                                   value_states)  # Shape: (batch_size * TOKEN_TYPES, num_heads, seq_len, head_dim)

        if attn_output.size() != (extended_batch_size, num_heads, seq_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(extended_batch_size, num_heads, seq_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Shape: (batch_size * TOKEN_TYPES, seq_len, num_heads, head_dim)
        attn_output = attn_output.transpose(1, 2).contiguous()

        all_heads_outputs = []
        layer_activations = torch.zeros(
            original_batch_size,
            seq_len,
            num_heads * self.head_dim,
            device=attn_output.device,
            dtype=attn_output.dtype
        )  # Shape: (batch_size, seq_len, num_heads * head_dim)

        for current_head in range(num_heads):
            w_o_slice = self.o_proj.weight.T[current_head * self.head_dim: (current_head + 1) * self.head_dim, :]
            if past_key_value is not None and seq_len == 1:
                target_tp = model_wrapper.TOKEN_TYPES - 1  # Use the last token type for past key value
                all_heads_outputs.append(
                    attn_output[target_tp * original_batch_size: (target_tp + 1) * original_batch_size, :, current_head,
                    :].to(
                        self.o_proj.weight.device) @
                    self.o_proj.weight.T[current_head * self.head_dim: (current_head + 1) * self.head_dim, :])
            else:
                # For each head, create an empty tensor to store the output
                # Shape: (batch_size, seq_len, hidden_size)
                all_heads_outputs.append(torch.empty(original_batch_size, seq_len, self.o_proj.weight.shape[0],
                                                     device=attn_output.device, dtype=attn_output.dtype))

                for tp in range(model_wrapper.TOKEN_TYPES):
                    # Current token type positions
                    tp_pos = model_wrapper.tp_inds == tp  # Shape: (batch_size, seq_len)
                    head_attention = attn_output[tp * original_batch_size: (tp + 1) * original_batch_size, :,
                    current_head, :].to(
                        self.o_proj.weight.device)

                    # Slice the head position
                    head_slice = slice(current_head * self.head_dim, (current_head + 1) * self.head_dim)

                    # Store attention activations
                    mask = tp_pos.unsqueeze(-1).expand(-1, -1, self.head_dim)  # (batch_size, seq_len, head_dim)
                    # Use slice to create a view
                    target_activations = layer_activations[..., head_slice]
                    # Use masked_scatter to assign values in place
                    target_activations.masked_scatter_(mask, head_attention[mask])

                    # Assign the output of the attention to the corresponding token type positions
                    all_heads_outputs[-1][model_wrapper.tp_inds == tp] = \
                        (head_attention @ w_o_slice)[model_wrapper.tp_inds == tp]

        assert len(all_heads_outputs) == num_heads
        all_heads_outputs = sum(all_heads_outputs)  # Shape: (batch_size, seq_len, hidden_size)

        if not output_attentions:
            attn_weights = None

        # Assign the attention activations to the attention head activations cache
        if not model_wrapper.is_corrupted_run and not model_wrapper.is_original_run:
            model_wrapper.attention_head_activations[
                layer] = layer_activations.detach()  # Shape: (batch_size=1, seq_len, n_heads * head_dim)

        return all_heads_outputs, attn_weights, past_key_value

    @staticmethod
    def decoder_forward(
            hidden_states: torch.Tensor,  # Shape: (batch_size, seq_len, head_dim)
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_value: Optional[Any] = None,
            output_attentions: Optional[bool] = False,
            use_cache: Optional[bool] = False,
            cache_position: Optional[torch.LongTensor] = None,
            self=None,  # corresponds to model.layers[layer]
            layer=None,  # corrupted activation position
            model_wrapper=None,
            **kwargs
    ):
        num_heads = model_wrapper.model_config["n_heads"]
        sliding_window = self.config.sliding_window

        # Sliding window attention in efficient SDPA
        # Literally will not be used
        # attention_mask Shape: (batch_size, 1, seq_len, seq_len)
        if sliding_window is not None and attention_mask is not None:
            assert not self.config._attn_implementation == "flash_attention_2"

            min_dtype = torch.finfo(hidden_states.dtype).min
            sliding_window_mask = torch.tril(
                torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-sliding_window
            )
            attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
            if attention_mask.shape[-1] <= 1:  # when decoding
                attention_mask = attention_mask[:, :, :, -sliding_window:]

        # Keep the original hidden states for residual connection
        residual = hidden_states

        # hidden_states_per_head_and_type Shape: (num_heads, token_types, batch_size, seq_len, head_dim)
        if past_key_value is not None and hidden_states.shape[1] == 1:
            hidden_states_per_head = {head: hidden_states.clone() for head in range(num_heads)}
            hidden_states_per_head_and_type = {
                head: {tp: hidden_states_per_head[head].clone() for tp in range(model_wrapper.TOKEN_TYPES)} for head
                in range(num_heads)}
        else:
            hidden_states_per_head = {head: hidden_states.clone() for head in range(num_heads)}
            hidden_states_per_head_and_type = {
                head: {tp: hidden_states_per_head[head].clone() for tp in range(model_wrapper.TOKEN_TYPES)} for head
                in range(num_heads)}

            # Make sure the corrupted activations has the same sequence length as the clean ones
            if model_wrapper.is_corrupted_run:
                model_wrapper.corrupted_activations["input_to_attn_per_type"][layer] = \
                    self.input_layernorm(hidden_states).detach().clone().cpu()
                # Apply the same ablation for corruption if needed
                if model_wrapper.corruption_with_ablation:
                    for (from_tp, to_tp, ablation_type) in model_wrapper.corruption_ablate_edges_map:
                        if ablation_type == 'zero':
                            B, _, L, _ = attention_mask.shape  # Shape: (batch_size, 1, seq_len, seq_len)

                            # mask_q[b,i,:] = True
                            mask_q = (model_wrapper.tp_inds == to_tp).unsqueeze(2).expand(B, L, L)
                            # mask_k[b,:,j] = True
                            mask_k = (model_wrapper.tp_inds == from_tp).unsqueeze(1).expand(B, L, L)
                            # zero_positions[b,i,j] = True
                            zero_positions = mask_q & mask_k  # (B,L,L)
                            # Fill in the attention mask with -inf for the zero positions
                            attention_mask = attention_mask.masked_fill(
                                zero_positions.unsqueeze(1),  # Extend to Shape: (B,1,L,L)
                                -torch.inf
                            )
            else:
                if model_wrapper.is_original_run:
                    model_wrapper.original_activations["input_to_attn_per_type"][layer] = \
                        self.input_layernorm(hidden_states).detach().clone().cpu()
                # Traverse the ablation edges map and apply the ablation
                for (from_tp, to_tp, ablation_type) in model_wrapper.ablate_edges_map:
                    if ablation_type == "corrupted":
                        for head in range(num_heads):
                            # Assign the corrupted activations to the corresponding from-to token types
                            hidden_states_per_head_and_type[head][to_tp][model_wrapper.tp_inds == from_tp] = \
                                (model_wrapper.corrupted_activations["input_to_attn_per_type"][layer]).to(
                                    hidden_states_per_head_and_type[head][from_tp].device)[
                                    model_wrapper.tp_inds == from_tp].detach()
                    elif ablation_type == "zero":
                        B, _, L, _ = attention_mask.shape  # Shape: (batch_size, 1, seq_len, seq_len)

                        # mask_q[b,i,:] = True
                        mask_q = (model_wrapper.tp_inds == to_tp).unsqueeze(2).expand(B, L, L)
                        # mask_k[b,:,j] = True
                        mask_k = (model_wrapper.tp_inds == from_tp).unsqueeze(1).expand(B, L, L)
                        # zero_positions[b,i,j] = True
                        zero_positions = mask_q & mask_k  # (B,L,L)

                        # Fill in the attention mask with -inf for the zero positions
                        attention_mask = attention_mask.masked_fill(
                            zero_positions.unsqueeze(1),  # Extend to Shape: (B,1,L,L)
                            -torch.inf
                        )

        # NOTE: LayerNorm before self-attention!!
        hidden_states_per_head_and_type = {head: {
            tp: self.input_layernorm(hidden_states_per_head_and_type[head][tp])
            for tp in range(model_wrapper.TOKEN_TYPES)
        }
            for head in range(num_heads)
        }  # Shape: (num_heads, token_types, batch_size, seq_len, head_dim)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states_per_head_and_type,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
        )  # hidden_states Shape: (batch_size, seq_len, head_dim)

        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states.clone()
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)

        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs