import os
import json
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import timm
import chromadb
from config.config import BASE_DIR, COLLECTION_PREFIX, IMAGE_ID_WHITELIST, DATASET_TEST_EXTERNAL, DISTANCES_TEST_EXTERNAL

# === CONFIGURATION ===

# Registry of the embedding models to evaluate.
# Each short model key maps to a 3-tuple that the __main__ loop unpacks as
# (timm model name, embedding dimension, input image size) and forwards to
# process_model(). The key is also used as the suffix of the ChromaDB
# collection name and as the prefix of the output CSV file name.
MODELS = {
    "mobilenet_v2": ("mobilenetv2_100", 1000, 224),
    "dinov1_s": ("vit_small_patch16_224_dino", 384, 224),
    "dinov1_b": ("vit_base_patch16_224_dino", 768, 224),
    "dinov1_b8": ("vit_base_patch8_224_dino", 768, 224),
    "dinov2_s": ("vit_small_patch14_dinov2.lvd142m", 384, 518),
    "dinov2_b": ("vit_base_patch14_dinov2.lvd142m", 768, 518),
}

# Pick the compute device once at import time, in order of preference:
# CUDA GPU, then Apple MPS, then plain CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

# Load image ID whitelist

# The whitelist is a CSV file that must contain an `image_id` column.
whitelist_path = IMAGE_ID_WHITELIST
whitelist_df = pd.read_csv(whitelist_path)
# IDs are cast to str and stored in a set: FlatImageDataset compares them
# against file-name stems (which are strings) with an `in` membership test.
whitelist_ids = set(whitelist_df['image_id'].astype(str))

# Define directories

base_dir = BASE_DIR  # NOTE(review): not referenced anywhere else in the visible source
DATASET_PATH = DATASET_TEST_EXTERNAL  # directory that process_model() scans for images
CSV_SAVE_DIR = DISTANCES_TEST_EXTERNAL  # directory receiving the per-model similarity CSVs
os.makedirs(CSV_SAVE_DIR, exist_ok=True)  # runs at import time; no error if it already exists
# HTTP client for a ChromaDB server expected at localhost:8010; shared by
# every process_model() call.
client = chromadb.HttpClient(host="localhost", port=8010)

def get_transform(image_size):
    """Build the preprocessing pipeline applied to every image.

    Args:
        image_size: Value passed to both ``Resize`` and ``CenterCrop``.

    Returns:
        A ``transforms.Compose`` that resizes, center-crops, converts the
        image to a tensor and normalizes each channel.
    """
    # Per-channel statistics commonly used for ImageNet-pretrained models.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)

