from abc import ABC, abstractmethod
from typing import List, Optional, Tuple

import numpy as np
import scipy.stats
from encryption import Encryptor
from utils.bitstream import BitStream
from utils.random_utils import CryptoRandomGenerator


# Abstract base class for all symbol coders.
class BaseCoder(ABC):
    """Interface for coders that map single symbols to bit lists and back.

    Every method, including the constructor, is abstract: a concrete coder
    must implement all three before it can be instantiated.
    """

    @abstractmethod
    def __init__(self, freqs: List[int]):
        """Build the coder from ``freqs``.

        NOTE(review): presumably per-symbol frequency counts — confirm
        against concrete implementations.
        """

    @abstractmethod
    def encode_symbol(self, symbol: int) -> List[int]:
        """Return the list of bits that encodes ``symbol``."""

    @abstractmethod
    def decode_symbol(
        self, bits: List[int], offset: int = 0
    ) -> Tuple[Optional[int], int]:
        """Decode one symbol from ``bits``, starting at position ``offset``.

        Returns a ``(symbol, int)`` pair. ``symbol`` may be ``None``
        (presumably when no complete symbol could be read). The second
        element is presumably the updated offset or the number of bits
        consumed — TODO confirm against concrete implementations.
        """


class BaseEncoder:
    """Base class for message encoders driven by a token probability model.

    A message produced by :meth:`encode_message` is laid out as
    ``header + payload``:

    * the header stores the prompt token IDs (see :meth:`_encode_header`);
    * the payload is the bytes returned by :meth:`encode_message_impl`,
      encrypted with ``self.encryptor`` unless ``encrypt=False``.

    Subclasses provide :meth:`encode_message_impl`,
    :meth:`decode_message_impl` and :meth:`statistical_test1_impl`.

    NOTE(review): this class does not derive from ``ABC``, so the
    ``@abstractmethod`` markers are not enforced when a subclass is
    instantiated. Deriving from ``ABC`` would be stricter but could break
    existing subclasses that leave one of the hooks unimplemented.
    """

    # Header layout: a big-endian token count followed by one big-endian
    # entry per prompt token.
    _HEADER_COUNT_BYTES = 4
    _HEADER_TOKEN_BYTES = 2

    def __init__(self, encryptor: Optional[Encryptor] = None):
        # Optional; only needed by the methods that encrypt or decrypt.
        self.encryptor = encryptor

    def _encode_header(self, prompt_token_ids: List[int]) -> bytes:
        """Serialize ``prompt_token_ids`` into the message header.

        Args:
            prompt_token_ids: integer token IDs. Integer-like scalars (e.g.
                NumPy integers) are accepted.

        Returns:
            A 4-byte big-endian token count followed by a 2-byte big-endian
            entry per token.

        Raises:
            ValueError: if a token ID lies outside ``[0, 2**16)``.
            TypeError: if a token ID is not an integer.
        """
        import operator

        max_token_id = 2 ** (8 * self._HEADER_TOKEN_BYTES)
        parts = [len(prompt_token_ids).to_bytes(self._HEADER_COUNT_BYTES, "big")]
        for token_id in prompt_token_ids:
            # operator.index accepts NumPy/other integer scalars but rejects
            # floats (int() would silently truncate them).
            token_id = operator.index(token_id)
            if not 0 <= token_id < max_token_id:
                raise ValueError(
                    f"Prompt token ID {token_id} does not fit in "
                    f"{self._HEADER_TOKEN_BYTES} bytes"
                )
            parts.append(token_id.to_bytes(self._HEADER_TOKEN_BYTES, "big"))
        return b"".join(parts)

    def _decode_header(self, ciphertext: bytes) -> Tuple[List[int], bytes]:
        """Parse the header written by :meth:`_encode_header`.

        Args:
            ciphertext: a full message (header followed by payload).

        Returns:
            ``(prompt_token_ids, payload)``, where ``payload`` is every byte
            after the header.

        Raises:
            ValueError: if ``ciphertext`` is shorter than the header it
                declares.
        """
        count_size = self._HEADER_COUNT_BYTES
        token_size = self._HEADER_TOKEN_BYTES
        if len(ciphertext) < count_size:
            raise ValueError("Message is too short to contain a header")

        num_tokens = int.from_bytes(ciphertext[:count_size], "big")
        header_end = count_size + token_size * num_tokens
        if len(ciphertext) < header_end:
            raise ValueError(
                f"Truncated header: expected {num_tokens} prompt tokens "
                f"({header_end} bytes), got {len(ciphertext)} bytes"
            )

        token_ids = [
            int.from_bytes(ciphertext[start : start + token_size], "big")
            for start in range(count_size, header_end, token_size)
        ]
        return token_ids, ciphertext[header_end:]

    @staticmethod
    def _prompt_model(prompt_token_ids, probs_func):
        """Feed every prompt token to ``probs_func``; return the last result.

        Raises:
            ValueError: if ``prompt_token_ids`` is empty, because there would
                be no initial distribution to start from.
        """
        # len() rather than truthiness so NumPy arrays are accepted too.
        if len(prompt_token_ids) == 0:
            raise ValueError("At least one prompt token is required")
        for token_id in prompt_token_ids:
            initial_probs = probs_func(token_id)
        return initial_probs

    def _read_payload(
        self, message: bytes, decrypt: bool
    ) -> Tuple[List[int], bytes]:
        """Split ``message`` into prompt token IDs and the encoded plaintext.

        The payload is decrypted with ``self.encryptor`` when ``decrypt`` is
        true.

        Raises:
            ValueError: if ``decrypt`` is true but no encryptor was supplied,
                or if the header is malformed.
        """
        if decrypt and self.encryptor is None:
            raise ValueError("No encryptor was supplied")
        prompt_token_ids, ciphertext = self._decode_header(message)
        if decrypt:
            return prompt_token_ids, self.encryptor.decrypt(ciphertext)
        return prompt_token_ids, ciphertext

    @staticmethod
    def _entropy_stats(probs) -> Tuple[float, float]:
        """Return the mean and variance of the surprisal ``-log2(p)`` under ``probs``.

        Zero-probability entries contribute nothing (0 * log 0 := 0).
        Without this, ``0 * log2(0)`` evaluates to NaN and poisons the sums.
        """
        p = np.asarray(probs, dtype=np.float64)
        p = p[p > 0]
        surprisal = -np.log2(p)
        mu = float(np.sum(p * surprisal))
        var = float(np.sum(p * (surprisal - mu) ** 2))
        return mu, var

    def encode_message(
        self,
        prompt_token_ids,
        token_ids,
        probs_func,
        rng: CryptoRandomGenerator,
        target_length: int = 0,
        encrypt: bool = True,
    ) -> bytes:
        """Encode ``token_ids`` into a message conditioned on a prompt.

        Args:
            prompt_token_ids: prompt tokens. They are fed to ``probs_func``
                one at a time and stored in the message header.
            token_ids: the tokens to encode.
            probs_func: callable taking a token ID and returning
                probabilities. Only the result for the last prompt token is
                kept, so it is presumably stateful (advances a model) — TODO
                confirm.
            rng: random generator passed to the encoder and the encryptor.
            target_length: forwarded to ``Encryptor.encrypt``.
            encrypt: whether to encrypt the payload.

        Returns:
            ``header + payload`` as bytes.

        Raises:
            ValueError: if ``encrypt`` is true but no encryptor was supplied,
                or if ``prompt_token_ids`` is empty.
        """
        if encrypt and self.encryptor is None:
            raise ValueError("No encryptor was supplied")

        header = self._encode_header(prompt_token_ids)
        initial_probs = self._prompt_model(prompt_token_ids, probs_func)

        encoded_plaintext = self.encode_message_impl(
            initial_probs, token_ids, probs_func, rng
        )
        if encrypt:
            payload = self.encryptor.encrypt(encoded_plaintext, rng, target_length)
        else:
            payload = encoded_plaintext

        return header + payload

    def decode_message(
        self,
        message: bytes,
        probs_func,
        decrypt: bool = True,
        terminator_token_id: int = -1,
        max_length: int = -1,
    ) -> Tuple[List[int], List[int], float]:
        """Decode a message produced by :meth:`encode_message`.

        Args:
            message: ``header + payload`` bytes.
            probs_func: the same probability callable used for encoding.
            decrypt: whether the payload must be decrypted first.
            terminator_token_id: forwarded to :meth:`decode_message_impl`
                (presumably a stop token; -1 for none — confirm).
            max_length: forwarded to :meth:`decode_message_impl`.

        Returns:
            ``(prompt_token_ids, token_ids, perplexity)``.

        Raises:
            ValueError: if ``decrypt`` is true but no encryptor was supplied,
                if the header is malformed, or if the prompt is empty.
        """
        prompt_token_ids, encoded_plaintext = self._read_payload(message, decrypt)
        initial_probs = self._prompt_model(prompt_token_ids, probs_func)

        token_ids, perplexity, _ = self.decode_message_impl(
            initial_probs,
            encoded_plaintext,
            probs_func,
            terminator_token_id,
            max_length,
        )
        return prompt_token_ids, token_ids, perplexity

    @abstractmethod
    def encode_message_impl(
        self,
        initial_probs,
        token_ids: List[int],
        probs_func,
        rng: CryptoRandomGenerator,
    ) -> bytes:
        """Encode ``token_ids`` into plaintext bytes (subclass hook).

        ``initial_probs`` is the value ``probs_func`` returned for the last
        prompt token.
        """

    @abstractmethod
    def decode_message_impl(
        self,
        initial_probs,
        encoded_plaintext: bytes,
        probs_func,
        terminator_token_id: int,
        max_length: int,
        return_probs: bool = False,
    ) -> Tuple[List[int], float, Optional[np.array]]:
        """Decode tokens from ``encoded_plaintext`` (subclass hook).

        Returns ``(token_ids, perplexity, probs)``. When ``return_probs`` is
        true, :meth:`statistical_test2` uses ``probs`` as one distribution per
        decoded token, indexable by token ID.
        """

    @staticmethod
    def runs_test(seq) -> float:
        """
        Performs a Wald–Wolfowitz runs test on the provided sequence, to check
        whether the provided samples are correlated.

        https://en.wikipedia.org/wiki/Wald%E2%80%93Wolfowitz_runs_test

        Assumes that the input sequence is a boolean array (nonzero values
        count as True). Returns a two-sided p-value from the normal
        approximation. If it is lower than a threshold (e.g. p < 0.01), it is
        likely that samples are correlated and thus not truly random.

        Returns NaN when the statistic is undefined: only one distinct value
        is present, or the run-count variance is zero (e.g. ``[0, 1]``).

        Raises:
            ValueError: if ``seq`` is empty.
        """
        bits = np.asarray(seq, dtype=bool)
        n = bits.size
        if n == 0:
            raise ValueError("runs_test requires a non-empty sequence")

        n1 = int(np.count_nonzero(bits))
        n0 = n - n1
        if n0 == 0 or n1 == 0:
            return float("nan")

        # Number of runs = 1 + number of positions where the value changes.
        runs = 1 + int(np.count_nonzero(bits[1:] != bits[:-1]))

        mu = 1 + 2 * n0 * n1 / n
        var = 2 * n0 * n1 * (2 * n0 * n1 - n) / (n**2 * (n - 1))
        if var <= 0:
            return float("nan")

        z = abs(runs - mu) / np.sqrt(var)
        return float(2 * scipy.stats.norm.sf(z))

    def statistical_test1(
        self, message: bytes, decrypt: bool = True
    ) -> Tuple[float, float]:
        """
        Checks that the provided ciphertext is indistinguishable from random
        noise. This test is model-free (the encoded tokens do not need to be
        decoded).

        Internally, it runs two tests:

        * a frequency test: a chi-square test on the counts returned by
          :meth:`statistical_test1_impl`, to check whether the decrypted
          stream (i.e. the encoded plaintext) is indistinguishable from
          samples from a uniform distribution;
        * a correlation test: :meth:`runs_test` on the payload bits, to check
          whether samples are independent.

        Returns two p-values ``(p_freqs, p_corr)``. If either is below a
        threshold (e.g. p < 0.01), the stream is likely non-random.

        Raises:
            ValueError: if ``decrypt`` is true but no encryptor was supplied,
                if the header is malformed, or if the payload is empty.
        """
        _, encoded_plaintext = self._read_payload(message, decrypt)

        # Frequency test
        freqs = self.statistical_test1_impl(encoded_plaintext)
        p_freqs = scipy.stats.chisquare(freqs.astype(np.float64)).pvalue

        # Correlation test
        bits = BitStream(from_bytes=encoded_plaintext).get_bits()
        p_corr = BaseEncoder.runs_test(bits)

        return p_freqs, p_corr

    @abstractmethod
    def statistical_test1_impl(self, encoded_plaintext: bytes) -> np.array:
        """Return frequency counts for ``encoded_plaintext`` (subclass hook).

        :meth:`statistical_test1` feeds the result to a chi-square test.
        """

    def statistical_test2(
        self,
        message: bytes,
        probs_func,
        stride: int = 0,
        max_length: int = -1,
        decrypt: bool = True,
        terminator_token_id: int = -1,
    ) -> List[float]:
        """
        Checks that the provided ciphertext looks as if it was sampled from
        the model. This test is model-based (the encoded tokens are decoded
        using the model).

        For each decoded token, the measured surprisal ``-log2 p(token)`` is
        accumulated. Each token also contributes the mean (entropy) and
        variance of the surprisal under its own distribution. The cumulative
        surprisal is compared against a normal distribution with the
        accumulated mean and variance (a central-limit approximation).

        Args:
            message: ``header + payload`` bytes.
            probs_func: the probability callable used for encoding.
            stride: if > 0, also emit a p-value every ``stride`` tokens. The
                final cumulative value is always included.
            max_length: forwarded to :meth:`decode_message_impl`.
            decrypt: whether the payload must be decrypted first.
            terminator_token_id: forwarded to :meth:`decode_message_impl`.

        Returns:
            A list of two-sided p-values, one per checkpoint.

        Raises:
            ValueError: if ``decrypt`` is true but no encryptor was supplied,
                if the header is malformed, or if the prompt is empty.
        """
        import logging

        prompt_token_ids, encoded_plaintext = self._read_payload(message, decrypt)
        initial_probs = self._prompt_model(prompt_token_ids, probs_func)

        token_ids, _, probs = self.decode_message_impl(
            initial_probs,
            encoded_plaintext,
            probs_func,
            terminator_token_id=terminator_token_id,
            max_length=max_length,
            return_probs=True,
        )

        total_entropy = 0.0
        mu_sum = 0.0
        var_sum = 0.0
        all_mu = []
        all_var = []
        all_entropies = []
        for i, (token_id, p) in enumerate(zip(token_ids, probs)):
            total_entropy += float(-np.log2(p[token_id]))
            mu, var = self._entropy_stats(p)
            mu_sum += mu
            var_sum += var
            if stride > 0 and (i + 1) % stride == 0:
                all_mu.append(mu_sum)
                all_var.append(var_sum)
                all_entropies.append(total_entropy)
        # Add the final value (it may repeat the last checkpoint). With
        # stride == 1 every token is already a checkpoint.
        if stride != 1:
            all_mu.append(mu_sum)
            all_var.append(var_sum)
            all_entropies.append(total_entropy)

        logger = logging.getLogger(__name__)
        if logger.isEnabledFor(logging.DEBUG):
            sigma = np.sqrt(var_sum)
            logger.debug(
                "entropy 99%% interval: [%s, %s], measured: %s",
                scipy.stats.norm.ppf(0.005, mu_sum, sigma),
                scipy.stats.norm.ppf(0.995, mu_sum, sigma),
                total_entropy,
            )

        sigmas = np.sqrt(all_var)
        p_values = 2 * np.minimum(
            scipy.stats.norm.sf(all_entropies, all_mu, sigmas),
            scipy.stats.norm.cdf(all_entropies, all_mu, sigmas),
        )
        return np.asarray(p_values).tolist()
