import wandb
import matplotlib.pyplot as plt
import numpy as np

def attention_matrix_to_image(
    attention_matrix,
    *,
    title="Attention Matrix",
    xlabel="Queries",
    ylabel="Keys",
    figsize=(10, 10),
    cmap="viridis",
    annotate=False,
    fmt=".3f",
):
    """Render a 2-D attention matrix as a heatmap on a new matplotlib figure.

    The matrix is drawn with ``Axes.matshow``, so row ``i`` runs down the
    y-axis and column ``j`` runs along the x-axis, with a colorbar alongside.

    NOTE(review): with the default labels the columns (x-axis) are labelled
    "Queries" and the rows (y-axis) "Keys". If callers pass the common
    ``attention[query, key]`` layout these labels are swapped -- confirm
    against the caller before changing the defaults.

    Args:
        attention_matrix: 2-D array-like accepted by ``Axes.matshow``.
        title: Axes title.
        xlabel: Label for the x-axis (matrix columns).
        ylabel: Label for the y-axis (matrix rows).
        figsize: Figure size in inches, ``(width, height)``.
        cmap: Matplotlib colormap name or ``Colormap`` instance.
        annotate: If True, write each cell's value inside the cell. Only
            practical for small matrices.
        fmt: Format spec applied to each value when ``annotate`` is True.

    Returns:
        The ``matplotlib.pyplot`` module, with the newly created figure as the
        current figure (kept for backward compatibility with existing callers).
        The figure is NOT closed here; callers should call ``plt.close()`` once
        they are done with it, otherwise open figures accumulate across calls.
    """
    fig, ax = plt.subplots(figsize=figsize)
    image = ax.matshow(attention_matrix, cmap=cmap)
    fig.colorbar(image, ax=ax)

    if annotate:
        # asanyarray (not asarray) so masked arrays keep their subclass.
        values = np.asanyarray(attention_matrix)
        for (row, col), value in np.ndenumerate(values):
            # Dark text on bright cells, light text on dark cells, so labels
            # stay readable across the whole colormap.
            color = "black" if image.norm(value) > 0.5 else "white"
            ax.text(
                col,
                row,
                f"{value:{fmt}}",
                ha="center",
                va="center",
                fontsize=10,
                color=color,
            )

    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    return plt
