from typing import List, Optional, Tuple

import numpy as np
from distribution import QuantizedDistribution
from encryption import Encryptor
from tqdm import tqdm
from utils.bitstream import BitStream
from utils.random_utils import CryptoRandomGenerator

from . import BaseEncoder
from .huffman import HuffmanCoder


class FixedLengthEncoder(BaseEncoder):
    """Encode every token as a fixed-width big-endian integer.

    At each step the current probabilities are quantized onto ``2**num_bits``
    bins with ``QuantizedDistribution``. The token is then encoded as a value
    sampled uniformly (via ``rng``) from that token's bin interval. Every token
    therefore occupies exactly ``num_bits // 8`` bytes of output.
    """

    def __init__(
        self,
        num_bits,
        encryptor: Optional[Encryptor] = None,
        smooth_distribution: bool = False,
        verbose: bool = True,
    ):
        """
        Args:
            num_bits: Width in bits of each encoded token; must be 8, 16 or 32.
            encryptor: Optional encryptor, forwarded to ``BaseEncoder``.
            smooth_distribution: Passed as ``smooth`` to ``QuantizedDistribution``
                (intended to add an epsilon to the distribution so that no token
                gets an empty interval).
            verbose: Whether to display tqdm progress bars.
        """
        super().__init__(encryptor)
        # Only byte-aligned widths are supported so that each token maps to a
        # whole number of bytes.
        assert num_bits in [8, 16, 32]
        self.num_bits = num_bits
        self.smooth_distribution = smooth_distribution  # Add epsilon to distribution
        self.verbose = verbose

    @staticmethod
    def _encode_token(distribution, token_id, rng):
        """Sample an integer uniformly from ``[l, r)``, the interval of ``token_id``.

        Raises:
            RuntimeError: if the quantized interval of ``token_id`` is empty.
        """
        l, r = distribution.get_interval(token_id)
        if l == r:
            # Unsatisfiable
            raise RuntimeError(
                "Unsatisfiable token (increase quantization bits or enable smoothing)"
            )

        return l + rng.randbelow(r - l)

    def encode_message_impl(
        self,
        initial_probs,
        token_ids: List[int],
        probs_func,
        rng: CryptoRandomGenerator,
    ) -> bytes:
        """Encode ``token_ids`` into ``len(token_ids) * num_bits // 8`` bytes.

        Args:
            initial_probs: Distribution used for the first token.
            token_ids: Tokens to encode.
            probs_func: Called with the previous token id to obtain the
                distribution for the next token.
            rng: Source of randomness for choosing a value within each interval.

        Returns:
            The concatenated fixed-width big-endian encodings.
        """
        nbytes = self.num_bits // 8
        num_bins = 2**self.num_bits
        last_index = len(token_ids) - 1
        probs = initial_probs
        # Collect chunks and join once; repeated ``bytes +=`` is quadratic.
        chunks = []
        for i, token_id in tqdm(
            enumerate(token_ids),
            disable=not self.verbose,
            desc="Encoding",
            total=len(token_ids),
        ):
            distribution = QuantizedDistribution(
                probs, num_bins, smooth=self.smooth_distribution
            )
            enc = self._encode_token(distribution, token_id, rng)
            chunks.append(enc.to_bytes(nbytes, "big"))
            if i < last_index:  # No need to evaluate probs for last token
                probs = probs_func(token_id)
        return b"".join(chunks)

    def decode_message_impl(
        self,
        initial_probs,
        encoded_plaintext: bytes,
        probs_func,
        terminator_token_id: int = -1,
        max_length: int = -1,
        return_probs: bool = False,
    ) -> Tuple[List[int], float, Optional[list]]:
        """Decode fixed-width chunks of ``encoded_plaintext`` back into token ids.

        Args:
            initial_probs: Distribution used for the first token.
            encoded_plaintext: Payload whose length must be a multiple of
                ``num_bits // 8``.
            probs_func: Called with the previous token id to obtain the
                distribution for the next token.
            terminator_token_id: Decoding stops after this token is produced.
                The default -1 presumably never matches a real token id.
            max_length: Decoding stops once this many tokens are produced. Any
                value below 1 (e.g. the default -1) means no limit.
            return_probs: If True, also return the distribution used at each
                decoded position.

        Returns:
            ``(token_ids, perplexity, all_probs)``. ``perplexity`` is
            ``2 ** (-mean(log2(probs[token])))`` over the unquantized
            probabilities, or NaN when no token was decoded. ``all_probs`` is
            None unless ``return_probs`` is set.
        """
        nbytes = self.num_bits // 8
        num_bins = 2**self.num_bits
        assert len(encoded_plaintext) % nbytes == 0
        last_offset = len(encoded_plaintext) - nbytes
        probs = initial_probs
        token_ids = []
        log2_likelihood = 0.0
        all_probs = [probs] if return_probs else None
        for offset in tqdm(
            range(0, len(encoded_plaintext), nbytes),
            disable=not self.verbose,
            desc="Decoding",
        ):
            encoded_token = int.from_bytes(
                encoded_plaintext[offset : offset + nbytes], "big"
            )
            distribution = QuantizedDistribution(
                probs, num_bins, smooth=self.smooth_distribution
            )
            token_id = distribution.decode(encoded_token)
            token_ids.append(token_id)
            log2_likelihood += np.log2(probs[token_id])
            if token_id == terminator_token_id or len(token_ids) == max_length:
                break
            if offset < last_offset:  # No need to evaluate probs after last chunk
                probs = probs_func(token_id)
                if return_probs:
                    all_probs.append(probs)

        if not token_ids:
            # Perplexity is undefined for an empty message; avoid dividing by 0.
            return token_ids, float("nan"), all_probs
        perplexity = np.exp2(-log2_likelihood / len(token_ids))
        return token_ids, perplexity, all_probs

    def statistical_test1_impl(self, encoded_plaintext: bytes) -> np.ndarray:
        """Return a 256-bin ``uint32`` histogram of the byte values in the payload.

        The window is hardcoded to a single byte regardless of ``num_bits``,
        because wider windows make the histogram too sparse for a significant
        test.
        """
        byte_values = np.frombuffer(encoded_plaintext, dtype=np.uint8)
        return np.bincount(byte_values, minlength=256).astype(np.uint32)


