import ctypes
import heapq
from typing import List, Optional, Tuple

import numpy as np

from . import BaseCoder

def _load_library() -> ctypes.CDLL:
    """Load the native ``libhuffman.so`` shared library.

    ``./libhuffman.so`` (resolved against the current working directory) is
    tried first, preserving the original lookup. If that fails, the copy that
    sits next to this module is tried, so importing the package no longer
    depends on the process's working directory.

    Raises:
        OSError: if the library cannot be loaded from any candidate path.
    """
    import os

    candidates = (
        "./libhuffman.so",
        os.path.join(os.path.dirname(os.path.abspath(__file__)), "libhuffman.so"),
    )
    errors = []
    for path in candidates:
        try:
            return ctypes.cdll.LoadLibrary(path)
        except OSError as err:
            errors.append(f"{path}: {err}")
    raise OSError("could not load libhuffman.so; tried:\n  " + "\n  ".join(errors))


lib = _load_library()

# HuffmanCoder_new(frequencies, n) -> opaque handle (NULL maps to None).
lib.HuffmanCoder_new.restype = ctypes.c_void_p
lib.HuffmanCoder_new.argtypes = [ctypes.POINTER(ctypes.c_uint32), ctypes.c_int]

# HuffmanCoder_encode_symbol(handle, symbol, out_bits) -> number of bits written.
lib.HuffmanCoder_encode_symbol.restype = ctypes.c_int
lib.HuffmanCoder_encode_symbol.argtypes = [
    ctypes.c_void_p,
    ctypes.c_int,
    ctypes.POINTER(ctypes.c_bool),
]

# HuffmanCoder_decode_symbol(handle, bits, n_bits, offset) -> symbol.
# ``offset`` is an in/out pointer: the Python wrapper reads it back as the
# position after the decoded codeword.
lib.HuffmanCoder_decode_symbol.restype = ctypes.c_int
lib.HuffmanCoder_decode_symbol.argtypes = [
    ctypes.c_void_p,
    ctypes.POINTER(ctypes.c_bool),
    ctypes.c_int,
    ctypes.POINTER(ctypes.c_int),
]

# HuffmanCoder_get_code_lengths(handle, out_lengths): fills one uint32 per symbol.
lib.HuffmanCoder_get_code_lengths.restype = None
lib.HuffmanCoder_get_code_lengths.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_uint32)]

# HuffmanCoder_get_max_height(handle) -> size used for the encode scratch buffer.
lib.HuffmanCoder_get_max_height.restype = ctypes.c_int
lib.HuffmanCoder_get_max_height.argtypes = [ctypes.c_void_p]

# HuffmanCoder_delete(handle): releases the native coder.
lib.HuffmanCoder_delete.restype = None
lib.HuffmanCoder_delete.argtypes = [ctypes.c_void_p]


class HuffmanCoder(BaseCoder):
    """Huffman coder backed by the native ``libhuffman`` library.

    The code tree is built natively from a 1-D ``uint32`` frequency table;
    this class owns the opaque native handle and exposes per-symbol
    encoding/decoding on boolean bit arrays.
    """

    buffer: np.ndarray = None

    def __init__(self, x: np.ndarray):
        """Build a coder from symbol frequencies.

        Args:
            x: 1-D ``np.uint32`` array; ``x[i]`` is the frequency of symbol ``i``.

        Raises:
            AssertionError: if ``x`` is not a 1-D ``uint32`` ndarray.
            RuntimeError: if the native constructor returns a null handle.
        """
        # Set first so __del__ is safe even if construction fails below.
        self.obj = None

        assert isinstance(x, np.ndarray)
        assert x.dtype == np.uint32
        assert x.ndim == 1

        # Native code reads len(x) consecutive uint32 values from the raw
        # pointer, so a strided (non-contiguous) view would be misread.
        # ascontiguousarray returns ``x`` itself when it is already contiguous.
        x = np.ascontiguousarray(x)
        # Keep the array alive as long as the coder: it is not visible here
        # whether the native constructor copies the table or keeps the pointer.
        self._freqs = x
        x_ctype = x.ctypes.data_as(ctypes.POINTER(ctypes.c_uint32))

        obj = lib.HuffmanCoder_new(x_ctype, len(x))
        if not obj:
            raise RuntimeError("HuffmanCoder_new returned a null handle")
        self.obj = obj
        max_height = lib.HuffmanCoder_get_max_height(self.obj)

        # Scratch buffer that encode_symbol passes to the native encoder,
        # sized from HuffmanCoder_get_max_height (presumably the longest
        # codeword length — confirm against the native source).
        self.buffer = np.empty((max_height,), dtype=bool)
        self.buffer_ctype = self.buffer.ctypes.data_as(ctypes.POINTER(ctypes.c_bool))
        self.vocab_size = len(x)

    def encode_symbol(self, x: int) -> np.ndarray:
        """Return the codeword of symbol ``x`` as a boolean array.

        A fresh array is returned: the native encoder writes into a scratch
        buffer shared by every call, so returning a view into it would let
        the next call overwrite codewords returned earlier.

        Raises:
            IndexError: if ``x`` is outside ``[0, vocab_size)``.
        """
        symbol = int(x)  # accept numpy integer scalars as well as int
        if not 0 <= symbol < self.vocab_size:
            raise IndexError(
                f"symbol {symbol} out of range [0, {self.vocab_size})"
            )
        offset = lib.HuffmanCoder_encode_symbol(self.obj, symbol, self.buffer_ctype)

        return self.buffer[:offset].copy()

    def decode_symbol(self, buffer: np.ndarray, offset: int) -> Tuple[int, int]:
        """Decode one symbol from ``buffer`` starting at bit ``offset``.

        Args:
            buffer: boolean array of encoded bits.
            offset: index of the first bit to read.

        Returns:
            ``(symbol, new_offset)`` as reported by the native decoder; how
            truncated input is signalled is defined by the native library.

        Raises:
            AssertionError: if ``buffer`` is not of dtype ``bool``.
        """
        assert buffer.dtype == bool

        # The native decoder walks raw memory; compact strided views first.
        buffer = np.ascontiguousarray(buffer)
        buffer_ctype = buffer.ctypes.data_as(ctypes.POINTER(ctypes.c_bool))
        # argtypes declares POINTER(c_int), so ctypes passes this by reference
        # and the native side can update it in place.
        offset_ref = ctypes.c_int(offset)

        symbol = lib.HuffmanCoder_decode_symbol(
            self.obj, buffer_ctype, len(buffer), offset_ref
        )

        return symbol, offset_ref.value

    def get_effective_probabilities(self) -> np.ndarray:
        """Return ``2 ** -L_i`` for each symbol ``i``.

        ``L_i`` is the code length reported by the native
        ``HuffmanCoder_get_code_lengths``.
        """
        code_lengths = np.zeros((self.vocab_size,), dtype=np.uint32)
        lib.HuffmanCoder_get_code_lengths(
            self.obj, code_lengths.ctypes.data_as(ctypes.POINTER(ctypes.c_uint32))
        )
        return np.exp2(-code_lengths.astype(np.float64))

    def __del__(self):
        """Release the native coder exactly once."""
        # getattr: instances created without __init__ have no ``obj``.
        obj = getattr(self, "obj", None)
        if obj:
            lib.HuffmanCoder_delete(obj)
            self.obj = None


