import hashlib
from chromadb import HttpClient
from tqdm import tqdm
from config.config import COLLECTION_PREFIX
# --- CONFIG ---
# Model suffixes; each is joined to COLLECTION_PREFIX to form one collection name.
MODEL_NAMES = [
    "dinov1_s",
    "dinov1_b",
    "dinov1_b8",
    "dinov2_s",
    "dinov2_b",
    "mobilenet_v2",
]
# How many entries to request per collection.get() page.
BATCH_SIZE = 10000

# --- Connect to ChromaDB ---
# HTTP client for the Chroma server on localhost:8010.
client = HttpClient(port=8010, host="localhost")

def hash_embedding(embedding, decimals=16):
    """Return a SHA-256 hex digest identifying an embedding's rounded values.

    Each component is rounded to ``decimals`` places, normalized to a plain
    Python float, and the string form of the resulting tuple is hashed. Two
    embeddings whose components are equal after rounding therefore get the
    same digest, which makes the digest usable for duplicate detection.

    Args:
        embedding: Iterable of numeric components (e.g. a list of floats or a
            1-D array).
        decimals: Number of decimal places each component is rounded to.

    Returns:
        The 64-character lowercase hexadecimal SHA-256 digest.
    """
    rounded = tuple(
        # float() gives every component the same type, and therefore the same
        # text form, whatever the element type of `embedding` (e.g. array
        # scalars). "+ 0.0" turns -0.0 into 0.0. Tiny negative and tiny positive
        # values that both round to zero compare equal, but str() would render
        # them differently ("-0.0" vs "0.0").
        float(round(x, decimals)) + 0.0
        for x in embedding
    )
    return hashlib.sha256(str(rounded).encode()).hexdigest()

# --- Global set for tracking unique hashes ---
# Kept so existing importers of `seen_hashes` keep working. count_duplicates
# only uses it when it is passed explicitly as `seen=seen_hashes` (for
# cross-collection dedup). By default each call starts from an empty set.
seen_hashes = set()

def count_duplicates(collection_name, seen=None):
    """Return the ids of entries whose embedding duplicates an earlier entry.

    Pages through the collection BATCH_SIZE entries at a time. The first
    entry with a given embedding hash is kept; every later entry with the
    same hash is reported as a duplicate. Nothing is deleted here.

    Args:
        collection_name: Name of the Chroma collection to scan. It is opened
            via `get_or_create_collection`, so (per Chroma's API) a missing
            collection is presumably created empty rather than raising.
        seen: Optional set of embedding hashes already seen. It is updated in
            place. Pass a shared set (e.g. the module-level `seen_hashes`) to
            detect duplicates across several collections. Defaults to a fresh
            set, i.e. duplicates are detected within this collection only.

    Returns:
        List of ids of the duplicate entries, in scan order.
    """
    if seen is None:
        seen = set()

    collection = client.get_or_create_collection(name=collection_name)
    offset = 0
    duplicate_ids = []

    print(f" Checking: {collection_name}")
    while True:
        results = collection.get(
            limit=BATCH_SIZE,
            offset=offset,
            include=["embeddings"]
        )
        embeddings = results["embeddings"]
        ids = results["ids"]  # always included by default

        # Explicit None/len() test rather than `not embeddings`. This is
        # presumably because the client may return an array type whose
        # truth value is ambiguous. An empty page means we are done.
        if embeddings is None or len(embeddings) == 0:
            break

        for emb, doc_id in zip(embeddings, ids):
            emb_hash = hash_embedding(emb)
            if emb_hash in seen:
                duplicate_ids.append(doc_id)
            else:
                seen.add(emb_hash)

        # Advance by the number of rows actually returned, not BATCH_SIZE.
        # The final page is usually short, and a short page must not cause
        # entries to be skipped.
        offset += len(ids)
        tqdm.write(f"  {collection_name} — processed {offset} entries, duplicates so far: {len(duplicate_ids)}")

    return duplicate_ids

# --- Run for all models ---
# Scan each model's collection and print how many entries are duplicates.
# This only reports counts; no deletion is performed.
for model_name in MODEL_NAMES:
    target_collection = f"{COLLECTION_PREFIX}_{model_name}"
    dupes = count_duplicates(target_collection)
    print(f" {target_collection}: {len(dupes)} duplicate entries would be deleted")
