import os
from typing import Dict, Any, Optional, Tuple, Union, List
import functools
import math
import warnings
from pathlib import Path
import json

import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.models.gemma2.modeling_gemma2 import apply_rotary_pos_emb, repeat_kv, \
    BaseModelOutputWithPast, HybridCache

from patching_gemma import logger
from patching_gemma.models.gemma2 import Gemma2Model
from patching_gemma.models.utils.importance_scores.importance_scores_calculation import compute_importance_score_inside_attn
from patching_gemma.models.utils.pruning.prune_heads import prune_heads
from patching_gemma.models.utils.pruning.prune_edges_inside_attn import prune_edges_inside_attn
from patching_gemma.models.utils.data_processing.collate_function import collate_fn_everything
from patching_gemma.models.utils.data_processing.dataset import RequestDataset

class Gemma2PruneEdgesInCircuits(Gemma2Model):
    def run(self, task, limit, batch_size, prune_using_imp_scores, prune_k,
                affect_whom, log_dir) -> None:
        """Set up pruning for ``task``, then compute importance scores and evaluate.

        Args:
            task: Task object. It must expose ``config.name``, ``TOKEN_TYPES``,
                ``loss`` and a truthy ``can_be_token_separable``.
            limit: Forwarded to ``RequestDataset`` (presumably the maximum number
                of requests -- confirm against ``RequestDataset``).
            batch_size: Batch size used for both the scoring and generation passes.
            prune_using_imp_scores: ``None`` or a sequence of directories with
                previously saved scores. A directory containing ``attn_scores.npy``
                is used for head pruning; one containing ``inside_attn.npy`` is
                used for edge pruning.
            prune_k: One entry per directory in ``prune_using_imp_scores``,
                forwarded to the corresponding pruning helper.
            affect_whom: Iterable of ``(tp1, tp2)`` token-type pairs; the matching
                edge is always pruned in every layer and head.
            log_dir: Directory passed to the pruning helpers and used for saving
                the importance scores.

        Raises:
            ValueError: If a directory in ``prune_using_imp_scores`` contains
                neither ``attn_scores.npy`` nor ``inside_attn.npy``.
        """
        logger.info(f"Started run method for task {task.config.name}")
        edges_between_types_to_prune_for_sure = [tuple(e) for e in affect_whom]
        self.model_logs["edges_between_types_to_prune_for_sure"] = edges_between_types_to_prune_for_sure
        self.edges_between_types_to_prune_for_sure = edges_between_types_to_prune_for_sure

        num_layers = len(self.model.model.layers)
        num_heads = self.model.model.layers[0].self_attn.num_heads

        # Every forced type-to-type edge is pruned in all layers and heads.
        edges_to_prune_for_sure = [
            (layer, head, tp1, tp2)
            for layer in range(num_layers)
            for head in range(num_heads)
            for tp1, tp2 in edges_between_types_to_prune_for_sure
        ]
        self.model_logs["edges_to_prune_for_sure"] = edges_to_prune_for_sure
        self.task = task
        assert task.can_be_token_separable
        num_token_types = task.TOKEN_TYPES

        # Split the score directories into node-level (head) and edge-level ones.
        # All four stay None when no score directories were given.
        if prune_using_imp_scores is None:
            edge_paths = edge_prune_k = None
            node_paths = node_prune_k = None
        else:
            edge_paths, edge_prune_k = [], []
            node_paths, node_prune_k = [], []
            assert len(prune_using_imp_scores) == len(prune_k)
            for path, p_k in zip(prune_using_imp_scores, prune_k):
                scores_dir = Path(path)
                if scores_dir.joinpath("attn_scores.npy").exists():
                    node_paths.append(path)
                    node_prune_k.append(p_k)
                elif scores_dir.joinpath("inside_attn.npy").exists():
                    edge_paths.append(path)
                    edge_prune_k.append(p_k)
                else:
                    # Name the offending directory so the failure is actionable.
                    raise ValueError(
                        f"No importance scores found in {str(path)!r}: expected "
                        "'attn_scores.npy' or 'inside_attn.npy'"
                    )

        dataset = RequestDataset(task, limit, corrupted=True, tokenizer=self.tokenizer)
        # Guard against datasets with fewer than three requests.
        self.model_logs["first_3_dataset_examples"] = [dataset[i] for i in range(min(3, len(dataset)))]
        self.num_requests = len(dataset)

        def per_head_store():
            # layer -> head -> activation (filled in by the patched forwards)
            return {layer: {head: None for head in range(num_heads)}
                    for layer in range(num_layers)}

        def per_head_and_type_store():
            # layer -> head -> token type -> activation
            return {layer: {head: {tp: None for tp in range(num_token_types)}
                            for head in range(num_heads)}
                    for layer in range(num_layers)}

        self.original_activations = {
            "attn": per_head_store(),
            "input_to_attn_per_type": per_head_and_type_store(),
        }

        self.corrupted_minus_original_activations = {
            "input_to_attn_per_type": per_head_and_type_store(),
        }

        self.corrupted_activations = {
            "attn": per_head_store(),
            "input_to_attn_per_type": per_head_and_type_store(),
        }

        # One score per (layer, head, source token type, target token type).
        self.importance_scores = {
            "inside_attn": torch.zeros(num_layers, num_heads, num_token_types, num_token_types)
        }

        self.important_components_by_layer = prune_heads(node_paths, node_prune_k,
                                                         num_layers, num_heads, num_token_types,
                                                         log_dir, is_split_by_token_type=True)
        self.important_edges = prune_edges_inside_attn(edge_paths, edge_prune_k,
                                                       edges_to_prune_for_sure,
                                                       num_layers, num_heads, num_token_types,
                                                       log_dir, is_split_by_token_type=True)
        if self.important_components_by_layer is not None:
            self.model_logs["number_of_important_heads"] = sum(
                len(self.important_components_by_layer[layer]["attn"])
                for layer in self.important_components_by_layer
            )
        if self.important_edges is not None:
            self.model_logs["number_of_important_edges_inside_attn"] = self.important_edges["inside_attn"].sum().item()

        self.save_importance_scores(task.loss, dataset, batch_size, log_dir)
        self.generate(dataset, batch_size)

    def save_importance_scores(self, loss_function, dataset, batch_size, log_dir) -> None:
        """Run the scoring pass over ``dataset`` and save the importance scores.

        For each batch the corrupted contexts are run first without gradients
        (with ``is_corrupted_run = True``), then the clean inputs are run with
        gradients and ``loss_function`` is back-propagated. Parameter gradients
        are discarded after each batch. The scores themselves are presumably
        accumulated into ``self.importance_scores`` by the backward hooks that
        the patched decoder forward registers -- confirm against
        ``compute_importance_score_inside_attn``.

        Args:
            loss_function: Called as ``loss_function(logits, target_ids, lens, tp_inds)``;
                must return a scalar tensor.
            dataset: Dataset to iterate; it is collated with ``collate_fn_everything``.
            batch_size: DataLoader batch size.
            log_dir: Directory the scores are written to via ``_save_imp_scores``.

        Side effects:
            Appends one entry per batch to ``self.model_logs["run_logs"]`` and
            sets ``self.model_logs["loss"]`` to the loss averaged over predictive
            positions (``nan`` when there were none).
        """
        self.generate_mode = False
        # Follow the model's device instead of assuming "cuda".
        device = self.model.model.device
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                                collate_fn=functools.partial(collate_fn_everything,
                                                             padding_left=False, use_corrupted_activations=True,
                                                             tokenizer=self.tokenizer),
                                num_workers=len(os.sched_getaffinity(0)) - 1)
        weighted_loss_sum = 0
        examples = []
        num_predictive_inds_total = 0

        for batch in dataloader:
            # Batch fields as used here: 2 = lens, 5 = inputs, 6 = corrupted contexts, 7 = target ids.
            inputs = batch[5].to(device)
            target_ids = batch[7].to(device)
            lens = batch[2]
            input_ids = inputs["input_ids"]
            num_rows = input_ids.shape[0]

            # Keep the first three decoded inputs overall as logged examples.
            for row in input_ids[: 3 - len(examples)]:
                examples.append(self.tokenizer.decode(row.detach().cpu()))

            corrupted_contexts = batch[6].to(device)

            # TODO: averaging only over non-padding tokens (need to think how to do it in omp scores compute too)
            # For token types computation
            self.tp_inds = self.task.get_token_types_for_contexts_with_targets(self.tokenizer, input_ids.detach().cpu())
            self.corrupted_tp_inds = self.task.get_token_types_for_contexts_with_targets(self.tokenizer, corrupted_contexts["input_ids"].detach().cpu())
            assert (self.corrupted_tp_inds == self.tp_inds).all()

            # Corrupted pass: runs without gradients, with the corrupted-run flag set.
            self.is_corrupted_run = True
            with torch.no_grad():
                self.model(**corrupted_contexts, use_cache=False)

            for p in self.model.parameters():
                assert p.grad is None

            # Clean pass with gradients.
            self.is_corrupted_run = False
            logits = self.model(**inputs, use_cache=False).logits
            loss = loss_function(logits, target_ids.to(logits.device), lens, self.tp_inds)
            loss.backward()

            # Parameter gradients are not needed; drop them to free memory.
            for p in self.model.parameters():
                p.grad = None

            run_logs = self.model_logs.setdefault("run_logs", [])

            # Positions of target / last-separator type, minus the last such
            # position in every row.
            predictive_inds = ((self.tp_inds == self.task.TARGET_TYPE) | (self.tp_inds == self.task.LAST_SEP_TYPE))
            predictive_inds[torch.arange(predictive_inds.shape[0]), (predictive_inds.cumsum(dim=-1) * predictive_inds).argmax(dim=-1)] = 0
            num_predictive_inds = torch.count_nonzero(predictive_inds).item()
            num_predictive_inds_total += num_predictive_inds
            last_5 = min(num_rows, 5)
            run_logs.append({
                "batch_len": len(batch),
                "inputs_shape": input_ids.shape,
                "target_shape": target_ids.shape,
                "last_3_queries": [self.tokenizer.decode(input_ids[-i]) for i in range(min(num_rows, 3), 0, -1)],
                "last_1_corrupted_contexts_tokenized": [self.tokenizer.convert_ids_to_tokens(corrupted_contexts["input_ids"][-i]) for i in range(min(corrupted_contexts["input_ids"].shape[0], 1), 0, -1)],
                "last_1_queries_tokenized": [self.tokenizer.convert_ids_to_tokens(input_ids[-i]) for i in range(min(num_rows, 1), 0, -1)],
                "last_1_tp_inds": [self.tp_inds[-i].tolist() for i in range(min(num_rows, 1), 0, -1)],
                "last_5_lens": lens.tolist()[-last_5:],
                "last_5_predictive_inds_in_input": self.tokenizer.convert_ids_to_tokens(input_ids[predictive_inds].detach().cpu())[-last_5:],
                "last_5_correct_targets": self.tokenizer.convert_ids_to_tokens(target_ids.detach().cpu())[-last_5:],
                "last_5_correct_logits_probs": torch.nn.functional.softmax(logits.detach().cpu(), dim=-1)[predictive_inds, target_ids.detach().cpu()].tolist()[-last_5:],
                "shapes_of_logits_and_targets": (logits[predictive_inds, :].shape, target_ids.shape),
                "loss": loss.item()
            })
            # Weight the batch loss by its predictive positions so the final
            # value is an average over positions, not over batches.
            weighted_loss_sum += loss.item() * num_predictive_inds

            # Drop every reference to device tensors before emptying the cache.
            del inputs, input_ids, corrupted_contexts
            del target_ids
            del logits
            del loss
            torch.cuda.empty_cache()

        self.model_logs["first_3_loader_no_generate_exampels"] = examples
        # Avoid ZeroDivisionError when no predictive position was seen.
        self.model_logs["loss"] = (weighted_loss_sum / num_predictive_inds_total
                                   if num_predictive_inds_total else float("nan"))
        self._save_imp_scores(log_dir, len(dataset))

    def generate(self, dataset, batch_size, max_new_tokens: int = 7) -> None:
        """Generate continuations for ``dataset`` and log exact-prefix accuracy.

        For each batch the corrupted contexts are run once without gradients
        (with ``is_corrupted_run = True``), then the clean inputs are decoded
        with ``self.model.generate``. A continuation counts as correct when,
        after stripping whitespace, it starts with the target string.

        Args:
            dataset: Dataset to iterate; it is collated with ``collate_fn_everything``
                using ``padding_left=True``.
            batch_size: DataLoader batch size.
            max_new_tokens: Number of tokens to generate per prompt; also used
                to size the key/value cache. Defaults to 7.

        Side effects:
            Writes ``continuations``, ``is_correct``, ``targets``,
            ``first_3_loader_generate_exampels`` and ``accuracy`` into
            ``self.model_logs``. ``accuracy`` is divided by ``self.num_requests``
            and is ``nan`` when that is zero.
        """
        self.generate_mode = True
        logger.debug("Start generate part")
        # Follow the model's device instead of assuming "cuda".
        device = self.model.model.device
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                                collate_fn=functools.partial(collate_fn_everything,
                                                             padding_left=True, use_corrupted_activations=True,
                                                             tokenizer=self.tokenizer),
                                num_workers=len(os.sched_getaffinity(0)) - 1)
        sum_accuracy = 0
        continuations = []
        examples = []
        is_correct = []
        all_targets = []

        for batch in dataloader:
            # Batch fields as used here: 0 = inputs, 3 = target strings, 4 = corrupted contexts.
            inputs = batch[0].to(device)
            targets = batch[3]
            # Keep the first three decoded inputs overall as logged examples.
            for row in inputs["input_ids"][: 3 - len(examples)]:
                examples.append(self.tokenizer.decode(row.detach().cpu()))

            corrupted_contexts = batch[4].to(device)

            self.tp_inds = self.task.get_token_types_for_contexts(self.tokenizer, inputs["input_ids"].detach().cpu())
            self.corrupted_tp_inds = self.task.get_token_types_for_contexts(self.tokenizer, corrupted_contexts["input_ids"].detach().cpu())
            assert (self.corrupted_tp_inds == self.tp_inds).all()

            # Corrupted pass: runs without gradients, with the corrupted-run flag set.
            self.is_corrupted_run = True
            with torch.no_grad():
                self.model(**corrupted_contexts, use_cache=False)

            self.is_corrupted_run = False

            prompt_len = inputs["input_ids"].shape[1]
            # The patched attention caches keys/values already expanded to one
            # entry per attention head and stacks the per-token-type copies along
            # the batch axis, so the cache is built with num_key_value_heads set
            # to num_attention_heads and a batch of bsz * TOKEN_TYPES.
            num_key_value_heads = self.model.config.num_key_value_heads
            self.model.config.num_key_value_heads = self.model.config.num_attention_heads
            try:
                past_key_values = HybridCache(config=self.model.config, max_batch_size=inputs["input_ids"].shape[0] * self.task.TOKEN_TYPES,
                                              max_cache_len=prompt_len + max_new_tokens, device=self.model.model.device, dtype=self.model.model.dtype)
                out = self.model.generate(**inputs, past_key_values=past_key_values, cache_implementation=None, max_new_tokens=max_new_tokens)
            finally:
                # Always restore the real config value, even if generation fails.
                self.model.config.num_key_value_heads = num_key_value_heads

            for i in range(out.shape[0]):
                # Keep only the newly generated tokens.
                continuation = self.tokenizer.decode(out[i][prompt_len:])
                correct = int(continuation.strip().startswith(targets[i]))

                continuations.append(continuation)
                is_correct.append(correct)
                all_targets.append(targets[i])
                sum_accuracy += correct

            # Drop every reference to device tensors before emptying the cache.
            del inputs, corrupted_contexts, past_key_values, out
            torch.cuda.empty_cache()

        self.model_logs["continuations"] = continuations
        self.model_logs["is_correct"] = is_correct
        self.model_logs["targets"] = all_targets
        self.model_logs["first_3_loader_generate_exampels"] = examples
        # Avoid ZeroDivisionError for an empty dataset.
        self.model_logs["accuracy"] = (sum_accuracy / self.num_requests
                                       if self.num_requests else float("nan"))

    def break_into(self) -> None:
        """Install the patched forwards on every decoder layer and its attention.

        The callables that were in place beforehand are remembered in
        ``self.prev_forwards`` (keys ``"attn"`` and ``"decoder"``, one entry per
        layer) and ``self.hook_handles`` is reset to an empty list.
        """
        self.hook_handles = []
        self.prev_forwards = {"decoder": [], "attn": []}

        for layer_idx, decoder_layer in enumerate(self.model.model.layers):
            attention = decoder_layer.self_attn

            # Attention module: remember the current forward, then replace it.
            # The module itself is bound as the static method's ``self``.
            self.prev_forwards["attn"].append(attention.forward)
            attention.forward = functools.partial(
                self.attn_forward, layer=layer_idx, self=attention, gemma_model=self
            )

            # Decoder layer: same procedure.
            self.prev_forwards["decoder"].append(decoder_layer.forward)
            decoder_layer.forward = functools.partial(
                self.decoder_forward, layer=layer_idx, self=decoder_layer, gemma_model=self
            )

    def break_out(self) -> None:
        """Undo ``break_into``: restore the saved forwards and remove hooks.

        The callables stored in ``self.prev_forwards`` are put back unchanged on
        every layer. They were read from the modules as already-callable
        attributes, so they are not re-bound with an extra ``self=`` keyword;
        doing that on a bound method raises ``TypeError`` at call time.
        """
        layers = self.model.model.layers
        # Walk the per-layer lists, not the dict itself: iterating the dict
        # yields only its two keys and would restore just layers 0 and 1.
        saved_forwards = zip(self.prev_forwards["attn"], self.prev_forwards["decoder"])
        for layer, (saved_attn_forward, saved_decoder_forward) in enumerate(saved_forwards):
            layers[layer].self_attn.forward = saved_attn_forward
            layers[layer].forward = saved_decoder_forward
        for h in self.hook_handles:
            h.remove()


    def _save_imp_scores(self, log_dir, num_requests_for_imp_scores_averaging) -> None:
        """Average the accumulated ``inside_attn`` scores and save them as ``.npy``.

        The scores in ``self.importance_scores["inside_attn"]`` are divided by
        ``num_requests_for_imp_scores_averaging``, stored back in place of the
        accumulated values, and written to ``log_dir / "inside_attn.npy"``.
        ``log_dir`` is created (with parents) if it does not exist.
        """
        log_dir.mkdir(parents=True, exist_ok=True)

        averaged_scores = self.importance_scores["inside_attn"] / num_requests_for_imp_scores_averaging
        self.importance_scores["inside_attn"] = averaged_scores

        output_file = log_dir.joinpath("inside_attn.npy")
        np.save(output_file, averaged_scores.numpy())
        logger.info(f"Saved {1} file to {str(log_dir)}")

    @staticmethod
    def attn_forward(
        self,
        hidden_states: dict[int, dict[int, torch.Tensor]],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Any] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        layer: Optional[int] = None,
        gemma_model: Optional[Any] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Attention forward that keeps a separate input per head and token type.

        Installed by ``break_into`` in place of the attention module's forward.
        Although declared as a static method, ``self`` is the attention module:
        ``break_into`` binds it, together with ``layer`` and ``gemma_model``,
        through ``functools.partial``.

        Args:
            self: The attention module being patched.
            hidden_states: ``hidden_states[head][tp]`` is the 3-D input tensor
                (unpacked below as ``bsz, q_len, _``) to use for attention head
                ``head`` in the copy dedicated to token type ``tp``.
            attention_mask: Optional additive mask; it is expanded to the
                token-type-stacked batch before being added to the scores.
            position_ids: Position ids; expanded to the stacked batch when their
                first dimension equals the un-stacked batch size.
            past_key_value: Optional cache, updated with the per-head keys/values.
            output_attentions: Whether the attention weights are returned.
            use_cache: Not read in this function.
            cache_position: Forwarded to the cache update.
            layer: Index of the layer, used to address the activation stores.
            gemma_model: Owning wrapper; provides ``task``, ``tp_inds``, the run
                flags, the pruning decisions and the activation stores.

        Returns:
            ``(attn_out, attn_weights, past_key_value)`` where ``attn_out`` is the
            sum over heads of the per-head output projections and
            ``attn_weights`` is ``None`` unless ``output_attentions`` is set.

        Raises:
            ValueError: If the attention output does not have the expected size.
        """
        bsz, q_len, _ = hidden_states[0][0].size()
        extended_bsz = bsz * gemma_model.task.TOKEN_TYPES
        original_bsz = bsz

        # One (bsz, q_len, num_heads, head_dim) buffer per token type for each of
        # Q, K and V. Device and dtype are taken from the last head's / last
        # token type's input.
        query_states = {tp: torch.empty(bsz, q_len, self.num_heads, self.head_dim,
                                        device=hidden_states[self.num_heads - 1][gemma_model.task.TOKEN_TYPES - 1].device, dtype=hidden_states[self.num_heads - 1][gemma_model.task.TOKEN_TYPES - 1].dtype)
                        for tp in range(gemma_model.task.TOKEN_TYPES)}
        key_states = {tp: torch.empty(bsz, q_len, self.num_heads, self.head_dim,
                                    device=hidden_states[self.num_heads - 1][gemma_model.task.TOKEN_TYPES - 1].device, dtype=hidden_states[self.num_heads - 1][gemma_model.task.TOKEN_TYPES - 1].dtype)
                        for tp in range(gemma_model.task.TOKEN_TYPES)}
        value_states = {tp: torch.empty(bsz, q_len, self.num_heads, self.head_dim,
                                        device=hidden_states[self.num_heads - 1][gemma_model.task.TOKEN_TYPES - 1].device, dtype=hidden_states[self.num_heads - 1][gemma_model.task.TOKEN_TYPES - 1].dtype)
                        for tp in range(gemma_model.task.TOKEN_TYPES)}

        # Each (head, tp) input goes through the full projection, but only that
        # head's slice is kept. For K and V the slice is the key/value head this
        # attention head maps to (head // num_key_value_groups), so K and V end up
        # with num_heads entries and no later repeat_kv is needed.
        for head in range(self.num_heads):
            for tp in range(gemma_model.task.TOKEN_TYPES):
                query_states[tp][:, :, head, :] = self.q_proj(hidden_states[head][tp]).view(bsz, q_len, self.num_heads, self.head_dim)[:, :, head, :]
                key_states[tp][:, :, head, :] = self.k_proj(hidden_states[head][tp]).view(bsz, q_len, self.num_key_value_heads, self.head_dim)[:, :, head // self.num_key_value_groups, :]
                value_states[tp][:, :, head, :] = self.v_proj(hidden_states[head][tp]).view(bsz, q_len, self.num_key_value_heads, self.head_dim)[:, :, head // self.num_key_value_groups, :]

        # Stack the token-type copies along the batch axis (type 0 rows first,
        # then type 1, ...) and move heads to dim 1:
        # (TOKEN_TYPES * bsz, num_heads, q_len, head_dim).
        query_states = torch.vstack([query_states[tp] for tp in range(gemma_model.task.TOKEN_TYPES)]).transpose(1, 2)
        key_states = torch.vstack([key_states[tp] for tp in range(gemma_model.task.TOKEN_TYPES)]).transpose(1, 2)
        value_states = torch.vstack([value_states[tp] for tp in range(gemma_model.task.TOKEN_TYPES)]).transpose(1, 2)

        if query_states.device != key_states.device:
            key_states = key_states.to(query_states.device)

        # From here on ``bsz`` is the stacked batch size (equal to extended_bsz).
        bsz = bsz * gemma_model.task.TOKEN_TYPES

        # Repeat the position ids once per token-type copy when they still have
        # the un-stacked batch size.
        if position_ids.shape[0] == original_bsz:
            position_ids = position_ids.expand(gemma_model.task.TOKEN_TYPES, original_bsz, q_len).reshape(extended_bsz, q_len)
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "sliding_window": self.sliding_window,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)


        # The following lines are deleted because now we do that in the beginning because of how interventions are designed
        # key_states = repeat_kv(key_states, self.num_key_value_groups)
        # value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

        # Optional soft-capping: cap * tanh(scores / cap).
        if self.config.attn_logit_softcapping is not None:
            attn_weights = attn_weights / self.config.attn_logit_softcapping
            attn_weights = torch.tanh(attn_weights)
            attn_weights = attn_weights * self.config.attn_logit_softcapping
        if attention_mask is not None:  # no matter the length, we just slice it
            # Repeat the mask once per token-type copy so it matches the stacked batch.
            attention_mask = attention_mask.expand(gemma_model.task.TOKEN_TYPES, original_bsz, 1, q_len, attn_weights.shape[-1]).reshape(extended_bsz, 1, q_len, attn_weights.shape[-1])
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        # The following two lines are deleted because now we do it for each head separately
        # attn_output = attn_output.view(bsz, q_len, -1)
        # attn_output = self.o_proj(attn_output)

        # Project each head separately with its own slice of the output
        # projection weight (only ``o_proj.weight`` is used; no bias is added),
        # so a head's contribution can be replaced or recorded on its own.
        head_outputs_together = []
        for cur_head in range(self.num_heads):
            if gemma_model.generate_mode and past_key_value is not None and q_len == 1:
                # Single-token step with a cache during generation: use only the
                # copy that belongs to TARGET_TYPE, with no patching or recording.
                target_tp = gemma_model.task.TARGET_TYPE
                head_outputs_together.append(attn_output[target_tp * original_bsz : (target_tp + 1) * original_bsz, :, cur_head, :].to(self.o_proj.weight.device) @
                         self.o_proj.weight.T[cur_head * self.head_dim : (cur_head + 1) * self.head_dim, :])
            else:
                # Buffer filled position by position according to token type.
                # NOTE(review): it starts uninitialized, so this assumes every
                # position in tp_inds has a type in range(TOKEN_TYPES) -- confirm.
                head_outputs_together.append(torch.empty(original_bsz, q_len, self.o_proj.weight.shape[0],
                                        device=attn_output.device, dtype=attn_output.dtype))
                for tp in range(gemma_model.task.TOKEN_TYPES):
                    if not gemma_model.is_corrupted_run and gemma_model.important_components_by_layer is not None and (cur_head, tp) not in gemma_model.important_components_by_layer[layer]["attn"]:
                        # (head, tp) is not among the important components:
                        # positions of type tp take the output recorded during
                        # the corrupted run.
                        head_outputs_together[-1][gemma_model.tp_inds == tp] = \
                            gemma_model.corrupted_activations["attn"][layer][cur_head][gemma_model.tp_inds == tp].to(attn_output.device)
                    else:
                        # Otherwise positions of type tp take this head's output
                        # computed from the tp-th stacked copy.
                        head_outputs_together[-1][gemma_model.tp_inds == tp] = \
                            (attn_output[tp * original_bsz : (tp + 1) * original_bsz, :, cur_head, :].to(self.o_proj.weight.device) @
                            self.o_proj.weight.T[cur_head * self.head_dim : (cur_head + 1) * self.head_dim, :])[gemma_model.tp_inds == tp]
                # Record this head's output on the CPU: as the corrupted
                # activation in a corrupted run, or as the original activation in
                # a clean, non-generate run.
                if gemma_model.is_corrupted_run:
                    gemma_model.corrupted_activations["attn"][layer][cur_head] = head_outputs_together[-1].detach().clone().cpu()
                elif not gemma_model.generate_mode:
                    gemma_model.original_activations["attn"][layer][cur_head] = head_outputs_together[-1].detach().clone().cpu()

        assert len(head_outputs_together) == self.num_heads
        # Summing the per-head projections gives the attention output.
        head_outputs_together = sum(head_outputs_together)


        if not output_attentions:
            attn_weights = None

        return head_outputs_together, attn_weights, past_key_value

    @staticmethod
    def decoder_forward(
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Any] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        self = None,
        layer = None,
        gemma_model = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Decoder-layer forward that patches the attention input per head and type.

        Installed by ``break_into`` in place of the decoder layer's forward.
        ``self`` (the decoder layer), ``layer`` and ``gemma_model`` are bound
        through ``functools.partial``, which is why they are keyword parameters
        with ``None`` defaults.

        The layer input is cloned once per (head, token type). In a corrupted
        run the clones are recorded. In a clean run, positions belonging to
        pruned type-to-type edges are overwritten with the recorded corrupted
        input, the resulting tensors are recorded, and (outside generate mode)
        a backward hook is registered on them. The rest of the layer follows
        the usual norm / attention / MLP / residual sequence.

        Args:
            hidden_states: Input tensor of the layer.
            attention_mask: Optional additive attention mask.
            position_ids: Forwarded to the attention module.
            past_key_value: Optional cache forwarded to the attention module.
            output_attentions: Whether attention weights are added to the output.
            use_cache: Whether the cache object is added to the output.
            cache_position: Forwarded to the attention module.
            self: The decoder layer being patched.
            layer: Index of the layer, used to address the activation stores.
            gemma_model: Owning wrapper; provides ``task``, ``tp_inds``, the run
                flags, the pruning decisions and the activation stores.

        Returns:
            Tuple starting with the output hidden states, followed by the
            attention weights if ``output_attentions`` and by the cache if
            ``use_cache``.
        """
        # Progress log for every tenth layer.
        if layer % 10 == 0:
            logger.info(f"Decoder layer {layer}, forward pass")

        if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
            assert not self.config._attn_implementation == "flash_attention_2"

            # Entries selected by the lower-triangular pattern shifted by
            # -sliding_window are set to the most negative representable value.
            min_dtype = torch.finfo(hidden_states.dtype).min
            sliding_window_mask = torch.tril(
                torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
            )
            attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
            if attention_mask.shape[-1] <= 1:  # when decoding
                attention_mask = attention_mask[:, :, :, -self.sliding_window :]
        residual = hidden_states

        if gemma_model.generate_mode and past_key_value is not None and hidden_states.shape[1] == 1:
            # Single-token step with a cache during generation: the per-head /
            # per-type clones are built but nothing is patched or recorded.
            hidden_states_per_head = {head: hidden_states.clone() for head in range(self.self_attn.num_heads)}
            hidden_states_per_head_and_type = {head: {tp: hidden_states_per_head[head].clone() for tp in range(gemma_model.task.TOKEN_TYPES)} for head in range(self.self_attn.num_heads)}
        else:
            hidden_states_per_head = {head: hidden_states.clone() for head in range(self.self_attn.num_heads)}
            hidden_states_per_head_and_type = {head: {tp: hidden_states_per_head[head].clone() for tp in range(gemma_model.task.TOKEN_TYPES)} for head in range(self.self_attn.num_heads)}

            if gemma_model.is_corrupted_run:
                # Corrupted run: record each (head, type) input on the CPU.
                for head in range(self.self_attn.num_heads):
                    for tp in range(gemma_model.task.TOKEN_TYPES):
                        gemma_model.corrupted_activations["input_to_attn_per_type"][layer][head][tp] = hidden_states_per_head_and_type[head][tp].detach().clone().cpu()
            else:
                for to_tp in range(gemma_model.task.TOKEN_TYPES):
                    for head in range(self.self_attn.num_heads):
                        # Edges are patched only for (head, to_tp) pairs that are
                        # kept, or for all pairs when there is no head pruning.
                        if ((gemma_model.important_components_by_layer is None) or ((head, to_tp) in gemma_model.important_components_by_layer[layer]["attn"])):
                            for tp in range(gemma_model.task.TOKEN_TYPES):
                                # Edge tp -> to_tp is pruned if it is in the forced
                                # list or not marked important in important_edges.
                                # Positions of type tp in the to_tp copy are then
                                # overwritten in place with the corrupted input.
                                if ((tp, to_tp) in gemma_model.edges_between_types_to_prune_for_sure or 
                                    ((gemma_model.important_edges is not None) and (
                                    (not gemma_model.important_edges["inside_attn"][layer, head, tp, to_tp])))):
                                    hidden_states_per_head_and_type[head][to_tp][gemma_model.tp_inds == tp] = \
                                        gemma_model.corrupted_activations["input_to_attn_per_type"][layer][head][to_tp][gemma_model.tp_inds == tp].to(
                                            hidden_states_per_head_and_type[head][tp].device).detach()
                        # Both stores below are filled after the patching above, so
                        # they reflect the patched input, not the raw clean one.
                        gemma_model.corrupted_minus_original_activations["input_to_attn_per_type"][layer][head][to_tp] = (
                            gemma_model.corrupted_activations["input_to_attn_per_type"][layer][head][to_tp].to(hidden_states_per_head_and_type[head][to_tp].device).detach() - hidden_states_per_head_and_type[head][to_tp]
                        ).cpu()
                        gemma_model.original_activations["input_to_attn_per_type"][layer][head][to_tp] = hidden_states_per_head_and_type[head][to_tp].clone().cpu()

                        if not gemma_model.generate_mode:
                            # Scoring pass: register a backward hook on the input
                            # of every kept (head, to_tp) pair. The helper
                            # presumably turns the gradient into an importance
                            # score -- confirm in compute_importance_score_inside_attn.
                            if (gemma_model.important_components_by_layer is None) or ((head, to_tp) in gemma_model.important_components_by_layer[layer]["attn"]):
                                hidden_states_per_head_and_type[head][to_tp].register_hook(functools.partial(compute_importance_score_inside_attn,
                                                                                                                layer=layer, head=head, to_tp=to_tp,
                                                                                                                llama_model=gemma_model))   
        # The input layer norm is applied to every (head, type) copy separately.
        hidden_states_per_head_and_type = {head: {
                                                    tp: self.input_layernorm(hidden_states_per_head_and_type[head][tp])
                                                    for tp in range(gemma_model.task.TOKEN_TYPES)
                                                }
                                                for head in range(self.self_attn.num_heads)
                                            }

        # Self Attention
        # The attention module receives the nested dict instead of one tensor.
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states_per_head_and_type,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states.clone()

        # Feed-forward block with pre- and post-norm, then the residual connection.
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs