# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .em import EM, EmptyClusterResolveError


class PQ(EM):
    """
    Quantizes the layer weights W with the standard Product Quantization
    technique. This learns a codebook of codewords or centroids of size
    block_size from W. For further reference on using PQ to quantize
    neural networks, see "And the Bit Goes Down: Revisiting the Quantization
    of Neural Networks", Stock et al., ICLR 2020.

    PQ is performed in two steps:
    (1) The matrix W (weights or fully-connected or convolutional layer)
        is reshaped to (block_size, -1).
            - If W is fully-connected (2D), its columns are split into
              blocks of size block_size.
            - If W is convolutional (4D), its filters are split along the
              spatial dimension.
    (2) We apply the standard EM/k-means algorithm to the resulting reshaped matrix.

    Args:
        - W: weight matrix to quantize of size (in_features x out_features)
        - block_size: size of the blocks (subvectors)
        - n_centroids: number of centroids
        - n_iter: number of k-means iterations
        - eps: for cluster reassignment when an empty cluster is found
        - max_tentatives for cluster reassignment when an empty cluster is found
        - verbose: print information after each iteration

    Remarks:
        - block_size be compatible with the shape of W
    """

    def __init__(
        self,
        W,
        block_size,
        n_centroids=256,
        n_iter=20,
        eps=1e-6,
        max_tentatives=30,
        verbose=True,
    ):
        self.block_size = block_size
        W_reshaped = self._reshape(W)
        super(PQ, self).__init__(
            W_reshaped,
            n_centroids=n_centroids,
            n_iter=n_iter,
            eps=eps,
            max_tentatives=max_tentatives,
            verbose=verbose,
        )

    def _reshape(self, W):
        """
        Reshapes the matrix W as expained in step (1).
        """

        # fully connected: by convention the weight has size out_features x in_features
        if len(W.size()) == 2:
            self.out_features, self.in_features = W.size()
            assert (
                self.in_features % self.block_size == 0
            ), "Linear: n_blocks must be a multiple of in_features"
            return (
                W.reshape(self.out_features, -1, self.block_size)
                .permute(2, 1, 0)
                .flatten(1, 2)
            )

        # convolutional: we reshape along the spatial dimension
        elif len(W.size()) == 4:
            self.out_channels, self.in_channels, self.k_h, self.k_w = W.size()
            assert (
                self.in_channels * self.k_h * self.k_w
            ) % self.block_size == 0, (
                "Conv2d: n_blocks must be a multiple of in_channels * k_h * k_w"
            )
            return (
                W.reshape(self.out_channels, -1, self.block_size)
                .permute(2, 1, 0)
                .flatten(1, 2)
            )
        # not implemented
        else:
            raise NotImplementedError(W.size())

    def encode(self):
        """
        Performs self.n_iter EM steps.
        """

        self.initialize_centroids()
        for i in range(self.n_iter):
            try:
                self.step(i)
            except EmptyClusterResolveError:
                break

    def decode(self):
        """
        Returns the encoded full weight matrix. Must be called after
        the encode function.
        """

        # fully connected case
        if "k_h" not in self.__dict__:
            return (
                self.centroids[self.assignments]
                .reshape(-1, self.out_features, self.block_size)
                .permute(1, 0, 2)
                .flatten(1, 2)
            )

        # convolutional case
        else:
            return (
                self.centroids[self.assignments]
                .reshape(-1, self.out_channels, self.block_size)
                .permute(1, 0, 2)
                .reshape(self.out_channels, self.in_channels, self.k_h, self.k_w)
            )
