import argparse
import numpy as np
import time
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score, davies_bouldin_score

parser = argparse.ArgumentParser()
parser.add_argument("--X_path", type=str, required=True)
parser.add_argument("--y_path", type=str, required=True)
args = parser.parse_args()

X = np.load(args.X_path).astype(np.float32)
y = np.load(args.y_path)
print("Loaded:", X.shape, y.shape)

d = 2
start_time = time.time()

tsne = TSNE(n_components=d, random_state=42, method="barnes_hut")
embedding = tsne.fit_transform(X).astype(np.float32)

elapsed = time.time() - start_time
print(f"t-SNE on DINOv2 features took {elapsed:.2f} seconds.")

silhouette_val = silhouette_score(embedding, y)
dbi_val = davies_bouldin_score(embedding, y)
print("Silhouette Score:", silhouette_val)
print("Davies-Bouldin Index:", dbi_val)

plt.figure(figsize=(8, 8))
unique_labels = np.unique(y)
for lbl in unique_labels:
    idxs = np.where(y == lbl)[0]
    plt.scatter(embedding[idxs, 0], embedding[idxs, 1], label=f"Class {lbl}", alpha=0.6)
plt.legend()
plt.savefig("tsne.png", dpi=300)
plt.show()
