from typing import Callable, Optional
import numpy as np
import torch
from ._grids import EquidistantGrid, RatedGrid, Grid, rtn_grid
from nn_compression._core import DeepCabacRdQuantiser as _DeepCabacRdQuantiser
from nn_compression.quantisation._interfaces import (
    Quantiser,
    Quantised,
)
from sklearn.preprocessing import KBinsDiscretizer
from sklearn.cluster import KMeans
from data_utils.arrays import make_numpy


class VectorQuantiser(Quantiser):
    """Quantiser that uses the KMeans implementation to quantise the input tensor to a
    grid with a codebook. Be careful with the input parameters, as VQ is known
    to be quite slow for large datasets, especially with large codebook sizes (> 2**5 bits)

    Usage example:
    >>> import torch
    >>> from nn_compression.quantisation import VectorQuantiser
    >>> x = torch.tensor([1.0, 2.0, 1.0, 2.0, 8.0, 9.0, 9.0, 8.0])
    >>> q = VectorQuantiser(1)
    >>> q.find_params(x)
    >>> xq = q.quantise(x).x
    >>> assert torch.allclose(
    ...     xq, torch.tensor([1.5, 1.5, 1.5, 1.5, 8.5, 8.5, 8.5, 8.5])
    ... )
    """

    def __init__(self, bits) -> None:
        """Initialises the quantiser with the number of bits to quantise the input tensor to.

        Args:
            bits: Uses 2**bits codewords (k-means clusters) to quantise the input tensor to.
        """
        super().__init__(bits)
        self.K = 2**bits
        # Fitted sklearn estimator; stays None until find_params succeeds.
        self.encoder = None

    def ready(self):
        """Return True once find_params has fitted an encoder."""
        return self.encoder is not None

    def find_params(self, x: torch.Tensor, weight=False) -> None:
        """Determines the encoder parameters for the input tensor.

        This involves running k-means clustering on the input tensor, which might take some time,
        especially if bits and/or the input tensor are large.

        As Vector Quantisation is especially useful for quantising vector-valued tensors, the input tensor
        is expected to have a shape of (n, d), where n is the number of vectors and d is the dimension of each vector.
        Alternatively, the input tensor can be a 1D tensor (treated as n scalar samples).

        Args:
            x: Samples to cluster, 1-D or (n, d).
            weight: Unused; accepted for interface compatibility.
        """
        samples = make_numpy(x)
        if len(x.shape) == 1:
            # sklearn expects a 2-D (n_samples, n_features) array.
            samples = samples.reshape(-1, 1)
        encoder = KMeans(n_clusters=self.K)
        encoder.fit(samples)
        # Only publish the encoder after fitting succeeded, so ready() never
        # reports an unfitted estimator.
        self.encoder = encoder

    def quantise(self, x: torch.Tensor) -> Quantised:
        """Quantise the input tensor with the fitted codebook.

        Every sample is replaced by the centre of the cluster KMeans assigns it to.

        Args:
            x: Tensor laid out as for find_params (1-D, or (n, d)).

        Returns:
            Quantised holding the reconstruction (same shape as ``x`` and, for
            floating-point input, the same dtype) and a RatedGrid built from it.
        """
        assert self.encoder is not None
        samples = make_numpy(x)
        if len(x.shape) == 1:
            samples = samples.reshape(-1, 1)
        quant_idx = self.encoder.predict(samples)
        # cluster_centers_ is float64; keep the caller's floating dtype so the result
        # can be combined/compared with x (torch.allclose requires equal dtypes).
        dtype = x.dtype if x.is_floating_point() else None
        xq = torch.tensor(
            self.encoder.cluster_centers_[quant_idx], dtype=dtype
        ).reshape(x.shape)
        # TODO: Return cluster centers as grid
        grid = (
            RatedGrid.from_tensor(xq)
            if len(x.shape) == 1
            else RatedGrid.from_vectors(xq)
        )
        return Quantised(xq, grid)


