import warnings
from typing import Optional, Tuple

import numpy as np
from utils.random_utils import CryptoRandomGenerator


class QuantizedDistribution:
    """
    Represents a probability distribution quantized to a fixed number of bins.

    The range [0, num_bins) is split into consecutive integer intervals, one
    per outcome, whose widths are the quantized probabilities. Once built, the
    distribution is held as integer arrays (interior interval boundaries and
    per-outcome widths), so sampling, decoding and interval queries only
    involve integer comparisons. The quantization step itself is carried out
    in float64.

    The input may be a single distribution (1-D array) or a batch of
    distributions (2-D array, one distribution per row). In batched mode the
    per-distribution methods require a ``batch_idx``.
    """

    def __init__(self, probs: np.ndarray, num_bins: int, smooth: bool = False):
        """
        Quantize ``probs`` onto ``num_bins`` bins.

        Args:
            probs: 1-D array of probabilities, or 2-D array with one
                distribution per row (the last axis runs over outcomes).
            num_bins: Total number of bins shared among the outcomes.
            smooth: If True, mix ``probs`` with the uniform distribution using
                weight ``eps = num_outcomes / num_bins`` before quantizing.
                A warning is emitted when ``eps`` exceeds 1e-3.

        Raises:
            ValueError: If ``num_bins`` is not positive, or is too large to be
                stored in an unsigned 64-bit integer.
        """
        probs = np.asarray(probs)
        assert len(probs.shape) in [1, 2]
        if num_bins < 1:
            raise ValueError("The number of bins must be positive")

        # Pick the narrowest unsigned type whose maximum exceeds num_bins.
        for candidate in (np.uint8, np.uint16, np.uint32, np.uint64):
            if num_bins < np.iinfo(candidate).max:
                dtype = candidate
                break
        else:
            raise ValueError("The number of bins is too large")

        if smooth:
            num_outcomes = probs.shape[-1]
            assert num_bins >= num_outcomes
            unif = np.full(
                probs.shape, fill_value=1.0 / num_outcomes, dtype=np.float64
            )
            # With this weight the uniform component alone contributes
            # 1 / num_bins of probability mass to every outcome.
            eps = num_outcomes / num_bins
            if eps > 1e-3:
                warnings.warn(
                    "The distribution may be oversmoothed. You might want to increase the quantization steps."
                )
            probs = probs.astype(np.float64) * (1 - eps) + unif * eps

        # Interior boundaries only: the first interval implicitly starts at 0
        # and the last one implicitly ends at num_bins.
        cdf = np.cumsum(probs.astype(np.float64), axis=-1)[..., :-1]
        cdf = np.clip(cdf * num_bins, 0, num_bins).astype(dtype)
        frequencies = np.diff(cdf, axis=-1, prepend=dtype(0), append=dtype(num_bins))
        self.cdf = cdf
        self.frequencies = frequencies  # array can change in place
        self.num_bins = num_bins

    def _cdf_row(self, batch_idx: Optional[int]) -> np.ndarray:
        """Return the 1-D boundary array for the requested distribution."""
        if len(self.cdf.shape) == 1:
            assert batch_idx is None
            return self.cdf
        # Batched mode
        assert batch_idx is not None
        return self.cdf[batch_idx]

    def get_cdf(self) -> np.ndarray:
        """Returns the unnormalized c.d.f./cumsum as an array of integers."""
        return self.cdf

    def get_frequencies(self) -> np.ndarray:
        """
        Returns the unnormalized probability distribution as an array of integers.

        The result is a copy with dtype of at least uint32. When the bins are
        stored as uint64 that dtype is kept, because narrowing to uint32 would
        silently wrap counts of 2**32 or more.
        """
        widened = np.promote_types(self.frequencies.dtype, np.uint32)
        return self.frequencies.astype(widened)

    def get_num_bins(self) -> int:
        """Returns the number of bins of the probability distribution."""
        return self.num_bins

    def sample(
        self, rng: "CryptoRandomGenerator", batch_idx: Optional[int] = None
    ) -> int:
        """
        Sample an index from the probability distribution.

        Draws a pointer with ``rng.randbelow(num_bins)`` and decodes it.
        ``batch_idx`` selects the distribution in batched mode.
        """
        return self.decode(rng.randbelow(self.num_bins), batch_idx)

    def decode(self, t: int, batch_idx: Optional[int] = None) -> int:
        """
        Given a pointer t in [0, num_bins), finds the corresponding index.

        The result is the number of interior boundaries that are <= t, i.e.
        the index of the interval [l, r) containing t. Outcomes with an empty
        interval are never returned.
        """
        assert t >= 0 and t < self.num_bins
        cdf = self._cdf_row(batch_idx)
        # Cast t to the storage dtype so the search compares values of a
        # single integer type; mixing Python ints with uint64 scalars can
        # fall back to a float comparison on older NumPy versions.
        pointer = cdf.dtype.type(t)
        return int(np.searchsorted(cdf, pointer, side="right"))

    def get_interval(
        self, idx: int, batch_idx: Optional[int] = None
    ) -> Tuple[int, int]:
        """
        Returns the interval [l, r) corresponding to the provided index.
        Note that, if l == r, the element has probability 0 and cannot be represented.

        Raises:
            IndexError: If ``idx`` is negative or past the last outcome.
        """
        cdf = self._cdf_row(batch_idx)
        last = cdf.shape[-1]  # Index of the final outcome
        if not 0 <= idx <= last:
            # A negative idx would otherwise wrap around through NumPy
            # indexing and yield the interval of a different outcome.
            raise IndexError(f"index {idx} is out of range [0, {last}]")

        l = 0 if idx == 0 else cdf[idx - 1].item()  # Inclusive
        r = cdf[idx].item() if idx < last else self.num_bins  # Exclusive
        return l, r
