import torch
import torch.nn as nn
from torchvision import transforms
from datasets import load_dataset
from tqdm import tqdm
import chromadb
import timm
from itertools import islice
import warnings
import numpy as np
import hashlib
import time
from config.config import COLLECTION_PREFIX
warnings.filterwarnings("ignore", message="Mapping deprecated model name")

# --- CONFIGURATION ---
# Upper bound on the image index to process; None disables the limit.
IMAGE_LIMIT = None  # Set to e.g. 1000 for testing

# Run on Apple-silicon MPS when torch reports it as available, else on CPU.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

# Number of buffered embeddings per model that triggers a write to Chroma.
BATCH_SIZE = 50
# Page size used when scanning a collection for its highest stored image_id.
DETECT_CHUNK_SIZE = 10_000

# short name -> (timm model id, embedding dimension, input image size).
# The dimension entry is unpacked but not otherwise read by the code in
# this file.
MODELS = dict(
    dinov1_s=("vit_small_patch16_224_dino", 384, 224),
    dinov1_b=("vit_base_patch16_224_dino", 768, 224),
    dinov1_b8=("vit_base_patch8_224_dino", 768, 224),
    dinov2_s=("vit_small_patch14_dinov2.lvd142m", 384, 518),
    dinov2_b=("vit_base_patch14_dinov2.lvd142m", 768, 518),
    mobilenet_v2=("mobilenetv2_100", 960, 224),
)

# Address of the Chroma HTTP server this script writes to.
CHROMA_HOST = "localhost"
CHROMA_PORT = 8010

print("🔌 Connecting to Chroma HTTP server (ClickHouse backend)...")
client = chromadb.HttpClient(host=CHROMA_HOST, port=CHROMA_PORT)
print("✅ Connected")

def get_collection(model_name):
    """Return the Chroma collection for ``model_name``, creating it if absent.

    The collection is named ``COLLECTION_PREFIX`` and ``model_name`` joined
    by an underscore.
    """
    collection_name = f"{COLLECTION_PREFIX}_{model_name}"
    return client.get_or_create_collection(name=collection_name)

def detect_resume_index(model_name):
    """Return the next image index to embed for ``model_name``.

    Pages through the model's collection ``DETECT_CHUNK_SIZE`` rows at a
    time and tracks the largest ``image_id`` found in the stored metadata.
    The result is that value plus one, or ``0`` when no entry carries an
    ``image_id``.

    Errors raised while scanning are reported and swallowed (best effort),
    so after a failure the returned index reflects only the pages that were
    read before the error.
    """
    collection = get_collection(model_name)
    max_id = -1
    offset = 0
    try:
        while True:
            results = collection.get(limit=DETECT_CHUNK_SIZE, offset=offset, include=["metadatas"])
            metadatas = results["metadatas"]
            if not metadatas:
                break
            # Records without metadata may come back as None; skip them
            # instead of letting a TypeError abort the whole scan.
            ids = [m["image_id"] for m in metadatas if m and "image_id" in m]
            if ids:
                max_id = max(max_id, max(ids))
            # Advance by the number of rows actually returned so a short
            # page never skips entries and the progress count stays exact.
            offset += len(metadatas)
            print(f"📦 {model_name}: checked {offset} entries")
    except Exception as e:
        print(f"⚠️ Error detecting resume index for {model_name}: {e}")
    return max_id + 1

print("🔍 Detecting resume index per model ...")
resume_indices = {name: detect_resume_index(name) for name in MODELS}
# Stream from the LOWEST per-model index. The embedding loop skips, per
# model, every image below that model's own resume index, so models that are
# further ahead do no duplicate work while a model that is behind is caught
# up. Starting from the highest index made that per-model skip unreachable
# and left permanent gaps for any model that was behind.
# The variable keeps its historical name because the rest of the script
# refers to it.
max_resume_index = min(resume_indices.values())
print("📌 Resume indices:", resume_indices)
print(f"⏭️ Will resume streaming from index: {max_resume_index}")

def get_transform(image_size):
    """Build the preprocessing pipeline for a model input of ``image_size``.

    The returned transform applies, in order: resize, center crop to
    ``image_size``, conversion to a tensor, and per-channel normalization
    with the mean/std constants below.
    """
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    steps = [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        normalize,
    ]
    return transforms.Compose(steps)

def load_model(timm_id):
    """Create the pretrained timm model ``timm_id`` as a feature extractor.

    The classification layer is removed so that a forward pass returns
    features rather than class logits, and the model is put in eval mode
    and moved to ``DEVICE``.

    NOTE(review): for architectures whose classifier was previously left in
    place, the embedding size produced by this function changes; collections
    already populated with the old vectors may need to be rebuilt.
    """
    print("🔄 Loading model:", timm_id)
    model = timm.create_model(timm_id, pretrained=True)
    # Use timm's own API whenever the model offers it. The earlier check for
    # a `head` attribute missed architectures that name their classifier
    # differently (e.g. `classifier`), which then kept emitting logits.
    if hasattr(model, "reset_classifier"):
        model.reset_classifier(0)
    elif hasattr(model, "fc"):
        model.fc = nn.Identity()
    model.eval().to(DEVICE)
    return model

