# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Arithmetic coder."""

import io
import math
import random
import typing as tp
import torch

from ..binary import BitPacker, BitUnpacker


def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int,
                               roundoff: float = 1e-8, min_range: int = 2,
                               check: bool = True) -> torch.Tensor:
    """Turn the given PDF into a quantized CDF that splits
    `[0, 2 ** total_range_bits - 1]` into chunks of size roughly proportional
    to the PDF.

    Every symbol is given a chunk of at least `min_range` values: the PDF is
    blended with a uniform distribution so that `min_range * len(pdf)` values
    are shared equally and the remainder is split proportionally to `pdf`.

    Args:
        pdf (torch.Tensor): probability distribution, shape should be `[N]`.
        total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
            during the coding process is `[0, 2 ** total_range_bits - 1]`.
        roundoff (float): will round the pdf down to a multiple of that level to
            remove differences coming from e.g. evaluating the Language Model on
            different architectures. A falsy value disables the rounding.
        min_range (int): minimum range width. Should always be at least 2 for numerical
            stability. Use this to avoid pathological behavior if a value
            that is expected to be rare actually happens in real life.
        check (bool): if True, checks that nothing bad happened, can be deactivated for speed.

    Returns:
        torch.Tensor: integer tensor holding the cumulative sum of the per-symbol
        range widths; entry `i` is the exclusive upper bound of symbol `i`.

    Raises:
        ValueError: if `min_range` is below 2, or, when `check` is set, if some
            symbol ended up with a range narrower than `min_range`.
        AssertionError: if `min_range * len(pdf)` exceeds `2 ** total_range_bits`,
            or, when `check` is set, if the CDF exceeds `2 ** total_range_bits`.
    """
    # Validate the argument before doing any work. Otherwise an invalid (too small)
    # `min_range` combined with a large alphabet would trip the `alpha` assertion
    # below and ask the caller to *reduce* `min_range` even further.
    if min_range < 2:
        raise ValueError("min_range must be at least 2.")

    pdf = pdf.detach()
    if roundoff:
        pdf = (pdf / roundoff).floor() * roundoff
    # interpolate with uniform distribution to achieve desired minimum probability.
    total_range = 2 ** total_range_bits
    cardinality = len(pdf)
    # Fraction of the total range reserved for the uniform `min_range` floor.
    alpha = min_range * cardinality / total_range
    assert alpha <= 1, "you must reduce min_range"
    ranges = (((1 - alpha) * total_range) * pdf).floor().long()
    ranges += min_range
    quantized_cdf = torch.cumsum(ranges, dim=-1)
    if check:
        assert quantized_cdf[-1] <= total_range, quantized_cdf[-1]
        if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
            raise ValueError("You must increase your total_range_bits.")
    return quantized_cdf


