import torch
import torch.nn.functional as F

def relative_l1_norm(attention_map1, attention_map2):
    """
    Return the L1 distance between two attention maps, relative to the
    L1 magnitude of the second map.

    The summed absolute element-wise difference is divided by the summed
    absolute values of ``attention_map2``. A tiny constant is added to the
    divisor so an all-zero second map does not cause a division by zero.
    """
    epsilon = 1e-10  # guards the division when attention_map2 sums to zero
    difference_mass = (attention_map1 - attention_map2).abs().sum()
    reference_mass = attention_map2.abs().sum()
    return difference_mass / (reference_mass + epsilon)

def root_mean_squared_error(attention_map1, attention_map2):
    """
    Calculate the Root Mean Squared Error (RMSE) between two attention maps.

    Parameters:
    attention_map1 (torch.Tensor): The first attention map.
    attention_map2 (torch.Tensor): The second attention map; must be
        broadcast-compatible with the first.

    Returns:
    torch.Tensor: A 0-dimensional tensor holding
        sqrt(mean((attention_map1 - attention_map2) ** 2)).
    """
    # torch.mean already divides by the number of elements; dividing by
    # numel() a second time would shrink the result by sqrt(numel).
    mse = torch.mean((attention_map1 - attention_map2) ** 2)
    return torch.sqrt(mse)

def cosine_similarity(attention_map1, attention_map2):
    """
    Return the cosine similarity of two attention maps, each treated as a
    single flattened vector.

    The result is a tensor of shape ``(1,)`` because each map is passed to
    ``F.cosine_similarity`` as a one-row matrix.
    """
    # Collapse each map to 1-D, then add a leading row axis so the
    # similarity is computed along dim 1 over all elements at once.
    row1 = torch.flatten(attention_map1)[None, :]
    row2 = torch.flatten(attention_map2)[None, :]
    return F.cosine_similarity(row1, row2, dim=1)

def evaluate_attention_maps(attention_map1, attention_map2):
    """
    Evaluate the difference between two attention maps using Relative L1 Norm, RMSE, and Cosine Similarity.

    Each metric is computed separately for every (batch, head) map and then
    averaged over all of them.

    Parameters:
    attention_map1 (torch.Tensor): The first attention map with shape [BS, num_head, N, N].
    attention_map2 (torch.Tensor): The second attention map with shape [BS, num_head, N, N].

    Returns:
    dict: A dictionary containing the Relative L1 Norm, RMSE, and Cosine Similarity between the two maps,
        each as a Python float averaged over batch elements and heads.

    Raises:
    ValueError: If the shapes differ, the inputs are not 4-dimensional, or
        the batch/head dimensions are empty (nothing to average over).
    """
    if attention_map1.shape != attention_map2.shape:
        raise ValueError("Attention maps must have the same shape.")
    if attention_map1.dim() != 4:
        raise ValueError(
            "Attention maps must be 4-dimensional [BS, num_head, N, N], "
            f"got shape {tuple(attention_map1.shape)}."
        )

    bs, num_head = attention_map1.shape[:2]
    num_maps = bs * num_head
    if num_maps == 0:
        # Without this guard the accumulators below would stay plain ints
        # and the averaging step would fail with an opaque ZeroDivisionError.
        raise ValueError(
            "Attention maps must contain at least one batch element and one head."
        )

    # Running sums of each metric over every (batch, head) pair
    relative_l1 = 0
    rmse = 0
    cos_sim = 0

    # Iterating a tensor yields slices along its first dimension, so the
    # outer loop walks batch elements and the inner loop walks heads.
    for heads1, heads2 in zip(attention_map1, attention_map2):
        for map1, map2 in zip(heads1, heads2):
            relative_l1 += relative_l1_norm(map1, map2)
            rmse += root_mean_squared_error(map1, map2)
            cos_sim += cosine_similarity(map1, map2)

    # Normalize by the total number of heads and batch size
    relative_l1 /= num_maps
    rmse /= num_maps
    cos_sim /= num_maps

    return {
        'Relative L1 Norm': relative_l1.item(),
        'RMSE': rmse.item(),
        'Cosine Similarity': cos_sim.item()
    }

# Example usage
if __name__ == "__main__":
    # Run the demo on the GPU when one is present, otherwise on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Demo dimensions: batch size, number of attention heads, and the
    # sequence length (or spatial dimension) of each N x N map.
    BS, num_head, N = 2, 4, 8
    demo_shape = (BS, num_head, N, N)

    # Two independent random attention maps of shape [BS, num_head, N, N],
    # created on the CPU and then moved to the selected device.
    attention_map1 = torch.rand(demo_shape).to(device)
    attention_map2 = torch.rand(demo_shape).to(device)

    # Compare the two maps with all three metrics.
    results = evaluate_attention_maps(attention_map1, attention_map2)

    # Report each averaged metric on its own line.
    print("Relative L1 Norm:", results['Relative L1 Norm'])
    print("RMSE:", results['RMSE'])
    print("Cosine Similarity:", results['Cosine Similarity'])