#import stuff
import torch
import typing
import argparse
from .wm_provider import WmProvider
from Crypto.Cipher import ChaCha20
from Crypto.Random import get_random_bytes
from Crypto.Hash import SHAKE128

from scipy.stats import norm,truncnorm
from functools import reduce

from utils.image_utils import torch_to_PIL

from . import prc
from . import gspp_provider

import numpy as np

# Module-level argument parser; built with help disabled (add_help=False).
parser = argparse.ArgumentParser(
    add_help=False,
)

# TODO: also import the gs_reference_parser; otherwise this does not work.

class GsPpReferenceProvider(gspp_provider.GsPpProvider):
    """
    Gaussian Shading++ reference watermark provider.

    Original by https://github.com/bsmhmmlf/Gaussian-Shading/ and heavily modified for debugging
    and getting access to internals.

    Each latent is sampled (via ``self._truncSampling``) from the concatenation of two bit strings:

    * a PRC codeword (``prc.Encode_gspp``) encoding a per-image nonce, and
    * a Gaussian-Shading part: the message bits repeated ``self.rep_factor`` times and encrypted
      with ChaCha20 under a key/nonce derived from ``self.key`` and that same nonce.
    """

    def __init__(self,
                 message: typing.Optional[str] = None,
                 **kwargs):
        """
        Create the provider; almost all configuration is handled by ``GsPpProvider``.

        Per the original author: works with a fixed latent size of 4 x 64 x 64 and a 256-bit
        message; half of the channels carry the PRC codeword, the other half the GS part.

        Keyword arguments forwarded to the superclass (as listed by the original author):
        @param message_width_in_bytes: width of the message in bytes
        @param num_replications: num_replications
        @param offset: offset in the list of messages, keys, and nonces
        @param key: key to use for all batch_size
        @param nonce: nonce to use for all batch_size

        @param message: message to use for all batch_size. It is only checked for ``None`` here
            and is NOT forwarded to the superclass. When it is ``None``, the nonce list is loaded
            from ``.nonces_2``.
            NOTE(review): when a message is given, ``self.nonces`` is not set by this class;
            presumably the superclass provides it. Confirm.
        """
        super().__init__(**kwargs)

        # All important state (key, key_enc/key_dec, messages, lengths, ...) is set by the superclass.
        if message is None:
            # Local import: the long nonce table is only needed in this branch.
            from .nonces_2 import NONCES as NONCES_LONG
            self.nonces = NONCES_LONG

    def get_wm_type(self) -> str:
        return "GSppReference"

    def __extend_key(self, key, nonce_bytes, final_len=40):
        """
        Derive a ChaCha20 key and nonce from ``key`` and ``nonce_bytes``.

        ``final_len`` bytes are squeezed from SHAKE128(key + nonce_bytes). The first 32 bytes
        become the ChaCha20 key and bytes 32..40 the (8-byte) ChaCha20 nonce.
        """
        shake = SHAKE128.new()
        shake.update(key + nonce_bytes)
        randomness = shake.read(final_len)

        return randomness[:32], randomness[32:40]

    def get_wm_latents(self, custom_message_offset=None, **kwargs) -> typing.Dict[str, typing.Any]:
        """
        Generate one watermarked initial latent per batch element.

        For each batch index ``i`` in ``[self.offset, self.offset + self.batch_size)``:
          1. the first ``self.nonce_len // 8`` bytes of ``self.nonces[i]`` are the nonce,
          2. the nonce bits are PRC-encoded with ``self.key_enc``,
          3. every message bit is repeated ``self.rep_factor`` times and the result is encrypted
             with ChaCha20 under a key/nonce derived from ``self.key`` and the nonce,
          4. PRC codeword and encrypted bits are concatenated and a latent is sampled from them.

        @param custom_message_offset: if given, every batch element embeds
            ``self.messages[custom_message_offset]`` instead of ``self.messages[i]``.
        @return: dict with the stacked latents as a tensor cast to ``self.dtype`` (``zT_torch``),
            their PIL conversion (``zT_PIL`` and ``zT``) and the embedded messages
            (``message_bits_str_list``; despite the key name these are the raw message bytes).
        """
        nonce_byte_len = self.nonce_len // 8
        current_messages = []
        latents_torch = []
        for i in range(self.offset, self.offset + self.batch_size):
            # PRC part: encodes the (truncated) nonce so the detector can recover it.
            nonce_bytes = self.nonces[i][:nonce_byte_len]
            key, padded_nonce = self.__extend_key(self.key, nonce_bytes)
            nonce_bits = self.unpack_bytes(nonce_bytes)
            prc_codeword = prc.Encode_gspp(self.key_enc, nonce_bits)

            # GS part: select the message embedded in this batch element.
            if custom_message_offset is not None:
                message_bytes = self.messages[custom_message_offset]
            else:
                message_bytes = self.messages[i]
            # TODO: store a bit string ("0101...") to match the result key name.
            current_messages.append(message_bytes)

            # Repetition code, then stream-cipher encryption of the repeated bits.
            bit_array = np.unpackbits(np.frombuffer(message_bytes, dtype=np.uint8))
            sd = np.repeat(bit_array, self.rep_factor)
            cipher = ChaCha20.new(key=key, nonce=padded_nonce)
            m_byte = cipher.encrypt(np.packbits(sd).tobytes())
            m_bit = np.unpackbits(np.frombuffer(m_byte, dtype=np.uint8))

            # Sample a Gaussian latent from the combined bit string.
            latent_bits = np.concatenate([prc_codeword, m_bit])
            latent = self._truncSampling(latent_bits)
            latents_torch.append(latent.squeeze(0))

        latents_torch = torch.stack(latents_torch, dim=0).to(dtype=self.dtype)
        latents_PIL = torch_to_PIL(latents_torch)

        return {
            "zT_torch": latents_torch,
            "zT_PIL": latents_PIL,
            "zT": latents_PIL,
            "message_bits_str_list": current_messages,
        }

    def __recover_messages_from_latent(
            self, latent: typing.Union[torch.Tensor, np.ndarray]) -> typing.Optional[np.ndarray]:
        """
        Recover the embedded message bits from one latent.

        The first ``self.prc_len`` flattened values are PRC-decoded (``prc.recover_posteriors``
        with variance 1.5, then ``prc.Decode_gspp``) to obtain the nonce. The nonce re-derives
        the ChaCha20 key stream, which is used to undo the encryption on the remaining (GS)
        values. Each block of ``self.rep_factor`` values is then summed; the bit is 1 iff the
        sum is positive.

        @param latent: one latent (any shape; it is flattened). It is not modified.
        @return: bool array of length ``self.messagelen``, or ``None`` if PRC decoding failed.
        """
        if isinstance(latent, torch.Tensor):
            # May share memory with the caller's tensor, so never modify `flat` in place.
            flat = latent.flatten().numpy(force=True)
        else:
            flat = np.asarray(latent).ravel()

        prc_latent = torch.tensor(flat[:self.prc_len], dtype=torch.float64)
        gs_latent = flat[self.prc_len:]

        # PRC decode: recover the nonce.
        var = 1.5
        reversed_prc = prc.recover_posteriors(prc_latent, variances=var).flatten().cpu()
        nonce_short = prc.Decode_gspp(self.key_dec, reversed_prc)
        if nonce_short is None:
            # Without the nonce the GS part cannot be decrypted.
            return None

        # Regenerate the key stream by decrypting zeros. Size it from the GS part itself rather
        # than assuming it has the same length as the PRC part.
        key, padded_nonce = self.__extend_key(self.key, np.packbits(nonce_short).tobytes())
        cipher = ChaCha20.new(key=key, nonce=padded_nonce)
        n_gs = gs_latent.size
        keystream_bytes = cipher.decrypt(bytes((n_gs + 7) // 8))
        keystream = np.unpackbits(np.frombuffer(keystream_bytes, dtype=np.uint8))[:n_gs]

        # A key-stream bit of 1 inverted the encrypted bit, so flip the sign of that value.
        # Computed in float: in uint8, 1 - 2 wraps around to 255 instead of giving -1.
        signs = 1.0 - 2.0 * keystream
        gs_latent = gs_latent * signs  # not in place, see above

        # Repetition decode: majority vote via the sign of each block's sum.
        rep = self.rep_factor
        message = np.zeros(self.messagelen, dtype=bool)
        for j in range(self.messagelen):
            message[j] = np.sum(gs_latent[j * rep:(j + 1) * rep]) > 0
        return message

    def get_accuracies(self, latents: typing.Union[torch.Tensor, np.ndarray],
                       **kwargs) -> typing.Dict[str, typing.Any]:
        """
        Decode every latent and compare it with the message embedded for it.

        Latent ``i`` is compared with ``self.messages[i + self.offset]`` via
        ``self._calculate_bit_accuracy``.
        NOTE(review): a ``custom_message_offset`` used at generation time is not reflected here.

        @param latents: iterable of latents (e.g. a batched tensor, iterated along dim 0).
        @return: dict with per-latent accuracies (under both ``accuracies`` and
            ``bit_accuracies``) and the recovered messages (``message_bits_str_list``).
            Where PRC decoding failed, the message is ``None`` and the accuracy is 0.0.
        """
        rtn_messages = []
        rtn_accs = []
        for i, latent in enumerate(latents):
            rtn_message = self.__recover_messages_from_latent(latent)
            if rtn_message is None:
                # Watermark not detected: report an explicit invalid result instead of passing
                # None into the accuracy computation.
                bit_acc = 0.0
            else:
                ori_message = self.messages[i + self.offset]
                bit_acc = self._calculate_bit_accuracy(ori_message, rtn_message)
            rtn_messages.append(rtn_message)
            rtn_accs.append(bit_acc)

        return {
            "accuracies": rtn_accs,
            "bit_accuracies": rtn_accs,
            "message_bits_str_list": rtn_messages
        }