"""K-means clustering."""
import dataclasses
import gc
import os
from typing import Tuple

import h5py
import numpy as np
import torch
from tqdm import tqdm

from npeff_torch.util import hdf5_utils


###############################################################################


# deprecated
@dataclasses.dataclass
class KmeansClustering:
    """Numpy-backed container for the result of a k-means clustering.

    This class is marked as deprecated in this module. It stores the centroids
    together with the per-sample cluster assignments and distances, and can be
    written to / read from an HDF5 file.
    """

    # Cluster centers. shape = [n_clusters, d_data]
    centroids: np.ndarray

    # Cluster index of each sample. shape = [n_samples], dtype=np.int64
    cluster_assignments: np.ndarray

    # Per-sample distance values. shape = [n_samples]
    centroid_distances: np.ndarray

    #######################################################

    @property
    def n_clusters(self) -> int:
        """Number of clusters, i.e. the number of rows of `centroids`."""
        n_rows = self.centroids.shape[0]
        return int(n_rows)

    #######################################################

    def save(self, filepath: str):
        """Writes all three arrays to an HDF5 file at `filepath`.

        A leading `~` in `filepath` is expanded. The file is opened in "w" mode.
        """
        target_path = os.path.expanduser(filepath)
        with h5py.File(target_path, "w") as f:
            datasets = (
                ('data/centroids', self.centroids),
                ('data/cluster_assignments', self.cluster_assignments),
                ('data/centroid_distances', self.centroid_distances),
            )
            for key, array in datasets:
                hdf5_utils.save_h5_ds(f, key, array)

    @classmethod
    def load(cls, filepath: str):
        """Builds an instance from an HDF5 file laid out as written by `save`.

        A leading `~` in `filepath` is expanded.
        """
        source_path = os.path.expanduser(filepath)
        with h5py.File(source_path, "r") as f:
            centroids = hdf5_utils.load_h5_ds(f['data/centroids'])
            cluster_assignments = hdf5_utils.load_h5_ds(f['data/cluster_assignments'])
            centroid_distances = hdf5_utils.load_h5_ds(f['data/centroid_distances'])
            return cls(
                centroids=centroids,
                cluster_assignments=cluster_assignments,
                centroid_distances=centroid_distances,
            )


###############################################################################


