#import stuff
import torch
import typing
import argparse
from .gs_official_provider import GsReferernceProvider
from Crypto.Cipher import ChaCha20

#dino stuff
# import sys
# sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')

import chromadb
from transformers import AutoFeatureExtractor, AutoModel
from PIL import Image


import numpy as np

# Option group for the Chroma/DINO retrieval settings. Built with
# add_help=False -- presumably so it can be merged into another parser via
# `parents=[...]` without a clashing -h flag (confirm against the callers).
parser = argparse.ArgumentParser(add_help=False)

# TODO: the gs_reference_parser has to be imported alongside this parser,
# otherwise it doesn't work.
# Each row is (flag, value type, default); they are registered in this order.
_CHROMA_CLI_OPTIONS = (
    ("--collection_name", str, "images"),
    ("--dino_model", str, "facebook/dino-vits16"),
    ("--persist_directory", str, "dinodb_storage"),
    ("--topkretrieval", int, 100),
)
for _flag, _kind, _fallback in _CHROMA_CLI_OPTIONS:
    parser.add_argument(_flag, type=_kind, default=_fallback)
# Do not leave the loop variables behind in the module namespace.
del _flag, _kind, _fallback


class GsChromaProvider(GsReferernceProvider):
    """
    Gaussian Shading provider that looks up candidate secrets in a ChromaDB collection.

    Original by https://github.com/bsmhmmlf/Gaussian-Shading/ and heavily modified for debugging and getting access to internals.
    """

    # Candidate selection mode used by get_accuracies():
    #   False (default): stop at the first candidate whose bit accuracy exceeds
    #                    self.bit_accuracy_threshold.
    #   True:            test every retrieved candidate and keep the one with
    #                    the highest bit accuracy (no threshold, no early exit).
    BEST_MATCH = False

    def __init__(self,
                    dino_model="facebook/dino-vits16",
                    persist_directory="dinodb_storage",
                    collection_name="images",
                    topkretrieval = 50,
                    **kwargs):
        """
        This provider makes use of a chromadb and dino scores to find similar images. Needs a populated dino db upfront with a path to it.

        The collection entries are expected to carry hex-encoded "nonce" and "message" metadata
        (that is what get_accuracies() reads); these stored secrets are used for verification,
        which eliminates every possibility of false negatives due to wrong seeds.
        Generate more with "generate_secrets.py" as needed.

        Chroma / DINO params
        @param dino_model: model id or path passed to AutoFeatureExtractor / AutoModel.from_pretrained
        @param persist_directory: directory of the persistent chromadb storage
        @param collection_name: name of the chromadb collection (created if it does not exist)
        @param topkretrieval: number of nearest neighbours requested per query image
        @param kwargs: forwarded unchanged to the parent provider's constructor

        These are the params for the watermark (forwarded to the parent via kwargs):
        - message_width_in_bytes = Num users to distinguish
        - num_replications = strength of error correction

        message_width_in_bytes * 8 * num_replications  == num_channels * latent_resolution * latent_resolution
        Example:
        -> 32 (message_width_in_bytes) * 8 * 64 (num_replication)  = 16384 = 4 (num_channels) * 64 (latent_resolution) * 64 (latent_resolution)

        GS General Params
        @param message_width_in_bytes: width of the message in bytes
        @param num_replications: num_replications
        @param offset: offset in the list of messages, keys, and nonces
        @param message: message to use all batch_size
        @param key: key to use all batch_size
        @param nonce: nonce to use for all batch_size
        """
        super().__init__(**kwargs)

        self.topkretrieval = topkretrieval

        # init dino db
        self.client = chromadb.PersistentClient(path=persist_directory)
        self.collection = self.client.get_or_create_collection(name=collection_name)
        # DINO backbone used to embed query images; eval() because it is only used for inference.
        self.extractor = AutoFeatureExtractor.from_pretrained(dino_model)
        self.model = AutoModel.from_pretrained(dino_model)
        self.model.eval()

    def get_wm_type(self) -> str:
        """Return the identifier of this watermark provider."""
        return "GSChroma"

    def __extract_dino_features(self, image: Image.Image) -> np.ndarray:
        """
        Embed one image with the DINO model.

        @param image: image to embed
        @return: L2-normalised CLS-token embedding as a numpy array
                 (the original comment notes 384 values, i.e. for the default dino-vits16)
        """
        inputs = self.extractor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
            features = outputs.last_hidden_state[:, 0]  # CLS token
            features = torch.nn.functional.normalize(features, dim=1)
        return features.squeeze().cpu().numpy()

    def __recover_message_from_latent(self, latent: torch.Tensor, nonce: bytes, dbmessage: bytes) -> typing.Tuple[float, np.ndarray]:
        """
        Decrypt a latent with a candidate nonce and compare the result to a candidate message.

        @param latent: latent to test; its sign pattern is interpreted as the encrypted bit stream
        @param nonce: ChaCha20 nonce stored in the database for the candidate
        @param dbmessage: message stored in the database for the candidate
        @return: (bit accuracy as computed by the parent's _calculate_bit_accuracy, recovered message bits)
        """
        # One bit per latent element: positive -> 1, otherwise 0.
        latent_bits = (latent > 0).int().flatten().cpu().numpy()

        cipher = ChaCha20.new(key=self.key, nonce=nonce)
        sd_byte = cipher.decrypt(np.packbits(latent_bits).tobytes())
        sd_bit = np.unpackbits(np.frombuffer(sd_byte, dtype=np.uint8))
        # NOTE: fixed to a single 4 x 64 x 64 latent (16384 bits), matching the
        # example in the constructor docstring.
        sd_tensor = torch.from_numpy(sd_bit).reshape(1, 4, 64, 64).to(torch.uint8)

        reversed_message = self._decode_repetition(sd_tensor)
        reversed_message = reversed_message.flatten().cpu().numpy()

        acc = super()._calculate_bit_accuracy(dbmessage, reversed_message)

        return acc, reversed_message

    def get_accuracies(self, latents: typing.Union[torch.Tensor, np.ndarray], images: typing.Sequence[Image.Image]) -> typing.Dict[str, typing.Any]:
        """
        Verify each latent against the secrets of its most similar database images.

        For every (latent, image) pair the image is embedded with DINO, the
        `topkretrieval` nearest collection entries are fetched, and the nonce /
        message stored in their metadata are tried in retrieval order.

        @param latents: iterable of latents, paired element-wise with `images`
        @param images: iterable of images used for the similarity lookup
        @return: dict with one list entry per processed pair:
                 - "accuracies" / "bit_accuracies": accuracy of the selected candidate (0 if none was selected)
                 - "message_bits_str_list": recovered message bits of the selected candidate (None if none)
                 - "ranks": retrieval rank of the selected candidate; note that it is also 0 when
                   nothing was selected, so check the message / accuracy to tell the cases apart
                 - "debug_ids": ids of all candidates that were visited, in retrieval order
        """
        rtn_accuracies = []
        rtn_messages = []
        rtn_ranks = []
        rtn_debug_ids = []
        for latent, image in zip(latents, images):
            # sort images by dino similarity
            dinorep = self.__extract_dino_features(image)
            dinoresults = self.collection.query(
                query_embeddings=[dinorep.tolist()],
                n_results=self.topkretrieval
            )

            debug_ids = []
            last_acc = 0
            last_mes = None
            last_rank = 0
            # Index [0]: a single query embedding was submitted, so only the
            # first result list is relevant.
            candidates = zip(dinoresults["ids"][0], dinoresults["metadatas"][0])
            for rank, (entry_id, metadata) in enumerate(candidates):
                debug_ids.append(entry_id)

                metadata = metadata or {}
                nonce_hex = metadata.get('nonce', '')
                message_hex = metadata.get('message', '')
                if not nonce_hex or not message_hex:
                    # Without both secrets the candidate cannot be verified:
                    # a missing nonce would make ChaCha20 pick a random one and
                    # a missing message leaves nothing to compare against.
                    continue

                acc, messageretrieved = self.__recover_message_from_latent(
                    latent,
                    bytes.fromhex(nonce_hex),
                    bytes.fromhex(message_hex),
                )

                if self.BEST_MATCH:
                    # Exhaustive mode: remember the best candidate seen so far.
                    if acc > last_acc:
                        last_acc = acc
                        last_mes = messageretrieved
                        last_rank = rank
                elif acc > self.bit_accuracy_threshold:
                    # Default mode: first candidate above the threshold wins.
                    last_acc = acc
                    last_mes = messageretrieved
                    last_rank = rank
                    break

            rtn_debug_ids.append(debug_ids)
            rtn_accuracies.append(last_acc)
            rtn_messages.append(last_mes)
            rtn_ranks.append(last_rank)

        return {
            "accuracies": rtn_accuracies,
            "bit_accuracies": rtn_accuracies,
            "message_bits_str_list": rtn_messages,
            "ranks": rtn_ranks,
            "debug_ids": rtn_debug_ids
        }