import torch
import typing
import argparse
from .wm_provider import WmProvider
from utils.image_utils import torch_to_PIL
import numpy as np
import pickle

from . import prc

# Command-line options for the PRC watermark provider.
# The parser is created without its own -h/--help option (add_help=False),
# presumably so it can be combined with another parser as a parent.
parser = argparse.ArgumentParser(add_help=False)

# The option names mirror constructor parameters of PrcProvider below.
parser.add_argument(
    '--t',
    type=int,
    default=3,
    help="something about the sparsity of LDPC, no idea what it does tbh",
)
parser.add_argument(
    '--fpr',
    type=float,
    default=1e-9,
    help="desired fpr",
)
# NOTE(review): the default here is 256 while PrcProvider.__init__ defaults
# message_length to 512 -- confirm which value is intended.
parser.add_argument(
    "--message_length",
    type=int,
    default=256,
    help="message length to be encoded.",
)


class PrcProvider(WmProvider):
    """Watermark provider backed by the ``prc`` module.

    Messages (byte strings) are encoded with ``prc.Encode`` into codewords,
    which are sampled into initial latents. ``get_accuracies`` runs the
    reverse direction on recovered latents via ``prc.Decode``.

    The attributes ``num_channels``, ``latent_resolution``, ``batch_size``,
    ``device`` and ``dtype`` are read but not set here; they are presumably
    provided by ``WmProvider.__init__`` -- confirm against the base class.
    """

    def __init__(self,
                 filepath_keys=None,
                 t: int = 3,
                 fpr: float = 1e-9,
                 message_length: int = 512,
                 offset: int = 0,
                 message=None,
                 **kwargs):
        """Set up messages and PRC keys.

        @param filepath_keys: path to a pickle file written by ``dump_keys``;
            when None, a fresh key pair is generated with ``prc.KeyGen``.
        @param t: forwarded to ``prc.KeyGen`` as ``t``.
        @param fpr: forwarded to ``prc.KeyGen`` as ``false_positive_rate``.
        @param message_length: forwarded to ``prc.KeyGen`` as ``message_length``.
        @param offset: index of the first message used for a batch.
        @param message: indexable collection of messages (bytes); when None,
            ``MESSAGES`` from ``.messages_long`` is used.
        @param kwargs: forwarded unchanged to ``WmProvider.__init__``.
        """
        super().__init__(**kwargs)

        self.t = t
        self.fpr = fpr
        # Codeword length n: one entry per latent element.
        self.latent_bits = self.num_channels * self.latent_resolution ** 2
        self.message_length = message_length

        self.offset = offset
        if message is None:
            # Imported lazily so the default message table is only loaded
            # when it is actually needed.
            from .messages_long import MESSAGES
            self.messages = MESSAGES
        else:
            self.messages = message

        if filepath_keys is None:
            # No key file given: sample a fresh PRC key pair.
            self.key_enc, self.key_dec = prc.KeyGen(
                n=self.latent_bits,
                message_length=self.message_length,
                false_positive_rate=fpr,
                t=t,
            )
        else:
            # NOTE(security): pickle.load can execute arbitrary code while
            # unpickling. Only load key files from trusted sources.
            with open(filepath_keys, "rb") as f:
                keydict = pickle.load(f)
            self.key_enc = keydict["key_enc"]
            self.key_dec = keydict["key_dec"]

    def dump_keys(self, filepath):
        """Pickle the encoding and decoding keys to ``filepath``.

        The file holds a dict with the entries ``"key_enc"`` and
        ``"key_dec"``, the layout ``__init__`` expects for ``filepath_keys``.

        @param filepath: destination path; an existing file is overwritten.
        """
        keys = {"key_enc": self.key_enc,
                "key_dec": self.key_dec}
        with open(filepath, "wb") as f:
            pickle.dump(keys, f)

    def get_wm_type(self) -> str:
        """Return the identifier of this watermark type."""
        return "PRC"

    def get_wm_latents(self, **kwargs) -> typing.Dict[str, typing.Any]:
        """Create one watermarked initial latent per batch element.

        Batch element ``k`` carries ``self.messages[self.offset + k]``.

        @return: dict with
            ``"zT_torch"``: stacked latents of shape
                (batch_size, num_channels, latent_resolution, latent_resolution),
            ``"zT_PIL"`` / ``"zT"``: result of ``torch_to_PIL`` on the latents,
            ``"message_bits_str_list"``: the embedded messages as stored in
                ``self.messages``.
        """
        # Per-sample shape derived from the same attributes that define
        # latent_bits, instead of a hard-coded (4, 64, 64).
        latent_shape = (self.num_channels,
                        self.latent_resolution,
                        self.latent_resolution)

        latents_torch = []
        current_messages = []

        for i in range(self.offset, self.batch_size + self.offset):
            message_bytes = self.messages[i]
            message_bits = self.__unpack_bytes(message_bytes)
            prc_codeword = prc.Encode(self.key_enc, message=message_bits)
            init_latents = prc.sample(prc_codeword).reshape(latent_shape).to(self.device)

            latents_torch.append(init_latents)
            current_messages.append(message_bytes)

        # finalize
        latents_torch = torch.stack(latents_torch, dim=0)
        latents_torch = latents_torch.to(dtype=self.dtype)

        latents_PIL = torch_to_PIL(latents_torch)

        return {"zT_torch": latents_torch,
                "zT_PIL": latents_PIL,
                "zT": latents_PIL,
                "message_bits_str_list": current_messages,
                }

    def get_accuracies(self,
                       latents: typing.Union[torch.Tensor, np.ndarray],
                       **kwargs) -> typing.Dict[str, typing.Any]:
        """Decode a message from each latent and score it.

        Latent ``i`` is compared against ``self.messages[i + self.offset]``.

        @param latents: batch of recovered latents, iterated along the first
            axis; torch tensors and NumPy arrays are both accepted.

        @return: dict with
            ``"accuracies"`` / ``"bit_accuracies"``: bit accuracy per latent
                (0.0 when decoding fails),
            ``"message_bits_str_list"``: decoded messages packed into bytes.
        """
        # Passed to prc.recover_posteriors as `variances`.
        # NOTE(review): the value 1.5 is hard-coded; its origin is not
        # documented here.
        var = float(1.5)
        accuracies = []
        messages_recovered = []

        for i, latent in enumerate(latents):
            # as_tensor lets rows of a NumPy array go through the same path
            # as torch tensors (ndarrays have no .to()).
            latent = torch.as_tensor(latent)
            reversed_prc = prc.recover_posteriors(
                latent.to(torch.float64).flatten().cpu(),
                variances=var).flatten().cpu()
            decoding_result = prc.Decode(self.key_dec, reversed_prc, print_progress=False)
            if decoding_result is None:
                # Decoding failed: report zero accuracy and a placeholder.
                bit_acc = 0.0
                message = 0  # TODO: replace with a proper invalid value
            else:
                message = decoding_result
                ori_message = self.messages[i + self.offset]
                bit_acc = self.__calculate_bit_accuracy(ori_message, message)

            accuracies.append(bit_acc)
            messages_recovered.append(np.packbits(message).tobytes())

        return {
            "accuracies": accuracies,
            "bit_accuracies": accuracies,
            "message_bits_str_list": messages_recovered,
        }

    def __calculate_bit_accuracy(self,
                                 db_message: bytes,
                                 extracted_message: np.typing.ArrayLike) -> float:
        """
        Get the bit accuracy between a reference message and extracted bits.

        Both sequences are truncated to the shorter length before comparing.

        @param db_message: reference message as bytes
        @param extracted_message: extracted message as a sequence of bits

        @return: fraction of matching bits; 0.0 if there is nothing to compare
        """
        reference_bits = self.__unpack_bytes(db_message)
        extracted_bits = np.asarray(extracted_message)

        # Both should have the same length already; truncate defensively.
        min_length = min(len(reference_bits), len(extracted_bits))
        if min_length == 0:
            # Avoid dividing by zero for an empty message or empty extraction.
            return 0.0

        matching_bits = np.count_nonzero(
            reference_bits[:min_length] == extracted_bits[:min_length])
        return matching_bits / min_length

    def __unpack_bytes(self, bites):
        """Return the bits of ``bites`` as a uint8 array, 8 entries per byte."""
        return np.unpackbits(np.frombuffer(bites, dtype=np.uint8))