import os
import random
from pathlib import Path
from typing import Tuple, Union

import medmnist
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset, random_split
from torchvision import datasets, transforms
from torchvision.datasets.folder import pil_loader

# Worker processes passed as ``num_workers`` to the image DataLoaders below.
# An earlier setting of 12 was replaced by the current value.
NUM_WORKERS = 8


def gaussian_mixture_noise(n_samples: int):
    """Draw 2-D points from a two-component Gaussian mixture.

    The components are centred at ``(-pi/2, pi/2)`` and ``(pi/2, -pi/2)`` and
    share the isotropic covariance ``0.1 * I``. Points from the first
    component come first in the returned array.

    Args:
        n_samples: Total number of points to draw.

    Returns:
        Array of shape ``(n_samples, 2)``.
    """
    total = int(n_samples)
    n_first = total // 2
    # The second component takes the remainder so that an odd ``n_samples``
    # still yields exactly ``n_samples`` rows (previously one row was lost).
    n_second = total - n_first
    covariance = 0.1 * np.eye(2)

    first_component = np.random.multivariate_normal(
        mean=[-np.pi / 2, np.pi / 2],
        cov=covariance,
        size=n_first,
    )
    second_component = np.random.multivariate_normal(
        mean=[np.pi / 2, -np.pi / 2],
        cov=covariance,
        size=n_second,
    )
    return np.concatenate([first_component, second_component])


def make_toy_data(
    n_normal: int = 900,
    n_labeled_anomaly: int = 20,
    n_unlabeled_anomaly: int = 80,
    is_train=True,
    batch_size: int = 128,
) -> DataLoader:
    """Build a shuffled DataLoader over a 2-D toy anomaly-detection set.

    Normal points follow a noisy, scaled sine curve over ``[-pi, pi]``;
    anomalies are drawn with ``gaussian_mixture_noise``. Rows are stacked as
    normal, unlabeled anomalies, labeled anomalies.

    Args:
        n_normal: Number of normal points.
        n_labeled_anomaly: Number of anomalies that always carry label 1.
        n_unlabeled_anomaly: Number of anomalies whose label depends on
            ``is_train``.
        is_train: When truthy the unlabeled anomalies get label 0 (they are
            hidden among the normal points); otherwise they get label 1.
        batch_size: Batch size of the returned loader.

    Returns:
        DataLoader yielding ``(float32 features, int32 labels)`` batches.
    """
    normal_points = np.zeros((n_normal, 2))
    for row, angle in enumerate(np.linspace(-np.pi, np.pi, n_normal)):
        # One x-noise draw followed by one y-noise draw per point.
        normal_points[row, 0] = angle + np.random.normal(0, 0.1)
        normal_points[row, 1] = 3.0 * (np.sin(angle) + np.random.normal(0, 0.2))

    hidden_anomalies = gaussian_mixture_noise(n_unlabeled_anomaly)
    known_anomalies = gaussian_mixture_noise(n_labeled_anomaly)
    features = np.concatenate([normal_points, hidden_anomalies, known_anomalies])

    # Only the label of the unlabeled anomalies differs between train and test.
    hidden_label_factory = np.zeros if is_train else np.ones
    targets = np.concatenate(
        [
            np.zeros(shape=n_normal),
            hidden_label_factory(shape=n_unlabeled_anomaly),
            np.ones(shape=n_labeled_anomaly),
        ]
    )

    feature_tensor = torch.from_numpy(features.astype(np.float32))
    target_tensor = torch.from_numpy(targets.astype(np.int32))
    return DataLoader(TensorDataset(feature_tensor, target_tensor), batch_size=batch_size, shuffle=True)


