"""Extract DINOv2 CLS embeddings from ImageNet."""

import h5py
import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

from .config import EmbeddingConfig, DINO_DIMS


def load_dino_model(model_name: str, device: str) -> torch.nn.Module:
    """Fetch a DINOv2 network via torch hub, ready for inference.

    Args:
        model_name: Entry-point name passed to the
            ``facebookresearch/dinov2`` hub repository.
        device: Device the module is moved to before it is returned.

    Returns:
        The loaded module, placed on ``device`` and switched to eval mode.
    """
    hub_repo = "facebookresearch/dinov2"
    net = torch.hub.load(hub_repo, model_name).to(device)
    # Switch to eval mode: this module is only used for inference.
    net.eval()
    return net


def get_transform():
    """Build the preprocessing pipeline applied to every image.

    Returns:
        A torchvision ``Compose`` that resizes to 256 with bicubic
        interpolation, center-crops to 224, converts to a tensor and
        normalises each channel with the mean/std constants below.
    """
    # Imported lazily so torchvision is only needed when a transform is built.
    from torchvision import transforms as T

    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=channel_mean, std=channel_std),
    ]
    return T.Compose(steps)


def get_image_and_label(item: dict) -> tuple:
    """Pull the image and label out of a dataset record.

    Datasets name their columns differently, so each value is looked up
    under a list of candidate keys in priority order.

    Args:
        item: One dataset record.

    Returns:
        ``(image, label)``. The label is ``0`` when none of the known
        label keys is present.

    Raises:
        KeyError: If the record has neither an ``"image"`` nor an
            ``"img"`` key.
    """
    # Image: first matching key wins; a record without an image is an error.
    for image_key in ("image", "img"):
        if image_key in item:
            picture = item[image_key]
            break
    else:
        raise KeyError(f"No image key found. Available keys: {list(item.keys())}")

    # Label: first matching key wins; unlabeled records fall back to 0.
    target = 0
    for label_key in ("label", "fine_label", "coarse_label", "labels"):
        if label_key in item:
            target = item[label_key]
            break

    return picture, target


def collate_fn(batch, transform):
    """Turn a list of dataset records into a pair of batched tensors.

    Args:
        batch: Iterable of dataset records accepted by
            ``get_image_and_label``.
        transform: Callable applied to each RGB image; its outputs are
            stacked, so they must all have the same shape.

    Returns:
        ``(images, labels)`` where ``images`` is the stacked transform
        output and ``labels`` is a tensor built from the record labels.
    """
    tensors, targets = [], []
    for sample in batch:
        picture, target = get_image_and_label(sample)
        # Non-RGB images (e.g. grayscale, CMYK) are converted to RGB first.
        rgb = picture if picture.mode == "RGB" else picture.convert("RGB")
        tensors.append(transform(rgb))
        targets.append(target)
    return torch.stack(tensors), torch.tensor(targets)


def _stream_batches(dataset, transform, batch_size):
    """Yield collated ``(images, labels)`` batches from an iterable dataset.

    The final batch may be smaller than ``batch_size``.
    """
    pending = []
    for item in dataset:
        pending.append(item)
        if len(pending) >= batch_size:
            yield collate_fn(pending, transform)
            pending = []
    if pending:
        yield collate_fn(pending, transform)


def _write_batch(model, images, labels, emb_dset, label_dset, start, device):
    """Embed one batch and append it to the HDF5 datasets.

    Returns the index one past the last sample written, i.e. the ``start``
    value for the next batch.
    """
    with torch.no_grad():
        emb = model(images.to(device))  # (batch, dim)

    emb_np = emb.cpu().float().numpy().T  # (dim, batch)
    stop = start + emb_np.shape[1]

    # Grow both datasets along the sample axis, then fill the new slots.
    emb_dset.resize(stop, axis=1)
    label_dset.resize(stop, axis=0)
    emb_dset[:, start:stop] = emb_np
    label_dset[start:stop] = labels.numpy()
    return stop


