#import stuff
import torch
import typing
import argparse
from .gs_official_provider import GsReferernceProvider
from Crypto.Cipher import ChaCha20
from scipy.stats import binom

import numpy as np

# Stand-alone option group for this provider; add_help=False so that no
# -h/--help option is registered on this parser.
parser = argparse.ArgumentParser(add_help=False)

# TODO: import the gs_reference_parser as well, otherwise it doesn't work
# All lookup/check switches are on/off flags that default to False.
# BooleanOptionalAction also registers the matching "--no-<name>" form.
for _flag, _help_text in (
    ("--randomness_check", "whether to check for each bin if the content might be random or not by performing a statistical test. disable with no-XXX"),
    ("--direct_lookup", "Shall we use the nonce index to directly lookup the corresponding message (simulating a data base access)? Disable with no-XXX"),
    ("--optimised_lookup", "Shall we use the nonce index to directly lookup the corresponding message and simulate the db access faster?"),
    ("--db_lookup", "actually use sqlite as a database..."),
):
    parser.add_argument(
        _flag,
        action=argparse.BooleanOptionalAction,
        default=False,
        help=_help_text,
    )

# Path of the nonce/message database file (presumably forwarded as db_path to the provider -- confirm at the call site).
parser.add_argument("--db_path", default="./nonceANDmessage.db", type=str)
class GsOptimisedProvider(GsReferernceProvider):
    """
    Original by https://github.com/bsmhmmlf/Gaussian-Shading/ and heavily modified for debugging and getting access to internals.
    """

    def __init__(self,
                 randomness_check: bool = True,
                 direct_lookup: bool = True,
                 optimised_lookup: bool = True,
                 db_lookup: bool = False,
                 db_path: str = "nonceANDmessage.db",
                 **kwargs):
        """
        Store the decoding strategy flags and prepare the lookup structure they require.

        This provider uses a fixed list of keys, messages, and nonces for the watermarking process to eliminate
        every possibility of false negatives due to wrong seeds.
        Generate more with "generate_secrets.py" as needed.

        These are the params for the watermark:
        - message_width_in_bytes = Num users to distinguish
        - num_replications = strength of error correction

        message_width_in_bytes * 8 * num_replications  == num_channels * latent_resolution * latent_resolution
        Example:
        -> 32 (message_width_in_bytes) * 8 * 64 (num_replication)  = 16384 = 4 (num_channels) * 64 (latent_resolution) * 64 (latent_resolution)

        @param randomness_check: check if decrypted latent is random or not
        @param direct_lookup: simulate to know which message belongs to which nonce
        @param optimised_lookup: keep an in-memory list of (nonce, message) pairs built from self.nonces and self.messages_long
        @param db_lookup: open a NonceMessageDB and keep it in self.nonce_db (replaces the in-memory list if both flags are set)
        @param db_path: path handed to NonceMessageDB; only used when db_lookup is set
        @param kwargs: forwarded unchanged to the parent constructor

        GS General Params (consumed by the parent class, not by this constructor)
        @param message_width_in_bytes: width of the message in bytes
        @param num_replications: num_replications
        @param offset: offset in the list of messages, keys, and nonces
        @param message: message to use for the whole batch
        @param key: key to use for the whole batch
        @param nonce: nonce to use for the whole batch

        The watermark type reported by get_wm_type() depends on the specified behaviour:
        0: GS reference
        1: only randomness check
        2: only direct lookup
        3: both
        4: optimised direct lookup
        8: db lookup
        """
        super().__init__(**kwargs)

        # Strategy switches read later by the decoding methods.
        self.randomness_check = randomness_check
        self.direct_lookup = direct_lookup
        self.optimised_lookup = optimised_lookup
        self.db_lookup = db_lookup

        if self.optimised_lookup:
            # Pair every nonce with its message once, so decoding can walk a single list.
            self.nonce_db = list(zip(self.nonces, self.messages_long))
            print("optimised lookup")

        if self.db_lookup:
            # Imported lazily: the database module is only needed for this mode.
            from .nonce_sql_db import NonceMessageDB
            database = NonceMessageDB(db_path=db_path)
            self.nonce_db = database
            self.nonce_db.open_connection()
    def get_wm_type(self) -> str:
        #bit magic for the type
        index = (self.db_lookup << 3) | (self.optimised_lookup <<2) | ((self.direct_lookup << 1) | self.randomness_check)
        return "GSoptimised"+str(index)
    
    def __decode_repetition_with_sanitycheck(self,watermark_sd,beta_error=0.2,total_fpr = 0.1):
        '''works for bins of size 64 only so far.
        flag a bin as nonrandom if we are in the outlying 20% (power of the test 20%, default choice).
        Flags the entire string as nonrandom with fpr 0.1'''
        bin_size=self.num_replications
        #80% power is a good default choice for a test.
        beta_error = 0.2
        lower = binom.ppf(beta_error / 2, bin_size, 0.5, loc=0)
        upper = binom.ppf(1-(beta_error / 2), bin_size, 0.5, loc=0)
        #we achieve a 0.01 false discovery rate with this choice of parameters
        #one sided binomial(p=beta_error, n = message bit length), get a s.t. P(X<threshtotal)<1-fpr
        threshold_total = binom.ppf(1-total_fpr,self.message_width_in_bits,beta_error,loc=0)


        #bin stuff as before according to reference implementation
        ch_stride = 4 // 1
        hw_stride = 64 // 8
        ch_list = [ch_stride] * 1
        hw_list = [hw_stride] * 8
        split_dim1 = torch.cat(torch.split(watermark_sd, tuple(ch_list), dim=1), dim=0)
        split_dim2 = torch.cat(torch.split(split_dim1, tuple(hw_list), dim=2), dim=0)
        split_dim3 = torch.cat(torch.split(split_dim2, tuple(hw_list), dim=3), dim=0)

        decoded_m = torch.sum(split_dim3, dim=0).clone()

        bins = decoded_m.flatten().cpu().numpy() # convert to numpy because not good at torch

        # Vectorized comparison
        outside_mask = (bins < lower) | (bins > upper) #check if count is in center region
        nonrandom_bins = np.sum(outside_mask) #count how many blocks are actually good

        if nonrandom_bins < threshold_total:
            return False,None
 
        decoded_m[decoded_m <= 0.5*bin_size] = 0
        decoded_m[decoded_m > 0.5*bin_size] = 1
        return True,decoded_m
    


        # decoded_m = torch.sum(split_dim3, dim=0).clone()/bin_size

        # bins = decoded_m.flatten().cpu().numpy() # convert to numpy because not good at torch
        # #check if each bin is good enough
        # bins = np.abs(bins- 0.5)/0.25
        # #central limit theorem, but in a quirky version: goal: nonrandom if 3 std out.
        # bins = (bins*(2*np.sqrt(bin_size))) > 3
        # nonrandom_bins = np.count_nonzero(bins)/self.message_width_in_bits
        # #check if we have enough good bins
        # if nonrandom_bins < threshold_total:
        #     return False,None
        
        # decoded_m[decoded_m <= 0.5] = 0
        # decoded_m[decoded_m > 0.5] = 1
        # return True,decoded_m
    
    def __recover_messages_from_latents(self,latents: torch.Tensor) -> typing.List[typing.Union[float, bytes]]:
        '''contains Jonas idea with looking for non randomniss as well as Erwins idea of checking just one message with "Database".
        '''
        rtn_accuracies = list()
        rtn_messages = list()

        for latent in latents:
            # tokenise to integer
            latent_bits = (latent > 0).int()
            latent_bits = latent_bits.flatten().cpu().numpy()


            #decrypt
            if self.db_lookup:
                perfect_message, perfect_acc = self.__retriever_withdb(latent_bits)
            elif self.optimised_lookup:
                perfect_match, perfect_acc = self.__retriever_dbopt(latent_bits)
            else:
                perfect_match = 0
                perfect_acc = 0
    
                for i,nonce in enumerate(self.nonces):
                    cipher = ChaCha20.new(key=self.key, nonce=nonce)
                    sd_byte = cipher.decrypt(np.packbits(latent_bits).tobytes())
                    sd_bit = np.unpackbits(np.frombuffer(sd_byte, dtype=np.uint8))
                    sd_tensor = torch.from_numpy(sd_bit).reshape(1, 4, 64, 64).to(torch.uint8)

                    #decode
                    if self.randomness_check:
                        non_random,reversed_message = self.__decode_repetition_with_sanitycheck(sd_tensor)
                    else:
                        non_random = True
                        reversed_message = self._decode_repetition(sd_tensor)

                    #match
                    #if bar code is kinf of intact
                    if not non_random:
                        continue
                    else:
                        reversed_message=reversed_message.flatten().cpu().numpy()
                        if self.direct_lookup:
                            #directly lookup the message. If it fits, we're done
                            corresponding_message = self.messages_long[i]
                            best_acc = super()._calculate_bit_accuracy(corresponding_message,reversed_message)
                            best_match = i
                        else:
                            best_match = 0
                            best_acc = 0
                            #we iterate over the messages. if we find something above threshold, we break. Otherwise, we just return the best so far.
                            for i,mes in enumerate(self.messages_short):
                                bit_acc = super()._calculate_bit_accuracy(mes,reversed_message)
                                if bit_acc > best_acc:
                                    best_match = i
                                    best_acc = bit_acc
                                    if best_acc > self.bit_accuracy_threshold:
                                        break
                        if best_acc > self.bit_accuracy_threshold:
                            perfect_match = best_match    
                            perfect_acc = best_acc    
                            break
            
            rtn_accuracies.append(perfect_acc)
            if self.direct_lookup or self.optimised_lookup:
                rtn_messages.append(self.messages_long[perfect_match]) # buffer message somewhere? or hope that the os does that
            elif self.db_lookup:
                rtn_messages.append(perfect_message)
            else:
                rtn_messages.append(self.messages_short[perfect_match]) # buffer message somewhere? or hope that the os does that

        return rtn_accuracies,rtn_messages
    

    def __retriever_dbopt(self,latent_bits):
        perfect_match = 0
        perfect_acc = 0
        for i,noncetup in enumerate(self.nonce_db):
            nonce,corresponding_message = noncetup
            cipher = ChaCha20.new(key=self.key, nonce=nonce)
            sd_byte = cipher.decrypt(np.packbits(latent_bits).tobytes())
            sd_bit = np.unpackbits(np.frombuffer(sd_byte, dtype=np.uint8))
            sd_tensor = torch.from_numpy(sd_bit).reshape(1, 4, 64, 64).to(torch.uint8)

            #decode
            if self.randomness_check:
                non_random,reversed_message = self.__decode_repetition_with_sanitycheck(sd_tensor)
            else:
                non_random = True
                reversed_message = self._decode_repetition(sd_tensor)

            #match
            #if bar code is kinf of intact
            if not non_random:
                continue
            else:
                reversed_message=reversed_message.flatten().cpu().numpy()
                best_acc = super()._calculate_bit_accuracy(corresponding_message,reversed_message)
                best_match = i
                if best_acc > self.bit_accuracy_threshold:
                    perfect_match = best_match    
                    perfect_acc = best_acc    
                    break
        return perfect_match, perfect_acc

    def __retriever_withdb(self,latent_bits):
        perfect_message = None
        perfect_acc = 0
        for nonce, corresponding_message in self.nonce_db.iterate_all():
            cipher = ChaCha20.new(key=self.key, nonce=nonce)
            sd_byte = cipher.decrypt(np.packbits(latent_bits).tobytes())
            sd_bit = np.unpackbits(np.frombuffer(sd_byte, dtype=np.uint8))
            sd_tensor = torch.from_numpy(sd_bit).reshape(1, 4, 64, 64).to(torch.uint8)

            #decode
            if self.randomness_check:
                non_random,reversed_message = self.__decode_repetition_with_sanitycheck(sd_tensor)
            else:
                non_random = True
                reversed_message = self._decode_repetition(sd_tensor)

            #match
            #if bar code is kinf of intact
            if not non_random:
                continue
            else:
                reversed_message=reversed_message.flatten().cpu().numpy()
                best_acc = super()._calculate_bit_accuracy(corresponding_message,reversed_message)
                if best_acc > self.bit_accuracy_threshold:
                    perfect_message = corresponding_message    
                    perfect_acc = best_acc    
                    break
        return perfect_message, perfect_acc

    def get_accuracies(self, latents: typing.Union[torch.Tensor, np.array],**kwargs) -> typing.Dict[str, any]:
        rtn_accuracies,rtn_messages = self.__recover_messages_from_latents(latents)
        
        return {
            "accuracies": rtn_accuracies,
            "bit_accuracies": rtn_accuracies,
            "message_bits_str_list": rtn_messages
        }
    
    def detect(self, latents,fpr =0.1):
        #run decryption 
        results = {}
        results["detect"] = list()
        results["bitacc"] = list()
        results["bitacc_rev"] = list()

        if self.offset_debug is not None:
            nonce = self.nonces[self.offset_debug]

        for latent in latents:
            # tokenise to integer
            latent_bits = (latent > 0).int()
            latent_bits = latent_bits.flatten().cpu().numpy()
            cipher = ChaCha20.new(key=self.key, nonce=nonce)
            sd_byte = cipher.decrypt(np.packbits(latent_bits).tobytes())
            sd_bit = np.unpackbits(np.frombuffer(sd_byte, dtype=np.uint8))
            sd_tensor = torch.from_numpy(sd_bit).reshape(1, 4, 64, 64).to(torch.uint8)

            #decode
            non_random,reversed_message = self.__decode_repetition_with_sanitycheck(sd_tensor,total_fpr=fpr)
            results["detect"].append(non_random)
            if reversed_message is None:
                results["bitacc"].append(np.nan)
            else:
                mret = reversed_message.flatten().cpu().numpy()
                results["bitacc"].append(super()._calculate_bit_accuracy(self.messages_long[self.offset_debug],mret))
                results["bitacc_rev"].append(self._calculate_bit_accuracy(self.messages_long[self.offset_debug],~np.array(mret,dtype=bool)))
        return results