# Create a DB with one table of (nonce, users) with an index on nonces
#import sys
#sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')

import os
import torch
from transformers import AutoFeatureExtractor, AutoModel
from PIL import Image
import chromadb
import tqdm

# nonces and messages
from utils.wm.messages_long import MESSAGES as MESSAGES_LONG
from utils.wm.nonces_2 import NONCES

#pipes
from utils.wm.wm_utils import WmProviders
from utils.prompt_utils import get_huggingface_list

from utils.pipe import pipe_utils
from utils.wm.gs_official_provider import parser as gs_ref_parser


import argparse

# Use a CUDA device when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")


# DINO DB layout: each entry's metadata stores the nonce as hex and the message as hex.
# Population procedure: first add the 1000 pregenerated images, then add another 9k images generated from prompts, each also tagged with a nonce / message.



def _build_parser():
    """Build the command-line parser for the DINO-DB population script.

    Options are registered in a fixed order: storage / DINO settings, the
    prompt dataset, the diffusion model, then experiment settings.
    """
    cli = argparse.ArgumentParser(description="populate dino db")

    # (flag, type, default) for every option that takes no extra constraints.
    plain_options = (
        # dino db / storage
        ("--image_folder", str, "./images/GSreference/"),
        ("--collection_name", str, "images"),
        ("--dino_model", str, "facebook/dino-vits16"),
        ("--persist_directory", str, "dinodb_storage"),
        # image generation
        ("--prompt_dataset", str, "alfredplpl/anime-with-caption-cc0"),
        # model
        ("--model_id", str, "stabilityai/stable-diffusion-2-1-base"),
        ("--scheduler", str, "DDIM"),
        # experiment settings
        ("--batch_size", int, 1),
        ("--resolution", int, 512),
        ("--num_inference_steps", int, 50),
        ("--num_inversion_steps", int, 50),
        ("--guidance_scale", float, 7.5),
    )
    for flag, value_type, default in plain_options:
        cli.add_argument(flag, type=value_type, default=default)

    # The experiment selector is restricted to a closed set of modes.
    cli.add_argument(
        "--experiment",
        type=str,
        default="top",
        choices=["top", "misses", "random", "batch_create"],
    )
    cli.add_argument("--filepath", type=str, default="./data/noisy/")
    return cli


parser = _build_parser()

# parse_known_args tolerates extra CLI flags; they are collected in unknown_args.
args, unknown_args = parser.parse_known_args()
print(args)
dino_model = args.dino_model

### setup dino stuff
# 1. Load the pretrained DINO model (default: ViT-S/16) and its preprocessor.
extractor = AutoFeatureExtractor.from_pretrained(dino_model)
# Place the model on DEVICE so feature extraction uses the GPU when available.
model = AutoModel.from_pretrained(dino_model).to(DEVICE)
model.eval()


# 2. Define a feature extractor
def extract_dino_features(image_path):
    """Return the L2-normalised DINO CLS embedding of the image at ``image_path``.

    The image is converted to RGB, preprocessed with the module-level
    ``extractor``, and passed through the module-level ``model`` without
    gradient tracking. The CLS token (position 0 of ``last_hidden_state``)
    is L2-normalised and returned as a 1-D NumPy array on the CPU
    (384 values for ViT-S/16).
    """
    # The context manager closes the file handle; ``convert`` loads the
    # pixels into a new image, so ``image`` stays usable after the close.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    # Send the input tensors to whichever device the model lives on.
    inputs = extractor(images=image, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
        features = outputs.last_hidden_state[:, 0]  # CLS token, shape (1, hidden)
        features = torch.nn.functional.normalize(features, dim=1)
    # Drop only the batch dimension, then move back to host memory for NumPy.
    return features.squeeze(0).cpu().numpy()  # (384,) for ViT-S/16

# 3. Setup persistent ChromaDB client
# 3. Setup persistent ChromaDB client (stored under args.persist_directory)
client = chromadb.PersistentClient(path=args.persist_directory)

# 4. Get or create the collection
collection = client.get_or_create_collection(name=args.collection_name)

# 5. Index your image dataset (only add new images).
# Sorted so the indexing order is deterministic across runs and platforms
# (os.listdir returns entries in arbitrary order).
image_files = sorted(
    f
    for f in os.listdir(args.image_folder)
    if f.lower().endswith(('.jpg', '.png', '.jpeg'))
)

# Avoid duplicates by checking existing IDs. include=[] asks Chroma for the
# ids only, instead of also loading every stored document and metadata entry.
existing_ids = set(collection.get(include=[])["ids"])

# Iterate the list directly so tqdm knows the total and can show progress/ETA.
for filename in tqdm.tqdm(image_files):
    img_id = f"img-{filename}"
    if img_id in existing_ids:
        continue  # Skip if already in DB

    # The text before the first "." is used as the integer index into the
    # NONCES / MESSAGES_LONG lists.
    stem = filename.split(".", 1)[0]
    try:
        idx = int(stem)
    except ValueError as err:
        raise ValueError(
            f"expected image file name of the form '<index>.<ext>', got {filename!r}"
        ) from err

    path = os.path.join(args.image_folder, filename)
    vec = extract_dino_features(path)

    # prepare metadata: nonce and message stored as hex strings
    nonce = NONCES[idx].hex()
    message = MESSAGES_LONG[idx].hex()
    collection.add(
        documents=[filename],
        embeddings=[vec.tolist()],
        ids=[img_id],
        metadatas=[{"nonce": nonce, "message": message}],
    )

