import heapq
import io

import torch


def get_entropy_bits_batched_ref(x: torch.Tensor) -> torch.Tensor:
    """
    Loop-based reference for per-row Shannon entropy in bits, used to cross-check `get_entropy_bits`.

    Each row is histogrammed independently with `Tensor.unique`.

    :param x: (R, C) tensor of symbols
    :return: entropy_bits: (R), float64
    """

    out: torch.Tensor = torch.empty(x.shape[:-1], dtype=torch.float64, device=x.device)
    for row_idx in range(x.size(0)):
        row: torch.Tensor = x[row_idx]
        _, row_counts = row.unique(sorted=True, return_counts=True)
        # empirical probability of every distinct symbol in this row
        probs: torch.Tensor = row_counts.double() / float(row.numel())
        out[row_idx] = (probs * probs.log2()).sum().neg()
    return out


def get_entropy_bits(x: torch.Tensor, dim: int = None) -> torch.Tensor:
    """

    :param x: (..., C)
    :param dim: int
    :return: entropy_bits: (...)
    """

    device: torch.device = x.device
    f_dtype, i_dtype = torch.float64, torch.int64
    if dim is None:
        x = x.flatten()
        dim = -1
    x = x.to(dtype=i_dtype, copy=True)
    offset: int = x.min().item()
    n_bins: int = x.max().item() - offset + 1
    x -= offset
    ones: torch.Tensor = torch.ones((), dtype=i_dtype, device=device).expand_as(x)
    counts: torch.Tensor = torch.zeros(*x.shape[:dim], n_bins, *(x.shape[dim+1:] if dim != -1 else []), dtype=i_dtype, device=device)  # int64, (..., B)
    counts.scatter_add_(dim, x, ones)  # int64, (..., B)
    log_counts = torch.where(counts.to(dtype=torch.bool), counts, 1).to(dtype=f_dtype).log2()  # fp64, (..., B)
    length: torch.Tensor = torch.as_tensor(x.size(dim), dtype=f_dtype, device=device)  # fp64, ()
    entropy_bits: torch.Tensor = length.log2() - (counts * log_counts).sum(dim=dim) / length  # fp64, (...)
    return entropy_bits


def huffman_code(symbols: list[str | int], counts: list[float | int]) -> dict[str | int, str]:
    """
    Build the Huffman tree
    :param symbols: list
    :param counts: list
    :return: dict
    """
    heap = [[c, [[sym, '']]] for sym, c in zip(symbols, counts)]
    heapq.heapify(heap)
    while len(heap) > 1:
        c_lo, nodes_lo = heapq.heappop(heap)
        c_hi, nodes_hi = heapq.heappop(heap)
        for pair in nodes_lo:
            pair[1] = '0' + pair[1]
        for pair in nodes_hi:
            pair[1] = '1' + pair[1]
        heapq.heappush(heap, [c_lo + c_hi, nodes_lo + nodes_hi])
    _, final_nodes = heap[0]
    return {sym: code for sym, code in final_nodes}


def get_huffman_bits(x: torch.Tensor) -> float:
    """
    Average code length, in bits per element, of a Huffman code fitted to the empirical symbol counts of `x`.

    :param x: (...) tensor of symbols; all elements are pooled regardless of shape
    :return: float, average bits per element; 0.0 for an empty tensor
    """

    n_total: int = x.numel()
    if n_total == 0:
        return 0.0
    symbols, counts = x.unique(sorted=True, return_counts=True)
    symbols_list, counts_list = symbols.tolist(), counts.tolist()
    # Build the code from the exact integer counts. Frequencies `counts / numel` would be rounded to the
    # default float dtype before the heap compares and sums them.
    codes: dict = huffman_code(symbols_list, counts_list)
    total_bits: int = sum(c * len(codes[s]) for s, c in zip(symbols_list, counts_list))
    huffman_bits: float = total_bits / n_total
    return huffman_bits


class _BitWriter:
    """
    Helper to write bits to a binary stream in a streaming fashion.
    Accumulates bits and writes full bytes when possible.
    """
    def __init__(self, out: io.BytesIO):
        self.out: io.BytesIO = out
        self.buffer: int = 0  # current byte buffer
        self.nbits: int = 0  # number of bits currently in buffer
        self.length: int = 0  # number of total bits

    def write_bits(self, bits: str) -> None:
        """Write a string of '0'/'1' bits to the output stream."""
        for b in bits:
            self.buffer: int = (self.buffer << 1) | (1 if b == '1' else 0)
            self.nbits += 1
            if self.nbits == 8:
                self.out.write(bytes([self.buffer]))
                self.buffer: int = 0
                self.nbits: int = 0
                self.length += 8

    def flush(self) -> int:
        """Flush the remaining bits, padding with zeros."""
        self.length += self.nbits
        if self.nbits > 0:
            self.buffer <<= 8 - self.nbits
            self.out.write(bytes([self.buffer]))
            self.buffer = 0
            self.nbits = 0
        return self.length


