import os
import torch
from torchvision import datasets
import torchvision.transforms as transforms
import random
from torch.utils.data import Subset


def set_test_loader(root, dataset, preprocess, sample_size, batch_size, seed):
    """
    Build a non-shuffled DataLoader for the ImageNet validation set or an OOD dataset.

    Args:
        root: Base directory under which the dataset folders live.
        dataset: Dataset name; one of "imagenet", "iNaturalist", "SUN",
            "places365" or "Texture".
        preprocess: Transform applied to every image. When None, a default of
            ToTensor followed by CLIP mean/std normalization is used.
        sample_size: Maximum number of images to keep (forwarded to sample_dataset).
        batch_size: Batch size of the returned loader.
        seed: Seed forwarded to sample_dataset for the subset selection.

    Returns:
        A torch DataLoader over the (possibly subsampled) ImageFolder dataset.

    Raises:
        ValueError: If `dataset` is not one of the supported names.
    """
    if preprocess is None:
        # Default pipeline: tensor conversion + normalization constants for CLIP.
        preprocess = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                                 std=(0.26862954, 0.26130258, 0.27577711)),
        ])

    # Dataset name -> path components relative to `root`.
    # The first entry is the validation data, the rest are OOD datasets.
    folder_layout = (
        ("imagenet", ('imagenet/images/val',)),
        ('iNaturalist', ('iNaturalist',)),
        ('SUN', ('SUN',)),
        ('places365', ('Places',)),
        ('Texture', ('Texture', 'images')),
    )
    path_parts = next(
        (parts for name, parts in folder_layout if dataset == name),
        None,
    )
    if path_parts is None:
        raise ValueError(f"Invalid dataset: {dataset}")

    image_dir = os.path.join(root, *path_parts)
    testset = datasets.ImageFolder(image_dir, transform=preprocess)

    # Keep at most `sample_size` images; smaller datasets are used in full.
    testset = sample_dataset(testset, sample_size=sample_size, seed=seed)

    return torch.utils.data.DataLoader(
        testset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )


def sample_dataset(dataset, sample_size=5000, seed=42):
    """
    Sample a random subset of the dataset with the specified size.

    The selection is drawn from a private random generator seeded with
    `seed`, so the same seed always yields the same indices and the state of
    the module-level `random` generator is left untouched.

    Args:
        dataset: PyTorch dataset (anything supporting len() and indexing)
        sample_size (int): Number of samples to take (default: 5000)
        seed (int): Random seed for reproducibility

    Returns:
        The original dataset object when it holds `sample_size` items or
        fewer; otherwise a Subset of `sample_size` distinct items.
    """
    total_size = len(dataset)

    # Nothing to sample when the dataset already fits within the budget.
    if total_size <= sample_size:
        print(f"Dataset size ({total_size}) is smaller than or equal to sample size ({sample_size}). Using full dataset.")
        return dataset

    # Use a dedicated generator instead of random.seed(): reseeding the global
    # generator would silently alter every later random.* call in the process.
    rng = random.Random(seed)

    # Draw `sample_size` distinct indices without replacement.
    indices = rng.sample(range(total_size), sample_size)

    print(f"Sampling {sample_size} images from dataset of size {total_size}")

    return Subset(dataset, indices)
