import numpy as np
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.preprocessing import StandardScaler

def initialize_kmeans_plusplus(X, k, rng):
    """Pick ``k`` initial centroids from the rows of ``X`` by k-means++ seeding.

    The first centroid is a uniformly random row. Every further centroid is a
    row drawn with probability proportional to its squared Euclidean distance
    to the nearest centroid chosen so far.

    Parameters
    ----------
    X : numpy.ndarray of shape (n_samples, n_features)
        Data to seed from.
    k : int
        Number of centroids to select.
    rng : numpy.random.Generator
        Source of randomness; its ``integers`` and ``choice`` methods are used.

    Returns
    -------
    numpy.ndarray of shape (k, n_features)
        The selected rows of ``X``, stored with ``X``'s dtype.
    """
    n_samples, n_features = X.shape
    centroids = np.empty((k, n_features), dtype=X.dtype)

    # Step 1: Randomly choose the first centroid using the provided rng
    centroids[0] = X[rng.integers(n_samples)]

    # Distances are accumulated in float64 so integer inputs cannot overflow
    # or wrap around when differences are squared.
    points = np.asarray(X, dtype=np.float64)

    # Squared distance from every point to its nearest already-chosen
    # centroid. Only the most recently added centroid can lower it, so a
    # running minimum avoids recomputing distances to all previous centroids.
    closest_sq_dist = np.full(n_samples, np.inf)

    # Step 2 and 3: Choose the remaining centroids
    for c in range(1, k):
        diff = points - np.asarray(centroids[c - 1], dtype=np.float64)
        # Sum over features FIRST (a true squared distance per centroid) and
        # only then take the minimum over centroids.
        np.minimum(closest_sq_dist, np.sum(diff * diff, axis=1), out=closest_sq_dist)

        total = closest_sq_dist.sum()
        if total == 0:
            # Every point coincides with a chosen centroid: the weights would
            # be 0/0, so fall back to a uniform draw (p=None).
            probabilities = None
        else:
            probabilities = closest_sq_dist / total
        centroids[c] = X[rng.choice(n_samples, p=probabilities)]

    return centroids

