import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler


import torch

# Load the features (make sure the extracted features were saved with torch.save, e.g. as .tar archives)
##mnist_features, mnist_labels = torch.load("./features/mnist_features_clip.tar")
##fmnist_features, fmnist_labels = torch.load("./features/fmnist_features_clip.tar")
##
### Optionally, you can apply StandardScaler to normalize the features
##scaler = StandardScaler()
##
### Scale the features (fit and transform separately for MNIST and FMNIST)
##mnist_features_scaled = scaler.fit_transform(mnist_features)
##fmnist_features_scaled = scaler.fit_transform(fmnist_features)

def _to_numpy(values):
    """Return *values* as a NumPy array, detaching torch tensors first.

    scikit-learn and matplotlib work on NumPy arrays; a tensor that lives on
    a GPU or tracks gradients cannot be converted implicitly, so it is
    detached and moved to the CPU explicitly.
    """
    if torch.is_tensor(values):
        return values.detach().cpu().numpy()
    return np.asarray(values)


# Load the precomputed pairwise-distance matrix plus two companion objects.
# `x` is not used below; `y` is only used to colour the scatter points
# (presumably per-sample labels -- TODO confirm against the script that
# wrote this archive).
# NOTE(review): torch.load unpickles the file, so only load trusted archives.
# map_location='cpu' lets tensors that were saved from a GPU load on a
# CPU-only machine.
cdist, x, y = torch.load(
    './features/cifar10_group2_clip_98pcas_cdist_3000.tar',
    map_location='cpu',
)

distances = _to_numpy(cdist)
point_colors = _to_numpy(y)

# t-SNE for 2D visualization.
# metric="precomputed" tells t-SNE the input is already a square distance
# matrix; scikit-learn rejects init='pca' in that mode, hence init='random'.
# Alternative considered earlier: a UMAP embedding
# (umap.UMAP(n_components=2, random_state=42)) on the raw features.
tsne = TSNE(perplexity=40, n_components=2, random_state=42, metric="precomputed", init='random')
embedding_tsne = tsne.fit_transform(distances)

# Scatter plot of the 2D embedding, coloured by `y`.
plt.figure(figsize=(10, 8))
plt.scatter(embedding_tsne[:, 0], embedding_tsne[:, 1], c=point_colors, alpha=0.6, cmap='viridis')
# The title names the dataset of the archive loaded above (CIFAR-10 group2);
# no feature scaling is applied in this script.
plt.title("t-SNE Visualization of CIFAR-10 group2 Features")
plt.xlabel("Component 1")
plt.ylabel("Component 2")
plt.colorbar(label='Dataset')
plt.show()

# # Visualization for FMNIST
# plt.figure(figsize=(10, 8))
# plt.scatter(reduced_features_fmnist_tsne[:, 0], reduced_features_fmnist_tsne[:, 1], c='red', alpha=0.6)
# # For UMAP, use:
# # plt.scatter(reduced_features_fmnist_umap[:, 0], reduced_features_fmnist_umap[:, 1], c='red', alpha=0.6)
# plt.title("t-SNE Visualization of FMNIST Features (Scaled)")
# plt.xlabel("Component 1")
# plt.ylabel("Component 2")
# plt.colorbar(label='Dataset')
# plt.show()
