import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets



def get_dataloader(dataset_name, batch_size, transform, root="./dataset", is_train=True,
                   shuffle=True, num_workers=4):
    """Build a DataLoader over a torchvision CIFAR10 or STL10 split.

    The dataset is downloaded into ``root`` if it is not already present.

    Args:
        dataset_name: Either ``'CIFAR10'`` or ``'STL10'`` (case-sensitive).
        batch_size: Number of samples per batch.
        transform: Transform passed to the torchvision dataset.
        root: Directory where the dataset is stored / downloaded.
        is_train: Load the training split if True, otherwise the test split.
        shuffle: Whether the loader reshuffles samples every epoch. Defaults
            to True for backward compatibility; this also applies to the test
            split, so pass ``False`` for a deterministic evaluation order.
        num_workers: Number of worker subprocesses used for loading
            (0 loads in the main process).

    Returns:
        A ``torch.utils.data.DataLoader`` over the selected split.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    if dataset_name == 'CIFAR10':
        # CIFAR10 selects its split with a boolean ``train`` flag.
        dataset = datasets.CIFAR10(root=root, train=is_train, transform=transform, download=True)
    elif dataset_name == 'STL10':
        # STL10 selects its split by name; only 'train'/'test' are exposed here.
        dataset = datasets.STL10(root=root, split='train' if is_train else 'test', transform=transform, download=True)
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

def extract_embeddings(loader, model, device):
    """Run ``model.encode_images`` over every batch in ``loader`` and collect results.

    Gradients are disabled during extraction. This function does not call
    ``model.eval()``; put the model in the desired mode before calling.

    Args:
        loader: Iterable yielding ``(images, labels)`` tensor pairs, e.g. a DataLoader.
        model: Object exposing ``encode_images(images)`` that returns a tensor.
        device: Device the images are moved to before encoding.

    Returns:
        Tuple ``(images, labels, embeddings)``. Each element is the per-batch
        tensors concatenated along dim 0, located on the CPU, in loader order.

    Raises:
        RuntimeError: Raised by ``torch.cat`` if ``loader`` yields no batches.
    """
    print("Extracting embeddings...")
    embeddings, images, labels = [], [], []
    with torch.no_grad():
        for batch_images, batch_labels in loader:
            # Keep the host-side images before moving to ``device`` so they do
            # not make a needless device->host round trip; ``.cpu()`` is a
            # no-op for tensors already on the CPU.
            images.append(batch_images.cpu())
            labels.append(batch_labels.cpu())
            embeddings.append(model.encode_images(batch_images.to(device)).cpu())
    return torch.cat(images), torch.cat(labels), torch.cat(embeddings)