def extract_embeddings(config: EmbeddingConfig, no_wandb: bool = True) -> None:
    """Extract CLS embeddings from DINOv2 and save to HDF5.

    The output file contains:
      * ``embeddings``: float32, shape ``(embedding_dim, num_samples)``
        (stored as (d, n) to match Julia column-major convention),
      * ``labels``: int32, shape ``(num_samples,)``,
      * attributes ``model_name``, ``dataset``, ``embedding_dim`` and
        ``num_samples``.

    Args:
        config: Run configuration. Reads ``model_name``, ``device``,
            ``dataset``, ``split``, ``streaming``, ``cache_dir``,
            ``max_samples``, ``batch_size``, ``num_workers`` and
            ``output_path``.
        no_wandb: When True (default) Weights & Biases is neither imported
            nor used.

    Raises:
        KeyError: If ``config.model_name`` is not a key of ``DINO_DIMS``.
    """
    # Import lazily so wandb is only required when logging is requested.
    if not no_wandb:
        import wandb

        wandb.init(project="icml2026-embeddings", config=vars(config))

    emb_dim = DINO_DIMS[config.model_name]
    print(f"Loading {config.model_name} (dim={emb_dim})...")
    model = load_dino_model(config.model_name, config.device)
    transform = get_transform()

    print(f"Loading dataset {config.dataset} (streaming={config.streaming})...")
    dataset = load_dataset(
        config.dataset,
        split=config.split,
        streaming=config.streaming,
        cache_dir=config.cache_dir,
    )

    if config.streaming:
        # WARNING: Streaming mode uses a small shuffle buffer (10k) which does NOT
        # properly shuffle class-sorted datasets like ImageNet. Use only for local
        # debugging with small samples. For production, use streaming=False.
        print("WARNING: Streaming mode uses small shuffle buffer - class distribution may be biased!")
        print("         Use streaming=False for production runs.")
        dataset = dataset.shuffle(seed=42, buffer_size=10_000)
        if config.max_samples:
            dataset = dataset.take(config.max_samples)
        batches = _stream_batches(dataset, transform, config.batch_size)
        # Unknown when max_samples is unset; tqdm accepts total=None.
        expected_samples = config.max_samples
    else:
        if config.max_samples:
            dataset = dataset.select(range(min(config.max_samples, len(dataset))))

        from functools import partial

        batches = DataLoader(
            dataset,
            batch_size=config.batch_size,
            num_workers=config.num_workers,
            # partial over a module-level function stays picklable, which
            # worker processes need on spawn-based platforms; a nested
            # closure does not.
            collate_fn=partial(collate_fn, transform=transform),
            pin_memory=True,
            shuffle=True,  # Shuffle to avoid class-ordered storage
        )
        expected_samples = len(dataset)

    print(f"Extracting embeddings to {config.output_path}...")
    config.output_path.parent.mkdir(parents=True, exist_ok=True)

    chunk_cols = min(config.batch_size, 1024)

    with h5py.File(config.output_path, "w") as f:
        # The sample axis is unlimited (None): the datasets grow batch by
        # batch, a streaming run has no reliable upper bound, and a fixed
        # maximum smaller than the chunk width would be rejected.
        emb_dset = f.create_dataset(
            "embeddings",
            shape=(emb_dim, 0),
            maxshape=(emb_dim, None),
            dtype=np.float32,
            chunks=(emb_dim, chunk_cols),
        )
        label_dset = f.create_dataset(
            "labels",
            shape=(0,),
            maxshape=(None,),
            dtype=np.int32,
            chunks=(chunk_cols,),
        )
        # Store metadata
        f.attrs["model_name"] = config.model_name
        f.attrs["dataset"] = config.dataset
        f.attrs["embedding_dim"] = emb_dim

        current_idx = 0
        with tqdm(total=expected_samples, desc="Extracting") as pbar:
            for images, labels in batches:
                new_idx = _write_batch(
                    model, images, labels, emb_dset, label_dset,
                    current_idx, config.device,
                )
                pbar.update(new_idx - current_idx)
                # Advanced for every batch, including a final partial one,
                # so num_samples below matches what was actually written.
                current_idx = new_idx

        f.attrs["num_samples"] = current_idx

    print(f"Saved {current_idx} embeddings of dim {emb_dim} to {config.output_path}")

    if not no_wandb:
        wandb.log({"num_samples": current_idx, "embedding_dim": emb_dim})
        wandb.finish()
