#!/usr/bin/env python3
"""
Fast Clustering Module - Optimized for Speed

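This module provides fast clustering: a candidate range of cluster counts (k) is
derived from the dataset size and variance, and the best k within that range is
chosen using only numerical metrics (silhouette score, with inertia reported) to
avoid expensive semantic evaluation.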
"""

import os
import numpy as np
from typing import List, Dict, Any, Tuple
from sklearn.cluster import MiniBatchKMeans
from sklearn.metrics import silhouette_score
from tqdm import tqdm
from typing import Tuple
from sklearn.metrics import pairwise_distances

class Clusterer:
    """Fast clustering with variance-based k-range selection using only numerical metrics.

    A candidate range of cluster counts is derived from the dataset size and the
    mean per-dimension variance. A sparse subset of k values in that range is
    fitted with MiniBatchKMeans, and the clustering with the highest silhouette
    score is kept.

    Args:
        max_k_samples: Number of k values to test when ``max_k - min_k`` is at
            most 20 (wider ranges test 10 or 15 values).
        variance_threshold: Mean per-dimension variance above which the
            "high variance" k-range formula is used. It is only consulted for
            datasets of 2 to 5000 samples.
        random_state: Seed for the nearest-neighbour subsample and for
            MiniBatchKMeans, so repeated runs give the same result.
    """

    def __init__(self, max_k_samples: int = 5, variance_threshold: float = 0.3,
                 random_state: int = 42):
        self.max_k_samples = max_k_samples
        self.variance_threshold = variance_threshold
        self.random_state = random_state

    @staticmethod
    def _clamp_k_range(min_k: int, max_k: int, n_samples: int) -> Tuple[int, int]:
        """Force ``1 <= min_k <= max_k <= max(1, n_samples - 1)``.

        The heuristics in ``calculate_variance_based_k_range`` use fixed floors
        (5 or 10), which can exceed their sqrt-based upper bound on small
        datasets. That would give an empty range, or more clusters than the
        silhouette score supports (it needs k <= n_samples - 1). When the
        bounds cross, the upper bound wins.
        """
        upper = max(1, n_samples - 1)
        max_k = max(1, min(int(max_k), upper))
        min_k = max(1, min(int(min_k), max_k))
        return min_k, max_k

    def _estimate_mean_nn_distance(self, embeddings: np.ndarray,
                                   sample_size: int = 1000) -> float:
        """Estimate the mean distance from a point to its nearest other point.

        At most ``sample_size`` points are used, drawn with a generator seeded
        from ``self.random_state``. Returns 1.0 if fewer than two points are
        available.
        """
        n_samples = len(embeddings)
        if n_samples > sample_size:
            rng = np.random.default_rng(self.random_state)
            sample = embeddings[rng.choice(n_samples, sample_size, replace=False)]
        else:
            sample = embeddings

        if len(sample) < 2:
            print(f"   ⚠️ Too few samples ({len(sample)}) for nearest neighbors - using fallback")
            return 1.0

        from sklearn.neighbors import NearestNeighbors

        nn = NearestNeighbors(n_neighbors=2, algorithm='ball_tree')
        nn.fit(sample)
        distances, _ = nn.kneighbors(sample)
        # Querying the fitted points themselves: column 0 is (normally) the
        # point itself at distance 0, so column 1 is the nearest other point.
        mean_nn_dist = float(np.mean(distances[:, 1]))
        print(f"   Mean NN distance (sampled): {mean_nn_dist:.4f}")
        return mean_nn_dist

    def calculate_variance_based_k_range(self, embeddings: np.ndarray) -> Tuple[int, int]:
        """Return the inclusive ``(min_k, max_k)`` range of cluster counts to try.

        The range depends on the number of samples and, for 2 to 5000 samples,
        on the mean per-dimension variance. In the low-variance branch, a small
        sampled nearest-neighbour distance (< 0.5) doubles ``max_k``, capped at
        50. The result is always clamped by ``_clamp_k_range``, so it is never
        empty and never asks for more clusters than ``n_samples - 1``.
        """
        print("📊 Calculating variance-based k-range (fast)...")

        n_samples = len(embeddings)

        # Global variance (fast): mean over dimensions of per-dimension variance.
        variance = np.var(embeddings, axis=0).mean()
        print(f"   Data variance: {variance:.4f}")

        if n_samples < 2:
            # Too few samples for clustering
            min_k = 1
            max_k = 1
            print(f"   ⚠️ Too few samples ({n_samples}) - using single cluster")
        elif n_samples > 15000:
            min_k = 15
            max_k = 60  # allow more k values for better optimization
            print(f"   Large dataset ({n_samples:,}): using dynamic k-range {min_k}-{max_k}")
        elif n_samples > 5000:
            min_k = 10
            max_k = 40
            print(f"   Medium dataset ({n_samples:,}): using k-range {min_k}-{max_k}")
        elif variance > self.variance_threshold:
            # High variance: use larger k-range
            min_k = max(10, int(np.sqrt(n_samples) / 4))
            max_k = min(int(np.sqrt(n_samples) * 2), 50)
            print(f"   High variance detected: using k-range {min_k}-{max_k}")
        else:
            # Low variance: use smaller k-range
            min_k = max(5, int(np.sqrt(n_samples) / 8))
            max_k = min(int(np.sqrt(n_samples) * 1.5), 30)
            print(f"   Low variance detected: using k-range {min_k}-{max_k}")

            # The NN distance is only consulted here, so it is computed lazily.
            overlap_threshold = 0.5
            if self._estimate_mean_nn_distance(embeddings) < overlap_threshold:
                max_k = min(max_k * 2, 50)
                print(f"   High overlap detected — bumping max_k to {max_k}")

        clamped = self._clamp_k_range(min_k, max_k, n_samples)
        if clamped != (min_k, max_k):
            print(f"   Adjusted k-range to {clamped[0]}-{clamped[1]} to fit {n_samples} samples")
        return clamped

    def sparse_sample_k_values(self, min_k: int, max_k: int) -> List[int]:
        """Pick a sorted subset of k values from the inclusive range [min_k, max_k].

        The budget is 15 values when ``max_k - min_k > 40``, 10 when it is
        > 20, and ``self.max_k_samples`` otherwise. Ranges within the budget
        are returned in full. Otherwise both endpoints are kept and the other
        slots are filled with logarithmically spaced values, which puts more
        candidates near ``min_k``. Rounding collisions may yield fewer values
        than the budget.
        """
        k_range = range(min_k, max_k + 1)

        if max_k - min_k > 40:
            max_samples = 15
        elif max_k - min_k > 20:
            max_samples = 10
        else:
            max_samples = self.max_k_samples

        if len(k_range) <= max_samples:
            return list(k_range)

        # Always include boundaries
        selected = [min_k, max_k]

        remaining = max_samples - len(selected)
        if remaining > 0:
            # Offsets between 1 and len(k_range) - 1 on a log scale; the first
            # and last generated offsets are discarded.
            log_indices = np.logspace(0, np.log10(len(k_range) - 1), remaining + 2)[1:-1]
            selected.extend(k_range[int(round(idx))] for idx in log_indices)

        # Remove duplicates and sort
        selected = sorted(set(selected))[:max_samples]

        print(f"   Dynamic sampling: testing {len(selected)} k-values: {selected}")
        return selected

    def evaluate_k_value_fast(self, embeddings: np.ndarray, k: int) -> Dict[str, Any]:
        """Fit MiniBatchKMeans with ``k`` clusters and score the result.

        Returns a dict with keys ``k``, ``silhouette_score``, ``inertia``,
        ``labels`` and ``cluster_centers``. The silhouette score is 0 when it
        is undefined (fewer than two samples or ``k == 1``) or cannot be
        computed. If fitting fails, ``labels`` and ``cluster_centers`` are None
        and ``inertia`` is ``inf``.
        """
        n_samples = len(embeddings)
        # About 10% of the data per mini-batch, capped at 2000. Datasets under
        # 10 points would get batch_size=0, which sklearn rejects, so they use
        # the whole dataset as one batch.
        batch_size = min(2000, n_samples // 10) or n_samples

        try:
            kmeans = MiniBatchKMeans(
                n_clusters=k,
                random_state=self.random_state,
                n_init=1,  # single init for speed
                batch_size=batch_size,
                max_iter=100,  # fewer iterations for speed
                init_size=min(3 * k, n_samples),
            )
            labels = kmeans.fit_predict(embeddings)
        except Exception as e:
            print(f"   Warning: Failed to evaluate k={k}: {e}")
            return {
                'k': k,
                'silhouette_score': 0,
                'inertia': float('inf'),
                'labels': None,
                'cluster_centers': None,
            }

        try:
            if n_samples > 1 and k > 1:
                silhouette = silhouette_score(embeddings, labels)
            else:
                silhouette = 0  # undefined for a single cluster or single sample
        except Exception as e:
            # e.g. all points collapsed into a single label
            print(f"      Warning: Could not calculate silhouette score for k={k}: {e}")
            silhouette = 0

        return {
            'k': k,
            'silhouette_score': silhouette,
            'inertia': kmeans.inertia_,
            'labels': labels,
            'cluster_centers': kmeans.cluster_centers_,
        }

    @staticmethod
    def _select_best_result(results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Return the successful evaluation with the highest silhouette score.

        Failed evaluations (``labels is None``) are excluded. Their placeholder
        silhouette of 0 would otherwise beat genuine clusterings with negative
        scores. Ties go to the earliest entry of ``results``.

        Raises:
            RuntimeError: If no evaluation succeeded (or ``results`` is empty).
        """
        successful = [r for r in results if r['labels'] is not None]
        if not successful:
            raise RuntimeError("Clustering failed for every candidate k value")
        return max(successful, key=lambda r: r['silhouette_score'])

    def find_optimal_k_fast(self, embeddings: np.ndarray) -> Dict[str, Any]:
        """Evaluate sampled k values and return the best result by silhouette score.

        The k values come from ``sparse_sample_k_values`` and are in ascending
        order, so a silhouette tie resolves to the smaller k.

        Raises:
            RuntimeError: If every candidate k failed to fit.
        """
        print("🎯 Finding optimal k with silhouette score optimization...")

        min_k, max_k = self.calculate_variance_based_k_range(embeddings)
        k_values = self.sparse_sample_k_values(min_k, max_k)

        results = [
            self.evaluate_k_value_fast(embeddings, k)
            for k in tqdm(k_values, desc="Evaluating k-values (fast)")
        ]

        best_result = self._select_best_result(results)
        print("   🎯 Selected k={} with silhouette score: {:.3f}".format(best_result["k"], best_result["silhouette_score"]))

        return best_result

    @staticmethod
    def _normalize_embeddings(embeddings: np.ndarray) -> np.ndarray:
        """Scale each row to unit L2 norm; all-zero rows are left as zeros."""
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        norms[norms == 0] = 1  # avoid division by zero for zero vectors
        return embeddings / norms

    def cluster_fast(self, embeddings: np.ndarray) -> Dict[str, Any]:
        """Row-normalise ``embeddings``, pick k, and return the clustering.

        The caller's array is not modified; normalisation produces a new array.

        Returns:
            Dict with ``labels``, ``cluster_centers``, ``optimal_k``,
            ``silhouette_score``, ``inertia`` and ``num_clusters`` (the number of
            distinct labels actually assigned, which can be below ``optimal_k``).

        Raises:
            ValueError: If ``embeddings`` is not a non-empty 2-D array.
            RuntimeError: If clustering failed for every candidate k.
        """
        print("🚀 Running fast clustering...")

        embeddings = np.asarray(embeddings)
        if embeddings.ndim != 2 or len(embeddings) == 0:
            raise ValueError(
                f"embeddings must be a non-empty 2-D array, got shape {embeddings.shape}"
            )

        embeddings = self._normalize_embeddings(embeddings)
        optimal_result = self.find_optimal_k_fast(embeddings)

        labels = optimal_result['labels']
        cluster_results = {
            'labels': labels,
            'cluster_centers': optimal_result['cluster_centers'],
            'optimal_k': optimal_result['k'],
            'silhouette_score': optimal_result['silhouette_score'],
            'inertia': optimal_result['inertia'],
            'num_clusters': len(np.unique(labels)),
        }

        print(f"✅ Fast clustering completed: {cluster_results['num_clusters']} clusters")
        print(f"   Silhouette score: {cluster_results['silhouette_score']:.3f}")
        print(f"   Inertia: {cluster_results['inertia']:.0f}")

        return cluster_results

def cluster_fast(embeddings: np.ndarray,
                 max_k_samples: int = 5,
                 variance_threshold: float = 0.3) -> Dict[str, Any]:
    """
    Fast clustering function that can be called directly.

    Builds a ``Clusterer`` with the given settings and returns the result of
    its ``cluster_fast`` method on ``embeddings``.

    Args:
        embeddings: Embedding vectors, one row per item.
        max_k_samples: Number of k values to test when the candidate k-range
            is narrow (wider ranges use a larger fixed budget).
        variance_threshold: Threshold for high variance detection.

    Returns:
        Dictionary containing clustering results (``labels``,
        ``cluster_centers``, ``optimal_k``, ``silhouette_score``, ``inertia``,
        ``num_clusters``).

    Note:
        There is no ``codes``/texts parameter. A second positional argument is
        interpreted as ``max_k_samples``, so pass only the embeddings
        positionally.
    """
    clusterer = Clusterer(
        max_k_samples=max_k_samples,
        variance_threshold=variance_threshold,
    )

    return clusterer.cluster_fast(embeddings)

# Backward compatibility function
def cluster():
    """Cluster embeddings stored on disk and write one parquet file per cluster.

    Kept for backward compatibility with the file-based pipeline. Reads
    ``../temp_files/embeddings.parquet`` (relative to this module's directory),
    which must have ``embedding`` and ``text`` columns, and runs ``cluster_fast``.
    For each cluster label it then writes
    ``../temp_files/clusters/cluster_<id>.parquet`` with the members' ``text``
    and ``embedding`` (as read from the input file).

    Returns:
        The result dictionary produced by ``cluster_fast``.
    """
    import pandas as pd

    temp_dir = os.path.join(os.path.dirname(__file__), "..", "temp_files")

    # Load embeddings
    embeddings_df = pd.read_parquet(os.path.join(temp_dir, "embeddings.parquet"))
    embeddings = np.array([emb for emb in embeddings_df['embedding']])
    codes = embeddings_df['text'].tolist()

    # Only the embeddings are passed: cluster_fast's second positional
    # parameter is max_k_samples, so passing `codes` there would make the text
    # list act as a k-sample count.
    results = cluster_fast(embeddings)
    labels = np.asarray(results['labels'])

    clusters_dir = os.path.join(temp_dir, "clusters")
    os.makedirs(clusters_dir, exist_ok=True)

    # Save each cluster separately
    unique_labels = np.unique(labels)
    for cluster_id in unique_labels:
        member_indices = np.flatnonzero(labels == cluster_id)
        cluster_df_single = pd.DataFrame({
            'text': [codes[i] for i in member_indices],
            'embedding': [emb.tolist() for emb in embeddings[member_indices]],
        })

        cluster_path = os.path.join(clusters_dir, f"cluster_{cluster_id}.parquet")
        cluster_df_single.to_parquet(cluster_path)

    print(f"✅ Saved {len(unique_labels)} clusters to {clusters_dir}")
    return results

if __name__ == '__main__':
    # Running the module as a script executes the file-based pipeline in cluster().
    cluster()