import os

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import (
    CIFAR10,
    CIFAR100,
    FGVCAircraft,
    Flowers102,
    Food101,
    ImageFolder,
    OxfordIIITPet,
    StanfordCars,
    ImageNet,
    SVHN
)
from torchvision.datasets.vision import VisionDataset
import pandas as pd
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_file_from_google_drive
from datasets import load_dataset

# Per-channel (R, G, B) mean and standard deviation handed to
# transforms.Normalize for both the train and eval pipelines in build_dataset.
# NOTE(review): the values coincide with the commonly used ImageNet statistics;
# confirm they match the preprocessing of the ViT-B/16 checkpoint being used.
ViT_B_16_MEAN = (0.485, 0.456, 0.406)
ViT_B_16_STD = (0.229, 0.224, 0.225)

def build_dataset(
    dataset: str = "cifar100",
    root: str = "./data",
    size: int = 224,
    batch_size: int = 256,
    workers: int = 8,
    eval_batch_size: int = 2048,
    eval_workers: int = 32,
):
    """Build train, ImageNet-val and test loaders for a downstream dataset.

    Args:
        dataset: One of ``"cifar100"``, ``"food101"``, ``"flowers102"`` or
            ``"svhn"``.
        root: Directory the torchvision datasets are stored in / downloaded
            to. The ImageNet validation split is read from
            ``<root>/imagenet-1k``.
        size: Side length (pixels) of the square images fed to the model.
        batch_size: Batch size of the training loader.
        workers: Number of worker processes of the training loader.
        eval_batch_size: Batch size of the ImageNet-val and test loaders.
        eval_workers: Number of worker processes of the ImageNet-val and
            test loaders.

    Returns:
        ``(train_loader, val_loader, test_loader, num_classes)`` where
        ``val_loader`` iterates the ImageNet validation split and
        ``num_classes`` is the class count of the selected ``dataset``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.

    Note:
        ``flowers102`` is loaded through Hugging Face ``datasets`` rather
        than torchvision, so its loaders presumably yield dict batches keyed
        by column name instead of ``(image, label)`` tuples -- verify against
        the training loop.
    """
    # Validate first: otherwise an unknown name only surfaces later as an
    # UnboundLocalError, after the ImageNet split has already been loaded.
    supported = ("cifar100", "food101", "flowers102", "svhn")
    if dataset not in supported:
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of: "
            f"{', '.join(supported)}."
        )

    normalize = transforms.Normalize(mean=ViT_B_16_MEAN, std=ViT_B_16_STD)
    transform_train = transforms.Compose([
        # BICUBIC is the enum equivalent of the deprecated integer code 3.
        transforms.RandomResizedCrop(
            size, interpolation=transforms.InterpolationMode.BICUBIC
        ),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_val = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        normalize,
    ])

    # Batch-level adapters for Hugging Face datasets: `set_transform` passes a
    # dict of columns, and 'image' is assumed to be the image column.
    def apply_trans_train(examples):
        examples['image'] = [transform_train(img) for img in examples['image']]
        return examples

    def apply_trans_test(examples):
        examples['image'] = [transform_val(img) for img in examples['image']]
        return examples

    if dataset == 'cifar100':
        train_ds = CIFAR100(root, train=True, transform=transform_train, download=True)
        test_ds = CIFAR100(root, train=False, transform=transform_val)
        num_classes = 100
    elif dataset == 'food101':
        train_ds = Food101(root, split="train", transform=transform_train, download=True)
        test_ds = Food101(root, split="test", transform=transform_val)
        num_classes = 101
    elif dataset == 'flowers102':
        train_ds = load_dataset("dpdl-benchmark/oxford_flowers102", split="train")
        train_ds.set_transform(apply_trans_train)
        test_ds = load_dataset("dpdl-benchmark/oxford_flowers102", split="test")
        test_ds.set_transform(apply_trans_test)
        num_classes = 102
    else:  # 'svhn' -- the only remaining supported name
        train_ds = SVHN(root, split="train", transform=transform_train, download=True)
        test_ds = SVHN(root, split="test", transform=transform_val, download=True)
        num_classes = 10

    # The ImageNet validation split is always loaded, independent of `dataset`.
    val_ds = ImageNet(
        os.path.join(root, "imagenet-1k"), split='val', transform=transform_val
    )

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        num_workers=workers,
    )

    # Both evaluation loaders share the same deterministic settings.
    eval_kwargs = dict(
        batch_size=eval_batch_size,
        shuffle=False,
        drop_last=False,
        pin_memory=True,
        num_workers=eval_workers,
    )
    val_loader = DataLoader(val_ds, **eval_kwargs)
    test_loader = DataLoader(test_ds, **eval_kwargs)

    return train_loader, val_loader, test_loader, num_classes