import torch
import numpy as np
from torchvision import transforms
from collections import defaultdict
from tqdm import tqdm
import clip

def extract_clip_features_from_subset(
    dataset_subset,
    batch_size=64,
    model_name="ViT-B/32",
    device=None
):
    """Encode every image of ``dataset_subset`` with CLIP, grouped by label.

    Args:
        dataset_subset: Indexable collection supporting ``len()`` whose items
            are ``(image, label)`` pairs. Each image must provide
            ``.convert("RGB")`` (e.g. a PIL image).
        batch_size: Number of images encoded per forward pass.
        model_name: CLIP model identifier handed to ``clip.load``.
        device: Device the model runs on. When ``None``, CUDA is used if
            available, otherwise the CPU.

    Returns:
        A dict mapping ``str(label)`` to a NumPy array that stacks the image
        features of that label (one row per image, in dataset order). An
        empty dataset yields an empty dict.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model, _ = clip.load(model_name, device=device, download_root="./CLIP")
    model.eval()

    # Not every CLIP variant takes 224x224 input (e.g. "ViT-L/14@336px"), so
    # read the expected size from the loaded model instead of hard-coding it.
    input_size = getattr(model.visual, "input_resolution", 224)

    # Images are squashed straight to the target square (aspect ratio is not
    # preserved), so no additional crop is needed afterwards.
    # Normalization must use CLIP's own statistics, not the ImageNet ones:
    # these are the values CLIP's reference preprocessing transform applies.
    preprocess = transforms.Compose([
        transforms.Resize(
            (input_size, input_size),
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                             std=(0.26862954, 0.26130258, 0.27577711))
    ])

    features_by_class = defaultdict(list)
    num_samples = len(dataset_subset)

    for start in tqdm(range(0, num_samples, batch_size), desc="Extracting CLIP features"):
        stop = min(start + batch_size, num_samples)
        batch = [dataset_subset[j] for j in range(start, stop)]

        # Force three channels so grayscale/RGBA images batch together.
        images_tensor = torch.stack(
            [preprocess(img.convert("RGB")) for img, _ in batch]
        ).to(device)

        with torch.no_grad():
            feats = model.encode_image(images_tensor).cpu().numpy()

        for feat, (_, label) in zip(feats, batch):
            features_by_class[str(label)].append(feat)

    return {label: np.stack(vectors) for label, vectors in features_by_class.items()}

def extract_dino_features_from_subset(
    dataset_subset,
    model_name='dinov2_vitb14',
    batch_size=64,
    device=None,
):
    """Run a DINOv2 hub model over ``dataset_subset``, grouping outputs by label.

    Args:
        dataset_subset: Indexable collection supporting ``len()`` whose items
            are ``(image, label)`` pairs. Images must be accepted by
            torchvision's ``Resize``/``ToTensor`` (e.g. PIL images).
        model_name: Entry point name loaded from the
            ``facebookresearch/dinov2`` torch hub repository.
        batch_size: Number of images passed to the model per forward pass.
        device: Device the model runs on. When ``None``, CUDA is used if
            available, otherwise the CPU.

    Returns:
        A dict mapping ``str(label)`` to a NumPy array that stacks the model
        outputs of that label (one row per image, in dataset order). An
        empty dataset yields an empty dict.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = torch.hub.load('facebookresearch/dinov2', model_name).to(device).eval()

    preprocess = transforms.Compose([
        transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225))
    ])

    def _to_rgb(img):
        # Grayscale or RGBA images would not match the 3-channel Normalize
        # above; inputs without a ``convert`` method are passed through as-is.
        return img.convert("RGB") if hasattr(img, "convert") else img

    features_by_class = defaultdict(list)
    num_samples = len(dataset_subset)

    for start in tqdm(range(0, num_samples, batch_size), desc="Extracting DINOv2 features"):
        stop = min(start + batch_size, num_samples)
        # Index item by item: the subset is not assumed to support slicing.
        batch = [dataset_subset[j] for j in range(start, stop)]

        images_tensor = torch.stack(
            [preprocess(_to_rgb(img)) for img, _ in batch]
        ).to(device)

        with torch.no_grad():
            feats = model(images_tensor).cpu().numpy()

        for feat, (_, label) in zip(feats, batch):
            features_by_class[str(label)].append(feat)

    return {label: np.stack(vectors) for label, vectors in features_by_class.items()}
