import hashlib
from chromadb import HttpClient
from tqdm import tqdm
from config.config import COLLECTION_PREFIX

# --- CONFIG ---
# Models whose embedding collections are audited. Each entry is combined
# with COLLECTION_PREFIX as f"{COLLECTION_PREFIX}_{model}" by the run loop.
MODEL_NAMES = [
    "dinov1_s",
    "dinov1_b",
    "dinov1_b8",
    "dinov2_s",
    "dinov2_b",
    "mobilenet_v2",
]
# Page size (the `limit` argument) used when paging through collection.get().
BATCH_SIZE = 10000

# --- Connect to ChromaDB ---
# Module-level client, used by count_unique_embeddings to open collections.
client = HttpClient(port=8010, host="localhost")

def hash_embedding(embedding, decimals=6):
    """Return a SHA-256 hex digest identifying an embedding up to rounding.

    Each component is converted to a Python ``float`` and rounded to
    ``decimals`` places before hashing, so embeddings whose components round
    to the same values produce the same digest.

    Args:
        embedding: Iterable of numeric components (for example a list of
            floats, or one row of a NumPy array).
        decimals: Number of decimal places kept before hashing.

    Returns:
        The 64-character hexadecimal SHA-256 digest of the rounded tuple.
    """
    # float() makes the textual form independent of the input scalar type
    # (ints, or NumPy scalars whose repr varies between NumPy versions).
    # Adding 0.0 folds -0.0 into 0.0, so tiny values of opposite sign that
    # both round to zero hash identically instead of as "-0.0" vs "0.0".
    rounded = tuple(round(float(x), decimals) + 0.0 for x in embedding)
    return hashlib.sha256(str(rounded).encode()).hexdigest()

def count_unique_embeddings(collection_name, batch_size=None):
    """Count distinct embeddings in a ChromaDB collection.

    Embeddings are paged out of the collection with ``collection.get`` and
    de-duplicated by their ``hash_embedding`` digest, so embeddings whose
    components round to the same values are counted once.

    Args:
        collection_name: Name of the collection to scan. The collection is
            opened with ``get_or_create_collection``, so a name that does not
            exist yields a new empty collection (and a count of 0) rather
            than an error.
        batch_size: Page size for each ``collection.get`` call. When ``None``
            the module-level ``BATCH_SIZE`` is used.

    Returns:
        The number of distinct embedding digests found.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    # Resolved at call time (not as a default argument) so the module-level
    # constant is looked up when the function runs.
    if batch_size is None:
        batch_size = BATCH_SIZE
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    collection = client.get_or_create_collection(name=collection_name)
    unique_hashes = set()
    offset = 0

    print(f" Checking: {collection_name}")
    while True:
        results = collection.get(
            limit=batch_size,
            offset=offset,
            include=["embeddings"]
        )
        embeddings = results["embeddings"]
        # Explicit None / len() checks rather than truthiness: the value may
        # be a NumPy array, whose truth value is ambiguous.
        if embeddings is None or len(embeddings) == 0:
            break

        unique_hashes.update(hash_embedding(emb) for emb in embeddings)

        # Advance by the number of rows actually returned. This keeps the
        # progress count exact on the final partial page, and means no rows
        # are skipped if a page comes back shorter than `batch_size`.
        offset += len(embeddings)
        tqdm.write(f" {collection_name} — processed {offset} entries")

    return len(unique_hashes)

# --- Run for all models ---
# Audit each model's collection in MODEL_NAMES order and report its count of
# distinct embeddings.
for model_name in MODEL_NAMES:
    target = f"{COLLECTION_PREFIX}_{model_name}"
    print(f" {target}: {count_unique_embeddings(target)} unique embeddings")