import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import timm
import chromadb
import numpy as np
from tqdm import tqdm
from config.config import COLLECTION_PREFIX
# --- CONFIGURATION ---
TOP_K = 100  # Number of nearest neighbours requested from Chroma per query
NUM_IMAGES = 5  # Number of synthetic test images
# Pick the fastest available backend: CUDA first, then Apple's MPS, and fall
# back to the CPU when no accelerator is present.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

# --- MODELS: name -> (timm_id, embedding_dim, image_input_size)
# * name: short key; also the suffix of the Chroma collection that is queried
#   (f"{COLLECTION_PREFIX}_{name}").
# * timm_id: identifier handed to timm.create_model by load_model.
# * embedding_dim: informational only -- nothing in this file reads it (both
#   consumers of MODELS discard this field).
# * image_input_size: passed to get_transform as the resize / crop size.
# NOTE(review): timm's mobilenetv2_100 appears to expose 1280 pooled features,
# not 960 -- confirm this value against the indexed collection.
MODELS = {
    "dinov1_s": ("vit_small_patch16_224_dino", 384, 224),
    "dinov1_b": ("vit_base_patch16_224_dino", 768, 224),
    "dinov1_b8": ("vit_base_patch8_224_dino", 768, 224),
    "dinov2_s": ("vit_small_patch14_dinov2.lvd142m", 384, 518),
    "dinov2_b": ("vit_base_patch14_dinov2.lvd142m", 768, 518),
    "mobilenet_v2": ("mobilenetv2_100", 960, 224),
}

# --- CHROMA CONNECTION ---

# HTTP client for the Chroma server expected at localhost:8010; every
# collection lookup further down goes through this single client.
client = chromadb.HttpClient(
    host="localhost",
    port=8010,
)

# --- TRANSFORM ---
def get_transform(image_size):
    """Build the preprocessing pipeline for a model fed ``image_size`` inputs.

    The pipeline resizes, center-crops to ``image_size``, converts the PIL
    image to a tensor and normalizes each channel.

    Args:
        image_size: Target size handed to both ``Resize`` and ``CenterCrop``.

    Returns:
        A ``transforms.Compose`` applying the four steps in order.
    """
    # Per-channel statistics; these look like the usual ImageNet mean/std.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)

# --- MODEL LOADING ---
def load_model(timm_id):
    """Create a pretrained timm backbone that outputs embeddings, not logits.

    Args:
        timm_id: Model identifier passed to ``timm.create_model``.

    Returns:
        The model with its classification head removed, switched to eval
        mode and moved to ``DEVICE``.
    """
    model = timm.create_model(timm_id, pretrained=True)
    # Key on the method that is actually called rather than on a `head`
    # attribute: timm architectures name their classifier differently
    # (`head`, `fc`, `classifier`, ...), so checking for `head` alone left
    # models such as MobileNetV2 returning class logits instead of features.
    # NOTE(review): if a collection was indexed with the old logits output,
    # it must be re-indexed for the embedding sizes to match.
    if hasattr(model, 'reset_classifier'):
        model.reset_classifier(0)
    elif hasattr(model, 'fc'):
        # Fallback for non-timm style modules exposing a plain `fc` layer.
        model.fc = nn.Identity()
    return model.eval().to(DEVICE)

# Eagerly load every backbone listed in MODELS, keyed by its short name.
models = {
    short_name: load_model(checkpoint_id)
    for short_name, (checkpoint_id, _dim, _size) in MODELS.items()
}

# --- GENERATE SYNTHETIC TEST IMAGES ---
def generate_random_image(size=(256, 256)):
    """Return a PIL image of random noise.

    Args:
        size: ``(width, height)`` of the image; the backing array is built
            as ``(height, width, 3)`` with ``uint8`` values.

    Returns:
        The ``PIL.Image`` created from the random array.
    """
    width, height = size[0], size[1]
    # Upper bound 256 is exclusive, so values span the full uint8 range.
    pixels = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
    return Image.fromarray(pixels)

# (name, image) pairs; the same NUM_IMAGES noise images are reused as
# queries for every model.
synthetic_images = [
    (f"synthetic_{i}", generate_random_image())
    for i in range(NUM_IMAGES)
]

# --- INFERENCE & QUERY ---
# For each configured model, embed every synthetic image and print the ids of
# its TOP_K nearest neighbours from that model's Chroma collection.
for model_name, (_, _, input_size) in MODELS.items():
    model = models[model_name]
    transform = get_transform(input_size)
    # Looked up outside the per-image try: a missing collection stops the run.
    collection = client.get_collection(f"{COLLECTION_PREFIX}_{model_name}")

    print(f"\n🔍 Running search with model: {model_name}")

    for image_name, image in tqdm(synthetic_images, desc=f"Querying {model_name}"):
        # Best-effort per image: report the failure and move on to the next.
        try:
            # Add a batch axis of size 1 before moving the tensor to DEVICE.
            image_tensor = transform(image).unsqueeze(0).to(DEVICE)

            with torch.no_grad():
                # squeeze() drops the size-1 axes -- assumes the model returns
                # a (1, dim) tensor so the result is a flat list of floats.
                embedding = model(image_tensor).squeeze().cpu().tolist()

            results = collection.query(
                query_embeddings=[embedding],
                n_results=TOP_K,
            )

            # tqdm.write (instead of print) keeps the active progress bar
            # intact; it still writes to stdout.
            tqdm.write(f"\nTop {TOP_K} for {image_name} with {model_name}:")
            # You can also print distances or metadatas
            tqdm.write(str(results["ids"][0]))

        except Exception as e:
            tqdm.write(f" Error processing {image_name} with model {model_name}: {e}")
