#import stuff
import torch
import typing
import argparse
from PIL import Image
from .wm_provider import WmProvider
from utils.image_utils import torch_to_PIL

from Crypto.Cipher import ChaCha20

import numpy as np
from scipy.stats import norm,truncnorm
from functools import reduce

# Command-line options of this watermark provider. The automatic -h/--help
# option is disabled on this parser (add_help=False).
parser = argparse.ArgumentParser(add_help=False)

parser.add_argument(
    "--num_replications",
    type=int,
    default=64,
    help="The number of replications of the message bits to get barcode image",
)
parser.add_argument(
    "--message_width_in_bytes",
    type=int,
    default=32,
    help="Message width in bytes",
)
parser.add_argument(
    "--offset_debug",
    type=int,
    default=None,
    help="how many nonces shall we consider for checking? Only serves for debugging.",
)


class GsReferernceProvider(WmProvider):
    """
    Original by https://github.com/bsmhmmlf/Gaussian-Shading/ and heavily modified for debugging and getting access to internals.
    """

    def __init__(self,
                 message_width_in_bytes: int = 32,  # (channels/f_c) * f_h * f_w // 8
                 num_replications: int = 64,
                 fpr = 1e-6,
                 number_of_users = 1000,
                 offset: int = 0,
                 offset_debug: int = None,
                 message: typing.Optional[str] = None,
                 key: str = None,
                 nonce: str = None,
                 **kwargs):
        """
        Configure the watermark geometry and load the secrets used for embedding and detection.

        By default the provider works on fixed, pre-generated lists of messages, keys and
        nonces, to rule out false negatives caused by wrong seeds. More secrets can be
        generated with "generate_secrets.py" as needed.

        The watermark parameters are expected to satisfy
        message_width_in_bytes * 8 * num_replications == num_channels * latent_resolution * latent_resolution
        Example:
        -> 32 (message_width_in_bytes) * 8 * 64 (num_replications) = 16384 = 4 (num_channels) * 64 * 64

        @param message_width_in_bytes: width of the message in bytes
        @param num_replications: number of replications of the message bits (strength of the error correction)
        @param fpr: false-positive rate used to look up the bit-accuracy threshold
        @param number_of_users: user count used to look up the bit-accuracy threshold
        @param offset: offset into the lists of messages, keys and nonces; batch_size entries are used from there
        @param offset_debug: how many nonces to try during detection (None means all of them)
        @param message: when not None, switches to caller-supplied key and nonce instead of the predefined lists
        @param key: key to use when message is given (must not be None in that case)
        @param nonce: nonce(s) to use when message is given (must not be None in that case)
        @param kwargs: forwarded unchanged to the parent constructor
        """
        super().__init__(**kwargs)

        # Message geometry; the barcode is as wide as the message in bits.
        width_in_bits = int(message_width_in_bytes * 8)
        self.message_width_in_bytes = message_width_in_bytes
        self.message_width_in_bits = width_in_bits
        self.barcode_width = width_in_bits
        self.bit_accuracy_threshold = self.lookup_bitacc_threshold(number_of_users,fpr)

        self.num_replications = num_replications
        # batch_size entries of messages/keys/nonces are consumed starting at this offset
        self.offset = offset
        self.offset_debug = offset_debug

        # number of latent elements needed to carry the replicated message
        self.latentLength = width_in_bits * num_replications

        if message is not None:
            # Custom deployment mode: the caller provides key and nonce (per the original
            # note, as arrays of length batch_size).
            assert key is not None and nonce is not None

            print("individual keys" )

            from .messages_long import MESSAGES as MESSAGES_LONG

            self.messages_long = MESSAGES_LONG
            # NOTE(review): self.messages_short is not set in this mode and the value of
            # `message` itself is not stored; detection reads self.messages_short, so it
            # looks like detection cannot run in this mode - confirm the intended behaviour.
            self.key = key
            self.nonces = nonce
            return

        # Default mode: use the predefined lists of crypto parameters.
        from .messages_short import MESSAGES as MESSAGES_SHORT
        from .messages_long import MESSAGES as MESSAGES_LONG
        from .keys_2 import KEYS
        from .nonces_2 import NONCES

        self.messages_long = MESSAGES_LONG
        self.messages_short = MESSAGES_SHORT
        self.key = KEYS[0]  # a single key shared by all samples
        self.nonces = NONCES

    def lookup_bitacc_threshold(self,users,fpr):
        #solved with a static dict as we'd need more precission to actually compute these values.
        bit_acc_dict = {
            (1000,1e-06):0.68359375,
            (1000,1e-16):0.7734375  ,
            (1000,1e-32):0.8671875  ,
            (1000,1e-64):0.9765625 ,
            (10000,1e-06):0.6953125,
            (10000,1e-16):0.78125,
            (10000,1e-32):0.87109375,
            (10000,1e-64):0.98046875,
            (100000,1e-06):0.70703125,
            (100000,1e-16):0.7890625,
            (100000,1e-32):0.875,
            (100000,1e-64):0.984375,
            (1,1e-06):0.6484375 ,
	        (1,1e-16):0.75,
	        (1,1e-32):0.8515625,
	        (1,1e-64):0.97265625
        }
        try:
            return bit_acc_dict[(users,fpr)]
        except KeyError:
            print((users,fpr),"fpr for this user unavailible, go for default.")
            return 0.8


    def get_wm_type(self) -> str:
        return "GSreference"
    
    # get_wm_latents: generate the watermarked latents for one batch.
    def get_wm_latents(self, **kwargs) -> typing.Dict[str, any]:
        """
        Build one watermarked latent per batch entry.

        For every index in [offset, offset + batch_size) the message at that index is
        unpacked to bits, expanded with a repetition code, encrypted with ChaCha20 using
        the nonce at the same index, and turned into a latent via __truncSampling.

        @param kwargs: accepted but not used here
        @return: dict with
                 "zT_torch": the stacked latents cast to self.dtype,
                 "zT_PIL" and "zT": the output of torch_to_PIL for those latents,
                 "message_bits_str_list": the embedded messages exactly as stored in self.messages_long
        """
        embedded_messages = []
        sampled_latents = []

        for idx in range(self.offset, self.batch_size + self.offset):
            payload = self.messages_long[idx]
            # TODO: store the message bits in an appropriate format; for now the raw entry is kept
            embedded_messages.append(payload)
            nonce_bytes = self.nonces[idx]

            # Repetition code: the message bits form a 4x8x8 tile that is repeated
            # 8 times along each spatial axis, giving a 4x64x64 bit pattern.
            payload_bits = np.unpackbits(np.frombuffer(payload, dtype=np.uint8))
            tile = torch.tensor(payload_bits, dtype=torch.float32).reshape((1, 4, 8, 8))
            replicated_bits = np.int8(tile.repeat(1,1,8,8).flatten().numpy())

            # Encrypt the replicated bits; self.key is either one shared key or a list indexed per sample.
            cipher_key = self.key[idx] if isinstance(self.key,list) else self.key
            cipher = ChaCha20.new(key=cipher_key, nonce=nonce_bytes)
            ciphertext = cipher.encrypt(np.packbits(replicated_bits).tobytes())
            encrypted_bits = np.unpackbits(np.frombuffer(ciphertext, dtype=np.uint8))

            # Sample a Gaussian latent whose signs carry the encrypted bits.
            sampled_latents.append(self.__truncSampling(encrypted_bits).squeeze(0))

        batch = torch.stack(sampled_latents, dim=0).to(dtype=self.dtype)
        batch_as_images = torch_to_PIL(batch)

        return {
            "zT_torch": batch,
            "zT_PIL": batch_as_images,
            "zT": batch_as_images,
            "message_bits_str_list": embedded_messages,
        }
    

    def __truncSampling(self, message):
        '''according to GS reference implementation. Looks pretty complicated to me though.'''
        z = np.zeros(self.latentLength)
        denominator = 2.0
        ppf = [norm.ppf(j / denominator) for j in range(int(denominator) + 1)]
        for i in range(self.latentLength):
            dec_mes = reduce(lambda a, b: 2 * a + b, message[i : i + 1])
            dec_mes = int(dec_mes)
            z[i] = truncnorm.rvs(ppf[dec_mes], ppf[dec_mes + 1])
        z = torch.from_numpy(z).reshape(1, 4, 64, 64)#.half()
        return z.to(self.device)
    
    def _decode_repetition(self,watermark_sd):
        '''according to GS reference implementation will decode the repetition. Hopefully does what it promises.'''
        ch_stride = 4 // 1
        hw_stride = 64 // 8
        ch_list = [ch_stride] * 1
        hw_list = [hw_stride] * 8
        split_dim1 = torch.cat(torch.split(watermark_sd, tuple(ch_list), dim=1), dim=0)
        split_dim2 = torch.cat(torch.split(split_dim1, tuple(hw_list), dim=2), dim=0)
        split_dim3 = torch.cat(torch.split(split_dim2, tuple(hw_list), dim=3), dim=0)
        vote = torch.sum(split_dim3, dim=0).clone()/64
        vote[vote <= 0.5] = 0
        vote[vote > 0.5] = 1
        return vote
    
    def __recover_messages_from_latents(self,latents: torch.Tensor) -> typing.List[typing.Union[float, bytes]]:
            # iterate keys, nonces
        rtn_accuracies = list()
        rtn_messages = list()

        for latent in latents:
            # tokenise to integer
            latent_bits = (latent > 0).int()
            latent_bits = latent_bits.flatten().cpu().numpy()

            perfect_match = 0
            perfect_acc = 0
            #decrypt
            for nonce in self.nonces[:self.offset_debug]:
                cipher = ChaCha20.new(key=self.key, nonce=nonce)
                sd_byte = cipher.decrypt(np.packbits(latent_bits).tobytes())
                sd_bit = np.unpackbits(np.frombuffer(sd_byte, dtype=np.uint8))
                sd_tensor = torch.from_numpy(sd_bit).reshape(1, 4, 64, 64).to(torch.uint8)
                #decode
                reversed_message = self._decode_repetition(sd_tensor)
                reversed_message=reversed_message.flatten().cpu().numpy()
                #match
                best_match = 0
                best_acc = 0
                for i,mes in enumerate(self.messages_short):
                    bit_acc = self._calculate_bit_accuracy(mes,reversed_message)
                    if bit_acc > best_acc:
                        best_match = i
                        best_acc = bit_acc
                if best_acc > perfect_acc:
                    perfect_match = best_match
                    perfect_acc = best_acc
        
            rtn_accuracies.append(perfect_acc)
            rtn_messages.append(self.messages_short[perfect_match]) # buffer message somewhere? or hope that the os does that

        return rtn_accuracies,rtn_messages

   


    def _calculate_bit_accuracy(self,
                                 db_message: bytes,
                                 extracted_message: np.typing.ArrayLike) -> float:
        """
        Gett bit accuracy between extracted bits and the original message hex

        @param original_message_hex: original message in hex
        @param extracted_message_bits_str: extracted message in bits

        @return: bit accuracy
        """
        db_message = np.unpackbits(np.frombuffer(db_message, dtype=np.uint8))

        # Ensure both binary strings are of the same length, should be the case though.
        min_length = min(len(db_message), len(extracted_message))
        db_message = db_message[:min_length]
        extracted_message = extracted_message[:min_length]
        
        # Calculate bit accuracy
        matching_bits = np.count_nonzero(db_message == extracted_message)
        bit_accuracy = matching_bits / min_length
        
        return bit_accuracy      
    

    # Compute bit accuracies; delegates to __recover_messages_from_latents, so changes to the matching logic belong there.
    def get_accuracies(self, latents: typing.Union[torch.Tensor, np.array],**kwargs) -> typing.Dict[str, any]:
        rtn_accuracies,rtn_messages = self.__recover_messages_from_latents(latents)
        
        return {
            "accuracies": rtn_accuracies,
            "bit_accuracies": rtn_accuracies,
            "message_bits_str_list": rtn_messages
        }
    
   