class CompressedEncoder(BaseEncoder):
    """Encode tokens with a Huffman code rebuilt from the distribution at each step.

    At every step the current probabilities are quantized onto ``num_bins``
    bins. A ``HuffmanCoder`` is built from the resulting frequencies, and the
    token's codeword is appended to a ``BitStream``. The stream is padded to a
    byte boundary (using ``rng``) at the end.
    """

    def __init__(self, encryptor: Optional[Encryptor] = None, verbose: bool = True):
        """
        Args:
            encryptor: Optional encryptor, forwarded to ``BaseEncoder``.
            verbose: Whether to display tqdm progress bars.
        """
        super().__init__(encryptor)
        # Hardcoded, required for defining the c type
        self.num_bins = 2**32
        # Public attribute; not currently computed by this class.
        self.average_entropy = None
        self.verbose = verbose

    def encode_message_impl(
        self,
        initial_probs,
        token_ids: List[int],
        probs_func,
        rng: CryptoRandomGenerator,
    ) -> bytes:
        """Huffman-encode ``token_ids`` and return the byte-padded bitstream.

        Args:
            initial_probs: Distribution used for the first token.
            token_ids: Tokens to encode.
            probs_func: Called with the previous token id to obtain the
                distribution for the next token.
            rng: Passed to ``BitStream.pad`` when padding to a whole byte.
        """
        probs = initial_probs
        encoded_tokens = BitStream()
        last_index = len(token_ids) - 1
        for i, token_id in tqdm(
            enumerate(token_ids),
            disable=not self.verbose,
            desc="Encoding",
            total=len(token_ids),
        ):
            distribution = QuantizedDistribution(probs, self.num_bins)
            tree = HuffmanCoder(distribution.get_frequencies())
            encoded_tokens += tree.encode_symbol(token_id)
            if i < last_index:  # No need to evaluate probs for last token
                probs = probs_func(token_id)

        encoded_tokens.pad(8, rng)  # Pad to bytes
        return encoded_tokens.get_bytes()

    def decode_message_impl(
        self,
        initial_probs,
        encoded_plaintext: bytes,
        probs_func,
        terminator_token_id: int = -1,
        max_length: int = -1,
        return_probs: bool = False,
    ) -> Tuple[List[int], float, Optional[list]]:
        """Decode Huffman codewords from ``encoded_plaintext`` back into token ids.

        Args:
            initial_probs: Distribution used for the first token.
            encoded_plaintext: Byte-padded bitstream produced by the encoder.
            probs_func: Called with the previous token id to obtain the
                distribution for the next token.
            terminator_token_id: Decoding stops after this token is produced.
                The default -1 presumably never matches a real token id.
            max_length: Decoding stops once this many tokens are produced. Any
                value below 1 (e.g. the default -1) means no limit.
            return_probs: If True, also return the Huffman tree's effective
                probabilities for each decoded token (one entry per token).

        Returns:
            ``(token_ids, perplexity, all_probs)``. ``perplexity`` is
            ``2 ** (-mean(log2(probs[token])))`` over the unquantized
            probabilities, or NaN when no token was decoded. ``all_probs`` is
            None unless ``return_probs`` is set.
        """
        probs = initial_probs
        offset = 0
        bits = BitStream(from_bytes=encoded_plaintext).get_bits()
        num_bits = len(bits)
        token_ids = []
        log2_likelihood = 0.0
        all_probs = [] if return_probs else None
        # Context manager guarantees the progress bar is closed on every exit.
        with tqdm(disable=not self.verbose, desc="Decoding", total=num_bits) as pbar:
            while True:
                distribution = QuantizedDistribution(probs, self.num_bins)
                tree = HuffmanCoder(distribution.get_frequencies())
                token_id, offset = tree.decode_symbol(bits, offset)
                if token_id < 0:
                    # End of stream reached. Presumably the remaining bits (e.g.
                    # byte padding) do not form a complete codeword.
                    break
                # Record probabilities only for positions that produced a token,
                # so that all_probs stays aligned with token_ids.
                if return_probs:
                    all_probs.append(tree.get_effective_probabilities())
                token_ids.append(token_id)
                log2_likelihood += np.log2(probs[token_id])
                pbar.update(offset - pbar.n)
                if token_id == terminator_token_id or len(token_ids) == max_length:
                    break
                if offset >= num_bits:
                    break
                probs = probs_func(token_id)

        if not token_ids:
            # Perplexity is undefined for an empty message; avoid dividing by 0.
            return token_ids, float("nan"), all_probs
        perplexity = np.exp2(-log2_likelihood / len(token_ids))
        return token_ids, perplexity, all_probs

    def statistical_test1_impl(self, encoded_plaintext: bytes) -> np.ndarray:
        """Return a 256-bin ``uint32`` histogram of the byte values in the payload.

        A byte-aligned window is used over the (bit-packed) compressed stream.
        """
        byte_values = np.frombuffer(encoded_plaintext, dtype=np.uint8)
        return np.bincount(byte_values, minlength=256).astype(np.uint32)