class ArithmeticCoder:
    """ArithmeticCoder,
    Let us take a distribution `p` over `N` symbols, and assume we have a stream
    of random variables `s_t` sampled from `p`. Let us assume that we have a budget
    of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
    corresponding to the range `[0, 2 ** B - 1]`. We can map each of those numbers to a single
    sequence `(s_t)` by doing the following:

    1) Initialize the current range to `[0, 2 ** B - 1]`.
    2) For each time step t, split the current range into contiguous chunks,
        one for each possible outcome, with size roughly proportional to `p`.
        For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
        would be `{[0, 2], [3, 3]}`.
    3) Select the chunk corresponding to `s_t`, and replace the current range with this.
    4) When done encoding all the values, just select any value remaining in the range.

    You will notice that this procedure can fail: for instance if at any point in time
    the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
    possible outcome. Intuitively, the more likely a value is, the less the range width
    will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
    coding scheme, likely outcomes would take fewer bits, and more of them can be coded
    with a fixed budget.

    In practice, we do not know `B` ahead of time, but we have a way to inject new bits
    when the current range decreases below a given limit (given by `total_range_bits`), without
    having to redo all the computations. If we encode mostly likely values, we will seldom
    need to inject new bits, but a single rare value can deplete our stock of entropy!

    In this explanation, we assumed that the distribution `p` was constant. In fact, the present
    code works for any sequence `(p_t)` possibly different for each timestep.
    We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
    the KL between the true distribution and `p_t`, the more efficient the coding will be.

    Args:
        fo (IO[bytes]): file-like object to which the bytes will be written.
        total_range_bits (int): the working range width is `2 ** total_range_bits`.
            Any time the current range width falls under this limit, new bits will
            be injected to rescale the range. Must be at most 30.
    """

    def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
        assert total_range_bits <= 30
        self.total_range_bits = total_range_bits
        self.packer = BitPacker(bits=1, fo=fo)  # we push single bits at a time.
        # Current range is `[low, high]` (inclusive), stored without the bits
        # that were already flushed to the stream.
        self.low: int = 0
        self.high: int = 0
        # Index of the most significant bit of `low` / `high` not flushed yet,
        # -1 when there is none.
        self.max_bit: int = -1
        # Following is for debugging: history of the range after each push.
        self._dbg: tp.List[tp.Any] = []
        self._dbg2: tp.List[tp.Any] = []

    @property
    def delta(self) -> int:
        """Return the current range width."""
        return self.high - self.low + 1

    def _flush_common_prefix(self) -> None:
        """Write to the stream the leading bits shared by `low` and `high`."""
        # If self.low and self.high start with the same bits,
        # those won't change anymore as we always just increase the range
        # by powers of 2, and we can flush them out to the bit stream.
        assert self.high >= self.low, (self.low, self.high)
        assert self.high < 2 ** (self.max_bit + 1)
        while self.max_bit >= 0:
            b1 = self.low >> self.max_bit
            b2 = self.high >> self.max_bit
            if b1 != b2:
                break
            # Drop the shared top bit from both bounds, then emit it.
            self.low -= (b1 << self.max_bit)
            self.high -= (b1 << self.max_bit)
            assert self.high >= self.low, (self.high, self.low, self.max_bit)
            assert self.low >= 0
            self.max_bit -= 1
            self.packer.push(b1)

    def push(self, symbol: int, quantized_cdf: torch.Tensor) -> None:
        """Push the given symbol on the stream, flushing out bits
        if possible.

        Args:
            symbol (int): symbol to encode with the AC, must be a valid
                index into `quantized_cdf`.
            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
                to build this from your pdf estimate.

        Raises:
            IndexError: if `symbol` is not in `[0, len(quantized_cdf) - 1]`.
        """
        # Reject bad symbols before touching the coder state. In particular a
        # negative symbol would otherwise be accepted through negative tensor
        # indexing and silently encoded as a different symbol.
        if not 0 <= symbol < len(quantized_cdf):
            raise IndexError(
                f"symbol {symbol} is out of range for a cdf with {len(quantized_cdf)} entries")

        total_range = 2 ** self.total_range_bits
        # Widen the range by appending (yet undetermined) low bits until it is
        # at least `total_range` wide.
        while self.delta < total_range:
            self.low *= 2
            self.high = self.high * 2 + 1
            self.max_bit += 1

        # Chunk of `[0, total_range - 1]` assigned to `symbol` (inclusive bounds).
        range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
        range_high = quantized_cdf[symbol].item() - 1
        # Rescale the chunk to the current range width. This float arithmetic
        # must stay identical to the one done by `ArithmeticDecoder`.
        scale = self.delta / total_range
        effective_low = int(math.ceil(range_low * scale))
        effective_high = int(math.floor(range_high * scale))
        assert self.low <= self.high
        # `high` first: both offsets are relative to the old `low`.
        self.high = self.low + effective_high
        self.low = self.low + effective_low
        assert self.low <= self.high, (effective_low, effective_high, range_low, range_high)
        self._dbg.append((self.low, self.high))
        self._dbg2.append((self.low, self.high))
        self._flush_common_prefix()
        assert self.low <= self.high
        assert self.max_bit >= -1
        assert self.max_bit <= 61, self.max_bit

    def flush(self) -> None:
        """Flush the remaining information to the stream.

        Writes the pending bits of `low`, most significant first, then
        flushes the underlying bit packer.
        """
        while self.max_bit >= 0:
            b1 = (self.low >> self.max_bit) & 1
            self.packer.push(b1)
            self.max_bit -= 1
        self.packer.flush()


