import os
import json
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import timm
import chromadb
from config.config import BASE_DIR, DATASET_TEST_INTERNAL, DISTANCES_TEST_INTERNAL, COLLECTION_PREFIX
# === CONFIGURATION ===


# Model registry: key -> (timm model name, expected embedding length, input image size).
# The key doubles as the ChromaDB collection suffix and the CSV filename prefix.
MODELS = dict(
    mobilenet_v2=("mobilenetv2_100", 1000, 224),
    dinov1_s=("vit_small_patch16_224_dino", 384, 224),
    dinov1_b=("vit_base_patch16_224_dino", 768, 224),
    dinov1_b8=("vit_base_patch8_224_dino", 768, 224),
    dinov2_s=("vit_small_patch14_dinov2.lvd142m", 384, 518),
    dinov2_b=("vit_base_patch14_dinov2.lvd142m", 768, 518),
)

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")



# Define directories
base_dir = BASE_DIR  # not used below in this script
DATASET_PATH = DATASET_TEST_INTERNAL  # flat folder of test images to embed
CSV_SAVE_DIR = DISTANCES_TEST_INTERNAL  # output folder for per-model similarity CSVs
os.makedirs(CSV_SAVE_DIR, exist_ok=True)

# ChromaDB server connection. Defaults are the previous hard-coded
# localhost:8010; set CHROMA_HOST / CHROMA_PORT to target another server
# without editing the script.
CHROMA_HOST = os.environ.get("CHROMA_HOST", "localhost")
CHROMA_PORT = int(os.environ.get("CHROMA_PORT", "8010"))
client = chromadb.HttpClient(host=CHROMA_HOST, port=CHROMA_PORT)

def get_transform(image_size):
    """Return the evaluation preprocessing pipeline for a square model input.

    The image's shorter side is resized to ``image_size``, the central
    ``image_size x image_size`` patch is cropped, then the result is converted
    to a tensor and normalized per channel with the ImageNet mean/std values.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline)

class FlatImageDataset(Dataset):
    """Dataset over the image files directly inside one directory (non-recursive).

    Files are selected by extension (``.jpg``, ``.jpeg``, ``.png``, matched
    case-insensitively) and sorted by filename, so iteration order -- and the
    row order of anything built from it -- is reproducible. Each item is
    ``(image, image_id)`` where ``image_id`` is the filename without extension.

    Args:
        root_dir: Directory to scan for images.
        transform: Optional callable applied to the RGB ``PIL.Image``.
    """

    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir order is arbitrary; sort it so outputs are deterministic.
        # Sub-directories whose names merely end in an image extension are skipped.
        self.image_paths = [
            os.path.join(root_dir, fname)
            for fname in sorted(os.listdir(root_dir))
            if fname.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(root_dir, fname))
        ]
        self.image_ids = [os.path.splitext(os.path.basename(p))[0] for p in self.image_paths]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, image_id)`` for ``idx``, skipping unreadable files.

        If the file at ``idx`` cannot be opened/decoded (or the transform raises
        ``OSError``/``ValueError``), the following files are tried in order,
        wrapping around to the start, and the first readable one is returned
        together with *its own* id. Callers that need exactly one row per file
        can detect a substitute by comparing the returned id with
        ``self.image_ids[idx]``.

        Raises:
            IndexError: If ``idx`` is out of range (normal sequence protocol).
            RuntimeError: If no file in the dataset can be read.
        """
        num_images = len(self)
        # range(...)[idx] normalizes negative indices and raises IndexError when
        # idx is out of bounds, exactly like indexing a plain list.
        start = range(num_images)[idx]
        # Loop instead of recursing: a long run of bad files can no longer hit the
        # recursion limit, and an all-bad directory fails with a clear error.
        for offset in range(num_images):
            pos = (start + offset) % num_images
            img_path = self.image_paths[pos]
            try:
                with Image.open(img_path) as img:
                    image = img.convert("RGB")
                if self.transform:
                    image = self.transform(image)
            except (OSError, ValueError):
                print(f" Skipping corrupted image: {img_path}")
                continue
            return image, self.image_ids[pos]
        raise RuntimeError(f"No readable images found in {self.root_dir!r}")