class FlatImageDataset(Dataset):
    """Dataset over the image files located directly inside one directory.

    Only files ending in ``.jpg``, ``.jpeg`` or ``.png`` (case-insensitive)
    are considered; sub-directories are not traversed. The image ID of a file
    is its name without the extension.

    Args:
        root_dir: Directory containing the image files.
        transform: Optional callable applied to each loaded RGB image.
        allowed_ids: Optional container of image IDs. When given (even if
            empty), only files whose ID is in it are kept; ``None`` keeps
            every image file.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, root_dir, transform=None, allowed_ids=None):
        self.root_dir = root_dir
        self.transform = transform
        self.allowed_ids = allowed_ids

        # os.listdir() returns entries in arbitrary order. Sorting makes the
        # sample order (and therefore the row order of anything derived from
        # it) reproducible across runs and machines.
        image_files = sorted(
            fname
            for fname in os.listdir(root_dir)
            if fname.lower().endswith((".jpg", ".jpeg", ".png"))
        )

        # Parallel lists: image_paths[i] is the file for image_ids[i].
        self.image_paths = []
        self.image_ids = []
        for fname in image_files:
            image_id = os.path.splitext(fname)[0]
            if allowed_ids is not None and image_id not in allowed_ids:
                continue
            self.image_paths.append(os.path.join(root_dir, fname))
            self.image_ids.append(image_id)

    def __len__(self):
        """Return the number of images kept after filtering."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, image_id)`` for ``idx``.

        If the image at ``idx`` cannot be read or transformed (``OSError`` or
        ``ValueError``), a message is printed and the next image (wrapping
        around) is returned instead, together with *that* image's ID.

        Raises:
            IndexError: If ``idx`` is out of range.
            RuntimeError: If no image in the dataset can be loaded. This
                replaces the unbounded recursion that previously ended in a
                ``RecursionError``.
        """
        total = len(self)
        candidate = idx
        last_error = None

        # At most one full pass over the dataset; max(..., 1) guarantees the
        # first lookup still happens so an empty dataset raises IndexError.
        for _ in range(max(total, 1)):
            img_path = self.image_paths[candidate]
            try:
                image = Image.open(img_path).convert("RGB")
                if self.transform:
                    image = self.transform(image)
            except (OSError, ValueError) as e:
                print(f" Skipping corrupted image: {img_path}")
                last_error = e
                candidate = (candidate + 1) % total
            else:
                return image, self.image_ids[candidate]

        raise RuntimeError(
            f"No readable image found in {self.root_dir}"
        ) from last_error

def process_model(model_key, model_name, emb_dim, img_size):
    """Embed the whitelisted test images with one model and save similarities.

    For every image, the model output is L2-normalized and used to query the
    ChromaDB collection ``{COLLECTION_PREFIX}_{model_key}`` for up to 1000
    neighbours. The returned distances are converted to similarities as
    ``1 - distance`` and written, one row per image, to
    ``{model_key}_cosine_similarities.csv`` inside ``CSV_SAVE_DIR``.

    Args:
        model_key: Short model identifier; selects the collection and names
            the output CSV.
        model_name: Model name passed to ``timm.create_model``.
        emb_dim: Expected embedding size. Currently unused; kept so the call
            signature stays stable.
        img_size: Image size passed to ``get_transform``.

    Returns:
        None. Returns early, without querying, when no image matches the
        whitelist.

    Raises:
        ValueError: If the expected collection does not exist in ChromaDB.
    """
    print(f"\n==> Processing model: {model_key}")
    model = timm.create_model(model_name, pretrained=True).to(DEVICE).eval()
    transform = get_transform(img_size)

    dataset = FlatImageDataset(DATASET_PATH, transform=transform, allowed_ids=whitelist_ids)
    print(f"🔢 Number of images to be analyzed by {model_key}: {len(dataset)}")
    if len(dataset) == 0:
        print(f" No matching images found for model {model_key}. Skipping...")
        return

    # batch_size=1 keeps one embedding / one query per iteration, which the
    # squeeze() and image_id[0] below rely on.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    full_collection_name = f"{COLLECTION_PREFIX}_{model_key}"
    # Depending on the installed Chroma release, list_collections() yields
    # either Collection objects or plain collection names (0.6.x, where
    # accessing `.name` on the returned string raises). Accept both.
    collections = {
        c if isinstance(c, str) else c.name
        for c in client.list_collections()
    }
    if full_collection_name not in collections:
        raise ValueError(f"Collection '{full_collection_name}' not found in ChromaDB.")

    collection = client.get_collection(name=full_collection_name)

    print("🔍 Querying cosine neighbors...")
    all_similarities = []
    image_ids = []

    for img_tensor, image_id in tqdm(dataloader, total=len(dataset)):
        img_tensor = img_tensor.to(DEVICE)
        with torch.no_grad():
            emb = model(img_tensor).squeeze().cpu().numpy()

        # Normalize to unit length; a zero vector is left untouched so the
        # query is not fed NaNs produced by a division by zero.
        norm = np.linalg.norm(emb)
        if norm > 0:
            emb = emb / norm

        result = collection.query(
            query_embeddings=[emb.tolist()],
            n_results=1000,
            include=["distances"]
        )
        distances = result["distances"][0]
        # NOTE(review): 1 - distance is a cosine similarity only if the
        # collection was created with the cosine distance space; confirm.
        similarities = 1 - np.array(distances)
        all_similarities.append(similarities)
        image_ids.append(image_id[0])

    # One row per image; column i holds the similarity to the i-th neighbour.
    df = pd.DataFrame(all_similarities, index=image_ids)
    df.index.name = "image_id"
    df.columns = [str(i) for i in df.columns]

    csv_path = os.path.join(CSV_SAVE_DIR, f"{model_key}_cosine_similarities.csv")
    df.to_csv(csv_path)
    print(f" Saved: {csv_path}")

if __name__ == "__main__":
    # Run the pipeline once per registered model, in MODELS insertion order.
    for model_key, (timm_name, embedding_dim, image_size) in MODELS.items():
        process_model(model_key, timm_name, embedding_dim, image_size)