import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.datasets.utils import download_and_extract_archive
from Dataloader_funcs.DL_registry import *


def getImageNetteDataloaders(root_dir, transform=transforms.ToTensor(), batch_size=32, num_workers=4, valid_transform=transforms.ToTensor(), subset=-1):
    """Build train/validation DataLoaders for the Imagenette 160px dataset.

    On first use the ``imagenette2-160.tgz`` archive is downloaded and
    extracted under ``<root_dir>/imagenette``; later calls reuse the
    extracted folders.

    Args:
        root_dir: Directory under which the ``imagenette`` folder is created.
        transform: Transform applied to training images.
        batch_size: Batch size used by both loaders.
        num_workers: Number of loader worker processes for both loaders.
        valid_transform: Transform applied to validation images.
        subset: Currently unused; accepted only for interface compatibility.

    Returns:
        A ``(train_dataloader, valid_dataloader)`` tuple. The training loader
        shuffles; the validation loader keeps a fixed order.
    """
    url = "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz"
    data_dir = os.path.join(root_dir, "imagenette")
    # The archive extracts into this sub-folder of data_dir.
    extracted_dir = os.path.join(data_dir, "imagenette2-160")

    # Test for the extracted folder, not just data_dir: an interrupted
    # download/extraction can leave data_dir behind without train/val, and
    # checking only data_dir would then never retry.
    if not os.path.isdir(extracted_dir):
        download_and_extract_archive(url, data_dir)

    train_dataset = datasets.ImageFolder(root=os.path.join(extracted_dir, "train"), transform=transform)
    valid_dataset = datasets.ImageFolder(root=os.path.join(extracted_dir, "val"), transform=valid_transform)

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_dataloader, valid_dataloader

@register_DL('ImageNette')
def ImageNette(root, batch_size, num_workers):
    """Registered 'ImageNette' factory returning augmented train / deterministic val loaders.

    Training images are randomly cropped/resized to 224x224, horizontally
    flipped, color-jittered, converted to tensors and normalized with ImageNet
    statistics. Validation images are resized, center-cropped to 224x224,
    converted and normalized with the same statistics, so both splits feed
    the model tensors of the same size and scale.

    Args:
        root: Root directory passed through to ``getImageNetteDataloaders``.
        batch_size: Batch size for both loaders.
        num_workers: Loader worker processes for both loaders.

    Returns:
        ``(train_dataloader, valid_dataloader)`` from ``getImageNetteDataloaders``.
    """
    # ImageNet channel mean/std, shared by both pipelines so train and
    # validation inputs are normalized identically.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    transform = transforms.Compose([
        transforms.RandomResizedCrop(224),        # Random crop, resized to 224x224
        transforms.RandomHorizontalFlip(),        # Random horizontal flip
        transforms.ColorJitter(brightness=0.2,    # Random brightness change
                               contrast=0.2,      # Random contrast change
                               saturation=0.2,    # Random saturation change
                               hue=0.1),          # Random hue change
        transforms.ToTensor(),                    # PIL image -> float tensor
        normalize,
    ])

    # Deterministic eval pipeline: fixed 224x224 output (same as the training
    # crop) so batches can be stacked, plus the same normalization.
    valid_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    # Pass the augmentation pipeline built above; previously a bare ToTensor()
    # was passed here, so none of the training transforms were applied.
    return getImageNetteDataloaders(root, transform=transform, batch_size=batch_size,
                                    num_workers=num_workers, valid_transform=valid_transform)