import os

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jax import random

from kmeans import kmeans, hierarchical_kmeans

def test_kmeans_k2():
    """Cluster a two-Gaussian mixture with k=2 and save a scatter plot.

    Draws 64 2-D points (32 per component) around (-2, -2) and (2, 2) with
    standard deviation 0.5, shuffles them, runs ``kmeans`` with ``k=2`` and
    ``iters=20``, and writes the figure to ``results/kmeans_k2_test.png``.

    This is a visual smoke test: it makes no assertions about the clustering
    itself, it only checks that the pipeline runs and produces an image.
    """
    output_dir = 'results'
    output_path = os.path.join(output_dir, 'kmeans_k2_test.png')

    # Fixed seed so the generated data set is reproducible.
    key = random.PRNGKey(42)
    key1, key2, key3 = random.split(key, 3)

    n_points = 64
    points_per_cluster = n_points // 2

    # First Gaussian: centered at (-2, -2) with std=0.5
    mean1 = jnp.array([-2.0, -2.0])
    std1 = 0.5
    points1 = random.normal(key1, shape=(points_per_cluster, 2)) * std1 + mean1

    # Second Gaussian: centered at (2, 2) with std=0.5
    mean2 = jnp.array([2.0, 2.0])
    std2 = 0.5
    points2 = random.normal(key2, shape=(points_per_cluster, 2)) * std2 + mean2

    # Combine and shuffle so the input order carries no cluster information.
    all_points = jnp.concatenate([points1, points2], axis=0)
    all_points = random.permutation(key3, all_points)

    # Run k-means with k=2 (separate seed from the data generation).
    centroids, labels = kmeans(all_points, random.PRNGKey(43), k=2, iters=20)

    # Split by predicted label and convert to numpy for matplotlib.
    cluster1_points_np = np.array(all_points[labels == 0])
    cluster2_points_np = np.array(all_points[labels == 1])
    centroids_np = np.array(centroids)

    # Visualize results
    plt.figure(figsize=(10, 8))
    plt.scatter(cluster1_points_np[:, 0], cluster1_points_np[:, 1], c='blue', label='Cluster 1', alpha=0.7)
    plt.scatter(cluster2_points_np[:, 0], cluster2_points_np[:, 1], c='red', label='Cluster 2', alpha=0.7)

    # Add true means to the plot for reference
    plt.scatter(mean1[0], mean1[1], c='blue', marker='X', s=200, edgecolor='black', label='True Mean 1')
    plt.scatter(mean2[0], mean2[1], c='red', marker='X', s=200, edgecolor='black', label='True Mean 2')

    # Plot the computed centroids
    plt.scatter(centroids_np[0, 0], centroids_np[0, 1], c='cyan', marker='*', s=300, edgecolor='black', label='Computed Centroid 1')
    plt.scatter(centroids_np[1, 0], centroids_np[1, 1], c='orange', marker='*', s=300, edgecolor='black', label='Computed Centroid 2')

    plt.title('K-Means Clustering (k=2) on Mixture of Gaussians')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.legend()
    plt.grid(alpha=0.3)

    # savefig does not create missing directories, so make sure it exists.
    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(output_path)
    plt.close()

    print(f"K-means with k=2 test completed. Results saved to {output_path}")

def test_kmeans_k4():
    """Cluster a four-Gaussian mixture with k=4 and save a scatter plot.

    Draws 128 2-D points (32 per component) around the four corners
    (+/-2, +/-2) with standard deviation 0.5, shuffles them, runs ``kmeans``
    with ``k=4`` and ``iters=20``, and writes the figure to
    ``results/kmeans_k4_test.png``.

    This is a visual smoke test: it makes no assertions about the clustering
    itself, it only checks that the pipeline runs and produces an image.
    """
    output_dir = 'results'
    output_path = os.path.join(output_dir, 'kmeans_k4_test.png')

    # Fixed seed so the generated data set is reproducible.
    key = random.PRNGKey(42)
    # keys[0..3] seed the four components, keys[4] seeds the shuffle.
    keys = random.split(key, 5)

    n_points = 128
    points_per_cluster = n_points // 4

    # Four Gaussians at different locations
    means = [
        jnp.array([-2.0, -2.0]),
        jnp.array([2.0, -2.0]),
        jnp.array([-2.0, 2.0]),
        jnp.array([2.0, 2.0]),
    ]
    std = 0.5

    # Generate points for each cluster
    clusters = [
        random.normal(keys[i], shape=(points_per_cluster, 2)) * std + mean
        for i, mean in enumerate(means)
    ]

    # Combine and shuffle so the input order carries no cluster information.
    all_points = jnp.concatenate(clusters, axis=0)
    all_points = random.permutation(keys[4], all_points)

    # Run k-means with k=4 (separate seed from the data generation).
    centroids, labels = kmeans(all_points, random.PRNGKey(43), k=4, iters=20)

    # Split by predicted label and convert to numpy for matplotlib.
    cluster_points_np = [np.array(all_points[labels == i]) for i in range(4)]
    centroids_np = np.array(centroids)

    # Colors for the clusters
    colors = ['blue', 'red', 'green', 'purple']
    centroid_colors = ['cyan', 'orange', 'lime', 'magenta']

    # Visualize results
    plt.figure(figsize=(10, 8))

    # Plot each cluster, skipping empty ones so they get no legend entry.
    for i, cluster_np in enumerate(cluster_points_np):
        if len(cluster_np) > 0:
            plt.scatter(
                cluster_np[:, 0],
                cluster_np[:, 1],
                c=colors[i],
                label=f'Cluster {i+1}',
                alpha=0.7,
            )

    # Add true means to the plot for reference
    for i, mean in enumerate(means):
        plt.scatter(
            mean[0],
            mean[1],
            c=colors[i],
            marker='X',
            s=200,
            edgecolor='black',
            label=f'True Mean {i+1}',
        )

    # Plot the computed centroids
    for i in range(4):
        plt.scatter(
            centroids_np[i, 0],
            centroids_np[i, 1],
            c=centroid_colors[i],
            marker='*',
            s=300,
            edgecolor='black',
            label=f'Computed Centroid {i+1}',
        )

    plt.title('K-Means Clustering (k=4) on Mixture of Gaussians')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.legend()
    plt.grid(alpha=0.3)

    # savefig does not create missing directories, so make sure it exists.
    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(output_path)
    plt.close()

    print(f"K-means with k=4 test completed. Results saved to {output_path}")