class ArithmeticDecoder:
    """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.

    Note that this must be called with **exactly** the same parameters and sequence
    of quantized cdf as the arithmetic encoder or the wrong values will be decoded.

    If the AC encoder current range is [L, H], with `L` and `H` having some common
    prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
    For instance, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
    `[b1 b2 b3 0 ... 0, b1 b2 b3 1 ... 1]`. Now this specific sub-range can only be obtained
    for a specific sequence of symbols and a binary-search allows us to decode those symbols.
    At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
    and we will need to read new bits from the stream and repeat the process.

    Args:
        fo (IO[bytes]): file-like object from which the bytes will be read.
        total_range_bits (int): must be the value used by the encoder, and
            like for the encoder must be at most 30.
    """
    def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
        # Same bound as `ArithmeticCoder`: no encoder can have produced a
        # stream with a larger value.
        assert total_range_bits <= 30
        self.total_range_bits = total_range_bits
        # Current range is `[low, high]` (inclusive) and `current` is the value
        # read so far from the stream, all stored without the common prefix
        # already dropped by `_flush_common_prefix`.
        self.low: int = 0
        self.high: int = 0
        self.current: int = 0
        # Index of the most significant bit still kept, -1 when there is none.
        self.max_bit: int = -1
        self.unpacker = BitUnpacker(bits=1, fo=fo)  # we pull single bits at a time.
        # Following is for debugging
        self._dbg: tp.List[tp.Any] = []
        self._dbg2: tp.List[tp.Any] = []
        self._last: tp.Any = None

    @property
    def delta(self) -> int:
        """Return the current range width."""
        return self.high - self.low + 1

    def _flush_common_prefix(self) -> None:
        """Drop the leading bits shared by `low` and `high` from the state."""
        # Given the current range [L, H], if both have a common prefix,
        # we know we can remove it from our representation to avoid handling large numbers.
        while self.max_bit >= 0:
            b1 = self.low >> self.max_bit
            b2 = self.high >> self.max_bit
            if b1 != b2:
                break
            self.low -= (b1 << self.max_bit)
            self.high -= (b1 << self.max_bit)
            self.current -= (b1 << self.max_bit)
            assert self.high >= self.low
            assert self.low >= 0
            self.max_bit -= 1

    def _find_symbol(self, quantized_cdf: torch.Tensor) -> tp.Tuple[int, int, int]:
        """Binary search the symbol whose sub-range contains `self.current`.

        Returns the symbol along with the inclusive bounds of its sub-range.
        Does not modify the decoder state.
        """
        # Loop invariant, so computed once. This float arithmetic must stay
        # identical to the one done by `ArithmeticCoder.push`.
        scale = self.delta / (2 ** self.total_range_bits)
        low_idx = 0
        high_idx = len(quantized_cdf) - 1
        # Binary search is not just for coding interviews :)
        while low_idx <= high_idx:
            mid = (low_idx + high_idx) // 2
            range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
            range_high = quantized_cdf[mid].item() - 1
            low = self.low + int(math.ceil(range_low * scale))
            high = self.low + int(math.floor(range_high * scale))
            if self.current < low:
                high_idx = mid - 1
            elif self.current > high:
                low_idx = mid + 1
            else:
                return mid, low, high
        raise RuntimeError("Binary search failed")

    def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
        """Pull a symbol, reading as many bits from the stream as required.
        This returns `None` when the stream has been exhausted.

        Args:
            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
                to build this from your pdf estimate. This must be **exactly**
                the same cdf as the one used at encoding time.

        Raises:
            RuntimeError: if no symbol of `quantized_cdf` has a sub-range
                containing the value read from the stream.
        """
        # Mirror the range widening done by the encoder, consuming one bit of
        # the stream for each doubling.
        while self.delta < 2 ** self.total_range_bits:
            bit = self.unpacker.pull()
            if bit is None:
                return None
            self.low *= 2
            self.high = self.high * 2 + 1
            self.current = self.current * 2 + bit
            self.max_bit += 1

        self._last = (self.low, self.high, self.current, self.max_bit)
        sym, self.low, self.high = self._find_symbol(quantized_cdf)
        self._dbg.append((self.low, self.high, self.current))
        self._flush_common_prefix()
        self._dbg2.append((self.low, self.high, self.current))

        return sym


def test():
    """Round-trip self-test: encode random symbol streams and decode them back.

    For each of 4 trials, symbols are sampled from random softmax
    distributions, pushed through an `ArithmeticCoder` writing to an in-memory
    buffer, then read back with an `ArithmeticDecoder` fed with CDFs rebuilt
    from the same PDFs. Seeds are fixed so that the run is reproducible.
    """
    torch.manual_seed(1234)
    random.seed(1234)
    n_trials = 4
    for _ in range(n_trials):
        # The order of the random draws below is part of the seeded scenario.
        cardinality = random.randrange(4000)
        steps = random.randrange(100, 500)
        buffer = io.BytesIO()
        encoder = ArithmeticCoder(buffer)
        history = []  # (pdf, symbol) pairs, in encoding order.
        for _step in range(steps):
            pdf = torch.softmax(torch.randn(cardinality), dim=0)
            q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
            symbol = torch.multinomial(pdf, 1).item()
            history.append((pdf, symbol))
            encoder.push(symbol, q_cdf)
        encoder.flush()

        # Rewind and decode, rebuilding each CDF from its PDF.
        buffer.seek(0)
        decoder = ArithmeticDecoder(buffer)
        for idx, (pdf, symbol) in enumerate(history):
            q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
            decoded_symbol = decoder.pull(q_cdf)
            assert decoded_symbol == symbol, idx
        # Once every symbol was read back, the stream is expected to be exhausted.
        assert decoder.pull(torch.zeros(1)) is None


# Run the round-trip self-test when this file is executed as the main module.
if __name__ == "__main__":
    test()