class _BitReader:
    """
    Helper to read bits from a binary stream in a streaming fashion.
    Reads bytes and yields bits one by one.
    """
    def __init__(self, inp: io.BytesIO, length: int = -1):
        self.inp: io.BytesIO = inp
        self.length: int = length
        self.buffer: int = 0
        self.nbits: int = 0

    def read_bit(self) -> str:
        """Read a single bit as '0' or '1'. Returns empty string on EOF or padding."""
        if self.length != 0:
            self.length -= 1
        else:
            return ''

        if self.nbits == 0:
            byte = self.inp.read(1)
            if not byte:
                return ''
            self.buffer = byte[0]
            self.nbits = 8

        # extract highest-order bit
        bit = '1' if (self.buffer & 0x80) else '0'
        self.buffer <<= 1
        self.nbits -= 1
        return bit


def encode_huffman_stream(data: list[str | int] | str, codes: dict[str | int, str], out: io.BytesIO) -> int:
    """
    Encode `data` symbol by symbol with `codes` and write the packed bits to `out`.

    Bits are packed most-significant bit first. The final partial byte, if any, is padded with zero bits.

    :param data: iterable of symbols (a str is encoded character by character); each must be a key of `codes`
    :param codes: mapping symbol -> code word, a string of '0'/'1'
    :param out: writable binary stream
    :return: total number of payload bits written, excluding padding. Pass it as `length` to
        `decode_huffman_stream` so that the padding bits are not decoded as symbols.
    :raises KeyError: if a symbol in `data` has no entry in `codes`
    """
    writer: _BitWriter = _BitWriter(out)
    write_bits = writer.write_bits  # bound once; called for every symbol
    for symbol in data:
        write_bits(codes[symbol])
    return writer.flush()


def decode_huffman_stream(inp: io.BytesIO, length: int, codes: dict[str | int, str]) -> list[str | int]:
    """
    Decode symbols from the packed bits in `inp` by walking a trie built from `codes`.

    :param inp: readable binary stream of MSB-first packed code words, as written by `encode_huffman_stream`
    :param length: number of payload bits to read (the value returned by `encode_huffman_stream`). A negative
        value reads until EOF; the zero padding in the last byte may then be decoded as extra symbols.
    :param codes: mapping symbol -> code word, a string of '0'/'1'; expected to be prefix-free
    :return: list of decoded symbols; bits left over after the last complete code word are ignored
    :raises ValueError: if `codes` contains an empty code word, which occupies no bits and cannot be decoded
    :raises KeyError: if the bits follow a path that matches no code word
    """
    # Trie of nested dicts keyed by '0'/'1'. A node that terminates a code word stores its symbol under 'sym'.
    trie: dict = {}
    for sym, bitstr in codes.items():
        node: dict = trie
        for b in bitstr:
            node = node.setdefault(b, {})
        node['sym'] = sym
    if 'sym' in trie:
        raise ValueError('empty code word cannot be decoded')

    reader: _BitReader = _BitReader(inp, length)
    decoded: list[str | int] = []
    node = trie
    # iterate over bits until the bit budget is spent or EOF
    while bit := reader.read_bit():
        node = node[bit]
        if 'sym' in node:
            decoded.append(node['sym'])
            node = trie  # restart at the root for the next code word
    return decoded