def process_model(model_key, model_name, emb_dim, img_size, n_results=1000):
    """Embed every test image with one timm model and save neighbor similarities.

    Each image in ``DATASET_PATH`` is embedded with ``model_name``, L2-normalized,
    and used to query the ChromaDB collection ``f"{COLLECTION_PREFIX}_{model_key}"``.
    Returned distances are converted to similarities as ``1 - distance`` and
    written, one row per image (index ``image_id``, columns ``"0"``, ``"1"``, ...
    in the order ChromaDB returns the neighbors), to
    ``CSV_SAVE_DIR/<model_key>_cosine_similarities.csv``.

    Args:
        model_key: Short key used as the collection suffix and CSV filename prefix.
        model_name: timm model identifier passed to ``timm.create_model``.
        emb_dim: Expected embedding length; checked against the model output.
        img_size: Resize / center-crop size for the preprocessing transform.
        n_results: Number of neighbors requested per query (default 1000).

    Raises:
        ValueError: If the collection does not exist in ChromaDB, or the model
            output length differs from ``emb_dim``.
    """
    print(f"\n==> Processing model: {model_key}")

    # Check the collection first so a missing one fails before any weights load.
    full_collection_name = f"{COLLECTION_PREFIX}_{model_key}"
    # Depending on the chromadb version, list_collections() yields Collection
    # objects or plain name strings; accept both.
    existing_names = {getattr(c, "name", c) for c in client.list_collections()}
    if full_collection_name not in existing_names:
        raise ValueError(f"Collection '{full_collection_name}' not found in ChromaDB.")
    collection = client.get_collection(name=full_collection_name)

    model = timm.create_model(model_name, pretrained=True).to(DEVICE).eval()
    transform = get_transform(img_size)

    dataset = FlatImageDataset(DATASET_PATH, transform=transform)
    # batch_size=1 and shuffle=False: batch number i is dataset index i (used below).
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    print(" Querying cosine neighbors...")
    all_similarities = []
    image_ids = []

    for position, (img_tensor, batch_ids) in enumerate(tqdm(dataloader, total=len(dataset))):
        image_id = batch_ids[0]
        # FlatImageDataset returns the next readable file (with that file's id)
        # in place of an unreadable one. That file also gets its own row at its
        # own position, so a stand-in is skipped here to avoid duplicate rows.
        if image_id != dataset.image_ids[position]:
            continue

        with torch.no_grad():
            emb = model(img_tensor.to(DEVICE)).cpu().numpy().reshape(-1)
        if emb.shape[0] != emb_dim:
            raise ValueError(
                f"Model '{model_name}' produced {emb.shape[0]}-dim embeddings; "
                f"expected {emb_dim} for '{model_key}'."
            )
        norm = np.linalg.norm(emb)
        if norm > 0:  # avoid NaNs from a degenerate all-zero embedding
            emb = emb / norm

        result = collection.query(
            query_embeddings=[emb.tolist()],
            n_results=n_results,
            include=["distances"],
        )
        # NOTE(review): 1 - distance equals cosine similarity only if the
        # collection uses cosine space; confirm where the collection is built.
        similarities = 1 - np.asarray(result["distances"][0])
        all_similarities.append(similarities)
        image_ids.append(image_id)

    df = pd.DataFrame(all_similarities, index=image_ids)
    df.index.name = "image_id"
    df.columns = [str(i) for i in df.columns]

    csv_path = os.path.join(CSV_SAVE_DIR, f"{model_key}_cosine_similarities.csv")
    df.to_csv(csv_path)
    print(f" Saved: {csv_path}")

if __name__ == "__main__":
    # Process every registered model in MODELS order; each spec unpacks into
    # (timm model name, expected embedding length, input image size).
    for model_key, (timm_name, embedding_dim, input_size) in MODELS.items():
        process_model(model_key, timm_name, embedding_dim, input_size)
