from sklearn_extra.cluster import KMedoids
from sklearn.metrics import pairwise_distances
import numpy as np

class WeightedKMedoids(KMedoids):
    """KMedoids whose ``'k-medoids++'`` seeding draws medoids in proportion to sample weights.

    ``sample_weight`` given to :meth:`fit` is only consulted when choosing the
    initial medoids with ``init='k-medoids++'``. The parent ``KMedoids.fit`` is
    called without any weights, so the medoid updates that follow seeding are
    unweighted. Every other ``init`` value is handled by the parent class.
    """

    def __init__(self, n_clusters=8, metric='euclidean', init='k-medoids++', max_iter=300,
                 random_state=None):
        super().__init__(n_clusters=n_clusters, metric=metric, init=init, max_iter=max_iter,
                         random_state=random_state)

    def fit(self, X, y=None, sample_weight=None):
        """
        Fit the model, remembering ``sample_weight`` for weighted k-medoids++ seeding.

        Parameters:
        - X: array-like of shape (n_samples, n_features)
        - y: ignored; accepted for scikit-learn API compatibility
        - sample_weight: array-like of shape (n_samples,), default=None
            Finite, non-negative weights with a positive sum. ``None`` gives every
            sample a weight of 1. Only used by ``init='k-medoids++'``.

        Returns:
        - self: object

        Raises:
        - ValueError: if ``sample_weight`` cannot be converted to floats, or, when
          the initial medoids are chosen, if it has the wrong length, contains
          negative or non-finite values, or sums to zero.
        """
        # Keep ``None`` as ``None``: ``np.array(None)`` is a 0-d object array, which
        # would defeat the ``is None`` test used when seeding and crash it.
        self.sample_weight_ = (
            None if sample_weight is None else np.asarray(sample_weight, dtype=np.float64)
        )
        return super().fit(X, y)

    def _initialize_medoids(self, D, n_clusters, random_state_, X=None):
        """
        Return the indices of the initial medoids.

        ``init='k-medoids++'`` uses weighted seeding (``_weighted_kpp_init``); every
        other ``init`` value, including array-like initial medoids, is delegated to
        ``KMedoids._initialize_medoids``.

        Parameters:
        - D: matrix of pairwise distances; row/column i corresponds to sample i
        - n_clusters: number of medoids to pick
        - random_state_: random generator used for sampling (via ``choice``)
        - X: the training data, forwarded unchanged to the parent implementation

        Returns:
        - ndarray of medoid indices into the samples

        Raises:
        - ValueError: if the stored sample weights are invalid for ``D``.
        """
        # Validate for every init method so bad weights are never silently ignored.
        weights = self._checked_sample_weight(D.shape[0])
        # Only compare with the string when ``init`` is a string: an array-like
        # ``init`` would otherwise trigger an element-wise comparison.
        if isinstance(self.init, str) and self.init == 'k-medoids++':
            return self._weighted_kpp_init(D, n_clusters, random_state_, weights)
        return super()._initialize_medoids(D, n_clusters, random_state_, X)

    def _checked_sample_weight(self, n_samples):
        """Return the fitted sample weights as a validated float array, or ones if none were given."""
        weights = getattr(self, 'sample_weight_', None)
        if weights is None:
            return np.ones(n_samples)
        weights = np.asarray(weights, dtype=np.float64)
        if weights.shape != (n_samples,):
            raise ValueError(
                f"sample_weight has shape {weights.shape}; expected ({n_samples},)"
            )
        if not np.all(np.isfinite(weights)) or np.any(weights < 0):
            raise ValueError("sample_weight must contain only finite, non-negative values")
        if weights.sum() <= 0:
            raise ValueError("sample_weight must have a positive sum")
        return weights

    @staticmethod
    def _weighted_kpp_init(D, n_clusters, random_state_, weights):
        """
        Draw ``n_clusters`` distinct medoid indices by weighted k-medoids++ seeding.

        The first medoid is drawn with probability proportional to ``weights``; each
        later one with probability proportional to ``weight * distance to the
        nearest medoid chosen so far``.
        NOTE(review): classic k-means++ samples by *squared* distance; this uses
        the plain distance, as the original implementation did.
        """
        n_samples = D.shape[0]
        medoids = [random_state_.choice(n_samples, p=weights / weights.sum())]
        dist_to_nearest_medoid = D[:, medoids[0]]

        for _ in range(1, n_clusters):
            weighted_distances = dist_to_nearest_medoid * weights
            # Never re-draw a chosen medoid, even if D has a non-zero diagonal
            # (possible with a precomputed metric).
            weighted_distances[medoids] = 0.0
            total_weighted_distance = weighted_distances.sum()
            if total_weighted_distance > 0:
                next_medoid = random_state_.choice(
                    n_samples, p=weighted_distances / total_weighted_distance
                )
            else:
                # Every candidate coincides with a medoid or has zero weight: pick
                # uniformly among the points not chosen yet so medoids stay distinct.
                remaining = np.setdiff1d(np.arange(n_samples), medoids)
                next_medoid = random_state_.choice(remaining)
            medoids.append(next_medoid)
            # Track each point's distance to its nearest chosen medoid.
            dist_to_nearest_medoid = np.minimum(dist_to_nearest_medoid, D[:, next_medoid])

        return np.array(medoids)
