"""
cluster_analysis.py

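This script scores how well saved embeddings cluster according to labels
loaded from a predictions file, using the Davies-Bouldin and
Calinski-Harabasz metrics.

Usage:
- Call cluster_analysis(file_name, task_no) with the dataset/run name and the
  task number; it reads the .npy files under ../predictions/<file_name>/ and
  returns (dbs_score, chs_score).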

"""

from sklearn.metrics import davies_bouldin_score as dbs
import numpy as np
from sklearn.metrics import silhouette_score as ss
from sklearn.metrics import calinski_harabasz_score as chs

def cluster_analysis(file_name, task_no, base_dir="../predictions", label_column=1):
    """
    Perform cluster analysis using embeddings and ground truth labels.

    Loads ``<base_dir>/<file_name>/embeddings_task_<task_no>.npy`` as the
    points and ``<base_dir>/<file_name>/predictions_task_<task_no>_tr_pr.npy``
    as a 2-D label table. It then scores the clustering defined by column
    ``label_column`` of that table.

    Parameters:
        file_name (str): Name of the file/dataset (sub-directory of base_dir).
        task_no (int or str): Task number for analysis. ``3`` and ``"3"``
            resolve to the same file names.
        base_dir (str): Directory containing the per-dataset folders.
            Defaults to ``"../predictions"`` (relative to the working directory).
        label_column (int): Column of the label table used as cluster labels.
            Defaults to 1.

    Returns:
        dbs_score (float): Davies-Bouldin score for the clustering.
        chs_score (float): Calinski-Harabasz score for the clustering.

    Raises:
        FileNotFoundError: If either ``.npy`` file does not exist.
    """
    # f-strings format task_no correctly whether it is an int (as documented
    # and as the __main__ example passes) or a str; "+" concatenation would
    # raise TypeError for an int.
    run_dir = f"{base_dir}/{file_name}"
    emb_file = f"{run_dir}/embeddings_task_{task_no}.npy"
    labels_file = f"{run_dir}/predictions_task_{task_no}_tr_pr.npy"

    # Loading embeddings and the label table
    points = np.load(emb_file)
    label_table = np.load(labels_file)

    # NOTE(review): the "_tr_pr" suffix suggests the columns are
    # (true, predicted). If so, the default column 1 is the *predicted* label
    # rather than ground truth -- confirm against the code that writes this file.
    labels = label_table[:, label_column]

    # Calculating cluster analysis scores
    dbs_score = dbs(X=points, labels=labels)
    chs_score = chs(X=points, labels=labels)
    return dbs_score, chs_score


if __name__ == "__main__":
    # Example usage: score the embeddings saved for task 3 of one run and
    # print the results so running the script produces visible output.
    file_name = "CIFAR10_blcl_1"
    task_no = 3
    dbs_score, chs_score = cluster_analysis(file_name, task_no)
    print(f"Davies-Bouldin score: {dbs_score:.4f}")
    print(f"Calinski-Harabasz score: {chs_score:.4f}")

