import numpy as np
import os

base_path = r'cifar_embeddings'


def _load_split(split, n_classes=10, root=base_path):
    """Load the per-class embedding files of one split and build its labels.

    Reads ``X_{split}_{i}.npy`` from ``root`` for every ``i`` in
    ``range(n_classes)``, stacks the arrays row-wise in file order, and labels
    every row that came from file ``i`` with the integer ``i``.

    Parameters
    ----------
    split : str
        Split name used in the file names (e.g. ``'train'`` or ``'test'``).
    n_classes : int
        Number of per-class files to read.
    root : str
        Directory containing the ``.npy`` files.

    Returns
    -------
    X : np.ndarray
        All loaded rows stacked with ``np.vstack``.
    y : np.ndarray
        One integer label per row of ``X``.
    """
    features, labels = [], []
    for class_idx in range(n_classes):
        x_i = np.load(os.path.join(root, f'X_{split}_{class_idx}.npy'))
        features.append(x_i)
        # Every row of this file belongs to class ``class_idx``.
        labels.append(np.full((x_i.shape[0],), class_idx))
    return np.vstack(features), np.concatenate(labels)


X_train, y_train = _load_split('train')
X_test, y_test = _load_split('test')

print("Train shape:", X_train.shape, y_train.shape)
print("Test shape:", X_test.shape, y_test.shape)

#%%
import numpy as np
from sklearn.metrics import (
    normalized_mutual_info_score,
    adjusted_rand_score,
    precision_score,
    recall_score,
    f1_score
)
from scipy.optimize import linear_sum_assignment
from collections import Counter

def compute_label_alignment(y_true, y_pred):
    """Match predicted cluster ids to true labels with the Hungarian algorithm.

    A ``D x D`` contingency matrix ``w[pred, true]`` is built, where ``D`` is
    one more than the largest label in either input. The one-to-one mapping
    that maximises the number of agreeing samples is then found with
    ``scipy.optimize.linear_sum_assignment``.

    Parameters
    ----------
    y_true : array-like of non-negative int, shape (n,)
        Ground-truth class labels.
    y_pred : array-like of non-negative int, shape (n,)
        Predicted cluster ids.

    Returns
    -------
    acc : float
        Clustering accuracy: the fraction of samples whose cluster is mapped
        to their true label by the optimal matching.
    y_pred_aligned : np.ndarray, shape (n,)
        ``y_pred`` with every cluster id replaced by its matched label.

    Raises
    ------
    ValueError
        If the inputs are empty, have different shapes, or contain a negative
        label (which would otherwise wrap around when used as an index).
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError(
            "y_true and y_pred must have the same shape, "
            f"got {y_true.shape} and {y_pred.shape}"
        )
    if y_pred.size == 0:
        raise ValueError("compute_label_alignment requires at least one sample")
    if min(y_true.min(), y_pred.min()) < 0:
        raise ValueError("labels must be non-negative integers")

    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=np.int64)
    # Vectorised contingency count. np.add.at (unlike ``w[idx] += 1``)
    # accumulates every occurrence of a repeated (pred, true) pair.
    np.add.at(w, (y_pred, y_true), 1)

    # linear_sum_assignment minimises cost, so turn counts into costs.
    row_ind, col_ind = linear_sum_assignment(w.max() - w)

    # Dense lookup table: cluster id -> matched label. The cost matrix is
    # square, so every row index 0..D-1 receives an assignment.
    mapping = np.empty(D, dtype=col_ind.dtype)
    mapping[row_ind] = col_ind
    y_pred_aligned = mapping[y_pred]

    acc = w[row_ind, col_ind].sum() / y_pred.size
    return acc, y_pred_aligned

def purity_score(y_true, y_pred):
    """Clustering purity: share of samples in their cluster's majority class.

    For every predicted cluster, the count of its most frequent true label is
    taken. Purity is the sum of those counts divided by the number of samples.

    Parameters
    ----------
    y_true : array-like, shape (n,)
        Ground-truth labels. Any values ``np.unique`` can sort are accepted.
    y_pred : array-like, shape (n,)
        Predicted cluster ids. They need not be contiguous or start at 0.

    Returns
    -------
    float
        Purity in ``[0, 1]``.

    Raises
    ------
    ValueError
        If the inputs are empty or have different shapes.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError(
            "y_true and y_pred must have the same shape, "
            f"got {y_true.shape} and {y_pred.shape}"
        )
    if y_true.size == 0:
        raise ValueError("purity_score requires at least one sample")

    # Re-index both label sets to 0..k-1 so arbitrary label values can
    # address a dense contingency table.
    pred_ids, pred_idx = np.unique(y_pred, return_inverse=True)
    true_ids, true_idx = np.unique(y_true, return_inverse=True)
    contingency = np.zeros((pred_ids.size, true_ids.size), dtype=np.int64)
    np.add.at(contingency, (pred_idx.ravel(), true_idx.ravel()), 1)

    # The row-wise max is the size of each cluster's majority class.
    return contingency.max(axis=1).sum() / y_true.size