# Load every configured backbone once, keyed by its short name.
models = {name: load_model(spec[0]) for name, spec in MODELS.items()}

print(" Connecting to dataset ...")
dataset = load_dataset(
    "evanarlian/imagenet_1k_resized_256",
    split="train",
    streaming=True,
)
# NOTE(review): the stream skips 10 rows more than the resume index, while
# the embedding loop numbers images starting at the resume index itself.
# The reason for the extra 10 is not evident from this file; confirm before
# changing it, since stored image_id values are relative to this offset.
stream = islice(dataset, max_resume_index + 10, None)

### Quick test run: uncomment the block below to process only the first
### FIRST_N images from scratch. Keep it commented out for the full run.
#FIRST_N = 1300
#resume_indices = {name: 0 for name in MODELS}  # ignore resume; start fresh
#max_resume_index = 0                            # start enumerating at 0
#stream = islice(dataset, 0, FIRST_N)
### End of quick-test block


# Per-model staging area: parallel lists that are sent to Chroma together
# once enough entries have accumulated.
batch_buffers = {
    name: {field: [] for field in ("ids", "embeddings", "metadatas")}
    for name in MODELS
}

# Image repeat detection using hash
last_image_hash = None


def image_hash(tensor):
    """Return the hexadecimal MD5 digest of ``tensor``'s raw bytes."""
    digest = hashlib.md5()
    raw_bytes = tensor.numpy().tobytes()
    digest.update(raw_bytes)
    return digest.hexdigest()

# --- EMBEDDING LOOP ---
# Each transform depends only on its target size, so build them once here
# rather than once per image and per model inside the loop.
hash_transform = get_transform(224)  # Default size for hash
model_transforms = {
    name: get_transform(image_size)
    for name, (_, _, image_size) in MODELS.items()
}

for i, row in enumerate(tqdm(stream, desc=f"Processing from {max_resume_index}"), start=max_resume_index):
    # Test the limit before doing any work: the `continue` in the repeat
    # check below can then never bypass it, and the image at index
    # IMAGE_LIMIT itself is not processed.
    if IMAGE_LIMIT and i >= IMAGE_LIMIT:
        break

    try:
        pil_image = row["image"]
        # Normalize() is configured for three channels, so force RGB in case
        # the dataset yields grayscale or CMYK images.
        if getattr(pil_image, "mode", "RGB") != "RGB":
            pil_image = pil_image.convert("RGB")
        current_hash = image_hash(hash_transform(pil_image))
    except Exception as e:
        tqdm.write(f" Error at image {i}: {e}")
        continue

    # Two consecutive identical images are treated as a stalled stream:
    # wait, then move on without embedding the duplicate.
    if current_hash == last_image_hash:
        tqdm.write(f"  Repeated image detected at index {i}. Pausing for 60 seconds...")
        time.sleep(60)
        continue

    last_image_hash = current_hash

    all_models_ok = True
    for name in MODELS:
        # This model already has this image stored from a previous run.
        if i < resume_indices[name]:
            continue

        # Isolate failures per model so that an error in one model does not
        # drop this image for the models that come after it.
        try:
            image = model_transforms[name](pil_image).unsqueeze(0).to(DEVICE)

            with torch.no_grad():
                embedding = models[name](image).squeeze().cpu().numpy()
            # L2-normalize before storing.
            embedding = (embedding / np.linalg.norm(embedding)).tolist()

            buf = batch_buffers[name]
            buf["ids"].append(f"{name}_{i}")
            buf["embeddings"].append(embedding)
            buf["metadatas"].append({
                "model": name,
                "image_id": i,
                "label": row["label"]
            })

            # If add() raises, the buffer is left intact and the write is
            # retried together with the next image.
            if len(buf["ids"]) >= BATCH_SIZE:
                get_collection(name).add(**buf)
                batch_buffers[name] = {"ids": [], "embeddings": [], "metadatas": []}
        except Exception as e:
            all_models_ok = False
            tqdm.write(f" Error at image {i} for model {name}: {e}")

    if all_models_ok:
        tqdm.write(f" Finished image {i} for all models")

# Write out whatever is still buffered below the batch threshold.
print("🧹 Flushing remaining batches...")
for name in batch_buffers:
    pending = batch_buffers[name]
    if not pending["ids"]:
        continue
    get_collection(name).add(**pending)

print("\n All embeddings saved to ClickHouse-backed ChromaDB.")
