#!/usr/bin/env python3

"""
K-means clustering on NumPy data.

Load a 2D array from a .npy file and run k-means clustering once for each
requested value of k. Each fitted model is saved with joblib, and the
corresponding cluster labels are saved as a .npy file.
"""

import argparse
import os
import numpy as np
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
import joblib
import time

def main(argv=None):
    """Run k-means for every requested k and save each model and its labels.

    The input must be a ``.npy`` file holding a 2D array. For each valid
    ``k`` a ``KMeans`` model is fitted on the raw array, then written to
    ``<output_dir>/kmeans_k<k>.pkl`` with joblib, and its labels are written
    to ``<output_dir>/labels_k<k>.npy``.

    Args:
        argv: Optional list of command-line arguments. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so calling
            ``main()`` with no arguments behaves like a normal CLI run.

    Raises:
        FileNotFoundError: If the input file does not exist.
        ValueError: If the loaded array is not 2-dimensional.
    """
    parser = argparse.ArgumentParser(description='K-means clustering on numpy data')
    parser.add_argument('-f', '--file', type=str, required=True,
                        help='Path to the .npy file containing 2D numpy matrix')
    parser.add_argument('-k', '--k_values', type=int, nargs='+',
                        default=[2, 4, 8, 16, 32, 64, 128, 256, 512, 1024],
                        help='List of k values for k-means clustering')
    parser.add_argument('-o', '--output_dir', type=str, default='kmeans_models',
                        help='Directory to save k-means models')
    # Kept so existing invocations that pass -n keep working. The value was
    # never forwarded to KMeans, and recent scikit-learn releases do not
    # accept an n_jobs argument for KMeans, so it stays unused.
    parser.add_argument('-n', '--n_jobs', type=int, default=1,
                        help='Accepted for backward compatibility; currently ignored')
    # The default is 2000 because that is the value every run used before
    # this flag was actually honored; default runs therefore stay unchanged.
    parser.add_argument('--max_iter', type=int, default=2000,
                        help='Maximum iterations for k-means algorithm')
    parser.add_argument('--random_state', type=int, default=42,
                        help='Random state for reproducibility')

    args = parser.parse_args(argv)

    # Check before creating the output directory so a bad path leaves no
    # empty directory behind.
    if not os.path.exists(args.file):
        raise FileNotFoundError(f"File not found: {args.file}")

    os.makedirs(args.output_dir, exist_ok=True)

    if args.n_jobs != 1:
        print("Note: --n_jobs is ignored; it is not passed to KMeans.")

    print(f"Loading data from: {args.file}")
    data = np.load(args.file)

    if len(data.shape) != 2:
        raise ValueError(f"Expected 2D array, got {len(data.shape)}D array")

    print(f"Data shape: {data.shape}")
    print(f"Data type: {data.dtype}")
    print(f"Memory usage: {data.nbytes / 1024 / 1024:.2f} MB")

    n_samples = data.shape[0]

    for k in args.k_values:
        # k-means needs at least one cluster and cannot form more clusters
        # than there are samples. Skip such values instead of letting the
        # fit fail and abort the remaining k values.
        if not 1 <= k <= n_samples:
            print(f"\nSkipping k={k}: k must be between 1 and "
                  f"the number of samples ({n_samples})")
            continue

        print(f"\n{'='*50}")
        print(f"Running k-means with k={k}")
        print(f"{'='*50}")

        start_time = time.time()

        kmeans = KMeans(
            n_clusters=k,
            random_state=args.random_state,
            max_iter=args.max_iter,
            n_init=100,  # many restarts; the lowest-inertia run is kept
            tol=1e-8,    # tight convergence tolerance
            verbose=0,
        )

        print("Training k-means...")
        kmeans.fit(data)

        elapsed_time = time.time() - start_time

        print(f"\nK-means with k={k} completed in {elapsed_time:.2f} seconds")
        print(f"Inertia (within-cluster sum of squares): {kmeans.inertia_:.4f}")
        print(f"Number of iterations: {kmeans.n_iter_}")

        model_path = os.path.join(args.output_dir, f'kmeans_k{k}.pkl')
        joblib.dump(kmeans, model_path)
        print(f"Model saved to: {model_path}")

        labels_path = os.path.join(args.output_dir, f'labels_k{k}.npy')
        np.save(labels_path, kmeans.labels_)
        print(f"Cluster labels saved to: {labels_path}")

    print("\n" + "="*50)
    print("All k-means clustering completed!")
    print(f"Models saved in: {args.output_dir}")
    print("="*50)

# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()