import numpy as np
import warnings
from sklearn.cluster import KMeans

class BaseGP:
    def __init__(self, max_cluster_size=None, weights=None):
        """
        Record the clustering configuration on the instance.

        Parameters:
        - max_cluster_size: stored unchanged as ``self.max_cluster_size``
          (presumably an upper bound on members per cluster, or None for no
          bound -- confirm against the subclasses that read it).
        - weights: stored unchanged as ``self.weights`` (presumably per-sample
          weights, or None for unit weights -- confirm against the caller).
        """
        self.weights = weights
        self.max_cluster_size = max_cluster_size

    def clustering(self, X, s):
        """
        Abstract clustering hook; this base implementation always raises.

        Overriding implementations should return a
        ``(centroids, labels, inertia)`` tuple.

        Raises:
        - NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()


    def _single_stochastic_run(self, X, s, weights, max_cluster_size, top_k, random_seed):
        """
        A single run of the stochastic clustering, isolated for parallel execution.
        """
        np.random.seed(random_seed)

        n_samples = X.shape[0]

        if weights is None:
            current_weights = np.ones((n_samples, 1))
        else:
            current_weights = np.array(weights, dtype=float).reshape(-1, 1)

        current_centroids = X.copy()
        cluster_members = [{i} for i in range(n_samples)]
        num_clusters = n_samples

        while num_clusters > s:
            C = current_centroids
            W = current_weights

            # 1. Cost Calculation
            dist_sq_matrix = np.sum(C ** 2, axis=1, keepdims=True) + \
                             np.sum(C ** 2, axis=1) - 2 * (C @ C.T)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore", RuntimeWarning)
                weight_factor_matrix = (W * W.T) / (W + W.T)

            cost_matrix = np.sqrt(np.maximum(2 * dist_sq_matrix * weight_factor_matrix, 0))

            # 2. Constraints
            np.fill_diagonal(cost_matrix, np.inf)
            mask = np.tril(np.ones_like(cost_matrix, dtype=bool))
            cost_matrix[mask] = np.inf

            if max_cluster_size is not None:
                sizes = np.array([len(m) for m in cluster_members])
                combined_sizes = sizes[:, None] + sizes[None, :]
                cost_matrix[combined_sizes > max_cluster_size] = np.inf

            # 3. Guided Stochastic Selection
            flat_costs = cost_matrix.ravel()
            valid_merges_count = np.sum(~np.isinf(flat_costs))

            if valid_merges_count == 0:
                return None  # Deadlock

            current_k = min(top_k, valid_merges_count)

            if current_k == 1:
                chosen_flat_idx = np.argmin(flat_costs)
            else:
                partitioned_indices = np.argpartition(flat_costs, current_k - 1)[:current_k]
                candidate_costs = flat_costs[partitioned_indices]

                # Softmax selection
                min_c = np.min(candidate_costs)
                relative_costs = candidate_costs - min_c
                scale = np.mean(relative_costs)

                if scale < 1e-9:
                    probs = np.ones(current_k) / current_k
                else:
                    weights_prob = np.exp(-relative_costs / scale)
                    probs = weights_prob / np.sum(weights_prob)

                chosen_flat_idx = np.random.choice(partitioned_indices, p=probs)

            i, j = np.unravel_index(chosen_flat_idx, cost_matrix.shape)
            if i > j: i, j = j, i

            # 4. Update
            c1, c2 = current_centroids[i], current_centroids[j]
            w1, w2 = current_weights[i, 0], current_weights[j, 0]
            new_w = w1 + w2
            new_c = (c1 * w1 + c2 * w2) / new_w
            new_members = cluster_members[i].union(cluster_members[j])

            current_centroids = np.delete(current_centroids, j, axis=0)
            current_weights = np.delete(current_weights, j, axis=0)
            cluster_members.pop(j)

            current_centroids[i] = new_c
            current_weights[i] = new_w
            cluster_members[i] = new_members

            num_clusters -= 1

        # --- Final Evaluation ---
        final_labels = np.zeros(n_samples, dtype=int)
        for c_idx, members in enumerate(cluster_members):
            for s_idx in members:
                final_labels[s_idx] = c_idx

        full_w = np.ones(n_samples) if weights is None else np.array(weights)
        assigned_c = current_centroids[final_labels]
        sq_dists = np.sum((X - assigned_c) ** 2, axis=1)
        total_inertia = np.sum(full_w * sq_dists)

        kmeans = KMeans(n_clusters=s, init=current_centroids, max_iter=10 ** 4, tol=10 ** -5).fit(X, sample_weight=weights)
        current_centroids, final_labels, total_inertia = kmeans.cluster_centers_, kmeans.labels_, kmeans.inertia_

        return (current_centroids, final_labels, total_inertia)

    def _kmeans_from_distance_matrix(self, D_sq, n_clusters, init_labels, sample_weight=None, max_iter=10000, tol=1e-5):
        """
        Performs K-Means clustering using a pairwise distance matrix.

        Parameters:
        - D: (n_samples, n_samples) Symmetric matrix of Euclidean distances.
        - n_clusters: int, number of clusters.
        - init_labels: (n_samples,) array of initial integer labels (0 to n_clusters-1).
        - sample_weight: (n_samples,) array of weights (optional).
        - max_iter: int, maximum number of iterations.
        - tol: float, tolerance for convergence (relative change in inertia).

        Returns:
        - cluster_centers_: None (Cannot be computed explicitly without coordinates).
        - labels_: (n_samples,) Predicted labels.
        - inertia_: float, Sum of squared distances of samples to their closest cluster center.
        """
        n_samples = D_sq.shape[0]

        # Handle weights
        if sample_weight is None:
            sample_weight = np.ones(n_samples)
        else:
            sample_weight = np.asarray(sample_weight)

        # Initialize
        labels = np.array(init_labels, dtype=int)
        inertia = np.inf

        # Pre-allocate distance matrix (Samples x Clusters)
        dist_to_centroids = np.zeros((n_samples, n_clusters))

        for iteration in range(max_iter):
            prev_labels = labels.copy()
            prev_inertia = inertia

            # --- Step 1: Update Implicit Centroids & Calculate Distances ---
            for k in range(n_clusters):
                # Indices of points currently in cluster k
                mask = (labels == k)

                # Handle empty clusters (standard KMeans strategy: skip or re-init; here we skip)
                if not np.any(mask):
                    dist_to_centroids[:, k] = np.inf
                    continue

                w_k = sample_weight[mask]
                W_k = np.sum(w_k)  # Total weight of cluster

                # Term 1: Weighted average squared distance from every point i to points in C_k
                # Shape: (n_samples, subset) @ (subset,) -> (n_samples,)
                term1 = np.divide((D_sq[:, mask] @ w_k), (W_k))

                # Term 2: Weighted internal variance of the cluster (constant for the cluster)
                # Shape: (subset,) @ (subset, subset) @ (subset,) -> scalar
                # We subset D_sq to only the rows/cols belonging to the cluster
                D_sq_subset = D_sq[np.ix_(mask, mask)]
                term2 = np.divide((w_k @ D_sq_subset @ w_k), (2 * (W_k ** 2)))

                # Combine to get squared distance to the implicit centroid
                dist_to_centroids[:, k] = term1 - term2

            # --- Step 2: Assignment ---
            # Clip negative values that might occur due to floating point errors
            dist_to_centroids = np.maximum(dist_to_centroids, 0)

            labels = np.argmin(dist_to_centroids, axis=1)

            # --- Step 3: Calculate Inertia ---
            # Inertia is sum of squared distances to the nearest centroid
            min_dists = np.min(dist_to_centroids, axis=1)
            inertia = np.sum(sample_weight * min_dists)

            # --- Step 4: Convergence Check ---
            # Check label stability
            if np.array_equal(labels, prev_labels):
                break

            # Check inertia tolerance
            if prev_inertia != np.inf:
                change = np.abs(prev_inertia - inertia)
                if change < tol:
                    break

        # Cluster centers cannot be returned as coordinates (F dims) because input was only D (N dims).
        # We return None to maintain the 3-element tuple structure requested.
        cluster_centers_ = None

        return cluster_centers_, labels, inertia