class _VectorQuantizerKbins(VectorQuantiser):
    """Quantiser that uses sklearn's KBinsDiscretizer to quantise the input tensor.

    This has some strange behaviour, such as that [1, 2, 1, 2, 8, 9, 8, 9]
    is quantised to [3, 3, 3, 3, 7, 7, 7, 7], as the reconstruction is just the bin
    centre (midpoint of the bin edges), which moves outward if the clusters have a
    higher distance to each other."""

    def find_params(self, x: torch.Tensor, weight=False) -> None:
        """Fit a k-means based KBinsDiscretizer with 2**bits bins on ``x``.

        Args:
            x: Samples, 1-D (treated as scalars) or (n, d) (binned per feature).
            weight: Unused; accepted for interface compatibility.
        """
        nbins = 2**self.bits
        encoder = KBinsDiscretizer(
            n_bins=nbins,
            encode="ordinal",
            strategy="kmeans",
        )
        x_ = make_numpy(x)
        if len(x.shape) == 1:
            x_ = x_.reshape(-1, 1)
        encoder.fit(x_)
        self.encoder = encoder

    def quantise(self, x: torch.Tensor) -> Quantised:
        """Quantise ``x`` to the bin centres of the fitted discretizer.

        Returns:
            Quantised with the reconstruction (same shape as ``x``; same dtype for
            floating-point input) and a RatedGrid built from it.
        """
        assert self.encoder is not None
        # Lay the samples out with as many columns as the encoder was fitted on.
        # A 1-D fit has a single feature, so every element is binned independently;
        # an (n, d) fit needs d columns (reshaping to one column would make
        # transform() fail on the feature-count mismatch).
        n_features = getattr(self.encoder, "n_features_in_", 1)
        samples = make_numpy(x).reshape(-1, n_features)
        quant_idx = self.encoder.transform(samples)
        recon = self.encoder.inverse_transform(quant_idx)
        # inverse_transform yields float64; keep the caller's floating dtype.
        dtype = x.dtype if x.is_floating_point() else None
        xq = torch.tensor(recon, dtype=dtype).reshape(x.shape)
        return Quantised(xq, RatedGrid.from_tensor(xq))


class PerRowGridQuantiser(Quantiser):
    """Composite quantiser holding one sub-quantiser per column of the fitted tensor.

    Sub-quantisers are either supplied up front via ``quantisers`` or built anew
    on every ``find_params`` call from the zero-argument factory ``quantiser_fn``.
    NOTE(review): the inline note says the weight is transposed internally, so the
    columns seen here presumably correspond to rows of the original weight — confirm.
    """

    def __init__(
        self,
        quantiser_fn: Optional[Callable] = None,
        quantisers: Optional[list[Quantiser]] = None,
    ) -> None:
        super().__init__(None)
        self.quantisers = quantisers
        self.quantiser_fn = quantiser_fn

    def ready(self):
        """True once sub-quantisers exist and each of them reports ready."""
        subs = self.quantisers
        if subs is None:
            return False
        return all(sub.ready() for sub in subs)

    def find_params(self, w, weight=True):
        """Fit one sub-quantiser per column of ``w``.

        With a factory configured, a fresh list of sub-quantisers is created
        (previous ones are discarded); otherwise the i-th supplied sub-quantiser
        is refitted on column i.
        """
        build_fresh = self.quantiser_fn is not None
        if build_fresh:
            self.quantisers = []
        # weight is transposed internally
        assert self.quantisers is not None
        for col in range(w.shape[1]):
            if build_fresh:
                sub = self.quantiser_fn()
                self.quantisers.append(sub)
            else:
                sub = self.quantisers[col]
            # slicing col:col + 1 keeps the column two-dimensional
            sub.find_params(w[:, col : col + 1], weight=weight)

    def quantise(self, x: torch.Tensor, row: int | None = None) -> Quantised:
        """Quantise ``x`` with the sub-quantiser at index ``row``.

        Raises:
            ValueError: If ``row`` is not given.
        """
        if row is None:
            raise ValueError("Row must be specified.")
        assert self.quantisers is not None
        sub = self.quantisers[row]
        return sub.quantise(x)

    def quantise_with_uncertainty(
        self, x: torch.Tensor, posterior_variance: torch.Tensor, col: int | None = None
    ) -> Quantised:
        """Delegate to the sub-quantiser at index ``col``, forwarding ``col``.

        Raises:
            ValueError: If ``col`` is not given.
        """
        if col is None:
            raise ValueError("Row must be specified.")
        assert self.quantisers is not None
        sub = self.quantisers[col]
        return sub.quantise_with_uncertainty(x, posterior_variance, col=col)


