import os
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader, Subset

from torchvision import datasets
from torchvision.transforms import (
    Compose, Resize, CenterCrop, ToTensor, Normalize, Grayscale
)

# =========================================================
# Configuration parameters
# =========================================================
# Accepted DATASET values:
#   "CIFAR10", "SVHN", "MedMNIST-PATH", "MedMNIST-BLOOD", "FashionMNIST", "STL10", "QMNIST"
DATASET = "SVHN"
NUM_CLASSES = 10               # requested class count; main() caps it at the dataset's own class count
IMAGES_PER_CLASS = 10000       # per-class sample cap applied to both the train and the test split
BATCH_SIZE = 2048              # DataLoader batch size used during feature extraction
NUM_WORKERS = 4                # DataLoader worker processes
OUTPUT_DIR = "output_fullclass_dinov3"   # directory receiving the saved .pt files
DATA_ROOT = "./data"           # root directory handed to the dataset constructors

# ⭐ Local checkpoint path and hub architecture name.
# Both can be overridden through environment variables of the same name.
DINOV3_CKPT = os.environ.get(
    "DINOV3_CKPT", "data/dinov3_vits16_pretrain_lvd1689m-08c60483.pth"
)
# e.g., dinov3_vitb16, dinov3_vitl16, dinov3_vits16plus
DINOV3_ARCH = os.environ.get("DINOV3_ARCH", "dinov3_vits16")

# =========================================================
# Model loading (DINOv3 / torch.hub + local .pth)
# =========================================================
print("📦 DINOv3 model loading (torch.hub + local ckpt)...")
device = "cuda" if torch.cuda.is_available() else "cpu"
# Build the architecture only; the weights are loaded manually from the local checkpoint below.
dinov3 = torch.hub.load("facebookresearch/dinov3", DINOV3_ARCH, pretrained=False)
# NOTE(security): torch.load may unpickle arbitrary objects depending on the
# PyTorch version/defaults — only point DINOV3_CKPT at a checkpoint you trust.
ckpt = torch.load(DINOV3_CKPT, map_location="cpu")

# Locate the actual state_dict inside the checkpoint.
# Handle typical key variations (model / state_dict / teacher / student, etc.)
state = None
if isinstance(ckpt, dict):
    for k in ("model", "state_dict", "teacher", "student"):
        if k in ckpt:
            state = ckpt[k]
            break
    if state is None:
        # No well-known key: provisionally use the first nested dictionary.
        for v in ckpt.values():
            if isinstance(v, dict):
                state = v
                break
if state is None:
    # Assume the checkpoint can be used directly as a state_dict.
    state = ckpt

# Strip a leading 'module.' (typically added by DataParallel-style wrappers).
# Only the prefix is removed: a plain str.replace would also rewrite keys that
# merely contain the substring (e.g. "foo_module.weight" -> "foo_weight"),
# and with strict=False such mangled keys would be skipped without an error.
_WRAPPER_PREFIX = "module."
new_state = {}
for k, v in state.items():
    if k.startswith(_WRAPPER_PREFIX):
        k = k[len(_WRAPPER_PREFIX):]
    new_state[k] = v

# strict=False tolerates key mismatches; the counts printed below are the only
# indication that some weights were not loaded, so check them.
missing, unexpected = dinov3.load_state_dict(new_state, strict=False)
print(f"✅ ckpt loaded (missing={len(missing)}, unexpected={len(unexpected)})")
dinov3.eval().to(device)
print(f"✅ finish model load: arch={DINOV3_ARCH}, ckpt={os.path.basename(DINOV3_CKPT)}")