# Field positions inside a HuffmanCoderRef tree node. Nodes are plain lists
# ``[frequency, id, left, right, is_leaf, parent, bit]`` so heapq orders them
# by (frequency, id); ids are unique, so comparisons never reach later fields.
_FREQ, _ID, _LEFT, _RIGHT, _IS_LEAF, _PARENT, _BIT = range(7)


class HuffmanCoderRef(BaseCoder):
    """Pure-Python reference Huffman coder operating on lists of 0/1 bits."""

    def __init__(self, freqs: List[int]):
        """Build the Huffman tree for ``freqs``.

        Args:
            freqs: ``freqs[i]`` is the weight of symbol ``i``.

        Raises:
            AssertionError: if fewer than two symbols are given.
        """
        assert len(freqs) > 1

        # Leaves use ids 0..len(freqs)-1; internal nodes get fresh ids after that.
        q = [[f, i, None, None, True, None, None] for i, f in enumerate(freqs)]
        index = [None] * len(freqs)
        heapq.heapify(q)
        unique_id = len(freqs)
        while len(q) > 1:
            left = heapq.heappop(q)
            right = heapq.heappop(q)
            node = [left[_FREQ] + right[_FREQ], unique_id, left, right, False, None, None]
            unique_id += 1
            heapq.heappush(q, node)
            left[_PARENT], right[_PARENT] = node, node
            left[_BIT], right[_BIT] = 0, 1

            # Every leaf is popped exactly once; record it in the symbol index.
            for child in (left, right):
                if child[_IS_LEAF]:
                    index[child[_ID]] = child
        # Only reachable for a single-symbol alphabet (excluded by the assert).
        if q[0][_IS_LEAF]:
            index[q[0][_ID]] = q[0]

        self.index = index
        self.root = q[0]
        self.freqs = freqs

    def encode_symbol(self, symbol: int) -> List[int]:
        """Return the codeword of ``symbol`` as a list of 0/1 ints, root first.

        Raises:
            IndexError: if ``symbol`` is outside ``[0, len(freqs))``. Negative
                symbols are rejected explicitly because list indexing would
                otherwise wrap around and silently encode a different symbol.
        """
        if symbol < 0:
            raise IndexError(f"symbol {symbol} out of range")
        node = self.index[symbol]
        assert node is not None
        bits = []
        # Walk leaf -> root collecting branch bits, then reverse.
        while node[_PARENT] is not None:
            bits.append(node[_BIT])
            node = node[_PARENT]
        bits.reverse()
        return bits

    def decode_symbol(
        self, bits: List[int], offset: int = 0
    ) -> Tuple[Optional[int], int]:
        """Decode one symbol from ``bits`` starting at ``offset``.

        Returns:
            ``(symbol, new_offset)``. ``symbol`` is None when ``bits`` runs out
            before a leaf is reached (truncated message).

        Raises:
            ValueError: if a bit other than 0/1 is encountered.
        """
        node = self.root
        # Internal nodes always have both children, so the walk never hits None.
        while not node[_IS_LEAF] and offset < len(bits):
            bit = bits[offset]
            if bit == 0:
                node = node[_LEFT]
            elif bit == 1:
                node = node[_RIGHT]
            else:
                raise ValueError("Invalid bit")
            offset += 1

        if not node[_IS_LEAF]:
            return None, offset
        return node[_ID], offset

    def compute_entropy(self) -> float:
        """Return the expected codeword length of this tree in bits/symbol.

        Each symbol's code length is weighted by its relative frequency in
        ``freqs``. Despite the method name, this is the average Huffman code
        length, which is bounded below by the Shannon entropy of that
        distribution rather than being equal to it.
        """
        sum_freq = sum(self.freqs)
        return sum(
            (f / sum_freq) * len(self.encode_symbol(i))
            for i, f in enumerate(self.freqs)
        )
