"""K-means clustering."""
import dataclasses
import os

import h5py
import numpy as np
import tensorflow as tf

from em.util import hdf5_util

###############################################################################


@dataclasses.dataclass
class KMeans:
    """Fitted K-means model holding the cluster centers.

    Provides nearest-center assignment encoded as a coefficient matrix, plus
    HDF5 persistence of the centers.
    """

    # shape = [n_components, n_features]
    cluster_centers: np.ndarray

    def create_coeffs(self, X: np.ndarray, batch_size: int = 64) -> np.ndarray:
        """Return per-example coefficients marking the nearest cluster center.

        For every example, the entry of its nearest center (by squared
        Euclidean distance) is set to the negative squared distance to that
        center; all other entries are filled with -1e9.

        Args:
            X: Examples, shape = [n_examples, n_features].
            batch_size: Number of examples processed at once; bounds the size
                of the intermediate [batch_size, n_components, n_features]
                difference tensor.

        Returns:
            float32 array of shape [n_examples, n_components].

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            # A non-positive step would never advance through the examples.
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        # X.shape = [n_examples, n_features]
        n_examples = X.shape[0]
        n_components = self.cluster_centers.shape[0]

        if n_examples == 0:
            # np.concatenate rejects an empty list, so handle this explicitly.
            return np.empty([0, n_components], dtype=np.float32)

        coeffs = []
        for start_index in range(0, n_examples, batch_size):
            batch = X[start_index : start_index + batch_size]

            # shape = [batch_size, n_components, n_features]
            diffs = batch[:, None, :] - self.cluster_centers[None, :, :]

            # shape = [batch_size, n_components]
            dists = np.sum(np.square(diffs), axis=-1)
            cluster_ids = np.argmin(dists, axis=-1)

            batch_coeffs = np.full(
                [batch.shape[0], n_components], -1e9, dtype=np.float32
            )
            # Pair each row with its own nearest cluster. Indexing with
            # `[:, cluster_ids]` would instead overwrite those columns for
            # every row in the batch.
            rows = np.arange(batch.shape[0])
            batch_coeffs[rows, cluster_ids] = -dists[rows, cluster_ids]
            coeffs.append(batch_coeffs)

        return np.concatenate(coeffs, axis=0)

    def save(self, filepath: str) -> None:
        """Write the cluster centers to an HDF5 file at ``filepath``.

        ``~`` in the path is expanded. The file is opened in "w" mode, so an
        existing file at that path is overwritten. The centers are stored via
        ``hdf5_util.save_h5_ds`` under the ``data`` group with the name
        ``cluster_centers``.
        """
        with h5py.File(os.path.expanduser(filepath), "w") as file:
            data_grp = file.create_group('data')
            hdf5_util.save_h5_ds(data_grp, 'cluster_centers', self.cluster_centers)

    @classmethod
    def load(cls, filepath: str) -> "KMeans":
        """Build an instance from the HDF5 file at ``filepath``.

        ``~`` in the path is expanded. Reads ``data/cluster_centers`` through
        ``hdf5_util.load_h5_ds``; this presumably returns the array written by
        ``save`` (helper is defined outside this file; confirm).
        """
        with h5py.File(os.path.expanduser(filepath), "r") as f:
            return cls(
                cluster_centers=hdf5_util.load_h5_ds(f['data/cluster_centers']),
            )