class FixedGridQuantiser(Quantiser):
    """Quantiser that uses a fixed grid to quantise the input tensor.

    The gridpoints are given at construction time; ready() always returns True.
    """

    def __init__(self, gridpoints: torch.Tensor | Grid) -> None:
        """Create a quantiser for a fixed set of gridpoints.

        Args:
            gridpoints: Tensor of gridpoints, or a Grid whose ``gridpoints``
                attribute is used.

        Raises:
            ValueError: If no gridpoints are given.
        """
        if not isinstance(gridpoints, torch.Tensor):
            gridpoints = gridpoints.gridpoints
        if gridpoints.numel() == 0:
            raise ValueError("FixedGridQuantiser requires at least one gridpoint.")
        # Bits needed to index every distinct gridpoint, i.e. ceil(log2(n_unique)),
        # matching the "2**bits levels" meaning of bits used by the other
        # quantisers. (n - 1).bit_length() computes it exactly in integers.
        n_unique = int(gridpoints.unique().numel())
        grid_bits = (n_unique - 1).bit_length()
        self.current_grid_idx = 0
        self._gridpoints = gridpoints
        super().__init__(grid_bits)

    def ready(self) -> bool:
        """The grid is fixed, so the quantiser is always ready."""
        return True

    def quantise(self, x: torch.Tensor, col=None) -> Quantised:
        """Quantise ``x`` onto the fixed grid.

        Delegates to ``rtn_grid`` (presumably round-to-nearest) with the
        flattened, sorted gridpoints. ``col`` is unused.
        """
        val = rtn_grid(x, self._gridpoints.flatten().sort()[0])
        return Quantised(val, RatedGrid.from_tensor(val))

    def reset(self) -> None:
        """Reset ``current_grid_idx`` to 0 (not read elsewhere in this class)."""
        self.current_grid_idx = 0

    def grid(self) -> torch.Tensor:
        """Return the gridpoints as given (unsorted, original shape)."""
        return self._gridpoints


def _determine_grid(x: torch.Tensor, nbins: int, sym: bool):
    x = x.flatten()
    if sym:
        absmax = x.abs().max().item()
        max = absmax
        min = -absmax
    else:
        max = x.max().item()
        min = x.min().item()

    step = (max - min) / (nbins - 1)
    if sym:
        zeropoint_shift = -min - int(np.round(nbins / 2)) * step
    else:
        zeropoint_shift = 0
    min = min + zeropoint_shift
    max = max + zeropoint_shift

    return min, max, step, zeropoint_shift


