import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score
from .base_handler import BaseClusteringHandler


class KMeansHandler(BaseClusteringHandler):
    """K-means clustering handler with optional PCA reduction and automatic k selection.

    Two modes are supported:

    * adaptive (``fixed_n_clusters is None``): every ``fit_predict`` call scans
      k between ``min_clusters`` and ``max_clusters`` and keeps the k with the
      best silhouette score;
    * fixed: exactly ``fixed_n_clusters`` clusters are requested, clamped to
      the number of samples.

    When ``use_pca`` is true the data is projected onto at most
    ``pca_components`` principal components before clustering.
    """

    def __init__(self, random_state: int = 42, use_pca: bool = True, fixed_n_clusters: int = None, **kwargs):
        """Configure the handler.

        Args:
            random_state: Seed forwarded to the base handler; it is later read
                back as ``self.random_state`` for PCA and KMeans (presumably
                stored by ``BaseClusteringHandler.__init__`` - not visible here).
            use_pca: Whether to reduce dimensionality with PCA before clustering.
            fixed_n_clusters: Number of clusters to use. ``None`` enables
                automatic selection of the cluster count.
            **kwargs: ``n_init``, ``max_iter``, ``tol`` and ``algorithm`` are
                passed to scikit-learn's ``KMeans``; ``early_stop_patience``
                and ``early_stop_threshold`` tune the k search. ``n_clusters``,
                ``min_clusters`` and ``max_clusters`` are dropped with a
                warning. Everything that remains is also forwarded to the
                base class constructor.
        """
        self.fixed_n_clusters = fixed_n_clusters
        self.use_adaptive_clustering = fixed_n_clusters is None

        # Placeholder handed to the base class; adaptive mode overwrites
        # self.n_clusters on every fit_predict call.
        initial_n_clusters = fixed_n_clusters if fixed_n_clusters is not None else 10

        # The cluster count is owned by this class, so caller-supplied
        # overrides are removed before kwargs reach the base constructor.
        ignored_params = []
        for name in ('n_clusters', 'min_clusters', 'max_clusters'):
            if name in kwargs:
                ignored_params.append(f"{name}={kwargs.pop(name)}")

        if ignored_params:
            if self.use_adaptive_clustering:
                print(f"WARNING: KMeansHandler ignoring cluster number related parameters ({', '.join(ignored_params)}), will automatically determine optimal cluster number")
            else:
                print(f"WARNING: KMeansHandler ignoring cluster number related parameters ({', '.join(ignored_params)}), using fixed cluster number: {fixed_n_clusters}")

        super().__init__(n_clusters=initial_n_clusters, random_state=random_state, **kwargs)

        # PCA settings; self.pca holds the PCA fitted by the latest fit.
        self.use_pca = use_pca
        self.pca_components = 32
        self.pca = None

        # Parameters passed to sklearn.cluster.KMeans.
        self.n_init = kwargs.get('n_init', 10)
        self.max_iter = kwargs.get('max_iter', 300)
        self.tol = kwargs.get('tol', 1e-4)
        self.algorithm = kwargs.get('algorithm', 'lloyd')

        # Early-stopping controls for the adaptive k search.
        self.early_stop_patience = kwargs.get('early_stop_patience', 3)
        self.early_stop_threshold = kwargs.get('early_stop_threshold', 0.01)

        # Search range for the adaptive k search.
        self.min_clusters = 2
        self.max_clusters = 30

    def _build_kmeans(self, n_clusters: int) -> KMeans:
        """Create an unfitted KMeans estimator using this handler's settings."""
        return KMeans(
            n_clusters=n_clusters,
            random_state=self.random_state,
            n_init=self.n_init,
            max_iter=self.max_iter,
            tol=self.tol,
            algorithm=self.algorithm,
        )

    def _reduce_dimensions(self, data: np.ndarray) -> np.ndarray:
        """Fit a fresh PCA on ``data`` and return its projection.

        Returns ``data`` unchanged when PCA is disabled. The PCA is refitted
        on every call so that a handler reused on a new dataset never applies
        a projection learned from earlier data.
        """
        if not self.use_pca:
            return data

        n_samples, n_features = data.shape
        # PCA cannot extract more components than min(n_samples, n_features).
        n_components = min(self.pca_components, n_samples, n_features)
        self.pca = PCA(n_components=n_components, random_state=self.random_state)
        return self.pca.fit_transform(data)

    def _search_optimal_k(self, clustering_data: np.ndarray) -> int:
        """Return the k with the best silhouette score on ``clustering_data``.

        Candidates run from ``min_clusters`` (at least 2) up to
        ``max_clusters`` (at most ``n_samples - 1``). The scan stops early once
        the score has stopped improving meaningfully for
        ``early_stop_patience`` candidates and the recent scores are flat.
        """
        n_samples = clustering_data.shape[0]

        # silhouette_score needs between 2 and n_samples - 1 distinct labels.
        max_k = min(self.max_clusters, n_samples - 1)
        min_k = max(self.min_clusters, 2)

        if max_k <= min_k:
            # Nothing to scan. Never ask for more clusters than samples,
            # otherwise the final KMeans fit would fail.
            return max(1, min(min_k, n_samples))

        best_score = -1
        best_k = min_k
        scores_history = []
        no_improvement_count = 0

        for k in range(min_k, max_k + 1):
            try:
                labels = self._build_kmeans(k).fit_predict(clustering_data)
                if len(np.unique(labels)) < 2:
                    # Degenerate result: silhouette is undefined for one cluster.
                    continue
                score = silhouette_score(clustering_data, labels)
            except Exception as exc:
                # Best effort: one failing candidate must not abort the search,
                # but the failure is reported instead of being hidden.
                print(f"WARNING: KMeansHandler skipped k={k} during cluster search: {exc}")
                continue

            scores_history.append(score)

            if score > best_score:
                improvement = score - best_score
                best_score = score
                best_k = k
                # A new best only resets the patience counter when the gain
                # is larger than the configured threshold.
                if improvement > self.early_stop_threshold:
                    no_improvement_count = 0
                else:
                    no_improvement_count += 1
            else:
                no_improvement_count += 1

            patience_exhausted = (
                no_improvement_count >= self.early_stop_patience
                and len(scores_history) >= self.early_stop_patience + 1
            )
            if patience_exhausted:
                recent_scores = scores_history[-self.early_stop_patience:]
                if self._is_elbow_point(recent_scores):
                    break

        return best_k

    def _determine_optimal_clusters(self, data: np.ndarray) -> int:
        """Pick the best cluster count for ``data``.

        Applies the optional PCA reduction (refitting ``self.pca``) and then
        runs the silhouette-based search on the reduced data.
        """
        return self._search_optimal_k(self._reduce_dimensions(data))

    def _is_elbow_point(self, recent_scores: list) -> bool:
        """Tell whether the recent silhouette scores have flattened out.

        Returns True when at least 80% of the consecutive score differences
        are smaller in magnitude than ``early_stop_threshold``; False when
        fewer than two scores are given.
        """
        if len(recent_scores) < 2:
            return False

        deltas = [later - earlier for earlier, later in zip(recent_scores, recent_scores[1:])]
        flat_steps = sum(1 for delta in deltas if abs(delta) < self.early_stop_threshold)
        return flat_steps >= len(deltas) * 0.8

    def fit_predict(self, data: np.ndarray) -> np.ndarray:
        """Cluster ``data`` and return one label per sample.

        The optional PCA is fitted once per call and the same reduced data is
        used both for choosing k (adaptive mode) and for the final fit. The
        fitted estimator is stored in ``self.clusterer`` and the cluster count
        actually used in ``self.n_clusters``.
        """
        clustering_data = self._reduce_dimensions(data)

        if self.use_adaptive_clustering:
            self.n_clusters = self._search_optimal_k(clustering_data)
        else:
            n_samples = data.shape[0]
            # KMeans cannot produce more clusters than there are samples.
            self.n_clusters = min(self.fixed_n_clusters, n_samples)
            if self.n_clusters != self.fixed_n_clusters:
                print(f"WARNING: Fixed cluster number ({self.fixed_n_clusters}) exceeds sample count ({n_samples}), adjusted to {self.n_clusters}")

        self.clusterer = self._build_kmeans(self.n_clusters)
        return self.clusterer.fit_predict(clustering_data)

    def cluster(self, data: np.ndarray) -> np.ndarray:
        """Cluster ``data`` and return the cluster centers.

        When PCA was applied, the centers are mapped back to the original
        feature space with ``inverse_transform``.
        """
        labels = self.fit_predict(data)

        if not hasattr(self.clusterer, 'cluster_centers_'):
            # Fall back to the base-class helper when the estimator exposes
            # no centers (not expected for a fitted KMeans).
            return self.get_cluster_centers(data, labels)

        if self.use_pca and self.pca is not None:
            return self.pca.inverse_transform(self.clusterer.cluster_centers_)
        return self.clusterer.cluster_centers_

    def get_algorithm_name(self) -> str:
        """Return the identifier of this clustering algorithm."""
        return 'kmeans'

    def get_info(self) -> dict:
        """Return the base handler's info dict extended with K-means settings."""
        if self.use_adaptive_clustering:
            mode = "automatic optimal cluster determination"
        else:
            mode = f"fixed cluster count: {self.fixed_n_clusters}"

        info = super().get_info()
        info.update({
            'n_init': self.n_init,
            'max_iter': self.max_iter,
            'tol': self.tol,
            'algorithm': self.algorithm,
            'use_pca': self.use_pca,
            'pca_components': self.pca_components if self.use_pca else None,
            'early_stop_patience': self.early_stop_patience,
            'early_stop_threshold': self.early_stop_threshold,
            'fixed_n_clusters': self.fixed_n_clusters,
            'use_adaptive_clustering': self.use_adaptive_clustering,
            'description': f'K-means with PCA option for acceleration, {mode}'
        })
        return info
