import numpy as np
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.preprocessing import StandardScaler

class MiniBatchKMeans:
    """Mini-batch k-means that starts from caller-supplied centroids.

    Every iteration draws ``batch_size`` rows without replacement, assigns
    each row to its nearest centroid and moves every centroid that received
    at least one row towards the mean of its rows:

        centroid <- (1 - lr) * centroid + lr * batch_mean

    The learning rate ``lr`` (capped at 1) is selected by ``new_lr``:

    * ``new_lr=False``: ``b_c / N_c``, where ``b_c`` is the number of batch
      rows assigned to the cluster and ``N_c`` is the running total of rows
      assigned to it during the current ``fit`` call.
    * ``new_lr=True``: ``sqrt(b_c / batch_size)``.

    Attributes set by ``fit``:
        centroids: array of cluster centres after the last iteration.
        final_inertia: mean squared distance from each row of the training
            data to its nearest centroid.
        labels_: index of the nearest centroid for each training row.
    """

    def __init__(
        self,
        n_clusters,
        batch_size,
        n_iterations,
        dynamic_batching=True,
        new_lr=False,
        lr_decay=False,
        use_var=False,
        use_beta=False,
        jl="None",
        random_state=None,
    ):
        """Store the hyper-parameters; no computation happens here.

        Args:
            n_clusters: number of clusters (rows expected in the initial
                centroids passed to ``fit``).
            batch_size: number of rows sampled per iteration.
            n_iterations: number of mini-batch update steps.
            new_lr: if true, use the square-root learning rate.
            random_state: seed passed to ``numpy.random.default_rng``.
            dynamic_batching, lr_decay, use_var, use_beta, jl: stored on the
                instance so they can be read back, but not used by ``fit``
                in this class.
        """
        self.n_clusters = n_clusters
        self.batch_size = batch_size
        self.n_iterations = n_iterations
        self.centroids = None
        self.final_inertia = 0
        self.new_lr = new_lr  # new learning rate flag
        self.random_state = random_state
        # Previously accepted but silently discarded; keep them so the
        # constructor arguments are recoverable from the instance.
        self.dynamic_batching = dynamic_batching
        self.lr_decay = lr_decay
        self.use_var = use_var
        self.use_beta = use_beta
        self.jl = jl

    # Needed for consistency with sklearn (for running experiments).
    def set_params(self, **params):
        """Set each keyword argument as an attribute and return ``self``."""
        for key, value in params.items():
            setattr(self, key, value)
        return self

    def get_inertia(self, points, centers, normalized=True):
        """Return the k-means cost of ``points`` and their assignments.

        Args:
            points: 2-D array of rows to evaluate.
            centers: 2-D array of centroids, one per row.
            normalized: if true, divide the cost by the number of rows.

        Returns:
            ``(inertia, nearest)`` where ``inertia`` is the sum (or mean, when
            ``normalized``) of squared distances from each row to its nearest
            centre, and ``nearest`` holds the index of that centre per row.
        """
        nearest = np.argmin(euclidean_distances(points, centers), axis=1)
        # Squared Frobenius norm of the residuals == sum of squared distances.
        inertia = np.linalg.norm(points - centers[nearest]) ** 2
        if normalized:
            inertia /= points.shape[0]
        return inertia, nearest

    def fit(self, X, init_centroids=None):
        """Run ``n_iterations`` mini-batch updates starting at ``init_centroids``.

        Args:
            X: 2-D array of training rows. Batches are sampled without
                replacement, so ``batch_size`` must not exceed its row count.
            init_centroids: starting centroids, one row per cluster.
                Required; the ``None`` default exists only for signature
                compatibility. The caller's array is never modified.

        Returns:
            ``self``, with ``centroids``, ``final_inertia`` and ``labels_`` set.

        Raises:
            ValueError: if ``init_centroids`` is ``None``.
        """
        # Explicit check instead of ``assert``, which is removed under ``-O``.
        if init_centroids is None:
            raise ValueError("init_centroids must be provided")

        self.total_sampled = 0  # reset here; not updated by this method
        rng = np.random.default_rng(self.random_state)
        n_samples, n_features = X.shape

        # Work on a private copy. Integer/bool centroids are promoted to
        # float, otherwise the fractional updates below would be truncated
        # when written back into the array.
        centroids = np.array(init_centroids)
        if centroids.dtype.kind in "biu":
            centroids = centroids.astype(float)
        self.centroids = centroids

        self.counts = np.zeros(self.n_clusters)  # initialised; not updated here
        # Cumulative number of rows assigned to each cluster (old lr rule).
        weight_history = np.zeros(self.n_clusters)

        for _ in range(self.n_iterations):
            batch_indices = rng.choice(n_samples, self.batch_size, replace=False)
            minibatch = X[batch_indices]

            # Only the assignments are needed for the update step.
            nearest = self.predict(minibatch)

            # Per-cluster sums and counts. ``np.add.at`` is unbuffered, so
            # rows are accumulated one by one in batch order.
            sums = np.zeros((self.n_clusters, n_features))
            np.add.at(sums, nearest, minibatch)
            cluster_sizes = np.bincount(
                nearest, minlength=self.n_clusters
            ).astype(float)

            new_centroids = np.copy(self.centroids)
            # Clusters that received no rows in this batch stay where they are.
            for c_idx in np.flatnonzero(cluster_sizes):
                size = cluster_sizes[c_idx]
                center_of_mass = sums[c_idx] / size
                if self.new_lr:
                    learning_rate = np.sqrt(size / self.batch_size)
                else:
                    weight_history[c_idx] += size
                    learning_rate = size / weight_history[c_idx]

                learning_rate = min(1, learning_rate)
                new_centroids[c_idx] = (
                    (1 - learning_rate) * self.centroids[c_idx]
                    + learning_rate * center_of_mass
                )

            self.centroids = new_centroids

        # Final inertia and cluster assignments for the entire dataset.
        self.final_inertia, self.labels_ = self.get_inertia(X, self.centroids)
        return self

    def predict(self, X):
        """Return the index of the nearest centroid for each row of ``X``."""
        return np.argmin(euclidean_distances(X, self.centroids), axis=1)