# =========================================================
# Image preprocessing (DINOv3: ImageNet standard normalization)
# =========================================================
def build_transform(rgb: bool = True):
    """Build the preprocessing pipeline applied to every input image.

    Args:
        rgb: When falsy, a ``Grayscale(num_output_channels=3)`` step is placed
            in front of the pipeline so the rest of it receives 3 channels.

    Returns:
        A ``Compose`` of (optional Grayscale) -> Resize(224) -> CenterCrop(224)
        -> ToTensor -> Normalize with the ImageNet mean/std.
    """
    channel_fix = [] if rgb else [Grayscale(num_output_channels=3)]
    shared_steps = [
        Resize(224),
        CenterCrop(224),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return Compose(channel_fix + shared_steps)

# =========================================================
# Normalize label types to int (same as the original code)
# =========================================================
def to_int_label(y):
    """Convert a label in any of the supported containers to a plain ``int``.

    Supported inputs:
      * Python / NumPy scalars -> ``int(y)``
      * tensors, lists, tuples and arrays holding exactly one element
        (any number of dimensions) -> that element's value
      * 1-D sequences with several elements -> index of the maximum
        (treated as a one-hot / score vector)
      * N-D arrays with several elements -> argmax over the last axis, which
        must leave a single value (e.g. shape ``(1, C)``)

    Raises:
        ValueError: if the input is an empty array, or an N-D array whose
            argmax over the last axis does not reduce to one value.
    """
    if torch.is_tensor(y):
        y = y.detach().cpu().numpy()
    elif isinstance(y, (list, tuple)):
        y = np.asarray(y)

    if not isinstance(y, np.ndarray):
        return int(y)

    # A lone element is the label itself, whatever the number of dimensions.
    # (Taking argmax of a (1, 1) array would always yield index 0.)
    if y.size == 1:
        return int(y.reshape(-1)[0])
    if y.ndim == 1:
        return int(y.argmax())
    # .item() raises ValueError when more than one row is present.
    return int(y.argmax(axis=-1).item())

# =========================================================
# Wrapper that adapts MedMNIST datasets to this pipeline (array/tensor images are converted to PIL, then the transform is applied)
# =========================================================
class MedMNISTWrapper(torch.utils.data.Dataset):
    """Dataset adapter that applies ``transform`` to samples of a base dataset.

    Each item of ``base_ds`` must be an ``(image, target)`` pair. Tensor and
    ndarray images are converted to PIL images first; anything else is handed
    to the transform unchanged. The target is normalized with ``to_int_label``.
    """

    def __init__(self, base_ds, transform):
        self.base = base_ds            # underlying dataset yielding (image, target)
        self.transform = transform     # callable applied to the (PIL) image

    def __len__(self):
        return len(self.base)

    @staticmethod
    def _as_pil(img):
        """Turn a tensor/ndarray image into a PIL image; pass other types through."""
        if torch.is_tensor(img):
            img = img.detach().cpu().numpy()
        if not isinstance(img, np.ndarray):
            return img
        if img.ndim == 2:
            # 2-D array: single-channel image
            return Image.fromarray(img.astype(np.uint8), mode="L")
        if img.ndim == 3:
            return Image.fromarray(img.astype(np.uint8))
        raise ValueError(f"Unexpected MedMNIST image shape: {img.shape}")

    def __getitem__(self, idx):
        raw_img, raw_target = self.base[idx]
        # The image is transformed first, then the label is normalized.
        return self.transform(self._as_pil(raw_img)), to_int_label(raw_target)

# =========================================================
# Dataset loader
# =========================================================
def _svhn_digit_label(y):
    """Fold an SVHN label into the 0-9 range.

    Kept as a module-level function (not a lambda) so that a dataset holding
    it as ``target_transform`` can be pickled when DataLoader worker processes
    are started with the spawn/forkserver methods.
    """
    return int(y) % 10


def _torchvision_pair(dataset_cls, is_rgb, train_kwargs, test_kwargs, **shared_kwargs):
    """Build the train/test pair of a 10-class torchvision dataset.

    ``train_kwargs`` / ``test_kwargs`` carry the split selector of the given
    dataset class (``train=...``, ``split=...`` or ``what=...``);
    ``shared_kwargs`` are forwarded to both constructors.
    """
    transform = build_transform(rgb=is_rgb)
    shared = dict(root=DATA_ROOT, download=True, transform=transform, **shared_kwargs)
    train = dataset_cls(**shared, **train_kwargs)
    test = dataset_cls(**shared, **test_kwargs)
    return train, test, 10, is_rgb


def _medmnist_pair(class_name, num_classes):
    """Build the wrapped train/test pair of the MedMNIST dataset ``class_name``."""
    try:
        import medmnist
        dataset_cls = getattr(medmnist, class_name)
    except (ImportError, AttributeError) as e:
        raise ImportError("MedMNIST not found. Please run `pip install medmnist`.") from e
    is_rgb = True
    transform = build_transform(rgb=is_rgb)
    # The base datasets get no transform; MedMNISTWrapper applies it instead.
    base_train = dataset_cls(root=DATA_ROOT, split="train", download=True, transform=None)
    base_test = dataset_cls(root=DATA_ROOT, split="test", download=True, transform=None)
    return (
        MedMNISTWrapper(base_train, transform),
        MedMNISTWrapper(base_test, transform),
        num_classes,
        is_rgb,
    )


def get_datasets(dataset_name: str):
    """Create the train and test datasets for ``dataset_name``.

    The name is matched case-insensitively after stripping surrounding
    whitespace; several spellings are accepted for some datasets.

    Args:
        dataset_name: One of CIFAR10, SVHN, FashionMNIST, STL10, QMNIST,
            MedMNIST-PATH, MedMNIST-BLOOD (or an accepted alias).

    Returns:
        ``(train, test, num_classes, is_rgb)`` where ``is_rgb`` is the flag
        that was passed to ``build_transform`` for this dataset.

    Raises:
        ValueError: if the name is not supported.
        ImportError: if a MedMNIST dataset is requested but cannot be imported.
    """
    name = dataset_name.strip().lower()

    if name == "cifar10":
        return _torchvision_pair(
            datasets.CIFAR10, True, {"train": True}, {"train": False}
        )

    if name == "svhn":
        return _torchvision_pair(
            datasets.SVHN, True, {"split": "train"}, {"split": "test"},
            target_transform=_svhn_digit_label,
        )

    if name in ("fashionmnist", "fashion-mnist", "fashion_mnist"):
        return _torchvision_pair(
            datasets.FashionMNIST, False, {"train": True}, {"train": False}
        )

    if name in ("stl10", "stl-10"):
        return _torchvision_pair(
            datasets.STL10, True, {"split": "train"}, {"split": "test"}
        )

    if name == "qmnist":
        return _torchvision_pair(
            datasets.QMNIST, False, {"what": "train"}, {"what": "test"}
        )

    if name in ("medmnist-path", "medmnist_path", "pathmnist", "medmnist/path"):
        return _medmnist_pair("PathMNIST", 9)

    if name in ("medmnist-blood", "medmnist_blood", "bloodmnist", "medmnist/blood"):
        return _medmnist_pair("BloodMNIST", 8)

    raise ValueError(f"Unsupported DATASET name: {dataset_name}")

# =========================================================
# Class filtering (extract up to max_per_class samples per class)
# =========================================================
def filter_by_classes(dataset, classes, max_per_class=25):
    """Return a ``Subset`` holding at most ``max_per_class`` samples per class.

    The dataset is scanned in index order and scanning stops as soon as every
    requested class has reached its quota. Each scanned sample is fetched via
    ``dataset[i]``, so whatever the dataset does on item access runs here too.

    Args:
        dataset: Indexable dataset yielding ``(sample, label)`` pairs.
        classes: Class ids to keep.
        max_per_class: Quota per class.

    Returns:
        ``Subset(dataset, indices)`` with the selected indices in scan order.
    """
    print(f"🔍 Extracting up to {max_per_class} samples from classes {classes}...")
    taken = dict.fromkeys(classes, 0)
    kept = []
    for position in tqdm(range(len(dataset))):
        _sample, raw_label = dataset[position]
        cls_id = to_int_label(raw_label)
        wanted = cls_id in classes and taken[cls_id] < max_per_class
        if wanted:
            kept.append(position)
            taken[cls_id] += 1
        # Stop once no class is still below its quota.
        if not any(n < max_per_class for n in taken.values()):
            break
    print(f"✅ Filtering complete: {len(kept)} samples in total")
    return Subset(dataset, kept)

# =========================================================
# Feature extraction (DINOv3 / torch.hub)
# =========================================================
def _select_cls_features(out):
    """Pick the per-image CLS embedding from a ``forward_features`` result dict.

    Raises:
        RuntimeError: if no usable entry is found.
    """
    feats = out.get("x_norm_clstoken")
    if feats is not None:
        return feats

    # Fallback to accommodate implementation differences: take token 0 of the
    # first available 3-D token sequence.
    # Explicit ``is not None`` checks are required here: ``a or b`` would call
    # bool() on a multi-element tensor, which raises.
    # NOTE(review): in implementations where "x_norm_patchtokens" excludes the
    # CLS token, index 0 is the first patch token rather than CLS — confirm
    # before relying on this path.
    tokens = None
    for key in ("x_norm_patchtokens", "x"):
        tokens = out.get(key)
        if tokens is not None:
            break
    if tokens is not None and tokens.dim() == 3:
        return tokens[:, 0]
    raise RuntimeError("Failed to obtain CLS features from DINOv3 forward_features.")


def _labels_to_long(labels):
    """Convert one batch of labels to a 1-D ``torch.long`` tensor."""
    if not torch.is_tensor(labels):
        return torch.as_tensor([to_int_label(y) for y in labels], dtype=torch.long)
    if labels.ndim == 2 and labels.size(1) > 1:
        # (N, C) with C > 1 is treated as one-hot / scores.
        return labels.argmax(dim=1).long()
    return labels.view(-1).long()


@torch.no_grad()
def extract_features(dataset, batch_size=64, num_workers=2):
    """Run the module-level ``dinov3`` model over ``dataset`` and collect features.

    Args:
        dataset: Dataset yielding ``(image_tensor, label)`` pairs.
        batch_size: DataLoader batch size.
        num_workers: Number of DataLoader worker processes.

    Returns:
        ``{"features": Tensor, "labels": LongTensor}``, both on the CPU and in
        dataset order (the loader does not shuffle).

    Raises:
        RuntimeError: if the dataset yields no batches, or if the model output
            contains no usable CLS features.
    """
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=(device == "cuda"),
        persistent_workers=(num_workers > 0)
    )
    features_all, labels_all = [], []

    print("🔄 Extracting features (DINOv3 / torch.hub)...")
    for imgs, labels in tqdm(loader):
        imgs = imgs.to(device, non_blocking=True)
        feats = _select_cls_features(dinov3.forward_features(imgs))
        features_all.append(feats.detach().cpu())
        labels_all.append(_labels_to_long(labels).cpu())

    if not features_all:
        # torch.cat on an empty list fails with a far less helpful message.
        raise RuntimeError("No samples to extract features from: the dataset is empty.")

    features = torch.cat(features_all, dim=0)
    labels = torch.cat(labels_all, dim=0)
    print(f"✅ Feature extraction complete: {features.shape[0]} samples, dim={features.shape[1]}")
    return {"features": features, "labels": labels}

