# save_as_embeddings_dinov2_local_per_class_facebook.py
# pip install torch torchvision pillow tqdm numpy transformers

import os
import argparse
import numpy as np
import torch
from PIL import Image
from torchvision.datasets import CIFAR100, ImageFolder
from tqdm import tqdm
from transformers import AutoImageProcessor, AutoModel

# ---- Compatibility shim: torch builds without get_default_device / set_default_device ----
# Only attributes that are missing get patched; an implementation that torch
# already provides is never replaced.
if not hasattr(torch, "get_default_device"):
    def _get_default_device():
        """Fallback: report CUDA when it is available, otherwise the CPU."""
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.device(device_type)

    setattr(torch, "get_default_device", _get_default_device)

if not hasattr(torch, "set_default_device"):
    def _set_default_device(*args, **kwargs):
        """Fallback: accept any arguments and do nothing."""
        return None

    setattr(torch, "set_default_device", _set_default_device)
# -------------------------------------------------------------------------------------------

# Short DINOv2 variant name -> Hugging Face Hub repository id.
# Entries are ordered from the smallest to the largest embedding size.
HF_MODEL_MAP = {
    # 384-dim
    "dinov2_vits14": "facebook/dinov2-small",
    # 768-dim
    "dinov2_vitb14": "facebook/dinov2-base",
    # 1024-dim
    "dinov2_vitl14": "facebook/dinov2-large",
    # 1536-dim
    "dinov2_vitg14": "facebook/dinov2-giant",
}

def get_dinov2_model_and_processor(model_name="dinov2_vitl14"):
    """Load a Facebook DINOv2 checkpoint and its image processor from Hugging Face.

    Args:
        model_name: Key of ``HF_MODEL_MAP`` selecting the checkpoint.

    Returns:
        Tuple ``(model, processor, device)``: the model switched to eval mode
        and moved to ``device`` (CUDA when available, otherwise CPU), the
        matching image processor, and the device itself.

    Raises:
        ValueError: If ``model_name`` is not a key of ``HF_MODEL_MAP``.
    """
    hf_id = HF_MODEL_MAP.get(model_name)
    if hf_id is None:
        known = list(HF_MODEL_MAP.keys())
        raise ValueError(f"Unknown model_name '{model_name}'. Choose one of: {known}")

    print(f"Loading Facebook DINOv2 from Hugging Face: {hf_id}")

    # use_fast=True avoids the warning and minor slowdowns
    processor = AutoImageProcessor.from_pretrained(hf_id, use_fast=True)
    model = AutoModel.from_pretrained(hf_id)

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    model.to(device)
    model.eval()

    # Not every config is guaranteed to expose hidden_size; report it only when present.
    hidden_size = getattr(model.config, "hidden_size", None)
    if hidden_size is not None:
        print(f"Embedding (CLS) dimension: {hidden_size}")
    print(f"Model loaded on device: {device}")

    return model, processor, device

def ensure_rgb(img):
    """Return ``img`` as a PIL image in RGB mode.

    Anything that is not already a PIL image is first converted through
    ``np.array`` and ``Image.fromarray``. An image that is already RGB is
    returned unchanged; any other mode is converted to RGB.
    """
    pil_img = img if isinstance(img, Image.Image) else Image.fromarray(np.array(img))
    return pil_img if pil_img.mode == "RGB" else pil_img.convert("RGB")

def extract_embeddings_pil(batch_pil_images, model, processor, device):
    """Embed a batch of images and return the first-token (CLS) features.

    Args:
        batch_pil_images: Images handed to ``processor`` via ``images=``.
        model: Callable taking the processed inputs as keyword arguments and
            returning an object with a ``last_hidden_state`` tensor.
        processor: Callable producing a mapping of tensors when invoked with
            ``images=...`` and ``return_tensors="pt"``.
        device: Device every processed tensor is moved to before the forward pass.

    Returns:
        NumPy array holding token 0 of ``last_hidden_state`` for each image, ``[B, D]``.
    """
    with torch.no_grad():
        encoded = processor(images=batch_pil_images, return_tensors="pt")
        on_device = {name: tensor.to(device) for name, tensor in encoded.items()}
        hidden_states = model(**on_device).last_hidden_state
        cls_tokens = hidden_states[:, 0, :]  # [B, D]
    return cls_tokens.detach().cpu().numpy()