@dataclasses.dataclass
class KmeansClusteringTorch:
    """Result of a k-means clustering, stored as torch tensors.

    Bundles the centroids with the per-sample cluster assignments and distance
    values. Also provides helpers to assign new samples to the stored centroids
    and to write/read the result as an HDF5 file.
    """

    # Cluster centers. shape = [n_clusters, d_data]
    centroids: torch.Tensor

    # Cluster index of each sample. shape = [n_samples], dtype=torch.int64
    cluster_assignments: torch.Tensor

    # Per-sample distance values. shape = [n_samples]
    centroid_distances: torch.Tensor

    #######################################################

    def __post_init__(self):
        # Plain instance attribute rather than a dataclass field; it is set to
        # True by normalize_reduced_components_to_unit_norm_.
        self.components_are_normalized = False

    #######################################################

    @property
    def n_examples(self):
        """Number of samples, read off the length of `centroid_distances`."""
        return self.centroid_distances.size(0)

    @property
    def n_clusters(self):
        """Number of clusters, i.e. the number of rows of `centroids`."""
        return self.centroids.size(0)

    @property
    def n_components(self):
        """Same value as `n_clusters`."""
        return self.n_clusters

    def to(self, device: torch.device) -> 'KmeansClusteringTorch':
        """Moves all three tensors to `device`, updating this object, and returns it."""
        for field_name in ('centroids', 'cluster_assignments', 'centroid_distances'):
            moved = getattr(self, field_name).to(device)
            setattr(self, field_name, moved)
        return self

    #######################################################

    def normalize_reduced_components_to_unit_norm_(self, eps: float = 1e-12):
        """Rescales every centroid to unit L2 norm and records that it was done.

        NOTE: Only `centroids` is replaced; `centroid_distances` and
        `cluster_assignments` are left untouched.
        """
        unit_centroids = torch.nn.functional.normalize(self.centroids, dim=-1, eps=eps)
        self.centroids = unit_centroids
        self.components_are_normalized = True

    #######################################################

    def compute_clusters(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Assigns each sample in `x` to its nearest stored centroid.

        Args:
            x: shape = [n_samples, d_data]

        Returns:
            cluster_assignments: shape = [n_samples], dtype=torch.int64
            centroid_distances: shape = [n_samples]
        """
        n_rows, _ = x.shape
        centroids = self.centroids

        # Squared distances, built up in a single buffer as
        # -2 * <x, c> + ||x||^2 + ||c||^2.
        sq_dists = torch.einsum('nd,kd->nk', x, centroids)
        sq_dists.mul_(-2.0)
        sq_dists.add_(torch.einsum('nd,nd->n', x, x).view(n_rows, 1))
        sq_dists.add_(torch.einsum('kd,kd->k', centroids, centroids).view(1, self.n_clusters))

        nearest = sq_dists.argmin(dim=1).type(torch.int64).view(-1)
        # Drop the [n_samples, n_clusters] buffer before computing the distances.
        del sq_dists
        distances = _compute_centroid_distances(x, centroids, nearest)

        return nearest, distances

    #######################################################

    # deprecated
    def numpy(self) -> KmeansClustering:
        """Returns a copy of this result backed by numpy arrays instead of torch Tensors."""

        def as_array(tensor):
            return tensor.detach().cpu().numpy()

        return KmeansClustering(
            centroids=as_array(self.centroids),
            cluster_assignments=as_array(self.cluster_assignments),
            centroid_distances=as_array(self.centroid_distances),
        )

    #######################################################

    def save(self, filepath: str):
        """Writes all three tensors (converted to numpy on the CPU) to an HDF5 file.

        A leading `~` in `filepath` is expanded. The file is opened in "w" mode.
        """
        target_path = os.path.expanduser(filepath)
        with h5py.File(target_path, "w") as f:
            datasets = (
                ('data/centroids', self.centroids),
                ('data/cluster_assignments', self.cluster_assignments),
                ('data/centroid_distances', self.centroid_distances),
            )
            for key, tensor in datasets:
                hdf5_utils.save_h5_ds(f, key, tensor.detach().cpu().numpy())

    @classmethod
    def load(cls, filepath: str):
        """Builds an instance from an HDF5 file laid out as written by `save`."""
        source_path = os.path.expanduser(filepath)
        with h5py.File(source_path, "r") as f:
            centroids = torch.from_numpy(hdf5_utils.load_h5_ds(f['data/centroids']))
            cluster_assignments = torch.from_numpy(hdf5_utils.load_h5_ds(f['data/cluster_assignments']))
            centroid_distances = torch.from_numpy(hdf5_utils.load_h5_ds(f['data/centroid_distances']))
            return cls(
                centroids=centroids,
                cluster_assignments=cluster_assignments,
                centroid_distances=centroid_distances,
            )

    @classmethod
    def load_n_clusters(cls, filepath: str) -> int:
        """Reads only the number of clusters from a saved file, using the dataset's shape."""
        source_path = os.path.expanduser(filepath)
        with h5py.File(source_path, "r") as f:
            centroids_ds = f['data/centroids']
            return centroids_ds.shape[0]

    @classmethod
    def load_cluster_assignments(cls, filepath: str) -> torch.Tensor:
        """Reads only the cluster assignments from a saved file."""
        source_path = os.path.expanduser(filepath)
        with h5py.File(source_path, "r") as f:
            assignments = hdf5_utils.load_h5_ds(f['data/cluster_assignments'])
            return torch.from_numpy(assignments)

    @classmethod
    def load_centroid_distances(cls, filepath: str) -> torch.Tensor:
        """Reads only the centroid distances from a saved file."""
        source_path = os.path.expanduser(filepath)
        with h5py.File(source_path, "r") as f:
            distances = hdf5_utils.load_h5_ds(f['data/centroid_distances'])
            return torch.from_numpy(distances)


###############################################################################



# Notes on possible improvements:
#   - Allow initial centroids to be passed in.
#   - Check for convergence (based on `cl` not changing), maybe only every certain number of steps.


# https://www.kernel-operations.io/keops/_auto_tutorials/kmeans/plot_kmeans_torch.html
def compute_kmeans_lloyds(
    x: torch.Tensor,
    n_clusters: int,
    n_iterations: int,
    *,
    verbose: bool = True,
) -> 'KmeansClusteringTorch':
    """Runs Lloyd's algorithm for k-means on `x`.

    The centroids are initialized to the first `n_clusters` rows of `x`. Each
    iteration assigns every sample to its nearest centroid (by squared
    Euclidean distance) and then replaces every centroid by the mean of the
    samples assigned to it. A cluster that receives no samples ends up with an
    all-zero centroid. Iteration stops early as soon as the assignments are
    identical to those of the previous iteration.

    NOTE: The returned assignments are the ones computed at the start of the
    last executed iteration, i.e. before that iteration's centroid update. The
    returned distances are measured from each sample to the final centroid of
    its assigned cluster.

    Args:
        x: shape = [n_samples, d_data]
        n_clusters: Number of clusters. Must satisfy 1 <= n_clusters <= n_samples.
        n_iterations: Maximum number of iterations. Must be at least 1.
        verbose: Whether to show a tqdm progress bar over the iterations.

    Returns:
        The clustering with centroids, per-sample assignments and per-sample
        distances.

    Raises:
        ValueError: If `x` is not 2-dimensional, or if `n_clusters` or
            `n_iterations` is out of range.
    """
    if x.ndim != 2:
        raise ValueError(f"x must have shape [n_samples, d_data], got {tuple(x.shape)}.")
    n_samples, _ = x.shape

    if not 1 <= n_clusters <= n_samples:
        raise ValueError(
            f"n_clusters must be in [1, n_samples={n_samples}], got {n_clusters}."
        )
    # Without at least one iteration there would be no assignments to return.
    if n_iterations < 1:
        raise ValueError(f"n_iterations must be at least 1, got {n_iterations}.")

    # Simplistic initialization for the centroids
    centroids = x[:n_clusters, :].clone()

    # Squared norms of the samples do not change across iterations.
    xx_i = torch.einsum('nd,nd->n', x, x).view(n_samples, 1)

    cl_old = None

    steps = range(n_iterations)
    for _ in tqdm(steps) if verbose else steps:
        # E step: assign points to the closest cluster -------------------------

        # Squared distances expanded as ||x_i||^2 + ||c_j||^2 - 2 <x_i, c_j>.
        cc_j = torch.einsum('kd,kd->k', centroids, centroids).view(1, n_clusters)
        xc_ij = torch.einsum('nd,kd->nk', x, centroids)
        D_ij = xx_i + cc_j - 2 * xc_ij

        cl = D_ij.argmin(dim=1).type(torch.int64).view(-1)  # Points -> Nearest cluster

        # M step: update the centroids to the normalized cluster average: ------

        # Compute the sum of points per cluster:
        mask = torch.nn.functional.one_hot(cl, num_classes=n_clusters).type_as(centroids)
        centroids = torch.einsum('nk,nd->kd', mask, x)

        # Divide by the number of points per cluster:
        Ncl = torch.bincount(cl, minlength=n_clusters).type_as(centroids).view(n_clusters, 1)
        Ncl.clamp_(min=1.0)  # Make sure we don't divide by zero.
        centroids /= Ncl  # in-place division to compute the average

        # Check convergence to stop early.
        if cl_old is not None and torch.all(cl_old == cl):
            break

        cl_old = cl

        # Presumably here so the progress bar reflects finished GPU work.
        if verbose and torch.cuda.is_available():
            torch.cuda.synchronize()

    return KmeansClusteringTorch(
        centroids=centroids,
        cluster_assignments=cl,
        centroid_distances=_compute_centroid_distances(x, centroids, cl),
    )


@torch.no_grad()
def compute_kmeans_lloyds2(
    x: torch.Tensor,
    n_clusters: int,
    n_iterations: int,
    *,
    batch_size: int = 4096,
    verbose: bool = True,
) -> 'KmeansClusteringTorch':
    """Runs Lloyd's algorithm for k-means on `x` with reduced memory use.

    Same algorithm as `compute_kmeans_lloyds`, adapted to use less memory so it
    can run on GPU: the distance matrix is accumulated in place and freed right
    after the assignment step, and the per-cluster sums are accumulated over
    batches of `batch_size` samples. Runs with gradient tracking disabled.

    The centroids are initialized to the first `n_clusters` rows of `x`. A
    cluster that receives no samples ends up with an all-zero centroid.
    Iteration stops early as soon as the assignments are identical to those of
    the previous iteration.

    NOTE: The returned assignments are the ones computed at the start of the
    last executed iteration, i.e. before that iteration's centroid update.

    NOTE: `centroid_distances` is computed on the CPU and returned as a CPU
    tensor, while `centroids` and `cluster_assignments` are returned as
    computed, without being moved to the CPU.

    Args:
        x: shape = [n_samples, d_data]
        n_clusters: Number of clusters. Must satisfy 1 <= n_clusters <= n_samples.
        n_iterations: Maximum number of iterations. Must be at least 1.
        batch_size: Number of samples per batch when summing the samples of
            each cluster. Must be at least 1.
        verbose: Whether to show a tqdm progress bar over the iterations.

    Returns:
        The clustering with centroids, per-sample assignments and per-sample
        distances.

    Raises:
        ValueError: If `x` is not 2-dimensional, or if `n_clusters`,
            `n_iterations` or `batch_size` is out of range.
    """
    if x.ndim != 2:
        raise ValueError(f"x must have shape [n_samples, d_data], got {tuple(x.shape)}.")
    n_samples, _ = x.shape

    if not 1 <= n_clusters <= n_samples:
        raise ValueError(
            f"n_clusters must be in [1, n_samples={n_samples}], got {n_clusters}."
        )
    # Without at least one iteration there would be no assignments to return.
    if n_iterations < 1:
        raise ValueError(f"n_iterations must be at least 1, got {n_iterations}.")
    # A negative batch size would make the batching loop empty and silently
    # produce all-zero centroids.
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}.")

    # Simplistic initialization for the centroids
    centroids = x[:n_clusters, :].clone()

    # Squared norms of the samples do not change across iterations.
    xx_i = torch.einsum('nd,nd->n', x, x).view(n_samples, 1)

    cl_old = None

    steps = range(n_iterations)
    for _ in tqdm(steps) if verbose else steps:
        # E step: assign points to the closest cluster -------------------------

        # Squared distances ||x_i||^2 + ||c_j||^2 - 2 <x_i, c_j>, accumulated
        # in place so that only one [n_samples, n_clusters] buffer is held.
        D_ij = torch.einsum('nd,kd->nk', x, centroids)
        D_ij *= -2.0
        D_ij += xx_i
        D_ij += torch.einsum('kd,kd->k', centroids, centroids).view(1, n_clusters)

        cl = D_ij.argmin(dim=1).type(torch.int64).view(-1)  # Points -> Nearest cluster
        del D_ij

        # M step: update the centroids to the normalized cluster average: ------

        # Compute the sum of points per cluster, one batch of samples at a time:
        centroids.zero_()
        for start in range(0, n_samples, batch_size):
            stop = start + batch_size
            mask = torch.nn.functional.one_hot(cl[start:stop], num_classes=n_clusters).type_as(centroids)
            centroids += torch.einsum('nk,nd->kd', mask, x[start:stop])

        # Divide by the number of points per cluster:
        Ncl = torch.bincount(cl, minlength=n_clusters).type_as(centroids).view(n_clusters, 1)
        Ncl.clamp_(min=1.0)  # Make sure we don't divide by zero.
        centroids /= Ncl  # in-place division to compute the average

        # Check convergence to stop early.
        if cl_old is not None and torch.all(cl_old == cl):
            break

        cl_old = cl

        # Presumably here so the progress bar reflects finished GPU work.
        if verbose and torch.cuda.is_available():
            torch.cuda.synchronize()

    return KmeansClusteringTorch(
        centroids=centroids,
        cluster_assignments=cl,
        # Do this on the CPU.
        centroid_distances=_compute_centroid_distances(x.cpu(), centroids.cpu(), cl.cpu()),
    )


###############################################################################


def _compute_centroid_distances(x: torch.Tensor, centroids: torch.Tensor, cluster_assignments: torch.Tensor) -> torch.Tensor:
    # x.shape = [n_samples, d_data]
    # centroids.shape = [n_clusters, d_data]
    # cluster_assignments.shape = [n_samples], dtype=torch.int64
    #
    # ret.shape = [n_samples]
    return torch.linalg.vector_norm(x - centroids[cluster_assignments], dim=-1)