def evaluate(y_true, y_pred, method='macro'):
    """Score a clustering against ground-truth labels.

    NMI, purity and ARI are computed on the raw cluster ids. Precision,
    recall and F1 are computed on the ids remapped onto class labels by
    ``compute_label_alignment``, which also supplies the accuracy.

    Parameters
    ----------
    y_true : array-like
        Ground-truth class labels.
    y_pred : array-like
        Predicted cluster ids.
    method : str
        Forwarded as ``average=`` to sklearn's precision/recall/F1 scorers.

    Returns
    -------
    np.ndarray, shape (7,)
        ``[acc, nmi, purity, f1, precision, recall, ari]``.
    """
    truth = np.asarray(y_true)
    clusters = np.asarray(y_pred)

    acc, matched = compute_label_alignment(truth, clusters)

    # Options shared by the three scores computed on the aligned labels.
    averaged = {'average': method, 'zero_division': 0}
    precision = precision_score(truth, matched, **averaged)
    recall = recall_score(truth, matched, **averaged)
    f1 = f1_score(truth, matched, **averaged)

    scores = (
        acc,
        normalized_mutual_info_score(truth, clusters),
        purity_score(truth, clusters),
        f1,
        precision,
        recall,
        adjusted_rand_score(truth, clusters),
    )
    return np.array(scores)
#%%
import numpy as np
from sklearn.cluster import KMeans


# Cluster train and test embeddings together: one feature matrix (rows
# concatenated, train first) and the matching label vector.
X = np.concatenate((X_train, X_test), axis=0)
y = np.concatenate((y_train, y_test))


import torch
import sys
# Put the local library checkout at the front of the import search path.
# This must run before the ``AnonymousLibrary`` imports that follow.
# NOTE(review): the path added is ./AnonymousLibrary itself (resolved against
# the current working directory), while the imports below are spelled
# ``AnonymousLibrary.*``. Those resolve either from the working directory (if
# it is on sys.path) or from ./AnonymousLibrary/AnonymousLibrary/. Presumably
# this insert exists for the library's own internal absolute imports — confirm.
Lib_root_path = os.path.abspath('./AnonymousLibrary')
sys.path.insert(0, Lib_root_path)
from AnonymousLibrary.manifolds import ProductManifold
from AnonymousLibrary.clustering.fuzzy_kmeans import RiemannianFuzzyKMeans



import time

# Geometry of the product manifold the clustering runs on.
# NOTE(review): each tuple presumably describes one factor manifold, e.g.
# (curvature, dimension) — confirm against ProductManifold's documentation.
signature = [
    (1,2),
    (1,2),
    (1,2),
    (1,2),
]
P = ProductManifold(signature, device="cpu", stereographic=False)


# Matches the 10 label values (0-9) produced by the loading cell.
n_clusters = 10


# Settings forwarded to RiemannianFuzzyKMeans.
opt = 'adan'
lr = .5
tol = 1e-2


# The model is fed a float32 torch tensor built from the stacked numpy data.
X_torch = torch.from_numpy(X).float()

model = RiemannianFuzzyKMeans(
    n_clusters=n_clusters,
    pm=P,
    max_iter=100,
    tol=tol,
    optimizer=opt,
    lr=lr,
    verbose=True,
    random_state=2025
)
# perf_counter is monotonic and high-resolution, so the measured interval is
# not distorted by system clock adjustments the way time.time() can be.
t1 = time.perf_counter()
labels = model.fit_predict(X_torch)
t2 = time.perf_counter()
print(t2 - t1)  # elapsed seconds for fit_predict

#%%
# Shape the 7 scores as a single (1, 7) row so further methods can be stacked
# underneath as extra rows.
result = np.expand_dims(evaluate(y, labels), axis=0)


print("RiemannianFuzzyKMeans results (ACC, NMI, Purity, F1, Precision, Recall, ARI):")
print(result)