# ------------------------- Tiny ImageNet support -------------------------
class TinyImageNetDataset(torch.utils.data.Dataset):
    """
    Tiny ImageNet dataset yielding ``(RGB PIL image, class index)`` pairs.

    The dataset directory is resolved from ``root``:
      * if ``root`` itself holds the dataset (it contains ``wnids.txt``,
        ``train/`` or ``val/``) it is used directly;
      * otherwise ``root/tiny-imagenet-200`` is used when that directory exists.

    Expected layout of the dataset directory:
        ├── wnids.txt                      (200 wnids)
        ├── words.txt
        ├── train/
        │   ├── n01443537/
        │   │     ├── images/*.JPEG
        │   │     └── boxes.txt (ignored)
        │   └── ...
        └── val/
            ├── images/*.JPEG
            └── val_annotations.txt  (filename  wnid  x  y  w  h)

    Class indices follow the order of ``wnids.txt`` (or the sorted ``train/``
    sub-folder names when that file is missing). A wnid met in the data but
    absent from that list is appended with the next free index.
    """

    # Name of the folder the official archive unpacks to.
    DEFAULT_DIRNAME = "tiny-imagenet-200"

    def __init__(self, root, split="train"):
        """
        Args:
            root: Dataset directory, or a directory containing ``tiny-imagenet-200``.
            split: ``"train"`` or ``"val"``.

        Raises:
            FileNotFoundError: If the dataset directory (or a required file) is missing.
            ValueError: If ``split`` is neither ``"train"`` nor ``"val"``.
        """
        super().__init__()
        self.base = self._resolve_base(root)
        if not os.path.isdir(self.base):
            raise FileNotFoundError(
                f"Tiny ImageNet not found at {self.base}. "
                f"Place 'tiny-imagenet-200' under data root."
            )
        self.split = split
        self.wnids = self._load_wnids()
        self.class_to_idx = {wnid: i for i, wnid in enumerate(self.wnids)}
        self.samples = self._gather_samples()

    @classmethod
    def _resolve_base(cls, root):
        """Return the directory that holds the dataset files for ``root``."""
        direct = os.path.join(root, "")
        markers = ("wnids.txt", "train", "val")
        if any(os.path.exists(os.path.join(direct, name)) for name in markers):
            return direct
        # Fall back to the documented "root/tiny-imagenet-200" layout.
        nested = os.path.join(root, cls.DEFAULT_DIRNAME)
        return nested if os.path.isdir(nested) else direct

    def _load_wnids(self):
        """Read the class wnids from ``wnids.txt``, or derive them from ``train/``."""
        wnids_file = os.path.join(self.base, "wnids.txt")
        if os.path.isfile(wnids_file):
            with open(wnids_file, "r", encoding="utf-8") as f:
                wnids = [line.strip() for line in f if line.strip()]
            if len(wnids) != 200:
                print(f"Warning: wnids.txt has {len(wnids)} entries (expected 200).")
            return wnids

        # Fallback: derive from train subfolders
        train_dir = os.path.join(self.base, "train")
        wnids = sorted(d for d in os.listdir(train_dir)
                       if os.path.isdir(os.path.join(train_dir, d)))
        if len(wnids) != 200:
            print(f"Warning: derived {len(wnids)} classes from train/ (expected 200).")
        return wnids

    def _class_index(self, wnid):
        """Return the index of ``wnid``, registering it first when it is unknown."""
        idx = self.class_to_idx.get(wnid)
        if idx is None:
            idx = len(self.class_to_idx)
            self.class_to_idx[wnid] = idx
            # Keep the index -> wnid list in step with the mapping.
            self.wnids.append(wnid)
        return idx

    def _gather_samples(self):
        """Build the list of ``(image path, class index)`` pairs for ``self.split``."""
        if self.split == "train":
            train_dir = os.path.join(self.base, "train")
            # ImageFolder walks the tree but numbers classes alphabetically by
            # folder; remap every sample to our own class_to_idx ordering.
            tmp_ds = ImageFolder(train_dir)
            return [(path, self._class_index(tmp_ds.classes[tmp_idx]))
                    for path, tmp_idx in tmp_ds.samples]

        if self.split == "val":
            val_dir = os.path.join(self.base, "val")
            anno_path = os.path.join(val_dir, "val_annotations.txt")
            if not os.path.isfile(anno_path):
                raise FileNotFoundError(f"Missing {anno_path}")
            image_dir = os.path.join(val_dir, "images")
            samples = []
            with open(anno_path, "r", encoding="utf-8") as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) < 2:
                        # Blank or malformed line: nothing to index.
                        continue
                    fname, wnid = parts[0], parts[1]
                    samples.append((os.path.join(image_dir, fname), self._class_index(wnid)))
            return samples

        raise ValueError("split must be 'train' or 'val' for TinyImageNet")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        """Return ``(RGB PIL image, class index)`` for sample ``i``."""
        path, y = self.samples[i]
        # convert() loads the pixel data and returns an independent copy, so the
        # file handle can be released right away instead of lingering per sample.
        with Image.open(path) as img:
            rgb = img.convert("RGB")
        return rgb, y

    @property
    def num_classes(self):
        """Number of distinct classes currently known to the dataset."""
        return len(self.class_to_idx)

