import numpy as np
from sklearn.cluster import DBSCAN
from sklearn.decomposition import PCA
from .base_handler import BaseClusteringHandler


class DBSCANHandler(BaseClusteringHandler):
    """DBSCAN-based clustering handler with optional PCA reduction.

    The number of clusters is never supplied by the caller: DBSCAN discovers
    it from the data. ``eps`` and ``min_samples`` may be given explicitly
    through ``kwargs``; when they are left as ``None`` they are derived from
    the data on every call to :meth:`fit_predict`.

    Points that DBSCAN marks as noise (label ``-1``) are not discarded; each
    one is turned into its own single-point cluster.
    """

    # Cluster-count options that have no meaning for DBSCAN; they are removed
    # from ``kwargs`` (with a warning) before the base class is initialised.
    _IGNORED_PARAMS = ('n_clusters', 'min_clusters', 'max_clusters')

    def __init__(self, random_state: int = 42, use_pca: bool = True, **kwargs):
        """Configure the handler.

        Args:
            random_state: Seed handed to the base class and to PCA.
            use_pca: When true, data is reduced with PCA before clustering.
            **kwargs: Optional settings read here: ``eps``, ``min_samples``,
                ``metric``, ``algorithm``, ``leaf_size``, ``n_jobs`` and
                ``eps_percentile``. Everything that remains after the
                cluster-count options are stripped is also forwarded to the
                base class constructor.
        """
        removed_params = [
            f"{param}={kwargs.pop(param)}"
            for param in self._IGNORED_PARAMS
            if param in kwargs
        ]

        if removed_params:
            print(f"WARNING: DBSCANHandler ignoring cluster number related parameters ({', '.join(removed_params)}), will automatically determine optimal parameters")

        # The base class is always given a fixed n_clusters=10; this handler
        # itself never reads that value and tracks the real count in
        # ``actual_n_clusters``.
        super().__init__(n_clusters=10, random_state=random_state, **kwargs)

        self.use_pca = use_pca
        self.pca_components = 32
        self.pca = None

        # DBSCAN hyper-parameters. ``eps`` / ``min_samples`` left as None
        # mean "derive from the data at fit time".
        self.eps = kwargs.get('eps', None)
        self.min_samples = kwargs.get('min_samples', None)
        self.metric = kwargs.get('metric', 'euclidean')
        self.algorithm = kwargs.get('algorithm', 'auto')
        self.leaf_size = kwargs.get('leaf_size', 30)
        self.n_jobs = kwargs.get('n_jobs', None)

        # Percentile of the k-distance distribution used as automatic eps.
        self.eps_percentile = kwargs.get('eps_percentile', 70)

        # State recorded by the most recent fit_predict call.
        self.actual_n_clusters = 0
        self.effective_min_samples = None

    def _estimate_eps(self, data: np.ndarray, min_samples=None) -> float:
        """Estimate ``eps`` from the k-nearest-neighbour distance distribution.

        For every sample the distance to its k-th nearest neighbour is
        computed (the query set is the training set, so the sample itself is
        among the returned neighbours). The ``eps_percentile``-th percentile
        of those distances is returned.

        Args:
            data: 2-D array of the points that will be clustered.
            min_samples: Neighbourhood size ``k``. Defaults to
                ``self.min_samples`` when omitted.

        Returns:
            A strictly positive ``eps`` value.

        Raises:
            ValueError: If no neighbourhood size is available.
        """
        from sklearn.neighbors import NearestNeighbors

        k = self.min_samples if min_samples is None else min_samples
        if k is None:
            raise ValueError("min_samples must be set before eps can be estimated")

        nbrs = NearestNeighbors(n_neighbors=k, metric=self.metric).fit(data)
        distances, _ = nbrs.kneighbors(data)

        # kneighbors returns k columns ordered by distance; the last one is
        # the distance to the k-th neighbour.
        k_distances = distances[:, -1]

        eps = float(np.percentile(k_distances, self.eps_percentile))

        if eps <= 0.0:
            # Heavily duplicated data can push the percentile to zero, which
            # DBSCAN rejects. Fall back to the smallest positive k-distance,
            # or to machine epsilon when every k-distance is zero.
            positive = k_distances[k_distances > 0]
            eps = float(positive.min()) if positive.size else float(np.finfo(float).eps)

        return eps

    def fit_predict(self, data: np.ndarray) -> np.ndarray:
        """Cluster ``data`` and return one label per sample.

        Steps: decide the working dimensionality, derive ``min_samples`` if it
        was not configured, skip clustering when there are too few samples,
        optionally reduce with PCA, estimate ``eps`` if it was not configured,
        run DBSCAN, then relabel noise points as single-point clusters.

        ``self.min_samples`` and ``self.eps`` are treated as configuration and
        are not modified; the ``min_samples`` value actually used is stored in
        ``self.effective_min_samples``.

        Args:
            data: 2-D array of shape ``(n_samples, n_features)``.

        Returns:
            Integer label array of length ``n_samples``. When there are fewer
            samples than ``min_samples`` every sample receives its own label
            and neither PCA nor DBSCAN is run (``self.pca`` is reset to
            ``None`` in that case).
        """
        n_samples, n_features = data.shape

        if self.use_pca:
            # PCA cannot produce more components than min(n_samples, n_features).
            n_components = min(self.pca_components, n_samples, n_features)
            clustering_features = n_components
        else:
            n_components = None
            clustering_features = n_features

        # Derived per call so a second fit on differently shaped data does
        # not reuse a value computed for earlier data.
        if self.min_samples is None:
            min_samples = clustering_features + 1
        else:
            min_samples = self.min_samples
        self.effective_min_samples = min_samples

        if n_samples < min_samples:
            # Not enough points for even one core sample: skip PCA/DBSCAN and
            # treat each sample as its own cluster.
            self.pca = None
            self.actual_n_clusters = n_samples
            return np.arange(n_samples)

        if self.use_pca:
            self.pca = PCA(n_components=n_components, random_state=self.random_state)
            clustering_data = self.pca.fit_transform(data)
        else:
            clustering_data = data

        eps = self.eps
        if eps is None:
            eps = self._estimate_eps(clustering_data, min_samples)

        self.clusterer = DBSCAN(
            eps=eps,
            min_samples=min_samples,
            metric=self.metric,
            algorithm=self.algorithm,
            leaf_size=self.leaf_size,
            n_jobs=self.n_jobs
        )

        original_labels = self.clusterer.fit_predict(clustering_data)

        labels = self._process_noise_points(original_labels)

        self.actual_n_clusters = len(np.unique(labels))

        return labels

    def _process_noise_points(self, labels: np.ndarray) -> np.ndarray:
        """Give every noise point (label ``-1``) its own new cluster label.

        New labels start right after the largest existing non-noise label (or
        at ``0`` when every point is noise) and increase in the order the
        noise points appear. The input array is not modified.

        Args:
            labels: 1-D label array in which ``-1`` marks noise.

        Returns:
            A copy of ``labels`` containing no ``-1`` entries.
        """
        processed_labels = labels.copy()

        noise_indices = np.flatnonzero(labels == -1)
        if noise_indices.size == 0:
            return processed_labels

        valid_labels = labels[labels != -1]
        max_label = np.max(valid_labels) if valid_labels.size else -1

        # One consecutive fresh label per noise point.
        processed_labels[noise_indices] = max_label + 1 + np.arange(noise_indices.size)

        return processed_labels

    def get_cluster_centers(self, data: np.ndarray, labels: np.ndarray) -> np.ndarray:
        """Compute the mean point of every cluster.

        Args:
            data: 2-D array of samples, row-aligned with ``labels``.
            labels: Cluster label of each row of ``data``.

        Returns:
            Array with one row per distinct label, ordered as
            ``np.unique(labels)``. When ``labels`` is empty the overall mean
            of ``data`` is returned as a single row.
        """
        unique_labels = np.unique(labels)

        if len(unique_labels) == 0:
            return np.mean(data, axis=0, keepdims=True)

        centers = np.zeros((len(unique_labels), data.shape[1]))

        # Each label comes from ``labels`` itself, so every mask selects at
        # least one row.
        for row, label in enumerate(unique_labels):
            centers[row] = np.mean(data[labels == label], axis=0)

        return centers

    def cluster(self, data: np.ndarray) -> np.ndarray:
        """Cluster ``data`` and return the cluster centers.

        Centers are computed from the original ``data``, not from the
        PCA-reduced representation used for clustering.
        """
        labels = self.fit_predict(data)
        return self.get_cluster_centers(data, labels)

    def get_algorithm_name(self) -> str:
        """Return the identifier of this algorithm."""
        return 'dbscan'

    def get_info(self) -> dict:
        """Return the base handler info extended with DBSCAN-specific fields.

        ``min_samples`` reports the configured value, or the value derived
        during the most recent fit when none was configured (``None`` before
        any fit).
        """
        if self.min_samples is not None:
            reported_min_samples = self.min_samples
        else:
            reported_min_samples = self.effective_min_samples

        info = super().get_info()
        info.update({
            'eps': self.eps,
            'eps_percentile': self.eps_percentile,
            'min_samples': reported_min_samples,
            'metric': self.metric,
            'use_pca': self.use_pca,
            'pca_components': self.pca_components if self.use_pca else None,
            'actual_n_clusters': self.actual_n_clusters,
            'noise_handling': 'convert_to_individual_clusters',
            'sample_check': 'skip_clustering_if_insufficient_samples',
            'description': 'DBSCAN with PCA option, noise points converted to individual clusters, and sample count validation'
        })
        return info
