#import stuff
import torch
import typing
import argparse
from .wm_provider import WmProvider
from Crypto.Cipher import ChaCha20
from Crypto.Random import get_random_bytes

from scipy.stats import norm,truncnorm
from functools import reduce

from utils.image_utils import torch_to_PIL

from . import prc

import pickle
import numpy as np

# Module-level option parser for this provider. No arguments are added to it in
# this part of the file; add_help=False suggests it is meant to be passed as a
# parent parser (``parents=[parser]``) to another ArgumentParser — TODO confirm.
parser = argparse.ArgumentParser(add_help=False)

# TODO: also import the gs_reference_parser; otherwise this parser does not work.

class GsPpProvider(WmProvider):
    """
    GS++ watermark provider: the first half of the flattened 4 x 64 x 64 initial latent
    carries a PRC codeword of a fresh random nonce, the second half carries a Gaussian
    Shading (GS) message (repetition-coded and ChaCha20-encrypted under that nonce).

    Original by https://github.com/bsmhmmlf/Gaussian-Shading/ and heavily modified for
    debugging and getting access to internals.
    """

    def __init__(self,
                 filepath_keys=None,
                 offset: int = 0,
                 message: typing.Optional[typing.Union[str, bytes]] = None,
                 key: typing.Optional[typing.Union[str, bytes]] = None,
                 t=3,
                 **kwargs):
        """
        Works with a fixed latent size of 4 x 64 x 64.
        The message is 256 bits. Half of the channels are used for the PRC (nonce),
        the other half for GS (message repeated 32 times, 256 * 32 = 2 * 64 * 64).

        @param filepath_keys: optional path to a pickle written by ``dump_keys``. When
            given, the PRC keys and the ChaCha20 key are loaded from it (and ``key`` is
            ignored); otherwise fresh PRC keys are sampled.
        @param offset: offset into the list of messages
        @param message: message to embed for every latent of the batch (bytes, or a hex
            string); must be at least 32 bytes. Defaults to the per-index messages
            from ``messages_long``.
        @param key: ChaCha20 key (bytes, or a hex string). Defaults to ``keys_2.KEYS[0]``.
        @param t: PRC parameter forwarded to ``prc.KeyGen_gspp``
        @raises ValueError: if ``message`` is shorter than 32 bytes
        """
        super().__init__(**kwargs)

        self.offset = offset

        # Latent layout (flattened):
        #   [0, prc_len)            -> PRC codeword of a nonce_len-bit random nonce
        #   [prc_len, latentLength) -> messagelen bits, each repeated rep_factor times,
        #                              XOR-encrypted with ChaCha20
        self.nonce_len = 32            # nonce length in bits
        self.prc_len = 2 * 64 * 64     # half of the channels
        self.rep_factor = 32
        self.messagelen = 256
        self.latentLength = 2 * self.prc_len

        if filepath_keys is not None:
            # Sets key_enc, key_dec and the ChaCha20 key (self.key).
            self.load_keys(filepath_keys)
        else:
            self.key_enc, self.key_dec = prc.KeyGen_gspp(n=self.prc_len,
                                                         message_length=self.nonce_len,
                                                         t=t)  # Sample PRC keys
            if key is not None:
                self.key = self._as_bytes(key)
            else:
                from .keys_2 import KEYS
                self.key = KEYS[0]

        if message is not None:
            fixed_message = self._as_bytes(message)
            if len(fixed_message) < self.messagelen // 8:
                raise ValueError(
                    f"message must be at least {self.messagelen // 8} bytes, "
                    f"got {len(fixed_message)}")
            self._fixed_message = fixed_message
            self.messages = [fixed_message]
        else:
            # Longer messages still fit: only the first latentLength bits are embedded.
            from .messages_long import MESSAGES as MESSAGES_LONG
            self._fixed_message = None
            self.messages = MESSAGES_LONG

    @staticmethod
    def _as_bytes(value) -> bytes:
        """Return ``value`` as bytes; ``str`` values are parsed as hex."""
        # NOTE(review): str-as-hex follows the hex convention mentioned in the GS-derived
        # docstrings of this class; confirm with callers that pass strings.
        if isinstance(value, str):
            return bytes.fromhex(value)
        return bytes(value)

    def _message_for(self, index: int) -> bytes:
        """Return the message bytes embedded in latent number ``index``."""
        fixed_message = getattr(self, "_fixed_message", None)
        if fixed_message is not None:
            return fixed_message
        return self.messages[index]

    def dump_keys(self, filepath):
        """Pickle the PRC encoding/decoding keys and the ChaCha20 key to ``filepath``."""
        with open(filepath, "wb") as f:
            keys = {"key_enc": self.key_enc,
                    "key_dec": self.key_dec,
                    "key_chacha": self.key}
            pickle.dump(keys, f)

    def load_keys(self, filepath):
        """Load the keys written by ``dump_keys`` from ``filepath`` into this provider."""
        # NOTE(review): pickle.load can execute arbitrary code — only load trusted key files.
        with open(filepath, "rb") as f:
            keydict = pickle.load(f)
        self.key_enc = keydict["key_enc"]
        self.key_dec = keydict["key_dec"]
        self.key = keydict["key_chacha"]

    def get_wm_type(self) -> str:
        return "GSpp"

    def unpack_bytes(self, bites):
        """Unpack a bytes-like object into a uint8 array of bits (MSB first)."""
        return np.unpackbits(np.frombuffer(bites, dtype=np.uint8))

    def _calculate_bit_accuracy(self,
                                db_message: bytes,
                                extracted_message: "np.typing.ArrayLike") -> float:
        """
        Get the fraction of matching bits between the stored message and the extracted bits.

        @param db_message: original message bytes (unpacked MSB first)
        @param extracted_message: extracted bits (0/1 or bool values)

        @return: bit accuracy in [0, 1]; 0.0 when there is nothing to compare
        """
        db_bits = self.unpack_bytes(db_message)
        extracted_bits = np.asarray(extracted_message)

        # Both should already have the same length; compare the common prefix.
        min_length = min(len(db_bits), len(extracted_bits))
        if min_length == 0:
            return 0.0
        matching_bits = np.count_nonzero(db_bits[:min_length] == extracted_bits[:min_length])
        return matching_bits / min_length

    def _truncSampling(self, message):
        """
        Sample a latent whose signs encode ``message`` (GS-style truncated sampling).

        Each of the first ``self.latentLength`` bits picks one half of N(0, 1)
        (bit 0 -> (-inf, 0], bit 1 -> [0, inf)) and a value is drawn from it. The result
        is then negated, so in the returned latent bit 0 is positive and bit 1 negative.

        @param message: sequence of 0/1 values, at least ``self.latentLength`` long
        @return: float64 tensor of shape (1, 4, 64, 64) on ``self.device``
        """
        bits = np.asarray(message[:self.latentLength]).astype(np.int64)
        # Bin edges of the two equal-probability halves of N(0, 1): [-inf, 0, inf].
        ppf = norm.ppf(np.arange(3) / 2.0)
        # One vectorized call instead of latentLength scalar truncnorm.rvs calls.
        z = truncnorm.rvs(ppf[bits], ppf[bits + 1], size=bits.shape)
        # Signs inverted relative to the GS reference sampling (kept from the original).
        z = -z
        return torch.from_numpy(z).reshape(1, 4, 64, 64).to(self.device)

    def __pad_nonce(self, nonce_bytes, final_len=8):
        """Zero-pad ``nonce_bytes`` on the right to ``final_len`` bytes."""
        return nonce_bytes + b'\x00' * (final_len - len(nonce_bytes))

    # get_wm_latents aka generate a latent
    def get_wm_latents(self, **kwargs) -> typing.Dict[str, typing.Any]:
        """
        Generate one watermarked initial latent per batch item.

        For each item: draw a random nonce and PRC-encode it (first half of the bits);
        repeat the message bits ``rep_factor`` times and encrypt them with ChaCha20 under
        the zero-padded nonce (second half); then sample a latent whose signs carry
        those bits.

        @return: dict with "zT_torch" (batch x 4 x 64 x 64 tensor in ``self.dtype``),
            "zT_PIL"/"zT" (output of ``torch_to_PIL``) and "message_bits_str_list"
            (the embedded message bytes, one entry per latent)
        """
        current_messages = []
        latents_torch = []
        for i in range(self.offset, self.batch_size + self.offset):
            # Fresh nonce per latent; it is recovered from the PRC half when decoding.
            nonce_bytes = get_random_bytes(self.nonce_len // 8)
            nonce_bits = self.unpack_bytes(nonce_bytes)
            prc_codeword = prc.Encode_gspp(self.key_enc, nonce_bits)

            message_bytes = self._message_for(i)
            # Stored as raw bytes despite the "bits_str" key name.
            current_messages.append(message_bytes)

            # Repetition code: each message bit repeated rep_factor times.
            repeated_bits = np.repeat(self.unpack_bytes(message_bytes), self.rep_factor)

            # Encrypt; the short GS++ nonce is zero-padded to ChaCha20's 8-byte nonce.
            cipher = ChaCha20.new(key=self.key, nonce=self.__pad_nonce(nonce_bytes))
            encrypted_bytes = cipher.encrypt(np.packbits(repeated_bits).tobytes())
            encrypted_bits = self.unpack_bytes(encrypted_bytes)

            latent_bits = np.concatenate([prc_codeword, encrypted_bits])
            latent = self._truncSampling(latent_bits)
            latents_torch.append(latent.squeeze(0))

        latents_torch = torch.stack(latents_torch, dim=0)
        latents_torch = latents_torch.to(dtype=self.dtype)

        latents_PIL = torch_to_PIL(latents_torch)

        return {"zT_torch": latents_torch,
                "zT_PIL": latents_PIL,
                "zT": latents_PIL,
                "message_bits_str_list": current_messages
                }

    def __recover_messages_from_latent(self, latent):
        """
        Decode the GS message carried by a single latent.

        PRC-decodes the nonce from the first half, regenerates the ChaCha20 keystream,
        undoes the keystream sign flips on the second half and majority-votes each
        block of ``rep_factor`` values.

        @param latent: one latent (torch tensor or array) with ``latentLength`` values
        @return: boolean array of recovered message bits, or None when PRC decoding
            fails (without the nonce the message cannot be decrypted)
        """
        # Flatten without ever writing to the caller's data (all later ops are out of place).
        if isinstance(latent, torch.Tensor):
            flat = latent.flatten().numpy(force=True)
        else:
            flat = np.asarray(latent).reshape(-1)

        prc_latent = torch.tensor(flat[:self.prc_len], dtype=torch.float64)
        gs_latent = np.asarray(flat[self.prc_len:], dtype=np.float64)

        # PRC half -> nonce bits. The variance value is kept as-is from the original.
        variance = 1.5
        posteriors = prc.recover_posteriors(prc_latent, variances=variance).flatten().cpu()
        nonce_bits = prc.Decode_gspp(self.key_dec, posteriors)
        if nonce_bits is None:
            # TODO: use a richer "invalid" marker if callers need more than None.
            return None

        # Regenerate the keystream used at encoding time (encrypting zeros yields it).
        nonce_long = self.__pad_nonce(np.packbits(nonce_bits).tobytes())
        cipher = ChaCha20.new(key=self.key, nonce=nonce_long)
        keystream_bytes = cipher.decrypt(bytes(self.prc_len // 8))
        keystream = np.unpackbits(np.frombuffer(keystream_bytes, dtype=np.uint8))

        # A keystream bit of 1 flipped the encoded bit, i.e. the latent's sign.
        # np.where avoids uint8 arithmetic, where 1 - 2 * k wraps to 255 instead of -1.
        gs_latent = gs_latent * np.where(keystream == 1, -1.0, 1.0)

        # _truncSampling maps bit 0 -> positive and bit 1 -> negative, so a block whose
        # sum is negative decodes to 1.
        block_sums = gs_latent.reshape(-1, self.rep_factor).sum(axis=1)
        return block_sums < 0

    def get_accuracies(self, latents: typing.Union[torch.Tensor, np.ndarray],
                       **kwargs) -> typing.Dict[str, typing.Any]:
        """
        Decode each latent and compare it to the message embedded at the same index.

        Latent ``i`` is compared with message ``i + self.offset``. Latents whose PRC
        part cannot be decoded get accuracy 0.0 and a ``None`` message.

        @return: dict with "accuracies"/"bit_accuracies" (per-latent bit accuracy) and
            "message_bits_str_list" (recovered bit arrays, or None)
        """
        rtn_messages = []
        rtn_accs = []
        for i, latent in enumerate(latents):
            rtn_message = self.__recover_messages_from_latent(latent)
            if rtn_message is None:
                bit_acc = 0.0
            else:
                ori_message = self._message_for(i + self.offset)
                bit_acc = self._calculate_bit_accuracy(ori_message, rtn_message)
            rtn_messages.append(rtn_message)
            rtn_accs.append(bit_acc)

        return {
            "accuracies": rtn_accs,
            "bit_accuracies": rtn_accs,
            "message_bits_str_list": rtn_messages
        }