import torch

def compute_importance_score_inside_attn(grad, layer, head, to_tp, gemma_model):
    """Accumulate "inside_attn" importance scores for one attention head from its gradient.

    Meant to be called from an attention backward hook; it refuses to run in generate mode.
    For every token type ``tp`` whose edge ``(layer, head, tp, to_tp)`` is enabled (all edges
    are enabled when ``gemma_model.important_edges`` is None), it:

    1. selects the tokens of type ``tp`` (``gemma_model.tp_inds == tp``);
    2. computes a per-token score ``<-delta_t, grad_t>``, where ``delta`` is
       ``corrupted_minus_original_activations["input_to_attn_per_type"][layer][head][to_tp]``;
    3. sums those scores over the selected tokens of each batch element;
    4. adds the sum of the absolute per-batch-element values to
       ``gemma_model.importance_scores["inside_attn"][layer, head, tp, to_tp]`` (in place).

    Afterwards it sets ``.grad = None`` on the parameters of layers ``layer + 2`` .. last and
    calls ``torch.cuda.empty_cache()``.

    Args:
        grad: gradient tensor of shape (bsz, q_len, hidden_dim).
        layer: index of the decoder layer being scored.
        head: index of the attention head being scored.
        to_tp: token-type index of the edge's destination.
        gemma_model: object exposing ``generate_mode``, ``task.TOKEN_TYPES``,
            ``important_edges``, ``tp_inds``, ``corrupted_minus_original_activations``,
            ``importance_scores`` and ``model.model.layers``. ``tp_inds`` is used as a boolean
            mask over ``grad``'s leading dims, so presumably it is (bsz, q_len) — confirm.
    """
    assert not gemma_model.generate_mode, "why are you in attn backward hook in generate mode??"
    # Check rank before unpacking so a bad input fails on the assert, not on an IndexError.
    assert len(grad.shape) == 3
    bsz, hidden_dim = grad.shape[0], grad.shape[2]
    grad = grad.detach()
    # Accumulate in at least float32 so bf16/fp16 activations do not lose precision.
    acc_dtype = torch.promote_types(grad.dtype, torch.float32)
    edges = gemma_model.important_edges

    for tp in range(gemma_model.task.TOKEN_TYPES):
        if edges is not None and not edges["inside_attn"][layer, head, tp, to_tp]:
            continue

        mask = gemma_model.tp_inds == tp
        n_tokens = int(mask.sum())
        delta = gemma_model.corrupted_minus_original_activations["input_to_attn_per_type"][layer][head][to_tp]
        neg_delta = -delta[mask].detach().view(n_tokens, hidden_dim).to(device=grad.device, dtype=acc_dtype)
        grad_sel = grad[mask].view(n_tokens, hidden_dim).to(acc_dtype)

        # Row-wise dot product: identical to diag(neg_delta @ grad_sel.T), but without
        # materializing the (n_tokens, n_tokens) matrix.
        per_token = (neg_delta * grad_sel).sum(dim=-1)

        # Boolean-mask selection is row-major, so the batch index of each selected token is
        # the first coordinate of the mask's nonzero entries; sum per batch element in one op.
        batch_of_token = mask.nonzero(as_tuple=True)[0].to(per_token.device)
        per_batch = per_token.new_zeros(bsz).index_add_(0, batch_of_token, per_token)

        gemma_model.importance_scores["inside_attn"][layer, head, tp, to_tp] += per_batch.abs().sum().item()

    # Drop gradients of later layers to free memory.
    # NOTE(review): the exclusive stop bound means layers layer+2 .. last are cleared while
    # layer+1 keeps its grads — confirm this is intentional.
    layers = gemma_model.model.model.layers
    for later in range(len(layers) - 1, layer + 1, -1):
        for p in layers[later].parameters():
            p.grad = None
    torch.cuda.empty_cache()
