from typing import List, Optional

import numpy as np
from utils.random_utils import CryptoRandomGenerator


class BitStream:
    """
    Represents a (dynamic) stream of bits, backed by a growable numpy boolean array.
    Useful for variable-length encoding.

    Only the first ``size`` entries of the backing array are meaningful; the rest is
    uninitialized spare capacity used to make appends cheap.
    """

    # Capacity of the backing array for an empty stream.
    initial_size: int = 1024
    # Minimum number of bits added whenever the backing array has to grow.
    expand_size: int = 1024
    # Number of valid bits in the stream (set per instance in __init__).
    size: int = 0

    def __init__(self, from_bytes: Optional[bytes] = None):
        """Create an empty stream, or one holding the bits of ``from_bytes`` (MSB first)."""
        if from_bytes is not None:
            self.bits = np.unpackbits(np.frombuffer(from_bytes, dtype=np.uint8)).astype(
                bool
            )
            self.size = len(self.bits)
        else:
            # np.empty leaves contents uninitialized; only bits[:size] are ever read.
            self.bits = np.empty((self.initial_size,), dtype=bool)
            self.size = 0

    def _ensure_capacity(self, extra: int) -> None:
        """Grow the backing array so that ``extra`` more bits fit after ``size``."""
        required = self.size + extra
        capacity = len(self.bits)
        if required <= capacity:
            return
        # Grow geometrically (and by at least expand_size) so a long sequence of small
        # appends costs amortized O(1) per bit instead of re-copying the whole buffer
        # on every expansion.
        new_capacity = max(required, capacity + self.expand_size, 2 * capacity)
        grown = np.empty((new_capacity,), dtype=bool)
        grown[: self.size] = self.bits[: self.size]
        self.bits = grown

    def __iadd__(self, bits: List[bool]) -> "BitStream":
        """
        Append ``bits`` (any sized sequence of truthy/falsy values) to the stream.

        Returns:
            ``self``, so ``stream += [True, False]`` works.
        """
        count = len(bits)
        self._ensure_capacity(count)

        self.bits[self.size : self.size + count] = bits
        self.size += count

        return self

    def __len__(self) -> int:
        """Return the number of valid bits in the stream."""
        return self.size

    def pad(self, block_size: int, rng: "CryptoRandomGenerator") -> None:
        """
        Append random bits until the stream length is a multiple of ``block_size``.

        Each padding bit is drawn with one ``rng.randbits(1)`` call, in order.

        Raises:
            ValueError: if ``block_size`` is not positive.
        """
        if block_size <= 0:
            raise ValueError("block_size must be a positive integer")

        padding = (block_size - self.size % block_size) % block_size
        # The backing array may be exactly full (e.g. a stream built from bytes),
        # so reserve room before writing past the current size.
        self._ensure_capacity(padding)
        for _ in range(padding):
            self.bits[self.size] = rng.randbits(1)
            self.size += 1

    def get_bits(self) -> np.ndarray:
        """Return the valid bits as a boolean numpy array (a view, not a copy)."""
        return self.bits[: self.size]

    def get_bytes(self) -> bytes:
        """
        Pack the stream into bytes, MSB first.

        Raises:
            ValueError: if the stream length is not a multiple of 8.
        """
        # An explicit check, because ``assert`` is stripped under ``python -O`` and
        # packbits would then silently zero-pad the last byte.
        if self.size % 8 != 0:
            raise ValueError("The stream must be padded to bytes")

        return np.packbits(self.bits[: self.size]).tobytes()