class MiniBatchKMeans:
    """Mini-batch k-means with a selectable per-cluster learning rate.

    Each iteration samples ``batch_size`` rows without replacement (or uses
    all rows when ``batch_size`` is ``None``), assigns them to their nearest
    centroid and moves every non-empty cluster's centroid towards the center
    of mass of the points assigned to it.

    Parameters
    ----------
    n_clusters : int
        Number of centroids.
    batch_size : int or None
        Rows sampled per iteration; ``None`` means full-batch updates where
        each centroid is replaced by its cluster's center of mass.
    n_iterations : int
        Number of update iterations performed by ``fit``.
    new_lr : bool
        If true, the learning rate of a cluster is
        ``sqrt(points_in_cluster / batch_size)``; otherwise it is
        ``points_in_cluster / cumulative_points_seen_by_cluster``.
        Either way it is capped at 1.
    random_state : int, numpy.random.Generator or None
        Passed to ``numpy.random.default_rng`` at the start of ``fit``.
    dynamic_batching, lr_decay, use_var, use_beta, jl
        Accepted so the signature stays compatible with callers, but not
        stored or used by this class.
    """

    def __init__(self, n_clusters, batch_size, n_iterations, dynamic_batching=True,
                 new_lr=False, lr_decay=False, use_var=False, use_beta=False,
                 jl="None", random_state=None):
        self.n_clusters = n_clusters
        self.batch_size = batch_size
        self.n_iterations = n_iterations
        self.centroids = None
        self.final_inertia = 0
        self.new_lr = new_lr  # new learning rate flag
        self.random_state = random_state

    # needed for consistency with sklearn (for running experiments)
    def set_params(self, **params):
        """Set each keyword argument as an attribute and return ``self``."""
        for key, value in params.items():
            setattr(self, key, value)
        return self

    def get_inertia(self, points, centers, normalized=True):
        """Return ``(inertia, nearest)`` for ``points`` against ``centers``.

        ``nearest`` holds, per point, the index of the closest center.
        ``inertia`` is the sum of squared distances from each point to its
        closest center, divided by the number of points when ``normalized``.
        """
        nearest = np.argmin(euclidean_distances(points, centers), axis=1)
        inertia = np.linalg.norm(points - centers[nearest]) ** 2
        if normalized:
            inertia /= points.shape[0]
        return inertia, nearest

    def fit(self, X, init_centroids=None):
        """Run ``n_iterations`` mini-batch updates on ``X``.

        Parameters
        ----------
        X : numpy.ndarray of shape (n_samples, n_features)
            Training data.
        init_centroids : numpy.ndarray of shape (n_clusters, n_features), optional
            Starting centroids. When omitted they are chosen with
            ``initialize_kmeans_plusplus``.

        Returns
        -------
        self
            After fitting, ``centroids``, ``final_inertia`` (normalized
            inertia over all of ``X``) and ``labels_`` are set.
        """
        # Reset on every fit; nothing in this class updates it afterwards.
        self.total_sampled = 0
        rng = np.random.default_rng(self.random_state)
        n_samples, n_features = X.shape

        if init_centroids is None:
            centroids = initialize_kmeans_plusplus(X, self.n_clusters, rng)
        else:
            centroids = init_centroids
        centroids = np.asarray(centroids)
        # Centroids are weighted averages. An integer dtype (inherited from an
        # integer X or init_centroids) would silently truncate every update.
        if not np.issubdtype(centroids.dtype, np.floating):
            centroids = centroids.astype(np.float64)
        self.centroids = centroids

        self.counts = np.zeros(self.n_clusters)
        # Cumulative number of points each cluster has received so far; used
        # by the default (count-based) learning rate.
        weight_history = np.zeros(self.n_clusters)

        for _ in range(self.n_iterations):
            if self.batch_size is None:
                # use all data
                minibatch = X
            else:
                minibatch_indices = rng.choice(n_samples, self.batch_size, replace=False)
                minibatch = X[minibatch_indices]

            # Assign points to clusters.
            nearest = self.predict(minibatch)

            # Per-cluster coordinate sums and sizes. np.add.at accumulates
            # unbuffered, so repeated cluster indices are all added.
            sums = np.zeros((self.n_clusters, n_features))
            np.add.at(sums, nearest, minibatch)
            cluster_sizes = np.bincount(nearest, minlength=self.n_clusters).astype(float)

            new_centroids = np.copy(self.centroids)
            # Clusters that received no point keep their current centroid.
            nonempty = cluster_sizes > 0
            sizes = cluster_sizes[nonempty]
            center_of_mass = sums[nonempty] / sizes[:, np.newaxis]

            if self.batch_size is None:
                # Full batch: jump straight to the center of mass.
                new_centroids[nonempty] = center_of_mass
            else:
                if self.new_lr:
                    learning_rate = np.sqrt(sizes / self.batch_size)  # new learning rate
                else:
                    weight_history[nonempty] += sizes
                    learning_rate = sizes / weight_history[nonempty]
                learning_rate = np.minimum(1, learning_rate)[:, np.newaxis]
                new_centroids[nonempty] = (
                    (1 - learning_rate) * self.centroids[nonempty]
                    + learning_rate * center_of_mass
                )

            self.centroids = new_centroids

        # Calculate final inertia and cluster assignments for the entire dataset
        self.final_inertia, self.labels_ = self.get_inertia(X, self.centroids)
        return self

    def predict(self, X):
        """Return the index of the nearest centroid for every row of ``X``.

        Raises
        ------
        ValueError
            If no centroids are available yet (``fit`` has not been called).
        """
        if self.centroids is None:
            # euclidean_distances(X, None) would compare X with itself and
            # return meaningless labels instead of failing.
            raise ValueError("MiniBatchKMeans has no centroids yet; call fit() first.")
        return np.argmin(euclidean_distances(X, self.centroids), axis=1)