import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

def plot_ari_heatmap(ari_matrix_path: str, models: list):
    """
    Plot a heatmap of a pairwise Adjusted Rand Index (ARI) matrix.

    The matrix is loaded from disk, each cell is annotated with its value to
    two decimal places, and the figure is displayed with ``plt.show()``.

    Args:
        ari_matrix_path: Path to the .npy file containing the ARI matrix.
        models: List of model names in the same order as the matrix rows and
            columns; used as both the x and y tick labels.

    Raises:
        ValueError: If the loaded array is not a square 2-D matrix, or if its
            size does not match ``len(models)``.
    """
    # Load the ARI matrix
    ari_matrix = np.load(ari_matrix_path)

    # Validate before plotting: seaborn is not guaranteed to reject a list of
    # tick labels whose length differs from the data, which could silently
    # attach the wrong model name to a row/column.
    if ari_matrix.ndim != 2 or ari_matrix.shape[0] != ari_matrix.shape[1]:
        raise ValueError(
            f"Expected a square 2-D ARI matrix, got shape {ari_matrix.shape}"
        )
    if ari_matrix.shape[0] != len(models):
        raise ValueError(
            f"ARI matrix is {ari_matrix.shape[0]}x{ari_matrix.shape[1]} but "
            f"{len(models)} model names were given"
        )

    # Draw on an explicit Axes rather than pyplot's implicit "current axes".
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        ari_matrix,
        annot=True,
        fmt='.2f',
        cmap='YlOrRd',
        xticklabels=models,
        yticklabels=models,
        ax=ax,
    )

    ax.set_title('Adjusted Rand Index (ARI) Matrix')
    fig.tight_layout()
    plt.show()

if __name__ == "__main__":
    # Embedding models whose clusterings were compared; the order must match
    # the row/column order of the saved ARI matrix.
    model_names = ['BGE', 'Fast-Text', 'gpt3', 'Mistral', 'NV-Embed', 'Word2Vec']

    # The matrix lives in <parent of this script's directory>/output/.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(script_dir)
    matrix_file = os.path.join(project_root, 'output', 'ari_scifact.npy')

    plot_ari_heatmap(matrix_file, model_names)