import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from torch_uncertainty.datamodules.classification import CIFAR100DataModule, ImageNetDataModule
from datamodules import \
    (CIFAR10DataModule, Food101DataModule, EuroSATDataModule, Sun397DataModule, 
     OxfordIIITPetDataModule, DTDDataModule, Caltech101DataModule)

def download_once(root, message, make_datamodule):
    """Prepare a dataset under *root* unless an earlier run already finished.

    The original script created *root* before downloading and later used
    "root exists" as the skip test, so a download that raised or was
    interrupted was never retried. To make that recoverable, a sibling marker
    file ``<root>.incomplete`` is written before any work starts and removed
    only after ``prepare_data()`` and ``setup("test")`` both return. A root
    that exists without a marker is still treated as complete, so directories
    populated by earlier versions of this script are skipped exactly as before.

    NOTE(review): the retry path assumes the datamodule's ``prepare_data()``
    can be safely re-run over a partially filled root - confirm for the
    datamodules used here.

    Args:
        root: Directory handed to the datamodule as its storage root.
        message: Line printed just before the work starts.
        make_datamodule: Zero-argument callable returning an object that
            exposes ``prepare_data()`` and ``setup(stage)``.

    Returns:
        True if the datamodule was prepared, False if the step was skipped.
    """
    marker = os.path.normpath(root) + ".incomplete"
    if os.path.exists(root) and not os.path.exists(marker):
        return False

    print(message)
    # Write the marker before creating the root so that there is no window in
    # which the root exists but nothing records that it is unfinished.
    parent = os.path.dirname(marker)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(marker, "w"):
        pass
    os.makedirs(root, exist_ok=True)

    dm = make_datamodule()
    dm.prepare_data()
    dm.setup("test")

    # Reached only when nothing above raised: the root is now trustworthy.
    os.remove(marker)
    return True


if __name__ == "__main__":
    # Loader settings shared by every datamodule constructed below.
    batch_size = 128
    num_workers = 4
    val_tfms = None  # no transform is supplied to any datamodule

    imagenet_root = "data/in1k_torch_uncertainty"
    download_once(
        imagenet_root,
        "Downloading Imagenet-1k and OOD sets",
        lambda: ImageNetDataModule(
            root=imagenet_root,
            batch_size=batch_size,
            num_workers=num_workers,
            eval_ood=True,
            pin_memory=True,
            persistent_workers=False,
        ),
    )

    cifar100_root = "data/cifar100_torch_uncertainty"
    download_once(
        cifar100_root,
        "Downloading CIFAR100 and OOD sets",
        lambda: CIFAR100DataModule(
            root=cifar100_root,
            batch_size=batch_size,
            test_transform=val_tfms,
            num_workers=num_workers,
            eval_ood=True,
            pin_memory=True,
            persistent_workers=False,
        ),
    )

    # The remaining datamodules come from the local `datamodules` package.
    # This script does not create their target paths itself; each step is
    # skipped when the path checked below is already present.
    cifar10_root = "data/cifar10"
    if not os.path.exists(cifar10_root):
        print("Downloading CIFAR10")
        CIFAR10DataModule(transform=val_tfms, data_dir=cifar10_root, batch_size=batch_size).prepare_data()

    food101_root = "data/food101"
    if not os.path.exists(food101_root):
        print("Downloading Food101")
        Food101DataModule(transform=val_tfms, data_dir=food101_root, batch_size=batch_size).prepare_data()

    # NOTE(review): Sun397 has no existence guard and relies on the
    # datamodule's default data_dir, which is not visible from this script -
    # confirm where it stores data before adding a guard.
    Sun397DataModule(transform=val_tfms, batch_size=batch_size).prepare_data()

    if not os.path.exists("data/oxford-iiit-pet"):
        print("Downloading Oxford-Pets")
        OxfordIIITPetDataModule(transform=val_tfms, data_dir="data", batch_size=batch_size).prepare_data()

    if not os.path.exists("data/dtd"):
        print("Downloading DTD")
        DTDDataModule(transform=val_tfms, data_dir="data", batch_size=batch_size).prepare_data()

    if not os.path.exists("data/eurosat/EuroSAT.zip"):
        print("Downloading Eurosat")
        EuroSATDataModule(transform=val_tfms, data_dir="data", batch_size=batch_size).prepare_data()

    if not os.path.exists("data/caltech101/Caltech101.zip"):
        print("Downloading Caltech101")
        Caltech101DataModule(transform=val_tfms, data_dir="data", batch_size=batch_size).prepare_data()