# This code accompanies the paper: "Angle K-Means" submitted to ICLR 2026.
import time
import scipy.sparse as sp
import scipy.io as sio
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import confusion_matrix
from scipy.optimize import linear_sum_assignment
import sys, os
sys.path.insert(0, os.path.abspath("src"))
from anglekmeans import mykm

def accuracy(y_true, y_pred):
    """Compute clustering accuracy under the best one-to-one label matching.

    Cluster ids are arbitrary, so predicted labels are first paired with the
    true labels by solving an assignment problem on the confusion matrix.
    The accuracy is the share of samples that land on the matched cells.

    Parameters
    ----------
    y_true: array-like
        Ground-truth labels.
    y_pred: array-like
        Labels produced by the clustering model.

    Returns
    -------
    float
        Matched sample count divided by the total sample count.
    """
    cm = confusion_matrix(y_true=y_true, y_pred=y_pred)
    # linear_sum_assignment minimises total cost, so flip the counts
    # (max - count) to make it pick the pairing with the largest overlap.
    cost = cm.max() - cm
    matched_rows, matched_cols = linear_sum_assignment(cost)
    matched = cm[matched_rows, matched_cols].sum()
    return matched / cm.sum()

# Experiment configuration
name = "RaFD"      # dataset name; selects the .mat files under demo_data/
k = 50             # number of clusters
max_iter = 1000    # maximum number of iterations

# Input files derived from the settings above: the dataset itself and the
# centroids stored for this (name, k) pair.
data_path = f"demo_data/{name}.mat"
centroids_path = f"demo_data/{name}_k{k}_centroids.mat"

# Load the dataset: feature matrix X (float32, C-contiguous) and the
# ground-truth labels flattened to a 1-D int32 vector.
data = sio.loadmat(data_path)
X = data["X"].astype(np.float32)
X = np.ascontiguousarray(X)
y_true = data["y_true"].astype(np.int32).ravel()
num, dim = X.shape

# Load the stored centroids and the seed saved in the same file.
center_data = sio.loadmat(centroids_path)
C = np.ascontiguousarray(center_data["centroids"])
random_state = int(center_data['seed'][0][0])

# Report the seed and the dataset dimensions.
# n and d are read from the centroid file rather than from X.shape.
print(f"seed = {random_state}")
print("\nData Info")
print(f"n = {center_data['n'][0][0]}")
print(f"d = {center_data['d'][0][0]}")
print(f"k = {k}")

# Lloyd: scikit-learn KMeans started from the stored centroids
print("\nLloyd")

# Run on private copies of the data and the centroids.
X2, C2 = X.copy(), C.copy()
mod = KMeans(
    n_clusters=k,
    init=C2,
    n_init=1,
    max_iter=max_iter,
    tol=1e-6,
    algorithm='lloyd',
).fit(X2)  # fit() returns the estimator itself
labels = mod.labels_
acc = accuracy(y_true, labels)

print(f"Iteration = {mod.n_iter_}")
# The figure reported for Lloyd is num * k, i.e. one distance per
# (sample, centroid) pair.
print(f"Per-iteration average of distance computation = {num*k:.1e}")
print(f"acc = {acc:.1e}")

# Angle: the project's k-means variant, started from the same centroids
print("\nAngle")

# Keep the Lloyd assignment now: `labels` is rebound to the Angle result
# below, and the last line compares the two runs. (Previously labels2 was
# taken after the rebinding, so the reported difference was always 0.)
labels2 = labels

# Run on private copies of the data and the centroids.
X2 = X.copy()
C2 = C.copy()
# The solver returns four values; ctime is not reported by this script.
# NOTE(review): n_threads=-1 presumably means "use all cores" - confirm in mykm.
labels, count, ctime, iters = mykm._kmeans_single_lloyd(
    X2, C2, max_iter=max_iter, verbose=False, tol=1e-6, n_threads=-1)
acc = accuracy(y_true, labels)
print(f"Iteration = {iters}")
print(f"Per-iteration average of distance computation = {(count.sum())/iters:.1e}")
print(f"acc = {acc:.1e}")
# Number of samples assigned to a different cluster index by the two runs.
# NOTE(review): assumes cluster indices line up because both runs start from
# the same centroids C - confirm.
print("Label difference =", np.sum(labels2 != labels))