# --------------------- dataset loader wrapper ---------------------
def load_dataset(dataset_name, split, root="./data", download=True):
    """Build the dataset object for ``dataset_name`` / ``split``.

    Args:
        dataset_name: ``"cifar100"`` or one of the Tiny ImageNet aliases
            (``"tinyimagenet"``, ``"tiny-imagenet"``, ``"tiny_imagenet"``,
            ``"tinyimg"``); matched case-insensitively.
        split: ``"train"`` or ``"test"``. For Tiny ImageNet, ``"test"`` selects
            the ``val`` split.
        root: Data directory passed on to the dataset class.
        download: Forwarded to ``CIFAR100``; not used for Tiny ImageNet.

    Returns:
        Tuple ``(dataset, num_classes, name_for_path)``.

    Raises:
        ValueError: On an unknown dataset name or an unsupported split.
    """
    key = dataset_name.lower()

    if key == "cifar100":
        if split != "train" and split != "test":
            raise ValueError("CIFAR-100 split must be 'train' or 'test'")
        wants_train = split == "train"
        cifar = CIFAR100(root=root, train=wants_train, download=download)
        return cifar, 100, "cifar100"

    if key in ("tinyimagenet", "tiny-imagenet", "tiny_imagenet", "tinyimg"):
        # Map mode: "train" -> TinyImageNet train, "test" -> TinyImageNet val
        if split == "train":
            tin_split = "train"
        elif split == "test":
            tin_split = "val"
        else:
            raise ValueError("TinyImageNet split must be 'train' or 'test' (test maps to val).")
        tiny = TinyImageNetDataset(root=root, split=tin_split)
        return tiny, tiny.num_classes, "tinyimg"

    raise ValueError("Unknown dataset. Use 'cifar100' or 'tinyimagenet'.")

# --------------------- embedding over a split ---------------------
def embed_split(dataset_name, split, model_name="dinov2_vitl14", batch_size=32, root="./data"):
    """Embed every image of one dataset split with DINOv2.

    Embedding is best effort: a batch whose embedding step raises is reported
    and skipped. Rows belonging to skipped batches are removed from the result,
    so the returned arrays never contain uninitialised memory.

    Args:
        dataset_name: Dataset name understood by ``load_dataset``.
        split: ``"train"`` or ``"test"``.
        model_name: Key of ``HF_MODEL_MAP`` selecting the checkpoint.
        batch_size: Number of images embedded per forward pass.
        root: Data directory passed on to ``load_dataset``.

    Returns:
        Tuple ``(embs, labels, num_classes, ds_name_for_path)`` where ``embs``
        is a float32 array ``[M, D]`` and ``labels`` an int64 array ``[M]``;
        ``M`` is the number of images that were embedded successfully.

    Raises:
        RuntimeError: If no image could be embedded (the split is empty or
            every batch failed).
    """
    model, processor, device = get_dinov2_model_and_processor(model_name)
    ds, num_classes, ds_name_for_path = load_dataset(dataset_name, split, root=root, download=True)

    N = len(ds)
    labels = np.empty(N, dtype=np.int64)
    # Marks rows that hold real data; np.empty leaves every other row undefined.
    filled = np.zeros(N, dtype=bool)
    embs = None

    print(f"Processing {N} images from {dataset_name}:{split} split...")

    for start in tqdm(range(0, N, batch_size), desc=f"Embedding {dataset_name}:{split}"):
        end_idx = min(start + batch_size, N)
        batch_images, batch_labels = [], []

        for i in range(start, end_idx):
            img, y = ds[i]
            batch_images.append(ensure_rgb(img))
            batch_labels.append(y)

        try:
            embeddings = extract_embeddings_pil(batch_images, model, processor, device)
            if embeddings.shape[0] != len(batch_labels):
                raise ValueError(
                    f"got {embeddings.shape[0]} embeddings for {len(batch_labels)} images"
                )

            if embs is None:
                # The embedding width is only known once the first batch succeeds.
                embedding_dim = embeddings.shape[1]
                embs = np.empty((N, embedding_dim), dtype=np.float32)
                print(f"Initialized embedding array with shape: {embs.shape}")

            embs[start:end_idx] = embeddings.astype(np.float32)
            labels[start:end_idx] = np.array(batch_labels, dtype=np.int64)
            # Flag the rows last, so a partially written batch stays excluded.
            filled[start:end_idx] = True

        except Exception as e:
            print(f"Error processing batch starting at {start}: {e}")
            continue

    if embs is None:
        raise RuntimeError(
            f"No embeddings were produced for {dataset_name}:{split} "
            f"(the split is empty or every batch failed)."
        )

    n_skipped = N - int(filled.sum())
    if n_skipped:
        print(f"Warning: dropped {n_skipped} of {N} images from skipped batches")
        embs = embs[filled]
        labels = labels[filled]

    print(f"Successfully processed {len(embs)} images from {dataset_name}:{split} split")
    return embs, labels, num_classes, ds_name_for_path