class AffineGridQuantiser(Quantiser):
    """Quantises a tensor using an affine (equidistant) grid with 2**bits gridpoints.

    The grid is derived from the values of the tensor given to find_params
    (via ``_determine_grid``); for inputs with a non-zero range:

    * ``symmetric=True``: the step is ``2 * absmax / (2**bits - 1)`` and the
      gridpoints are integer multiples of the step (zero is a gridpoint),
      running from ``-round(2**bits / 2) * step`` upwards.
    * ``symmetric=False``: the grid spans exactly from min to max, i.e.
      gridpoint ``k`` is ``min + k * step`` with ``step = (max - min) / (2**bits - 1)``.

    With ``per_row_grid=True`` a separate grid is determined for every row
    (first dimension) of the tensor given to find_params.
    """

    def __init__(
        self, bits, symmetric: bool = True, per_row_grid: bool = False
    ) -> None:
        # NOTE(review): super().__init__ is not called; _bits is set directly,
        # presumably backing the base-class `bits` used in __str__ — confirm.
        self._bits = bits
        self.grid = None
        self.sym = symmetric
        # Grid parameters set by find_params: Python floats (step: 0-d tensor)
        # for a global grid, 1-D tensors with one entry per row otherwise.
        self.min = None
        self.max = None
        self.step = None
        self.nbins = int(np.round(2**bits))
        self.per_row_grid = per_row_grid

    def ready(self) -> bool:
        """True once find_params has set the grid parameters."""
        return self.min is not None

    def find_params(self, x, weight=False):
        """Determine the grid(s) from ``x``; ``weight`` is unused."""
        if self.per_row_grid:
            steps, mins, maxs = [], [], []
            for row in range(x.shape[0]):
                row_min, row_max, row_step, _ = _determine_grid(
                    x[row, :], self.nbins, self.sym
                )
                steps.append(row_step)
                mins.append(row_min)
                maxs.append(row_max)
            # Publish all parameters together so a failure mid-loop cannot leave
            # partially filled lists that ready() would accept.
            self.step = torch.tensor(steps)
            self.min = torch.tensor(mins)
            self.max = torch.tensor(maxs)
            # Per-row grids are not represented by a single Grid object.
            self.grid = None
        else:
            self.min, self.max, self.step, _ = _determine_grid(x, self.nbins, self.sym)
            self.step = torch.tensor(self.step)
            self.grid = EquidistantGrid(self.min, self.step.item(), self.nbins)

    def quantise(self, x: torch.Tensor, col: int | None = None) -> Quantised:
        """Round ``x`` onto the grid, fitting the grid on ``x`` first if necessary.

        Args:
            x: Tensor to quantise. With ``per_row_grid`` the per-row parameters are
                combined with ``x`` by standard torch broadcasting (a 1-D ``x`` with
                one entry per row lines up element-wise).
            col: Unused; accepted for interface compatibility.

        Returns:
            Quantised with the dequantised values and, for a global grid, a
            RatedGrid carrying how often each gridpoint is used (None with
            ``per_row_grid``).
        """
        if not self.ready():
            self.find_params(x)
        assert self.min is not None
        assert self.max is not None
        assert isinstance(self.step, torch.Tensor)
        step = self.step
        if self.sym:
            # Symmetric gridpoints are integer multiples of step (min included),
            # so round on the integer lattice and clip to the grid's index range.
            lo_idx = torch.round(self.min / step)  # type: ignore
            hi_idx = torch.round(self.max / step)  # type: ignore
            lattice = torch.round(x / step).clamp(lo_idx, hi_idx)
            q = lattice * step
            level = lattice - lo_idx
        else:
            # Asymmetric gridpoints are min + k * step for k = 0..nbins-1; min is in
            # general not a multiple of step, so levels are measured from min.
            level = torch.round((x - self.min) / step).clamp(0, self.nbins - 1)  # type: ignore
            q = self.min + level * step  # type: ignore
        if self.per_row_grid:
            return Quantised(q, None)  # type: ignore
        return Quantised(q, RatedGrid(self.grid, self._level_counts(level)))  # type: ignore

    def _level_counts(self, level: torch.Tensor) -> torch.Tensor:
        """Count how often each of the ``nbins`` grid levels occurs.

        ``level`` holds (float) level indices in ``[0, nbins)``; NaNs are skipped.
        Counting integer indices avoids torch.histc's range check, which drops
        values lying a rounding error outside ``[min, max]``.
        """
        flat = level.flatten()
        flat = flat[~torch.isnan(flat)]
        return torch.bincount(flat.long(), minlength=self.nbins).to(level.dtype)

    def __repr__(self):
        return str(self)

    def __str__(self) -> str:
        return f"AffineGridQuantiser({self.bits})"