def search_scale_entropy(x: torch.Tensor, target_entropy: float, dim: int = -1, max_iter: int = 100) -> torch.Tensor:
    """
    Find, independently for every slice along `dim`, a quantization step `scale` for which
    `round(x / scale)` has an empirical entropy (via `get_entropy_bits`) below `target_entropy` bits.

    Bisection runs on [0, 3 * rms(x)], with rms taken along `dim`. A probe whose entropy is below the target
    becomes the new upper bound; otherwise it becomes the new lower bound. The result is the smallest probed
    scale whose entropy was below the target. If no probe got below the target, the initial upper bound
    3 * rms is returned. All-zero slices (rms 0) yield scale 0.

    NOTE(review): a target the data can never undershoot (e.g. >= log2(C)) drives the scale towards 0. That
    makes `x / scale` huge before the int64 cast, so confirm callers keep the target reachable.

    :param x: (..., C) floating-point tensor
    :param target_entropy: float, bits per element
    :param dim: int, dimension holding the C elements
    :param max_iter: int, maximum number of bisection steps
    :return: scale: shape of `x` with `dim` removed
    """

    i_dtype: torch.dtype = torch.int64
    length: int = x.size(dim)
    # root-mean-square along `dim` (no mean subtraction), kept with size 1 along `dim` for broadcasting
    rms: torch.Tensor = torch.linalg.vector_norm(x, dim=dim, keepdim=True) * length ** -.5
    scale_min: torch.Tensor = torch.zeros_like(rms)
    scale_max: torch.Tensor = rms * 3.

    for _ in range(max_iter):
        scale_mid: torch.Tensor = (scale_min + scale_max) * .5
        # For an all-zero slice, scale_mid is 0 and x / 0 would be 0/0 = NaN, which casts to an arbitrary int64.
        # Dividing by 1 there instead still quantizes every element of that slice to 0.
        divisor: torch.Tensor = torch.where(scale_mid > 0, scale_mid, torch.ones_like(scale_mid))
        qx: torch.Tensor = (x / divisor).round().to(dtype=i_dtype)  # (..., C)
        entropy_bits: torch.Tensor = get_entropy_bits(qx, dim=dim).unsqueeze(dim=dim)
        below: torch.Tensor = entropy_bits < target_entropy
        scale_min = torch.where(below, scale_min, scale_mid)
        scale_max = torch.where(below, scale_mid, scale_max)
        if scale_min.equal(scale_max):
            break
    # squeeze the reduced dimension itself; indexing [..., 0] was only correct for dim == -1
    return scale_max.squeeze(dim)


def _unit_test():
    """
    Self-test: cross-check the entropy implementations, compare Huffman bit counts, round-trip the Huffman
    stream codec, and sanity-check `search_scale_entropy`. Raises AssertionError on failure.
    """
    generator: torch.Generator = torch.Generator().manual_seed(0)  # fixed seed so failures are reproducible
    batch_dims: tuple[int, ...] = (5,)
    n_elements: int = 1000
    x: torch.Tensor = torch.randint(-5, 6, (*batch_dims, n_elements), generator=generator)
    assert get_entropy_bits(x, dim=-1).allclose(get_entropy_bits_batched_ref(x))
    # Reducing over a leading dimension must agree with the row-wise reference on the transpose.
    assert get_entropy_bits(x.T, dim=0).allclose(get_entropy_bits_batched_ref(x))

    entropy_bits: float = get_entropy_bits(x).item()
    huffman_bits: float = get_huffman_bits(x)
    print(f'Entropy: {entropy_bits:.4f} bits, Huffman: {huffman_bits:.4f} bits')
    # Source-coding bound for an optimal prefix code: H <= L < H + 1.
    assert entropy_bits - 1e-6 <= huffman_bits < entropy_bits + 1.

    # Sample Huffman code dictionary
    symbols, counts = x.unique(sorted=True, return_counts=True)
    codes: dict[str | int, str] = huffman_code(symbols.tolist(), counts.tolist())

    # Encode to a BytesIO stream
    out_stream: io.BytesIO = io.BytesIO()
    n_bits: int = encode_huffman_stream(x.flatten().tolist(), codes, out_stream)
    packed_bytes: bytes = out_stream.getvalue()
    assert len(packed_bytes) == (n_bits + 7) // 8
    huffman_bits_verify: float = n_bits / x.numel()
    # Compare with a tolerance: formatting both values to 4 decimals can disagree at a rounding boundary.
    assert abs(huffman_bits - huffman_bits_verify) < 1e-5
    print(f'#avg_bits={huffman_bits_verify}, #bits={n_bits}, #bytes={len(packed_bytes)}')

    # Decode from the BytesIO stream
    in_stream: io.BytesIO = io.BytesIO(packed_bytes)
    y: torch.Tensor = torch.as_tensor(decode_huffman_stream(in_stream, n_bits, codes), dtype=x.dtype, device=x.device).reshape_as(x)
    assert y.equal(x)

    # The scale found for each row must quantize that row below the target entropy.
    target_entropy: float = 2.
    scale: torch.Tensor = search_scale_entropy(x.to(dtype=torch.float64), target_entropy)
    assert scale.shape == batch_dims
    qx: torch.Tensor = (x / scale.unsqueeze(-1)).round()
    assert (get_entropy_bits(qx, dim=-1) < target_entropy).all()

    print('OK!')


if __name__ == '__main__':
    # Run the built-in self-test when this file is executed as a script.
    _unit_test()
