

import os
import numpy as np
from sklearn.metrics import (
    normalized_mutual_info_score,
    adjusted_rand_score,
    precision_score,
    recall_score,
    f1_score
)
from scipy.optimize import linear_sum_assignment
from collections import Counter

def compute_label_alignment(y_true, y_pred):
    """Match predicted cluster ids to true class ids and score the match.

    A contingency matrix ``w[pred, true]`` is built and the one-to-one
    assignment of predicted ids to true ids that maximises the number of
    agreeing samples is found with the Hungarian algorithm
    (``scipy.optimize.linear_sum_assignment``).

    Args:
        y_true: Ground-truth labels; non-negative integers.
        y_pred: Predicted cluster labels; non-negative integers, same
            shape as ``y_true``.

    Returns:
        tuple: ``(acc, y_pred_aligned)`` where ``acc`` is the fraction of
        samples that agree under the optimal assignment and
        ``y_pred_aligned`` is ``y_pred`` relabelled through that assignment.

    Raises:
        ValueError: If the inputs differ in shape, are empty, or contain
            negative labels.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred must have the same shape, "
            f"got {y_true.shape} and {y_pred.shape}"
        )
    if y_pred.size == 0:
        raise ValueError("y_true and y_pred must not be empty")
    # Labels are used directly as matrix indices; a negative value would
    # wrap around to the end of the matrix and silently corrupt the counts.
    if y_pred.min() < 0 or y_true.min() < 0:
        raise ValueError("labels must be non-negative integers")

    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=np.int64)
    # Unbuffered accumulation: repeated (pred, true) pairs are each counted.
    np.add.at(w, (y_pred, y_true), 1)

    # linear_sum_assignment minimises cost, so flip counts into costs.
    row_ind, col_ind = linear_sum_assignment(w.max() - w)

    # Lookup table: predicted id -> assigned true id.
    lookup = np.empty(D, dtype=col_ind.dtype)
    lookup[row_ind] = col_ind
    y_pred_aligned = lookup[y_pred]

    acc = w[row_ind, col_ind].sum() / y_pred.size
    return acc, y_pred_aligned

def purity_score(y_true, y_pred):
    """Return the clustering purity of ``y_pred`` with respect to ``y_true``.

    For every distinct predicted cluster, the size of its most frequent
    true label is taken; the sum of these sizes is divided by the number
    of samples.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted cluster labels, aligned element-wise with
            ``y_true``.

    Returns:
        float: Purity in ``[0, 1]``.
    """
    truth = np.asarray(y_true)
    assigned = np.asarray(y_pred)

    majority_total = 0
    for cluster_id in np.unique(assigned):
        # True labels of the samples that landed in this cluster.
        class_counts = Counter(truth[assigned == cluster_id])
        if class_counts:
            majority_total += max(class_counts.values())

    return majority_total / len(truth)

def evaluate(y_true, y_pred, method='macro'):
    """Compute a fixed set of clustering quality metrics.

    NMI, purity and ARI are computed on the raw predicted labels.
    Precision, recall and F1 are computed on the predictions after they
    have been relabelled by ``compute_label_alignment``.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted cluster labels.
        method: Value forwarded as ``average=`` to the scikit-learn
            precision / recall / F1 functions.

    Returns:
        numpy.ndarray: ``[acc, nmi, purity, f1, precision, recall, ari]``.
    """
    truth = np.asarray(y_true)
    clusters = np.asarray(y_pred)

    # Accuracy under the optimal matching, plus the relabelled predictions.
    accuracy, aligned = compute_label_alignment(truth, clusters)

    # Metrics that do not depend on how cluster ids are named.
    nmi = normalized_mutual_info_score(truth, clusters)
    purity = purity_score(truth, clusters)
    ari = adjusted_rand_score(truth, clusters)

    # Classification-style metrics need the aligned ids.
    averaging = {"average": method, "zero_division": 0}
    prec = precision_score(truth, aligned, **averaging)
    rec = recall_score(truth, aligned, **averaging)
    f1 = f1_score(truth, aligned, **averaging)

    scores = [accuracy, nmi, purity, f1, prec, rec, ari]
    return np.array(scores)

# ---------------------------------------------------------------------------
# Experiment script: sample a Gaussian mixture on a product manifold, cluster
# it with RiemannianFuzzyKMeans and with scikit-learn KMeans, and print the
# metrics returned by evaluate() for both.
# ---------------------------------------------------------------------------
# NOTE(review): torch is not referenced in the visible part of this file.
import torch
import sys
# Prepend ./AnonymousLibrary (resolved against the current working directory)
# to the module search path.
# NOTE(review): the imports below use the `AnonymousLibrary.` package prefix,
# which is normally resolved from the directory *containing* AnonymousLibrary;
# confirm that this path entry is actually what makes them importable.
Lib_root_path = os.path.abspath('./AnonymousLibrary')
sys.path.insert(0, Lib_root_path)
from AnonymousLibrary.manifolds import ProductManifold
from AnonymousLibrary.clustering.fuzzy_kmeans import RiemannianFuzzyKMeans
import numpy as np

# 1. Define the signature: a 3-factor manifold.
#    (numpy is re-imported here; it is already imported above.)
import numpy as np
#    Each entry is (curvature, dimension).
signature = [
    (0.0, 16),   # zero curvature, dimension 16 (Euclidean factor)
    (1.0, 16),   # positive curvature, dimension 16 (spherical factor)
    (-1.0, 16),  # negative curvature, dimension 16 (hyperbolic factor)
]

# 2. Construct the ProductManifold (without stereographic projection)
P = ProductManifold(signature, device="cpu", stereographic=False)

# Experiment parameters.
n_clusters = 3  # used for the mixture's num_classes and for both clusterers

# Settings forwarded to RiemannianFuzzyKMeans below
# (as optimizer=, lr= and tol= respectively).
opt = 'adan'
lr = .1
tol = 1e-2

# 3. Generate data using gaussian_mixture
#    - num_points=1000: sample 1000 points
#    - num_classes=n_clusters: generate n_clusters class labels (for clustering)
#    - seed=4: fix the random seed for reproducibility
X, y_true = P.gaussian_mixture(
    num_points=1000,
    num_classes=n_clusters,
    task="classification",
    cov_scale_points=.1 # <--- try decreasing this value
    ,seed=4
)
# Convert the labels to a NumPy array for the metric functions.
y_true = np.array(y_true)

# 4. Cluster the sampled points with RiemannianFuzzyKMeans on the product
#    manifold P, using the optimizer settings defined above and a fixed
#    random_state.


model = RiemannianFuzzyKMeans(n_clusters, 
            pm=P, 
            max_iter=100,
            tol=tol,
            optimizer=opt,
            lr=lr,
            verbose=True,
            random_state=1)
labels = model.fit_predict(X)

# 5. Baseline: scikit-learn KMeans on the same data X.
# NOTE(review): no random_state is passed here, so the baseline is presumably
# not reproducible from run to run — consider fixing it.
from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=n_clusters)
# Fit the data
kmeans.fit(X)
# Get the cluster labels from kmeans
labels_km = kmeans.labels_

# 6. Score both clusterings against the true labels. evaluate() returns
#    [acc, nmi, purity, f1, precision, recall, ari]; reshape to a single row
#    before printing (first the Riemannian model, then the KMeans baseline).
result = evaluate(y_true, labels).reshape(1, -1)
result2 = evaluate(y_true, labels_km).reshape(1, -1)
print(result)
print(result2)
# Convert the data and labels to NumPy arrays (y_true was already converted
# right after sampling; this repeats that conversion).
X=np.array(X)
y_true=np.array(y_true)