def load(
    name: str,
    is_gray_scale: bool = True,
    batch_size: int = 128,
    normal_class: int = 0,
    unseen_anomaly: int = 9,
    n_train: int = 4500,
    n_valid: int = 500,
    n_unlabeled_normal: int = 4500,
    n_labeled_anomaly: int = 250,
    n_unlabeled_anomaly: int = 250,
) -> tuple[DataLoader, DataLoader, DataLoader, DataLoader, DataLoader]:
    """Build anomaly-detection loaders from a torchvision image dataset.

    One class is treated as normal, one class is held out as an anomaly that
    only appears at test time, and every other class is a "seen" anomaly.
    The original class labels of the sampled items are overwritten in place
    with binary labels: in the training pool only the labeled anomalies get
    1 (normal items and unlabeled anomalies get 0); in the test pool every
    anomaly gets 1 and every normal item gets 0.

    Args:
        name: One of ``"MNIST"``, ``"FashionMNIST"``, ``"CIFAR10"``, ``"SVHN"``.
        is_gray_scale: If true, images are converted to one channel, resized
            to 32x32 and flattened; otherwise they are kept (or expanded) as
            three-channel 32x32 images.
        batch_size: Batch size of every returned loader.
        normal_class: Class index treated as normal.
        unseen_anomaly: Class index excluded from training and used as the
            unseen anomaly at test time.
        n_train: Size of the training split.
        n_valid: Size of the validation split.
        n_unlabeled_normal: Number of normal training items to sample.
        n_labeled_anomaly: Number of seen anomalies sampled with label 1.
        n_unlabeled_anomaly: Number of seen anomalies sampled with label 0.

    Returns:
        ``(train, valid, test_all, test_seen, test_unseen)`` loaders. The test
        pools hold at most 1000 normal, 500 seen and 500 unseen items.

    Raises:
        ValueError: If ``name`` is not a supported dataset, or (from
            ``random.sample``) if more items are requested than are available.
            ``random_split`` also rejects ``n_train``/``n_valid`` values that
            do not add up to the sampled pool size.
    """
    dataset_classes = {
        "MNIST": datasets.MNIST,
        "FashionMNIST": datasets.FashionMNIST,
        "CIFAR10": datasets.CIFAR10,
        "SVHN": datasets.SVHN,
    }
    # Previously an unknown name silently fell through to the SVHN branch.
    if name not in dataset_classes:
        raise ValueError(f"Unknown dataset {name!r}; expected one of {sorted(dataset_classes)}")
    is_svhn = name == "SVHN"

    # dataset path
    path = f"datasets/{name}/" if name in ["CIFAR10", "SVHN"] else "datasets/"
    os.makedirs(path, exist_ok=True)

    # transform
    base = [
        transforms.ToTensor(),
        transforms.Resize((32, 32), antialias=True),
    ]
    if is_gray_scale:
        # Same pipeline for every dataset: one channel, 32x32, flattened.
        transform = transforms.Compose(
            [transforms.Grayscale()] + base + [transforms.Lambda(lambda x: torch.flatten(x))]
        )
    elif name in ["CIFAR10", "SVHN"]:
        transform = transforms.Compose(base)
    else:  # MNIST, FashionMNIST: replicate the single channel three times
        transform = transforms.Compose([transforms.Grayscale(3)] + base)

    dataset_cls = dataset_classes[name]
    if is_svhn:
        train = dataset_cls(root=path, download=True, split="train", transform=transform)
        test = dataset_cls(root=path, download=True, split="test", transform=transform)
    else:
        train = dataset_cls(root=path, download=True, train=True, transform=transform)
        test = dataset_cls(root=path, download=True, train=False, transform=transform)

    # SVHN stores its labels under ``labels``; the other datasets use ``targets``.
    label_attr = "labels" if is_svhn else "targets"

    # Train
    train_labels = getattr(train, label_attr)
    train_classes = train_labels if torch.is_tensor(train_labels) else torch.tensor(train_labels)

    # ``flatten`` (not ``squeeze``) keeps a one-element result a list, which
    # ``random.sample`` requires.
    train_normal_indices = (train_classes == normal_class).nonzero().flatten().tolist()
    train_anomaly_indices = (
        torch.logical_and(train_classes != normal_class, train_classes != unseen_anomaly)
        .nonzero()
        .flatten()
        .tolist()
    )

    train_normal_bag = random.sample(train_normal_indices, k=n_unlabeled_normal)
    train_anomaly_bag = random.sample(train_anomaly_indices, k=n_labeled_anomaly + n_unlabeled_anomaly)

    train_positive_bag = train_anomaly_bag[:n_labeled_anomaly]
    train_unlabeled_bag = train_normal_bag + train_anomaly_bag[n_labeled_anomaly:]

    # Overwrite the class labels of the sampled items with binary labels.
    for i in train_positive_bag:
        train_labels[i] = 1

    for i in train_unlabeled_bag:
        train_labels[i] = 0

    train_subset = Subset(train, train_positive_bag + train_unlabeled_bag)
    train_subset, valid_subset = random_split(train_subset, [n_train, n_valid])
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
    valid_loader = DataLoader(valid_subset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

    # Test
    test_labels = getattr(test, label_attr)
    test_classes = test_labels if torch.is_tensor(test_labels) else torch.tensor(test_labels)

    test_normal_indices = (test_classes == normal_class).nonzero().flatten().tolist()
    test_unseen_anomaly_indices = (test_classes == unseen_anomaly).nonzero().flatten().tolist()
    test_seen_anomaly_indices = (
        torch.logical_and(test_classes != normal_class, test_classes != unseen_anomaly)
        .nonzero()
        .flatten()
        .tolist()
    )

    test_normal_bag = random.sample(test_normal_indices, k=min(len(test_normal_indices), 1000))
    test_unseen_anomaly_bag = random.sample(test_unseen_anomaly_indices, k=min(len(test_unseen_anomaly_indices), 500))
    test_seen_anomaly_bag = random.sample(test_seen_anomaly_indices, k=min(len(test_seen_anomaly_indices), 500))

    for i in test_unseen_anomaly_bag + test_seen_anomaly_bag:
        test_labels[i] = 1

    for i in test_normal_bag:
        test_labels[i] = 0

    test_subset = Subset(test, test_seen_anomaly_bag + test_unseen_anomaly_bag + test_normal_bag)
    test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

    test_seen_subset = Subset(test, test_seen_anomaly_bag + test_normal_bag)
    test_seen_loader = DataLoader(test_seen_subset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

    test_unseen_subset = Subset(test, test_unseen_anomaly_bag + test_normal_bag)
    test_unseen_loader = DataLoader(test_unseen_subset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

    return train_loader, valid_loader, test_loader, test_seen_loader, test_unseen_loader


def to_numpy(data_loader: DataLoader) -> tuple[np.ndarray, np.ndarray]:
    """Materialise a whole loader as a pair of NumPy arrays.

    Args:
        data_loader: Iterable of batches whose first two entries are tensors
            (inputs, labels).

    Returns:
        ``(inputs, labels)`` where inputs are stacked vertically and labels
        are stacked horizontally, in iteration order.
    """
    converted = [(batch[0].numpy(), batch[1].numpy()) for batch in data_loader]
    inputs = np.vstack([features for features, _ in converted])
    targets = np.hstack([labels for _, labels in converted])
    return inputs, targets


def to_tensor(data_loader: DataLoader) -> tuple[torch.Tensor, torch.Tensor]:
    """Materialise a whole loader as a pair of tensors.

    Args:
        data_loader: Iterable of batches whose first two entries are tensors
            (inputs, labels).

    Returns:
        ``(inputs, labels)``, each concatenated along dimension 0 in
        iteration order.
    """
    collected = [(batch[0], batch[1]) for batch in data_loader]
    inputs = torch.cat([features for features, _ in collected], dim=0)
    targets = torch.cat([labels for _, labels in collected], dim=0)
    return inputs, targets


def sparse2coarse(targets):
    """Map CIFAR-100 fine labels (0-99) to their coarse superclass (0-19).

    The table comes from
    https://github.com/ryanchankh/cifar100coarse/blob/master/sparse2coarse.py

    Args:
        targets: Anything usable as a NumPy index into a length-100 array
            (an int, a list of ints, an integer array, ...).

    Returns:
        The coarse label(s) selected by ``targets``.
    """
    # Row r holds the coarse labels of fine classes 10*r .. 10*r + 9.
    fine_to_coarse = np.array(
        [4, 1, 14, 8, 0, 6, 7, 7, 18, 3]
        + [3, 14, 9, 18, 7, 11, 3, 9, 7, 11]
        + [6, 11, 5, 10, 7, 6, 13, 15, 3, 15]
        + [0, 11, 1, 10, 12, 14, 16, 9, 11, 5]
        + [5, 19, 8, 8, 15, 13, 14, 17, 18, 10]
        + [16, 4, 17, 4, 2, 0, 17, 4, 18, 17]
        + [10, 3, 2, 12, 12, 16, 12, 1, 9, 19]
        + [2, 10, 0, 1, 16, 12, 9, 13, 15, 13]
        + [16, 19, 2, 4, 6, 19, 5, 5, 8, 19]
        + [18, 1, 2, 15, 6, 0, 17, 8, 14, 13]
    )
    return fine_to_coarse[targets]


def load_cifar100(
    batch_size: int = 128,
    normal_class: Tuple[int, ...] = (0, 1, 2, 4, 7, 8, 11, 12, 13, 15, 16, 17),
    unseen_anomaly: int = 14,
    n_train: int = 4500,
    n_valid: int = 500,
    n_unlabeled_normal: int = 4500,
    n_labeled_anomaly: int = 250,
    n_unlabeled_anomaly: int = 250,
) -> tuple[DataLoader, DataLoader, DataLoader, DataLoader, DataLoader]:
    """Build anomaly-detection loaders from CIFAR-100 coarse superclasses.

    Fine labels are mapped to coarse labels with ``sparse2coarse``. Items
    whose coarse label is in ``normal_class`` are normal, ``unseen_anomaly``
    is held out of training, and all remaining coarse classes are "seen"
    anomalies. The labels of the sampled items are overwritten in place with
    binary labels: in the training pool only labeled anomalies get 1; in the
    test pool every anomaly gets 1 and every normal item gets 0.

    Args:
        batch_size: Batch size of every returned loader.
        normal_class: Coarse class indices treated as normal.
        unseen_anomaly: Coarse class index used only as a test-time anomaly.
        n_train: Size of the training split.
        n_valid: Size of the validation split.
        n_unlabeled_normal: Number of normal training items to sample.
        n_labeled_anomaly: Number of seen anomalies sampled with label 1.
        n_unlabeled_anomaly: Number of seen anomalies sampled with label 0.

    Returns:
        ``(train, valid, test_all, test_seen, test_unseen)`` loaders. The test
        pools hold at most 1000 normal, 500 seen and 500 unseen items.

    Raises:
        ValueError: From ``random.sample`` if more items are requested than
            are available. ``random_split`` also rejects ``n_train``/``n_valid``
            values that do not add up to the sampled pool size.
    """
    # dataset path
    path = "datasets/CIFAR100/"
    os.makedirs(path, exist_ok=True)

    # transform
    transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((32, 32), antialias=True)])

    train = datasets.CIFAR100(root=path, download=True, train=True, transform=transform)
    test = datasets.CIFAR100(root=path, download=True, train=False, transform=transform)

    normal_classes = list(normal_class)

    # Train
    # Vectorised membership tests replace a per-item Python loop;
    # ``flatnonzero`` always yields a 1-D result, so ``tolist`` is always a list.
    train_coarse = np.asarray(sparse2coarse(train.targets))
    train_is_normal = np.isin(train_coarse, normal_classes)
    train_normal_indices = np.flatnonzero(train_is_normal).tolist()
    train_anomaly_indices = np.flatnonzero(~train_is_normal & (train_coarse != unseen_anomaly)).tolist()

    train_normal_bag = random.sample(train_normal_indices, k=n_unlabeled_normal)
    train_anomaly_bag = random.sample(train_anomaly_indices, k=n_labeled_anomaly + n_unlabeled_anomaly)

    train_positive_bag = train_anomaly_bag[:n_labeled_anomaly]
    train_unlabeled_bag = train_normal_bag + train_anomaly_bag[n_labeled_anomaly:]

    # Overwrite the fine labels of the sampled items with binary labels.
    for i in train_positive_bag:
        train.targets[i] = 1

    for i in train_unlabeled_bag:
        train.targets[i] = 0

    train_subset = Subset(train, train_positive_bag + train_unlabeled_bag)
    train_subset, valid_subset = random_split(train_subset, [n_train, n_valid])
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
    valid_loader = DataLoader(valid_subset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

    # Test
    test_coarse = np.asarray(sparse2coarse(test.targets))
    test_is_normal = np.isin(test_coarse, normal_classes)
    test_is_unseen = test_coarse == unseen_anomaly
    test_normal_indices = np.flatnonzero(test_is_normal).tolist()
    test_unseen_anomaly_indices = np.flatnonzero(test_is_unseen).tolist()
    test_seen_anomaly_indices = np.flatnonzero(~test_is_normal & ~test_is_unseen).tolist()

    test_normal_bag = random.sample(test_normal_indices, k=min(len(test_normal_indices), 1000))
    test_unseen_anomaly_bag = random.sample(test_unseen_anomaly_indices, k=min(len(test_unseen_anomaly_indices), 500))
    test_seen_anomaly_bag = random.sample(test_seen_anomaly_indices, k=min(len(test_seen_anomaly_indices), 500))

    for i in test_unseen_anomaly_bag + test_seen_anomaly_bag:
        test.targets[i] = 1

    for i in test_normal_bag:
        test.targets[i] = 0

    test_subset = Subset(test, test_seen_anomaly_bag + test_unseen_anomaly_bag + test_normal_bag)
    test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

    test_seen_subset = Subset(test, test_seen_anomaly_bag + test_normal_bag)
    test_seen_loader = DataLoader(test_seen_subset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

    test_unseen_subset = Subset(test, test_unseen_anomaly_bag + test_normal_bag)
    test_unseen_loader = DataLoader(test_unseen_subset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

    return train_loader, valid_loader, test_loader, test_seen_loader, test_unseen_loader


def load_medical_mnist(
    name: str,
    batch_size: int = 128,
    normal_class: int = 0,
    unseen_anomaly: int = 9,
    n_train: int = 4500,
    n_valid: int = 500,
    n_unlabeled_normal: int = 4500,
    n_labeled_anomaly: int = 250,
    n_unlabeled_anomaly: int = 250,
) -> tuple[DataLoader, DataLoader, DataLoader, DataLoader, DataLoader]:
    """Build anomaly-detection loaders from a MedMNIST dataset.

    One class is treated as normal, one class is held out as a test-only
    anomaly, and every other class is a "seen" anomaly. The labels of the
    sampled items are overwritten in place with binary labels: in the
    training pool only labeled anomalies get 1; in the test pool every
    anomaly gets 1 and every normal item gets 0. Images are resized to
    32x32, and datasets not listed as colour are expanded to three channels.

    Args:
        name: Key of ``options`` below (currently ``"PathMNIST"``,
            ``"OCTMNIST"`` or ``"TissueMNIST"``).
        batch_size: Batch size of every returned loader.
        normal_class: Class index treated as normal.
        unseen_anomaly: Class index used only as a test-time anomaly.
        n_train: Size of the training split.
        n_valid: Size of the validation split.
        n_unlabeled_normal: Number of normal training items to sample.
        n_labeled_anomaly: Number of seen anomalies sampled with label 1.
        n_unlabeled_anomaly: Number of seen anomalies sampled with label 0.

    Returns:
        ``(train, valid, test_all, test_seen, test_unseen)`` loaders. The test
        pools hold at most 1000 normal, 500 seen and 500 unseen items.

    Raises:
        KeyError: If ``name`` is not a key of ``options``.
        ValueError: From ``random.sample`` if more items are requested than
            are available. ``random_split`` also rejects ``n_train``/``n_valid``
            values that do not add up to the sampled pool size.
    """
    # dataset path
    path = "datasets/MedMNIST"
    os.makedirs(path, exist_ok=True)

    options = {
        "PathMNIST": medmnist.PathMNIST,
        # "ChestMNIST": medmnist.ChestMNIST,
        # "DermaMNIST": medmnist.DermaMNIST,
        "OCTMNIST": medmnist.OCTMNIST,
        # "PneumoniaMNIST": medmnist.PneumoniaMNIST,
        # "RetinaMNIST": medmnist.RetinaMNIST,
        # "BreastMNIST": medmnist.BreastMNIST,
        # "BloodMNIST": medmnist.BloodMNIST,
        "TissueMNIST": medmnist.TissueMNIST,
        # "OrganAMNIST": medmnist.OrganAMNIST,
        # "OrganCMNIST": medmnist.OrganCMNIST,
        # "OrganSMNIST": medmnist.OrganSMNIST,
    }
    # Same exception type as the bare lookup, but with a usable message.
    if name not in options:
        raise KeyError(f"Unsupported MedMNIST dataset {name!r}; enabled datasets: {sorted(options)}")

    color_images = ["PathMNIST", "DermaMNIST", "RetinaMNIST", "BloodMNIST"]

    # transform
    base = [
        transforms.ToTensor(),
        transforms.Resize((32, 32), antialias=True),
    ]
    colorizer = [transforms.Grayscale(3)]

    if name in color_images:
        transform = transforms.Compose(base)
    else:
        transform = transforms.Compose(colorizer + base)

    dataset_cls = options[name]
    train = dataset_cls(root=path, download=True, split="train", transform=transform)
    test = dataset_cls(root=path, download=True, split="test", transform=transform)

    # NOTE: the ChestMNIST branch is only reachable if that entry is
    # re-enabled in ``options`` above.
    if name == "ChestMNIST":
        train.labels = np.argmax(train.labels, axis=1)
        test.labels = np.argmax(test.labels, axis=1)
    else:
        train.labels = train.labels.flatten()
        test.labels = test.labels.flatten()

    # Train
    # Labels are 1-D at this point. ``flatten`` (not ``squeeze``) on the index
    # tensors keeps a one-element result a list, as ``random.sample`` requires.
    train_classes = torch.tensor(train.labels)
    train_normal_indices = torch.eq(train_classes, normal_class).nonzero().flatten().tolist()
    train_anomaly_indices = (
        torch.logical_and(
            torch.ne(train_classes, normal_class),
            torch.ne(train_classes, unseen_anomaly),
        )
        .nonzero()
        .flatten()
        .tolist()
    )

    train_normal_bag = random.sample(train_normal_indices, k=n_unlabeled_normal)
    train_anomaly_bag = random.sample(train_anomaly_indices, k=n_labeled_anomaly + n_unlabeled_anomaly)

    train_positive_bag = train_anomaly_bag[:n_labeled_anomaly]
    train_unlabeled_bag = train_normal_bag + train_anomaly_bag[n_labeled_anomaly:]

    # Overwrite the class labels of the sampled items with binary labels.
    for i in train_positive_bag:
        train.labels[i] = 1

    for i in train_unlabeled_bag:
        train.labels[i] = 0

    train_subset = Subset(train, train_positive_bag + train_unlabeled_bag)
    train_subset, valid_subset = random_split(train_subset, [n_train, n_valid])
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)
    valid_loader = DataLoader(valid_subset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

    # Test
    test_classes = torch.tensor(test.labels)
    test_normal_indices = torch.eq(test_classes, normal_class).nonzero().flatten().tolist()
    test_unseen_anomaly_indices = torch.eq(test_classes, unseen_anomaly).nonzero().flatten().tolist()
    test_seen_anomaly_indices = (
        torch.logical_and(torch.ne(test_classes, normal_class), torch.ne(test_classes, unseen_anomaly))
        .nonzero()
        .flatten()
        .tolist()
    )

    test_normal_bag = random.sample(test_normal_indices, k=min(len(test_normal_indices), 1000))
    test_unseen_anomaly_bag = random.sample(test_unseen_anomaly_indices, k=min(len(test_unseen_anomaly_indices), 500))
    test_seen_anomaly_bag = random.sample(test_seen_anomaly_indices, k=min(len(test_seen_anomaly_indices), 500))

    for i in test_unseen_anomaly_bag + test_seen_anomaly_bag:
        test.labels[i] = 1

    for i in test_normal_bag:
        test.labels[i] = 0

    test_subset = Subset(test, test_seen_anomaly_bag + test_unseen_anomaly_bag + test_normal_bag)
    test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

    test_seen_subset = Subset(test, test_seen_anomaly_bag + test_normal_bag)
    test_seen_loader = DataLoader(test_seen_subset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

    test_unseen_subset = Subset(test, test_unseen_anomaly_bag + test_normal_bag)
    test_unseen_loader = DataLoader(test_unseen_subset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

    return train_loader, valid_loader, test_loader, test_seen_loader, test_unseen_loader


class ImagePathDataset(Dataset):
    """Dataset that lazily loads images from file paths.

    ``files[i]`` is opened with ``pil_loader`` on access and paired with
    ``labels[i]``; ``transform`` (if given) is applied to the loaded image.
    The reported length is the number of files.
    """

    def __init__(self, files: list[Union[str, Path]], labels: list[int], transform=None):
        self.files, self.labels, self.transform = files, labels, transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        # Both lookups happen before the image is read from disk.
        path, label = self.files[i], self.labels[i]
        image = pil_loader(path)
        if self.transform is None:
            return image, label
        return self.transform(image), label


def split_mvtec_paths() -> tuple[list[Path], list[Path], list[Path], list[Path]]:
    """Collect MVTec image paths, split by role.

    Walks ``datasets/mvtec_anomaly_detection/<category>/`` and gathers
    ``train/good/*.png``, ``test/good/*.png`` and the PNGs of every other
    ``test/<defect>/`` directory. For each category one defect type (see
    ``unseen_anomalies``) is routed to the unseen list; all other defect
    types go to the seen list.

    Categories, defect directories and images are visited in sorted order so
    the result does not depend on filesystem enumeration order. Callers
    shuffle and split these lists with seeded RNGs, which is only
    reproducible if the input order is stable.

    Returns:
        ``(train_normal, test_normal, seen_anomaly, unseen_anomaly)`` path
        lists.

    Raises:
        KeyError: If a category directory that contains defect directories is
            not a key of ``unseen_anomalies``.
    """
    # Defect type held out as the unseen anomaly for each category.
    # ``None`` means no defect type is held out for that category.
    unseen_anomalies = {
        "bottle": "broken_large",
        "cable": "bent_wire",
        "capsule": "crack",
        "carpet": "color",
        "grid": "bent",
        "hazelnut": "crack",
        "leather": "color",
        "metal_nut": "bent",
        "pill": "color",
        "screw": "manipulated_front",
        "tile": "crack",
        "toothbrush": None,
        "transistor": "bent_lead",
        "wood": "color",
        "zipper": "broken_teeth",
    }

    root = Path("datasets/mvtec_anomaly_detection")

    train_normal_paths = []
    test_normal_paths = []
    seen_anomaly_paths = []
    unseen_anomaly_paths = []

    for category_dir in sorted(root.iterdir()):
        if not category_dir.is_dir():
            continue

        # normal data
        train_normal_paths += sorted((category_dir / "train" / "good").glob("*.png"))
        test_normal_paths += sorted((category_dir / "test" / "good").glob("*.png"))

        # anomaly data
        test_dir = category_dir / "test"
        anomaly_dirs = sorted(d for d in test_dir.iterdir() if d.name != "good" and d.is_dir())
        if not anomaly_dirs:
            continue

        excluded_name = unseen_anomalies[category_dir.name]
        for anomaly_dir in anomaly_dirs:
            images = sorted(anomaly_dir.glob("*.png"))
            if anomaly_dir.name == excluded_name:
                unseen_anomaly_paths += images
            else:
                seen_anomaly_paths += images

    return (
        train_normal_paths,
        test_normal_paths,
        seen_anomaly_paths,
        unseen_anomaly_paths,
    )


def load_mvtec(
    batch_size: int = 128,
) -> tuple[DataLoader, DataLoader, DataLoader, DataLoader, DataLoader]:
    """Build anomaly-detection loaders from the MVTec paths on disk.

    The seen anomalies are shuffled; the first 725 join the training pool
    (425 of them labeled 0 like the normal images, the remaining 300 labeled
    1) and the rest are kept for testing. The training pool is split into
    3500 training and 854 validation items. Images are resized to 224x224.

    Counts noted by the original author for this dataset:
    normal (train) 3629, normal (test) 467, seen 1006, unseen 252;
    425 / (3629 + 425) ~= 0.105.

    Args:
        batch_size: Batch size of every returned loader.

    Returns:
        ``(train, valid, test_all, test_seen, test_unseen)`` loaders.
    """
    train_normal_paths, test_normal_paths, seen_anomaly_paths, unseen_anomaly_paths = split_mvtec_paths()
    transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((224, 224), antialias=True)])

    def build_loader(dataset):
        # Every loader in this function shares the same settings.
        return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS)

    def build_test_dataset(anomaly_paths):
        # Normal test images (label 0) followed by the given anomalies (label 1).
        return ImagePathDataset(
            files=test_normal_paths + anomaly_paths,
            labels=[0] * len(test_normal_paths) + [1] * len(anomaly_paths),
            transform=transform,
        )

    # sampling train and test seen anomalies
    n_seen_anomaly = 725
    random.shuffle(seen_anomaly_paths)
    train_seen_anomaly_paths = seen_anomaly_paths[:n_seen_anomaly]
    test_seen_anomaly_paths = seen_anomaly_paths[n_seen_anomaly:]

    # train dataset
    n_unlabeled_anomaly = 425
    n_labeled_anomaly = n_seen_anomaly - n_unlabeled_anomaly
    unlabeled_part = [0] * len(train_normal_paths) + [0] * n_unlabeled_anomaly
    train_dataset = ImagePathDataset(
        files=train_normal_paths + train_seen_anomaly_paths,
        labels=unlabeled_part + [1] * n_labeled_anomaly,
        transform=transform,
    )
    n_train, n_valid = 3500, 854
    train_subset, valid_subset = random_split(train_dataset, [n_train, n_valid])

    train_loader = build_loader(train_subset)
    valid_loader = build_loader(valid_subset)

    # test datasets
    test_loader = build_loader(build_test_dataset(test_seen_anomaly_paths + unseen_anomaly_paths))
    test_seen_loader = build_loader(build_test_dataset(test_seen_anomaly_paths))
    test_unseen_loader = build_loader(build_test_dataset(unseen_anomaly_paths))

    return train_loader, valid_loader, test_loader, test_seen_loader, test_unseen_loader


if __name__ == "__main__":
    # Smoke test: build every loader set and print the split sizes.

    def _report(title, loaders):
        """Print the dataset name and the size of each of its five splits."""
        print(title)
        split_names = ("Train", "Valid", "Test (all)", "Test (seen)", "Test (unseen)")
        for split_name, loader in zip(split_names, loaders):
            print(f"{split_name}:", len(loader.dataset))  # type: ignore

    # image datasets
    for dataset_name in ["MNIST", "FashionMNIST", "CIFAR10", "SVHN"]:
        _report(dataset_name, load(dataset_name, batch_size=128))

    # CIFAR100
    _report("CIFAR100", load_cifar100(batch_size=128))

    # MedMNIST
    # "ChestMNIST" is not listed: load_medical_mnist has it disabled in its
    # options table, so requesting it raised KeyError and aborted the script.
    for dataset_name in ["PathMNIST", "OCTMNIST", "TissueMNIST"]:
        _report(dataset_name, load_medical_mnist(dataset_name, batch_size=128))

    # MVTec
    _report("MVTec", load_mvtec(batch_size=128))
