""" Attention utilities. """

from typing import Tuple, List
import torch

def concat_tensors(tensors: Tuple[torch.Tensor, ...], dim: int = 0) -> torch.Tensor:
    """
    Concatenate a sequence of tensors along the specified dimension.

    This is a thin wrapper around ``torch.cat``; shape compatibility is
    checked by ``torch.cat`` itself.

    Args:
        tensors (Tuple[torch.Tensor, ...]): The tensors to concatenate, e.g.
            an attention tensor and its gradients.
        dim (int): The dimension along which to concatenate. Defaults to 0.

    Returns:
        torch.Tensor: The concatenated tensor.
    """
    return torch.cat(tensors, dim=dim)


def get_modified_tensor(tensor: torch.Tensor, modified_indices: List[int]) -> torch.Tensor:
    """
    Return a copy of a 4-D tensor with the given rows and columns removed.

    The indices are removed from both of the last two dimensions. The kept
    positions are computed from the size of dimension 2 and then applied to
    dimension 3 as well, so the last two dimensions are presumably expected
    to have the same length (tokens x tokens).

    Args:
        tensor (torch.Tensor): The 4-D input tensor; it is not modified.
        modified_indices (List[int]): Positions to drop from the last two
            dimensions. Only values in ``range(tensor.shape[2])`` have an
            effect; negative or out-of-range values match nothing.

    Returns:
        torch.Tensor: A new tensor with the specified rows and columns removed.

    Raises:
        ValueError: If ``tensor`` does not have exactly four dimensions.
    """
    # Unpacking the shape enforces a 4-D input; any other rank raises ValueError.
    num_layers, num_heads, len_tokens, _ = tensor.shape
    print(f"The shape of the tensor is {num_layers} x {num_heads} x {len_tokens} x {len_tokens}")

    # A set gives O(1) membership tests; int() also normalises tensor / numpy
    # scalar entries so they compare by value rather than by identity.
    removed = {int(index) for index in modified_indices}
    remained_indices = [i for i in range(len_tokens) if i not in removed]
    print(f"The remained indices are {remained_indices} and the removed indices are {modified_indices}")

    # Advanced indexing always produces a fresh tensor, so no clone of the
    # input is needed first. An explicit long index tensor also makes the
    # "everything removed" (empty index) case well defined.
    keep = torch.tensor(remained_indices, dtype=torch.long, device=tensor.device)
    modified_tensor = tensor[:, :, keep, :]
    modified_tensor = modified_tensor[:, :, :, keep]

    print(f"The shape of the modified  tensor is {modified_tensor.shape}")

    return modified_tensor

def get_normalized_attribution(tensor, input_epsilon: float = 1e-10):
    """
    Normalize the non-negative part of a tensor along its last dimension.

    Negative entries are clamped to zero, a small epsilon is added to every
    entry, and each slice along the last dimension is divided by its sum so
    that it sums to one.

    Parameters:
        tensor (torch.Tensor): The input tensor.
        input_epsilon (float): Lower bound for the smoothing epsilon.
            Defaults to 1e-10.

    Returns:
        torch.Tensor: The normalized tensor, same shape as the input.
    """
    # Discard negative contributions.
    non_neg = torch.clamp(tensor, min=0)
    positives = non_neg[non_neg > 0]

    if positives.numel() > 0:
        # Epsilon is the power of ten at or below the smallest positive entry,
        # but never smaller than input_epsilon.
        magnitude = 10 ** torch.floor(torch.log10(positives.min()))
        epsilon = torch.clamp(magnitude, min=input_epsilon)
    else:
        # No positive entry at all: min() of an empty tensor would raise, so
        # fall back to the caller-supplied floor.
        epsilon = input_epsilon

    # Shifting by epsilon keeps every row sum strictly positive, which avoids
    # a division by zero below.
    shifted = non_neg + epsilon

    return torch.div(shifted, shifted.sum(dim=-1, keepdim=True))

def add_residual_block(attentions_tensor: torch.Tensor, add_res: bool = True, lambda_res: float = 0.5) -> torch.Tensor:
    """
    Average attention over dimension 1 and optionally mix in an identity matrix.

    Args:
        attentions_tensor (torch.Tensor):
        The attention weights. Dimension 1 is averaged away (presumably the
        heads dimension of a layers x heads x tokens x tokens tensor - confirm
        against the caller).
        add_res (bool, optional):
        Whether to add the residual connection to the aggregated attention weights. Defaults to True.
        lambda_res (float, optional):
        Weight of the identity term; the averaged attention is weighted by
        ``1 - lambda_res``. Defaults to 0.5.

    Returns:
        A tensor containing the aggregated attention weights: the mean over
        dimension 1 when ``add_res`` is False, otherwise
        ``(1 - lambda_res) * mean + lambda_res * identity``.
    """
    mean_attention_tensor = attentions_tensor.mean(dim=1)
    print(f'The shape of the mean attention tensor is {mean_attention_tensor.shape}')

    if not add_res:
        # No residual requested, so the identity matrix is never needed.
        return mean_attention_tensor

    # Build the identity on the same device as the attention so the sum below
    # does not fail for non-CPU inputs. The dtype is left at the default so the
    # result's dtype promotes exactly as it did before.
    identity_tensor = torch.eye(
        mean_attention_tensor.shape[1],
        device=mean_attention_tensor.device,
    ).unsqueeze(0)

    return (1 - lambda_res) * mean_attention_tensor + lambda_res * identity_tensor