def test_hierarchical_kmeans():
    """Run hierarchical k-means on nested Gaussian clusters and save a plot.

    Builds 16 sub-cluster means by adding four (+/-0.5, +/-0.5) offsets to
    each of four meta-cluster means at (+/-3, +/-3), draws 16 points per
    sub-cluster (256 total) with standard deviation 0.2, shuffles them, and
    calls ``hierarchical_kmeans`` with ``k=4``, ``levels=2`` and ``iters=20``.
    Points are colored by the returned label and the figure is written to
    ``results/hierarchical_kmeans_test.png``.

    This is a visual smoke test: it makes no assertions about the clustering
    itself, it only checks that the pipeline runs and produces an image.
    """
    # Kept local, matching the original placement of this import.
    import matplotlib.cm as cm

    output_dir = 'results'
    output_path = os.path.join(output_dir, 'hierarchical_kmeans_test.png')

    # Fixed seed so the generated data set is reproducible.
    key = random.PRNGKey(42)
    # keys[0] seeds the sub-cluster draws, keys[4] seeds the shuffle.
    keys = random.split(key, 5)

    n_points = 256

    # Four Gaussians at different locations (meta-clusters)
    meta_means = [
        jnp.array([-3.0, -3.0]),
        jnp.array([3.0, -3.0]),
        jnp.array([-3.0, 3.0]),
        jnp.array([3.0, 3.0]),
    ]

    # For each meta-cluster, we'll create 4 sub-clusters
    sub_offsets = [
        jnp.array([-0.5, -0.5]),
        jnp.array([0.5, -0.5]),
        jnp.array([-0.5, 0.5]),
        jnp.array([0.5, 0.5]),
    ]

    # All 16 sub-cluster means, grouped by meta-cluster.
    all_means = [
        meta_mean + offset
        for meta_mean in meta_means
        for offset in sub_offsets
    ]

    # Standard deviation for all clusters
    std = 0.2

    # Generate points for each sub-cluster, one PRNG key per sub-cluster.
    points_per_subcluster = n_points // 16
    subkeys = random.split(keys[0], 16)
    clusters = [
        random.normal(subkeys[i], shape=(points_per_subcluster, 2)) * std + mean
        for i, mean in enumerate(all_means)
    ]

    # Combine and shuffle so the input order carries no cluster information.
    all_points = jnp.concatenate(clusters, axis=0)
    all_points = random.permutation(keys[4], all_points)

    # Run hierarchical k-means with k=4 and levels=2
    labels = hierarchical_kmeans(all_points, random.PRNGKey(43), k=4, levels=2, iters=20)

    # Convert to numpy for visualization
    labels_np = np.array(labels)
    all_points_np = np.array(all_points)

    # One color per distinct label returned by the clustering.
    unique_labels = np.unique(labels_np)
    num_unique_labels = len(unique_labels)
    colors = cm.rainbow(np.linspace(0, 1, num_unique_labels))

    # Visualize results
    plt.figure(figsize=(12, 10))

    # Plot points with colors based on hierarchical labels
    for i, label in enumerate(unique_labels):
        mask = labels_np == label
        plt.scatter(
            all_points_np[mask, 0],
            all_points_np[mask, 1],
            color=colors[i],
            label=f'Cluster {label}',
            alpha=0.7,
        )

    # Plot the true sub-cluster means for reference. Their colors are simply
    # cycled from the palette by position; they are not matched to the
    # predicted labels.
    for i, mean in enumerate(all_means):
        plt.scatter(
            mean[0],
            mean[1],
            marker='X',
            s=100,
            edgecolor='black',
            color=colors[i % num_unique_labels],
            alpha=0.7,
        )

    plt.title('Hierarchical K-Means Clustering (k=4, levels=2)')
    plt.xlabel('X')
    plt.ylabel('Y')

    # Legend lists every cluster; it is anchored outside the axes so it does
    # not cover the data.
    plt.legend(loc='upper right', bbox_to_anchor=(1.3, 1))
    plt.grid(alpha=0.3)

    # savefig does not create missing directories, so make sure it exists.
    os.makedirs(output_dir, exist_ok=True)
    # bbox_inches='tight' expands the saved area to include the legend, which
    # would otherwise be clipped because it sits outside the axes.
    plt.savefig(output_path, bbox_inches='tight')
    plt.close()

    print(f"Hierarchical k-means test completed. Results saved to {output_path}")

def main():
    """Run every k-means visual test in order, then report completion."""
    # Order: flat k=2, flat k=4, then the hierarchical variant.
    scenarios = (
        test_kmeans_k2,
        test_kmeans_k4,
        test_hierarchical_kmeans,
    )
    for run_scenario in scenarios:
        run_scenario()
    print("All kmeans tests completed successfully!")

# Run all tests only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
