import hashlib
from typing import Optional

from Crypto.Cipher import AES
from utils.random_utils import CryptoRandomGenerator


class Encryptor:
    """AES-CBC encryptor whose output carries no padding-length marker.

    The key is either supplied directly or derived from a passphrase with
    SHA-256. ``encrypt`` pads the plaintext with bytes drawn from the caller's
    random generator and does not record how many were added, so ``decrypt``
    returns the plaintext followed by those padding bytes.
    """

    def __init__(self, passphrase: Optional[str] = None, key: Optional[bytes] = None):
        """Store ``key``, or derive a key from ``passphrase``.

        Exactly one of the two arguments must be provided.

        Args:
            passphrase: Text that is hashed into the key (see ``_derive_key``).
            key: Raw key bytes; stored unchanged and later passed to ``AES.new``.

        Raises:
            ValueError: If both, or neither, of ``passphrase`` and ``key`` are given.
        """
        # Raise explicitly instead of using ``assert``: assertions are removed
        # under ``python -O``, which would let invalid arguments through.
        if key is None and passphrase is None:
            raise ValueError("You must specify either key or passphrase")
        if key is not None and passphrase is not None:
            raise ValueError("You can only specify either key or passphrase")

        self.key = key if key is not None else self._derive_key(passphrase)

    def _derive_key(self, passphrase: str) -> bytes:
        """Return the SHA-256 digest (32 bytes) of the encoded passphrase."""
        # There are better key derivation functions, but this is just a proof of concept...
        return hashlib.sha256(passphrase.encode()).digest()

    def encrypt(
        self, encoded_plaintext: bytes, rng: CryptoRandomGenerator, block_size: int = 0
    ) -> bytes:
        """Encrypt ``encoded_plaintext`` with AES-CBC and prepend the IV.

        Args:
            encoded_plaintext: Bytes to encrypt.
            rng: Source of random bytes; its ``randbytes(n)`` method is called
                first for the IV and then for the padding.
            block_size: Granularity to which the plaintext is padded. ``0``
                selects the AES block size; any other value is rounded up to
                the next multiple of the AES block size.

        Returns:
            The IV followed by the encryption of the padded plaintext.

        Raises:
            ValueError: If ``block_size`` is negative.
        """
        # A negative size would round to zero or stay negative and fail later
        # with an unrelated error, so reject it before consuming any randomness.
        if block_size < 0:
            raise ValueError(f"block_size must be non-negative, got {block_size}")

        iv = rng.randbytes(AES.block_size)
        cipher = AES.new(self.key, AES.MODE_CBC, iv)

        if block_size == 0:
            # Use default AES block size
            block_size = AES.block_size
        else:
            # Round the requested size up to a multiple of the AES block size
            block_size += -block_size % AES.block_size

        # As opposed to the standard practice in AES encryption,
        # here the padding length is NOT included in the message,
        # as it would provide a way to verify integrity (which we do not want)
        padding_length = -len(encoded_plaintext) % block_size
        # Random bytes = random sampling from the model
        # Any other strategy would leak information
        padding = rng.randbytes(padding_length)

        return iv + cipher.encrypt(encoded_plaintext + padding)

    def decrypt(self, ciphertext: bytes) -> bytes:
        """Decrypt data laid out as produced by ``encrypt`` (IV, then AES-CBC blocks).

        Args:
            ciphertext: The IV followed by zero or more whole AES blocks.

        Returns:
            The decrypted bytes. Any padding added by ``encrypt`` is still
            attached, because its length is not stored in the message.

        Raises:
            ValueError: If ``ciphertext`` is shorter than one AES block or its
                length is not a multiple of the AES block size.
        """
        # Check the layout up front so malformed input gets a clear message
        # instead of an error from inside the cipher implementation.
        if len(ciphertext) < AES.block_size or len(ciphertext) % AES.block_size:
            raise ValueError(
                "ciphertext must be an IV followed by whole AES blocks "
                f"(a non-zero multiple of {AES.block_size} bytes), "
                f"got {len(ciphertext)} bytes"
            )

        iv = ciphertext[: AES.block_size]
        cipher = AES.new(self.key, AES.MODE_CBC, iv)
        return cipher.decrypt(ciphertext[AES.block_size :])