# ------------------------------- main -------------------------------
def main():
    """Command-line entry point: embed the requested split(s) and save per-class files.

    For every class index with at least one sample, a ``.npy`` file holding the
    pickled dict ``{"features": <array of that class's embeddings>}`` is written
    to the output directory (read it back with ``np.load(..., allow_pickle=True)``).

    Returns:
        ``0`` on completion, ``1`` if embedding any split raised an exception.
    """
    ap = argparse.ArgumentParser(
        description="Generate embeddings (per-class npy) using Facebook DINOv2 (Hugging Face)")
    ap.add_argument("--dataset", default="cifar100", choices=["cifar100", "tinyimg"],
                    help="Dataset to embed")
    ap.add_argument("--model-name", default="dinov2_vitl14",
                    choices=list(HF_MODEL_MAP.keys()))
    ap.add_argument("--ss-method", default="dinov2_facebook_hf")
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--mode", choices=["all", "train", "test"], default="all",
                    help="'train'/'test' for CIFAR100; for TinyImageNet, 'test' maps to 'val'")
    ap.add_argument("--batch-size", type=int, default=16)
    ap.add_argument("--base-dir", default="")
    ap.add_argument("--data-root", default="./data")
    args = ap.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # for filenames/paths
    dataset_name_for_path = args.dataset.lower().replace("-", "").replace("_", "")
    model_variant = args.model_name.split('/')[-1]
    save_dir = os.path.join(
        args.base_dir,
        f"representations_trained_{dataset_name_for_path}_{args.ss_method}_{model_variant}_seed_{args.seed}"
    )
    os.makedirs(save_dir, exist_ok=True)
    print(f"Saving results to: {save_dir}")

    splits = ["train", "test"] if args.mode == "all" else [args.mode]

    all_embs, all_labels = [], []
    num_classes = None
    for split in splits:
        print(f"\n=== Processing {args.dataset}:{split} split ===")
        try:
            embs, labels, n_cls, _ = embed_split(
                dataset_name=args.dataset,
                split=split,
                model_name=args.model_name,
                batch_size=args.batch_size,
                root=args.data_root
            )
        except Exception as e:
            print(f"Error processing {split} split: {e}")
            return 1

        if num_classes is None:
            num_classes = n_cls
        elif num_classes != n_cls:
            print(f"Warning: num_classes mismatch across splits: {num_classes} vs {n_cls}")
            # Keep the larger count so classes seen in only one split are still saved.
            num_classes = max(num_classes, n_cls)
        all_embs.append(embs)
        all_labels.append(labels)

    if len(all_embs) > 1:
        embs = np.vstack(all_embs)
        labels = np.concatenate(all_labels)
    else:
        embs = all_embs[0]
        labels = all_labels[0]

    print(f"\nCombined embeddings shape: {embs.shape}")
    print(f"Combined labels shape: {labels.shape}")

    embedding_dim = embs.shape[1]
    saved_classes = 0

    # Save per-class .npy (features only, like your original)
    for cls in range(int(num_classes)):
        feats_c = embs[labels == cls]
        if len(feats_c) == 0:
            print(f"Warning: No samples found for class {cls}")
            continue

        # The file name always carries "all", whichever --mode was used.
        out_path = os.path.join(
            save_dir,
            f"{dataset_name_for_path}_{args.ss_method}_{model_variant}_all_{cls}.npy"
        )
        try:
            # A dict is stored as a pickled 0-d object array.
            np.save(out_path, {"features": feats_c}, allow_pickle=True)
        except Exception as e:
            print(f"Error saving class {cls}: {e}")
        else:
            print(f"Saved class {cls:3d}: {feats_c.shape[0]:5d} samples, shape {feats_c.shape} -> {out_path}")
            saved_classes += 1

    print("\n=== Summary ===")
    print(f"Dataset used: {args.dataset}")
    print(f"Model used: {args.model_name} -> {HF_MODEL_MAP[args.model_name]}")
    print(f"Embedding dimension: {embedding_dim}")
    print(f"Total samples processed: {len(embs)}")
    print(f"Classes saved: {saved_classes}/{num_classes}")
    print(f"Files saved to: {save_dir}")
    return 0

if __name__ == "__main__":
    # The bare `exit` helper is injected by the `site` module and is missing
    # under `python -S` or in frozen builds; SystemExit needs no import and
    # passes main()'s return value on as the process exit status.
    raise SystemExit(main())