class RdQuantiserDeepCabac(Quantiser):
    """Quantiser that delegates index selection to the native ``_DeepCabacRdQuantiser``.

    The grid is the symmetric 2**bits-level grid from ``_determine_grid``,
    determined either for the whole tensor or, with ``per_row_grid=True``, for
    blocks of ``blocksize`` rows (all rows of a block share one grid).
    """

    def __init__(
        self, bits, lmbda: float, per_row_grid: bool = False, blocksize: int = 1
    ) -> None:
        """
        Args:
            bits: The grid has 2**bits levels.
            lmbda: First argument of ``_DeepCabacRdQuantiser`` (presumably the
                rate-distortion Lagrange multiplier — confirm).
            per_row_grid: Determine grids per block of rows instead of globally.
            blocksize: Number of consecutive rows sharing a grid when
                ``per_row_grid`` is set.
        """
        super().__init__(bits)
        self.per_row_grid = per_row_grid

        self.blocksize = blocksize
        self.nbins = int(np.round(2**bits))
        self.lmbda = lmbda
        # Grid parameters set by find_params: scalars for a global grid, numpy
        # arrays with one entry per row when per_row_grid is set.
        self.delta = None
        self.min_idx = None
        self.max_idx = None
        self.quantiser = _DeepCabacRdQuantiser(lmbda, 1, 0, 0)

    def ready(self) -> bool:
        """True once find_params has set the grid step."""
        return self.quantiser is not None and self.delta is not None

    def _find_params_weight(self, x: torch.Tensor) -> tuple[int, int, float]:
        """Return ``(min_idx, max_idx, delta)`` of the symmetric grid for ``x``.

        The gridpoints are ``k * delta`` for integer ``k`` in ``[min_idx, max_idx]``.
        """
        lo, hi, delta, _ = _determine_grid(x, self.nbins, sym=True)
        min_idx = int(np.round(lo / delta))
        max_idx = min_idx + (self.nbins - 1)
        assert max_idx == int(
            np.round(hi / delta)
        ), f"{max_idx} != {int(np.round(hi / delta))}"
        return min_idx, max_idx, delta

    def find_params(self, x: torch.Tensor, weight=False) -> None:
        """Determine the grid(s) from ``x``; ``weight`` is unused.

        For a global grid the native quantiser's grid is set here; per-row grids
        are instead passed explicitly on every quantisation call.
        """
        if self.per_row_grid:
            deltas, min_idxs, max_idxs = [], [], []
            maxrows = x.shape[0]
            for row in range(0, maxrows, self.blocksize):
                endrow = min(row + self.blocksize, maxrows)
                block = x[row:endrow, :]
                # One grid for the whole block, repeated for each of its rows.
                min_idx, max_idx, delta = self._find_params_weight(block.flatten())
                n_block_rows = block.shape[0]
                deltas.extend([delta] * n_block_rows)
                min_idxs.extend([min_idx] * n_block_rows)
                max_idxs.extend([max_idx] * n_block_rows)
            # Publish together so a failure mid-loop leaves no partial state.
            self.delta = np.array(deltas)
            self.min_idx = np.array(min_idxs)
            self.max_idx = np.array(max_idxs)

        else:
            self.min_idx, self.max_idx, self.delta = self._find_params_weight(x)
            self.quantiser.set_grid(self.min_idx, self.max_idx, self.delta)

    def quantise_with_uncertainty(
        self, w: torch.Tensor, posterior_variance: torch.Tensor, col: int | None = None
    ) -> Quantised:
        """Quantise ``w`` (flattened) with the native quantiser.

        Args:
            w: Values to quantise; flattened. With per-row grids it presumably
                holds one value per row (matching ``delta``) — confirm with callers.
            posterior_variance: Must contain exactly one element.
            col: Unused.

        Returns:
            Quantised with the flattened float32 reconstruction ``idx * delta``
            and no grid.

        Raises:
            ValueError: If the posterior variance has more than one element, or
                the per-row grids do not share the same index range.
        """
        assert self.quantiser is not None
        # .cpu() makes this work for CUDA tensors too (no-op on CPU).
        w = w.detach().cpu().numpy().flatten()
        pv = posterior_variance.detach().cpu().numpy().flatten()
        if pv.size > 1:
            raise ValueError("Only scalar posterior variance supported at the moment.")

        if self.per_row_grid:
            # The native call takes one index range, so all rows must agree.
            if not (self.min_idx[0] == self.min_idx).all():  # type: ignore
                raise ValueError("Min idx mismatch.")
            if not (self.max_idx[0] == self.max_idx).all():  # type: ignore
                raise ValueError("max idx mismatch.")
            wq_idx = self.quantiser.quantize(
                w, pv[0], self.delta, self.min_idx[0], self.max_idx[0]  # type: ignore
            )
        else:
            wq_idx = self.quantiser.quantize(w, pv[0])
        wq_idx = torch.tensor(wq_idx)
        # Multiply tensor by tensor explicitly when delta is a per-row ndarray.
        delta = (
            torch.from_numpy(self.delta)
            if isinstance(self.delta, np.ndarray)
            else self.delta
        )
        return Quantised((wq_idx * delta).to(torch.float32), None)

    def quantise(self, x: torch.Tensor, col=None):
        """Quantise ``x`` with a uniform posterior variance of 1.

        Raises:
            ValueError: If ``col`` is given (per-column grids are not supported).
        """
        if col is not None:
            raise ValueError("Per row grid not implemented.")
        # quantise_with_uncertainty only accepts a single-element variance, so a
        # uniform variance is passed as one element; ones_like(x) would be
        # rejected for every x with more than one element.
        return self.quantise_with_uncertainty(x, torch.ones(1, dtype=x.dtype))