# =========================================================
# Entry point (output file names embed the dataset, split, and architecture so runs are easy to tell apart)
# =========================================================
def _export_split(dataset, target_classes, output_path):
    """Filter one split, extract its features, and save them to ``output_path``."""
    subset = filter_by_classes(dataset, target_classes, max_per_class=IMAGES_PER_CLASS)
    payload = extract_features(subset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
    torch.save(payload, output_path)
    print(f"💾 Saved: {output_path}")


def main():
    """Extract and save DINOv3 features for the train and test splits of DATASET.

    Output files are written to OUTPUT_DIR as
    ``<dataset>_<split>_<n>class_<IMAGES_PER_CLASS>_<DINOV3_ARCH>.pt``.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    print(f"\n📁 Loading {DATASET} dataset...")
    train_ds, test_ds, ds_num_classes, is_rgb = get_datasets(DATASET)

    # Never request more classes than the dataset provides.
    used_num_classes = min(NUM_CLASSES, ds_num_classes)
    target_classes = list(range(used_num_classes))
    print(f"🎯 Target classes: {target_classes} / Total classes: {ds_num_classes} (RGB={is_rgb})")

    # get_datasets() accepts aliases containing "/" (e.g. "medmnist/path").
    # Embedding such a name verbatim would make the output path point into a
    # non-existent sub-directory and torch.save would fail only after the
    # costly extraction, so path separators are replaced up front.
    file_tag = DATASET.strip().lower().replace("/", "_").replace(os.sep, "_")

    def split_path(split_name):
        return os.path.join(
            OUTPUT_DIR,
            f"{file_tag}_{split_name}_{used_num_classes}class_{IMAGES_PER_CLASS}_{DINOV3_ARCH}.pt",
        )

    train_output_path = split_path("train")
    _export_split(train_ds, target_classes, train_output_path)

    test_output_path = split_path("test")
    _export_split(test_ds, target_classes, test_output_path)

    print(f"\n✅ Data saved: {train_output_path}, {test_output_path}")

if __name__ == "__main__":
    # Run the export only when this file is executed as a script.
    # The cuDNN benchmark flag is switched on before main() issues any forward pass.
    torch.backends.cudnn.benchmark = True
    main()
