from copy import deepcopy
import csv
import hashlib
import pathlib
from typing import Tuple, Dict, Union, Sequence, Optional, List

import numpy as np
import torch
import torchvision
from torch.utils.data import DataLoader, ConcatDataset, Dataset, Subset
from torchvision import transforms
import random

from .extra_datasets.patchcamelyon import PatchCamelyon
from .extra_datasets.resisc45_dataset import Resisc45Dataset, get_resisc45_data
from .transforms import assemble_transform
from ..utils import ROOT_DIR, get_logger
from . import custom_transforms

# Directory below the project root in which the dataset getters of this module store / look up their data.
DATAPATH: pathlib.Path = ROOT_DIR / "data"
# Keyword arguments in the format expected by the ``csv`` module; not used by the code visible in this module.
CSV_KWARGS = {"delimiter": ",", "quotechar": "|", "quoting": csv.QUOTE_MINIMAL}
# Module-wide logger.
logger = get_logger("datasets")

# Default number of DataLoader worker processes used by the dataloader getters and by ``split_dataset``.
NUM_WORKERS = 4
# Default batch size used when a dataset is iterated once to collect its labels for balanced splitting.
BATCH_SIZE = 32


def get_dataloader(dataset: Dataset, start_fraction: float = 0, stop_fraction: float = 1,
                   shuffle=False, seed=1, **kwargs) -> DataLoader:
    """
    Wraps a dataset, or a contiguous fraction of it, into a DataLoader.
    :param dataset: Dataset to load from
    :param start_fraction: Relative position (in [0, 1)) of the first sample to include
    :param stop_fraction: Relative position (in (0, 1]) one past the last sample to include
    :param shuffle: If True the sample order is permuted (seeded by ``seed``) before the fraction is cut out;
    the flag is additionally forwarded to the DataLoader
    :param seed: Seed of the numpy random state used for the permutation
    :param kwargs: Forwarded unchanged to the DataLoader constructor (e.g. batch_size, num_workers)
    :return: DataLoader over the full dataset if the fractions are exactly (0, 1), otherwise over a Subset
    """
    assert 0 <= start_fraction < stop_fraction <= 1

    covers_everything = start_fraction == 0 and stop_fraction == 1
    if covers_everything:
        # Nothing to cut out: hand the dataset to the loader untouched.
        return DataLoader(dataset, shuffle=shuffle, **kwargs)

    n_samples = len(dataset)
    order = torch.arange(n_samples)

    if shuffle:
        # The numpy array shares memory with ``order``; it is shuffled in place and copied back into a tensor.
        order_np = order.detach().numpy()
        np.random.RandomState(seed=seed).shuffle(order_np)
        order = torch.tensor(order_np)

    first = int(start_fraction * n_samples)
    last = int(stop_fraction * n_samples)

    return DataLoader(Subset(dataset, order[first:last]), shuffle=shuffle, **kwargs)


class DualDomainDataset(Dataset):
    """
    Dataset wrapper that returns every sample of a base dataset under two different transforms.
    Indexing yields the triple (t0(data), t1(data), target), where (data, target) is the item of the base dataset.
    """

    def __init__(self, base_dataset, t0: transforms.Compose, t1: transforms.Compose):
        """
        :param base_dataset: Dataset whose items are (data, target) pairs
        :param t0: Transform producing the first view of the data
        :param t1: Transform producing the second view of the data
        """
        super().__init__()
        self.t0: transforms.Compose = t0
        self.t1: transforms.Compose = t1
        self.dataset: Dataset = base_dataset

        # Not used by this class itself; kept as a public attribute.
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample, label = self.dataset[index]
        first_view = self.t0(sample)
        second_view = self.t1(sample)
        return first_view, second_view, label


def get_subsets(dataset_name: str, dataset: Dataset, fractions: List, shuffle: bool = True,
                balance: bool = True, batch_size: int = BATCH_SIZE, collate_fn=None) -> Tuple:
    """
    Splits a dataset with ``split_dataset`` using a seed that is derived from the dataset name.
    The seed is the SHA-1 digest of the name reduced modulo 2**32, so the same name always yields the same seed.
    :param dataset_name: Name the seed is derived from
    :param dataset: Dataset to split
    :param fractions: Fraction of the dataset assigned to each split
    :param shuffle: Forwarded to ``split_dataset``
    :param balance: Forwarded to ``split_dataset``
    :param batch_size: Forwarded to ``split_dataset``
    :param collate_fn: Forwarded to ``split_dataset``
    :return: Whatever ``split_dataset`` returns for these arguments
    """
    name_digest = hashlib.sha1(dataset_name.encode('utf-8')).hexdigest()
    seed = int(name_digest, 16) % (2 ** 32)
    logger.info(f"Generating splits using local seed {seed} for dataset {dataset_name}")

    split_options = dict(seed=seed, shuffle=shuffle, balance=balance, batch_size=batch_size,
                         collate_fn=collate_fn)
    return split_dataset(dataset, fractions, **split_options)


def _collect_labels(dataset: Dataset, batch_size: int, collate_fn=None) -> torch.Tensor:
    """
    Iterates once over the dataset in sequential order and returns all targets as one 1-D tensor.
    :param dataset: Dataset whose items end with the target (e.g. (data, target) or (view0, view1, target))
    :param batch_size: Batch size used for the sequential pass
    :param collate_fn: Optional collate function handed to the DataLoader
    :return: Tensor of length len(dataset) where entry i is the target of sample i
    """
    sequential_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS,
                                   collate_fn=collate_fn)
    labels = None
    write_pos = 0
    for batch in sequential_loader:
        # The target is taken from the last position so that datasets returning more than one
        # data element per sample (such as DualDomainDataset) are supported as well.
        target = batch[-1]
        if labels is None:
            labels = target.new_zeros(size=(len(dataset),))

        next_pos = write_pos + target.size(0)
        labels[write_pos:next_pos] = target
        write_pos = next_pos

    if labels is None:
        # Empty dataset: there is nothing to balance.
        labels = torch.zeros(0, dtype=torch.long)
    return labels


def split_dataset(dataset: Dataset, fractions: List, seed: int = 1, shuffle: bool = True,
                  balance: bool = False, batch_size: int = BATCH_SIZE, collate_fn=None) -> Tuple[Subset, ...]:
    """
    Splits the dataset into multiple subsets according to the fractions specified in the fractions list
    :param dataset: Dataset to split
    :param fractions: each element in the list belongs to the fraction used for the specific split.
    The sum of fractions needs to be lower or equal to one
    :param seed: Seed to use for random shuffling
    :param shuffle: Boolean option whether to shuffle the dataset, before splitting it into subsets
    :param balance: Boolean option whether to balance the number of images per class so each class is included in the
    subset with the same distribution as in the original dataset (no sampling with replacement)
    :param batch_size: Batch size to use for sequential sampling of labels
    :param collate_fn: Optional collate function used for sequential sampling of labels
    :return: Tuple of subsets representing the split dataset with the same length as the fractions list
    """
    # A tiny tolerance keeps fraction lists that only exceed one through float rounding from being rejected.
    assert sum(fractions) <= 1 + 1e-9, "The sum of fractions needs to be lower or equal to one"

    # indices[p] is the original dataset index of the sample placed at position p of the (shuffled) order.
    indices = torch.arange(len(dataset))
    if shuffle:
        rng = np.random.RandomState(seed=seed)
        np_indices = indices.numpy()
        rng.shuffle(np_indices)
        indices = torch.tensor(np_indices)

    selected_indices: List = []
    if balance:
        labels = _collect_labels(dataset, batch_size, collate_fn)
        # Bring the labels into exactly the same order as the indices. Shuffling the labels with a
        # second draw from the random state would produce a different permutation, so the positions
        # picked per class below would not belong to that class any more.
        labels_shuffled = labels[indices]

        classes, class_counts = torch.unique(labels_shuffled, sorted=True, return_counts=True)
        chunks_per_fraction: List[List[torch.Tensor]] = [[] for _ in fractions]
        for class_id, class_count in zip(classes, class_counts):
            # reshape(-1) keeps the result one-dimensional even if the class has a single sample.
            class_positions = (labels_shuffled == class_id).nonzero().reshape(-1)

            class_start_idx = 0
            for fraction_idx, fraction in enumerate(fractions):
                class_end_idx = class_start_idx + int(fraction * class_count)
                chunks_per_fraction[fraction_idx].append(class_positions[class_start_idx:class_end_idx])
                class_start_idx = class_end_idx

        for chunks in chunks_per_fraction:
            if chunks:
                selected_indices.append(torch.cat(chunks, dim=0))
            else:
                selected_indices.append(torch.zeros(0, dtype=torch.long))

    else:
        start_idx = 0
        for fraction in fractions:
            final_idx = start_idx + int(fraction * len(dataset))
            selected_indices.append(torch.arange(start=start_idx, end=final_idx))
            start_idx = final_idx

    # Translate positions in the (shuffled) order back into indices of the original dataset.
    return tuple(Subset(dataset, indices[fraction_positions]) for fraction_positions in selected_indices)


############################################################################################################
# DATASET GETTERS
############################################################################################################

def get_mnist_dataset(transformer: transforms.Compose) -> Tuple[Dataset, Dataset]:
    """
    Creates the torchvision MNIST train and test datasets below DATAPATH / "mnist" (downloaded if necessary).
    :param transformer: Indexable object; element 0 is the transform of the train set, element 1 of the test set
    :return: (train dataset, test dataset)
    """
    root = str(DATAPATH / "mnist")
    mnist = torchvision.datasets.MNIST

    train_split = mnist(root, train=True, download=True, transform=transformer[0])
    test_split = mnist(root, train=False, download=True, transform=transformer[1])
    return train_split, test_split


def get_kmnist_dataset(transformer: List[transforms.Compose], dataset_domain: List) -> Tuple[Dataset, Dataset, Dataset]:
    """
    Creates the torchvision KMNIST train and test datasets below DATAPATH / "kmnist".
    If the first entry of ``dataset_domain`` defines "n_classes", every target is replaced by
    round(target / n_classes).
    :param transformer: Element 0 is the transform of the train set, element 1 of the test set
    :param dataset_domain: List whose first entry is a dict of dataset options; may be empty
    :return: (train dataset, test dataset, test dataset) - the test set object is returned twice
    """
    n_classes = dataset_domain[0].get("n_classes", None) if dataset_domain else None

    if n_classes is None:
        target_transform = None
    else:
        target_transform = lambda target: round(target / n_classes)

    root = str(DATAPATH / "kmnist")
    kmnist = torchvision.datasets.KMNIST

    # Only the train split requests a download.
    train_split = kmnist(root, train=True, download=True,
                         transform=transformer[0], target_transform=target_transform)
    test_split = kmnist(root, train=False,
                        transform=transformer[1], target_transform=target_transform)
    return train_split, test_split, test_split


def get_cifar10_dataset(transformer: transforms.Compose) -> Tuple[Dataset, Dataset]:
    """
    Creates the torchvision CIFAR10 train and test datasets below DATAPATH / "cifar10" (downloaded if necessary).
    :param transformer: Indexable object; element 0 is the transform of the train set, element 1 of the test set
    :return: (train dataset, test dataset)
    """
    root = str(DATAPATH / "cifar10")
    cifar10 = torchvision.datasets.CIFAR10

    train_split = cifar10(root, train=True, download=True, transform=transformer[0])
    test_split = cifar10(root, train=False, download=True, transform=transformer[1])
    return train_split, test_split


def get_cifar100_dataset(transformer: transforms.Compose) -> Tuple[Dataset, Dataset]:
    """
    Creates the torchvision CIFAR100 train and test datasets below DATAPATH / "cifar100" (downloaded if necessary).
    :param transformer: Indexable object; element 0 is the transform of the train set, element 1 of the test set
    :return: (train dataset, test dataset)
    """
    root = str(DATAPATH / "cifar100")
    cifar100 = torchvision.datasets.CIFAR100

    train_split = cifar100(root, train=True, download=True, transform=transformer[0])
    test_split = cifar100(root, train=False, download=True, transform=transformer[1])
    return train_split, test_split


def get_fashion_mnist_dataset(all_transforms: transforms.Compose) -> Tuple[Dataset, Dataset]:
    """
    Creates the torchvision FashionMNIST train and test datasets below DATAPATH / "fashion_data".
    :param all_transforms: Indexable object; element 0 is the transform of the train set, element 1 of the test set
    :return: (train dataset, test dataset)
    """
    root = str(DATAPATH / "fashion_data")
    fashion_mnist = torchvision.datasets.FashionMNIST

    # Only the train split requests a download.
    train_split = fashion_mnist(root, train=True, download=True, transform=all_transforms[0])
    test_split = fashion_mnist(root, train=False, transform=all_transforms[1])
    return train_split, test_split


def get_svhn_dataset(all_transforms: transforms.Compose) -> Tuple[Dataset, Dataset]:
    """
    Creates the SVHN datasets below DATAPATH / "svhn" (downloaded if necessary).
    The returned train set is the concatenation of the 'train' and the 'extra' split.
    :param all_transforms: Indexable object; element 0 is used for the 'train' split, element 1 for the
    'extra' and the 'test' split
    :return: (concatenated train + extra dataset, test dataset)
    """
    root = str(DATAPATH / "svhn")
    train_data = torchvision.datasets.SVHN(root, split='train', download=True, transform=all_transforms[0])
    # NOTE(review): the extra split becomes part of the training data but receives all_transforms[1];
    # confirm whether all_transforms[0] was intended.
    extra_data = torchvision.datasets.SVHN(root, split='extra', download=True, transform=all_transforms[1])
    test_data = torchvision.datasets.SVHN(root, split='test', download=True, transform=all_transforms[1])

    # ConcatDataset is required here: the result has to support len() and integer indexing because it is
    # split into Subsets later on. ChainDataset only chains IterableDatasets and offers neither.
    extended_train_data = ConcatDataset((train_data, extra_data))
    return extended_train_data, test_data


def get_imagenet_dataset(all_transforms: transforms.Compose) -> Tuple[Dataset, Dataset]:
    """
    Creates the ImageNet train, validation and test datasets from a fixed local directory.
    The transforms are defined inside this function; ``all_transforms`` is not used.
    Three datasets are returned: the test dataset is a second instance of the 'val' split,
    because the ImageNet test split has no ground truth.
    :param all_transforms: Ignored
    :return: (train dataset, val dataset, test dataset)
    """
    imagenet_root = "/media/sdd/datasets/imagenet/"

    # PCA of the ImageNet colour channels, to be used with custom_transforms.Lighting.
    # It is the color augmentation AlexNet style based on ResNet paper.
    # Implementation is taken from: https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/fastai_imagenet.py
    pca_eigval = torch.Tensor([0.2175, 0.0188, 0.0045])
    pca_eigvec = torch.Tensor([
        [-0.5675, 0.7192, 0.4009],
        [-0.5808, -0.0045, -0.8140],
        [-0.5836, -0.6948, 0.4203],
    ])

    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    train_pipeline = transforms.Compose([
        custom_transforms.RandomResize((256, 480)),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
        custom_transforms.Lighting(0.1, pca_eigval, pca_eigvec),
    ])

    eval_pipeline = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ])

    imagenet = torchvision.datasets.ImageNet
    train_split = imagenet(imagenet_root, split='train', transform=train_pipeline)
    val_split = imagenet(imagenet_root, split='val', transform=eval_pipeline)
    # Test dataset has no GT. Use val as test (separate instance of the same split).
    test_split = imagenet(imagenet_root, split='val', transform=eval_pipeline)

    return train_split, val_split, test_split


def get_resisc45_dataset(all_transforms: List[transforms.Compose],
                         dataset_domain: Optional[List] = None) -> Tuple[Dataset, Dataset]:
    """
    Creates a train and a validation dataset from the RESISC45 data below DATAPATH / "remote_sensing".
    All samples are first wrapped in a Resisc45Dataset with a fixed resize / normalize transform, split
    80 / 20 (shuffled and class balanced) via ``get_subsets`` and each split is wrapped in a
    Resisc45Dataset again with the transform passed by the caller.
    :param all_transforms: Element 0 is the transform of the train split, element 1 of the validation split
    :param dataset_domain: Accepted so that the getter can be called like the other domain-aware getters;
    currently unused
    :return: (train dataset, validation dataset)
    """
    root = str(DATAPATH / "remote_sensing")
    data = get_resisc45_data(root=root)
    initial_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    logger.info(f"all transforms: {all_transforms}")

    all_data = Resisc45Dataset(data=data, transform=initial_transform, target_transform=None)
    train_data, val_data = get_subsets(dataset_name="RESISC45", dataset=all_data, fractions=[0.8, 0.2],
                                       shuffle=True, balance=True, batch_size=32)
    train_data = Resisc45Dataset(data=train_data, transform=all_transforms[0],
                                 target_transform=None)
    val_data = Resisc45Dataset(data=val_data, transform=all_transforms[1],
                               target_transform=None)

    return train_data, val_data


def get_emnist_dataset(all_transforms: List[transforms.Compose],
                       dataset_domain: Optional[List] = None) -> Tuple[Dataset, Dataset, Dataset]:
    """
    Creates the torchvision EMNIST ('letters' split) train and test datasets below DATAPATH / "letters_data".
    :param all_transforms: Element 0 is the transform of the train set, element 1 of the test set
    :param dataset_domain: Accepted so that the getter can be called like the other domain-aware getters;
    currently unused
    :return: (train dataset, test dataset, test dataset) - the test set object is returned twice
    """
    root = str(DATAPATH / "letters_data")
    # Only the train split requests a download.
    train_data = torchvision.datasets.EMNIST(root, split='letters', train=True, download=True,
                                             transform=all_transforms[0])
    test_data = torchvision.datasets.EMNIST(root, split='letters', train=False, transform=all_transforms[1])
    return train_data, test_data, test_data


def get_syn2real_dataset(all_transforms: transforms.Compose, dataset_domain: List):
    """
    Creates the Syn2Real train, validation and test ImageFolder datasets of one domain.
    The images are read from DATAPATH / "syn2real" / <domain> / {"train", "val", "test"}.
    :param all_transforms: Indexable object; element 0 is the transform of the train set, element 1 of the
    validation and the test set
    :param dataset_domain: List whose first entry is the domain name, either "synthetic" or "coco"
    :return: (train dataset, val dataset, test dataset)
    :raises ValueError: if no domain is given or the domain is not supported
    """
    supported_domains = ["synthetic", "coco"]

    if not dataset_domain:
        raise ValueError(f"Syn2Real requires a dataset domain, one of {supported_domains}")

    domain = dataset_domain[0]
    if domain not in supported_domains:
        raise ValueError(f"Unknown Syn2Real domain {domain!r}, expected one of {supported_domains}")

    domain_root = DATAPATH / "syn2real" / domain

    train_data = torchvision.datasets.ImageFolder(str(domain_root / "train"), transform=all_transforms[0])
    val_data = torchvision.datasets.ImageFolder(str(domain_root / "val"), transform=all_transforms[1])
    test_data = torchvision.datasets.ImageFolder(str(domain_root / "test"), transform=all_transforms[1])

    return train_data, val_data, test_data


def get_raw_datasets(dataset_name: str, transformer: transforms.Compose, dataset_domain: Optional[List] = None):
    """
    Dispatches to the dataset getter that belongs to ``dataset_name``.
    :param dataset_name: One of "MNIST", "CIFAR10", "CIFAR100", "FashionMNIST", "SVHN", "ImageNet", "EMNIST",
    "KMNIST", "SYN2REAL", "RESISC45", "OfficeHome", "OfficeCaltech"
    :param transformer: Indexable pair of transforms handed to the getter (element 0: train, element 1: evaluation)
    :param dataset_domain: Optional dataset specific options, only handed to the KMNIST and SYN2REAL getters
    :return: The tuple of datasets created by the getter; False for "OfficeHome" and "OfficeCaltech"
    :raises NotImplementedError: if the dataset name is unknown
    """
    # A fresh list per call instead of a shared mutable default argument.
    if dataset_domain is None:
        dataset_domain = []

    if dataset_name == "MNIST":
        return get_mnist_dataset(transformer)
    elif dataset_name == "CIFAR10":
        return get_cifar10_dataset(transformer)
    elif dataset_name == "CIFAR100":
        return get_cifar100_dataset(transformer)
    elif dataset_name == "FashionMNIST":
        return get_fashion_mnist_dataset(transformer)
    elif dataset_name == "SVHN":
        return get_svhn_dataset(transformer)
    elif dataset_name == "ImageNet":
        return get_imagenet_dataset(transformer)
    elif dataset_name == "EMNIST":
        # The EMNIST getter has no use for the domain options, so only the transforms are passed.
        return get_emnist_dataset(transformer)
    elif dataset_name == "KMNIST":
        return get_kmnist_dataset(transformer, dataset_domain)
    elif dataset_name == "SYN2REAL":
        return get_syn2real_dataset(transformer, dataset_domain)
    elif dataset_name == "RESISC45":
        # The RESISC45 getter has no use for the domain options, so only the transforms are passed.
        return get_resisc45_dataset(transformer)
    elif dataset_name == "OfficeHome":
        return False
    elif dataset_name == "OfficeCaltech":
        return False
    else:
        raise NotImplementedError(f"No dataset getter available for dataset {dataset_name!r}")


############################################################################################################
############################################################################################################


def assert_domain_match(domain0, domain1):
    """
    Checks that two domain configs agree in every key of ``domain0`` except for "transform".
    Identical "transform" entries are not an error but are reported with a critical log message.
    :param domain0: Reference config; its keys determine what is compared
    :param domain1: Config compared against the reference
    :raises AssertionError: if a value other than "transform" differs
    :raises KeyError: if a key of ``domain0`` is missing in ``domain1``
    """
    for key, reference_value in domain0.items():
        other_value = domain1[key]

        if key == "transform":
            if other_value == reference_value:
                logger.critical("Used a PairwiseDataloader with two exactly matching domains.")
            continue

        assert other_value == reference_value, "PairwiseDataloader: " \
                                               "Domains must match in all keys except for <transform>"


def _unique_by_identity(items) -> List:
    """Returns the items in their original order, dropping repeated references to the same object."""
    seen_ids = set()
    unique_items = []
    for item in items:
        if id(item) not in seen_ids:
            seen_ids.add(id(item))
            unique_items.append(item)
    return unique_items


def parse_domain_config(domain_config: Union[Dict, Sequence[Dict]]) -> Tuple[str, int, bool, Tuple[Dataset, ...],
                                                                             transforms.Compose]:
    """
    This function handles the differentiation between a single task config and two task configs.
    If two task configs are passed a DualDomain dataset is wrapped around the raw dataset obtained
    using get_raw_datasets, otherwise the raw datasets are returned.
    All options except the transforms are read from the first config.
    :param domain_config: A single config dict, or a sequence containing one or two config dicts
    :return: dataset name, batch size, shuffle flag, tuple of datasets and the transform used for the train data
    (None if two configs are passed)
    :raises ValueError: if the sequence does not contain one or two configs
    """
    transformers: Tuple[Optional[transforms.Compose], Optional[transforms.Compose]]

    # primary_config is the dict all options are read from; a sequence has no .get() itself.
    if isinstance(domain_config, dict):
        primary_config = domain_config
        is_dual_domain = False
    elif len(domain_config) == 1:
        primary_config = domain_config[0]
        is_dual_domain = False
    elif len(domain_config) == 2:
        assert_domain_match(domain_config[0], domain_config[1])
        primary_config = domain_config[0]
        is_dual_domain = True
    else:
        raise ValueError("You can pass a maximum of two domain configs to create a dual domain dataloader")

    # In the dual domain case the raw datasets stay untransformed; the two transforms are applied
    # by the DualDomainDataset wrapper instead.
    transformers = (None, None) if is_dual_domain else assemble_transform(primary_config)
    shuffle = primary_config["shuffle"]
    batch_size = primary_config["batch_size"]
    dataset_name = primary_config["dataset"]
    dataset_domain = primary_config.get("dataset_domain", [])

    data_augmentation = primary_config.get("data_augmentation", False)
    if data_augmentation:
        logger.info("Implementing data augmentation")
        # With data augmentation the second transform is treated as the basic (raw) one and the first
        # transform as the augmenting one for the train dataset.
        raw_transformer = transformers[1]
        data_augment_transformer = transformers[0]
    else:
        raw_transformer = transformers[0]
        data_augment_transformer = raw_transformer

    datasets = get_raw_datasets(dataset_name, [data_augment_transformer, raw_transformer], dataset_domain)
    target_transform_specs: Dict = primary_config.get("target_transform", {})
    logger.info(f"Target transform specs: {target_transform_specs}")
    logger.info(f"Domain config specs: {domain_config}")
    random_labels = target_transform_specs.get("random_labels", {}).get("state", False)
    permute_labels = target_transform_specs.get("permute_labels", {}).get("state", False)

    if random_labels or permute_labels:
        # Some getters return the same dataset object more than once (e.g. as validation and test set).
        # Every object must be modified exactly once, otherwise its labels would be permuted twice.
        unique_datasets = _unique_by_identity(datasets)

        if random_labels:
            for dataset in unique_datasets:
                # A dedicated, seeded numpy random state makes the shuffle reproducible;
                # seeding Python's ``random`` module has no effect on numpy.
                rng = np.random.RandomState(1)
                dataset.targets = list(rng.permutation(dataset.targets))
        else:
            rng = np.random.RandomState(1)
            n_labels = int(max(datasets[0].targets)) + 1
            permuted_labels = rng.permutation(n_labels)
            for dataset in unique_datasets:
                targets = dataset.targets
                for i in range(len(targets)):
                    targets[i] = permuted_labels[int(targets[i])]

    if is_dual_domain:
        # NOTE(review): elsewhere in this function the result of assemble_transform is indexed like a
        # pair, while DualDomainDataset calls t0 / t1 directly - confirm what assemble_transform returns.
        t0 = assemble_transform(domain_config[0])
        t1 = assemble_transform(domain_config[1])
        datasets = tuple(DualDomainDataset(dataset, t0, t1) for dataset in datasets)

    return dataset_name, batch_size, shuffle, datasets, data_augment_transformer


def get_torchvision_dataloaders(task_config: Union[Dict, Sequence[Dict]], val_fraction: float, num_workers: int,
                                val_from_train: bool = True, balance: bool = True) \
        -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Creates the train, validation and test dataloaders for the dataset described by ``task_config``.
    :param task_config: Config (or sequence of configs) handed to ``parse_domain_config``
    :param val_fraction: Fraction (< 1) that is cut off as the second split
    :param num_workers: Number of worker processes of the dataloaders
    :param val_from_train: If True, two datasets (train, test) are expected and the validation set is the
    ``val_fraction`` part of the train dataset. Otherwise three datasets (train, val, test) are expected; of the
    train and of the val dataset only the first (1 - val_fraction) part is used
    :param balance: Forwarded to ``get_subsets``
    :return: (train loader, validation loader, test loader)
    """
    assert val_fraction < 1

    dataset_name, batch_size, shuffle, datasets, data_augment_transformer = parse_domain_config(task_config)
    split_fractions = [1 - val_fraction, val_fraction]

    if val_from_train is True:
        train_data, test_data = datasets
        train_subset, val_subset = get_subsets(dataset_name, train_data, split_fractions, shuffle,
                                               batch_size=batch_size, balance=balance)
    else:
        train_data, val_data, test_data = datasets
        # The second split of either call is discarded.
        train_subset = get_subsets(dataset_name, train_data, split_fractions, shuffle, balance=balance)[0]
        val_subset = get_subsets(dataset_name, val_data, split_fractions, shuffle, balance=balance)[0]

    return create_dataloaders(train_subset, val_subset, test_data, batch_size, shuffle, num_workers,
                              data_augment_transformer)


def create_dataloaders(train_dataset: Dataset, val_dataset: Dataset, test_dataset: Dataset,
                       batch_size, shuffle: bool, num_workers: int,
                       data_augment_transformer: transforms.Compose = None,
                       collate_fn=None) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Wraps the three datasets into dataloaders that share batch size, shuffle flag, worker count and collate_fn.
    :param train_dataset: Dataset of the train loader
    :param val_dataset: Dataset of the validation loader
    :param test_dataset: Dataset of the test loader
    :param batch_size: Batch size of all loaders
    :param shuffle: Shuffle flag of all loaders
    :param num_workers: Number of worker processes of all loaders
    :param data_augment_transformer: If given, the train dataset is deep-copied and the ``transforms`` attribute
    of every innermost dataset of the copy is replaced by a StandardTransform built from this transform
    :param collate_fn: Optional collate function of all loaders
    :return: (train loader, validation loader, test loader)
    :raises NotImplementedError: if a dataset has none of the attributes ``transforms``, ``datasets``, ``dataset``
    """
    if data_augment_transformer is not None:
        # Work on a copy so that datasets sharing the same base dataset are left untouched.
        train_dataset = deepcopy(train_dataset)

        # Depth-first walk through wrapper datasets down to the datasets carrying the transforms.
        # NOTE(review): only the ``transforms`` attribute is replaced; datasets that read ``self.transform``
        # in __getitem__ would not pick up the change - confirm against the dataset classes in use.
        pending = [train_dataset]
        while pending:
            current = pending.pop()
            if hasattr(current, "transforms"):
                # final layer of dataset
                current.transforms = torchvision.datasets.vision.StandardTransform(data_augment_transformer)
            elif hasattr(current, "datasets"):
                # for ConcatDataset; reversed so that the members are visited in their original order
                pending.extend(reversed(list(current.datasets)))
            elif hasattr(current, "dataset"):
                # for subsets
                pending.append(current.dataset)
            else:
                raise NotImplementedError

    loader_options = dict(batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, collate_fn=collate_fn)
    train_loader = DataLoader(train_dataset, **loader_options)
    valid_loader = DataLoader(val_dataset, **loader_options)
    test_loader = DataLoader(test_dataset, **loader_options)
    return train_loader, valid_loader, test_loader


############################################################################################################
# Dataloader creator functions for all types of datasets
############################################################################################################

def get_mnist_dataloaders(task_config: Union[Dict, Sequence[Dict]]) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Dataloaders for MNIST task configs ((28, 28) sized images).
    10% of the train dataset is split off as validation set; NUM_WORKERS worker processes are used.
    :param task_config: Config (or sequence of configs) describing the task
    :return: (train loader, validation loader, test loader)
    """
    loaders = get_torchvision_dataloaders(task_config, val_fraction=0.1, num_workers=NUM_WORKERS)
    return loaders


def get_kmnist_dataloaders(task_config: Union[Dict, Sequence[Dict]]) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Dataloaders for KMNIST task configs ((28, 28) sized images).
    The validation set is not taken from the train dataset (val_from_train=False), the split fraction is 0.1
    and NUM_WORKERS worker processes are used.
    :param task_config: Config (or sequence of configs) describing the task
    :return: (train loader, validation loader, test loader)
    """
    loaders = get_torchvision_dataloaders(task_config, val_fraction=0.1, num_workers=NUM_WORKERS,
                                          val_from_train=False)
    return loaders


def get_cifar10_dataloaders(task_config: Union[Dict, Sequence[Dict]]) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Dataloaders for CIFAR10 task configs ((32, 32) sized images).
    The split between train and validation set is hard coded: 10% of the train dataset is used for validation.
    TODO: add normalization as last transform (dataset dependent)
    TODO: see https://github.com/kuangliu/pytorch-cifar/blob/master/main.py
    TODO: transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    :param task_config: Config (or sequence of configs) describing the task
    :return: (train loader, validation loader, test loader)
    """
    loaders = get_torchvision_dataloaders(task_config, val_fraction=0.1, num_workers=NUM_WORKERS)
    return loaders


def get_cifar100_dataloaders(task_config: Union[Dict, Sequence[Dict]]) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Dataloaders for CIFAR100 task configs ((32, 32) sized images).
    The split between train and validation set is hard coded: 10% of the train dataset is used for validation.
    :param task_config: Config (or sequence of configs) describing the task
    :return: (train loader, validation loader, test loader)
    """
    loaders = get_torchvision_dataloaders(task_config, val_fraction=0.1, num_workers=NUM_WORKERS)
    return loaders


def get_resisc45_dataloaders(task_config: Union[Dict, Sequence[Dict]]) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Dataloaders for RESISC45 task configs ((224, 224) sized images).
    10% of the first dataset returned for the task is split off as validation set; NUM_WORKERS worker
    processes are used.
    :param task_config: Config (or sequence of configs) describing the task
    :return: (train loader, validation loader, test loader)
    """
    loaders = get_torchvision_dataloaders(task_config, val_fraction=0.1, num_workers=NUM_WORKERS)
    return loaders


def get_camelyon_dataloaders(domain_config: Union[Dict, Sequence[Dict]]) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Dataloaders for the PatchCamelyon dataset stored below DATAPATH / "patchcamelyon/camelyon".
    The config is only logged; batch size (128), worker count (4) and drop_last=True are fixed for all three
    loaders. Only the train loader shuffles, and only the train dataset is created with augment=True.
    :param domain_config: Config of the task (logged, otherwise unused)
    :return: (train loader, validation loader, test loader)
    """
    logger.info(f"domain_config: {domain_config}")

    camelyon_path = str(DATAPATH / "patchcamelyon/camelyon")
    train_source_dataset = PatchCamelyon(path=camelyon_path, mode='train', augment=True)
    val_source_dataset = PatchCamelyon(path=camelyon_path, mode='valid')
    test_source_dataset = PatchCamelyon(path=camelyon_path, mode='test')

    # NOTE(review): drop_last=True also discards the last incomplete batch of the validation and test data.
    common_options = dict(batch_size=128, num_workers=4, drop_last=True)
    train_source_loader = get_dataloader(train_source_dataset, shuffle=True, **common_options)
    val_source_loader = get_dataloader(val_source_dataset, shuffle=False, **common_options)
    test_source_loader = get_dataloader(test_source_dataset, shuffle=False, **common_options)

    return train_source_loader, val_source_loader, test_source_loader


def get_fashion_mnist_dataloaders(
        task_config: Union[Dict, Sequence[Dict]]) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Dataloaders for FashionMNIST task configs ((28, 28) sized images).
    10% of the train dataset is split off as validation set; NUM_WORKERS worker processes are used.
    :param task_config: Config (or sequence of configs) describing the task
    :return: (train loader, validation loader, test loader)
    """
    loaders = get_torchvision_dataloaders(task_config, val_fraction=0.1, num_workers=NUM_WORKERS)
    return loaders


def get_svhn_dataloaders(task_config: Union[Dict, Sequence[Dict]]) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Dataloaders for SVHN task configs ((32, 32) sized images).
    10% of the train dataset is split off as validation set; NUM_WORKERS worker processes are used.
    :param task_config: Config (or sequence of configs) describing the task
    :return: (train loader, validation loader, test loader)
    """
    loaders = get_torchvision_dataloaders(task_config, val_fraction=0.1, num_workers=NUM_WORKERS)
    return loaders


def get_imagenet_dataloaders(task_config: Union[Dict, Sequence[Dict]]) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Dataloaders for ImageNet task configs ((224, 224) sized images after applying data augmentation).
    The datasets are used in full (val_fraction=0.0), the validation set is not taken from the train dataset
    and no class balancing is performed.
    :param task_config: Config (or sequence of configs) describing the task
    :return: (train loader, validation loader, test loader)
    """
    loaders = get_torchvision_dataloaders(task_config, val_fraction=0.0, num_workers=NUM_WORKERS,
                                          balance=False, val_from_train=False)
    return loaders


def get_domainnet_dataloaders(task_config: Union[Dict, Sequence[Dict]]) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Dataloaders for Domainnet task configs ((224, 224) sized images).
    10% of the train dataset is split off as class balanced validation set; NUM_WORKERS worker processes are used.
    NOTE(review): the dataset is selected by the "dataset" entry of the config - confirm that a getter for the
    Domainnet dataset name is registered in ``get_raw_datasets``.
    :param task_config: Config (or sequence of configs) describing the task
    :return: (train loader, validation loader, test loader)
    """
    loaders = get_torchvision_dataloaders(task_config, val_fraction=0.1, num_workers=NUM_WORKERS, balance=True)
    return loaders


def get_emnist_dataloaders(task_config: Union[Dict, Sequence[Dict]]) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Letters MNIST (EMNIST) dataloader with (28, 28) sized images.
    The EMNIST dataset getter returns three datasets (train, test, test), therefore the validation set is not
    taken from the train dataset (val_from_train=False), exactly as for KMNIST. With the default
    val_from_train=True the three datasets could not be unpacked into (train, test).
    :param task_config: Config (or sequence of configs) describing the task
    :return: (train loader, validation loader, test loader)
    """
    return get_torchvision_dataloaders(task_config, num_workers=NUM_WORKERS, val_fraction=0.1, val_from_train=False)


def get_syn2real_dataloaders(domain_config: Union[Dict, Sequence[Dict]]) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Dataloaders for Syn2Real task configs (dataset sampled with (64, 64) sized images).
    The three datasets obtained from ``parse_domain_config`` are used as they are, without further splitting;
    NUM_WORKERS worker processes are used.
    :param domain_config: Config (or sequence of configs) describing the task
    :return: (train loader, validation loader, test loader)
    """
    parsed = parse_domain_config(domain_config)
    dataset_name, batch_size, shuffle, datasets, data_augment_transformer = parsed
    train_data, val_data, test_data = datasets

    return create_dataloaders(train_data, val_data, test_data, batch_size, shuffle, NUM_WORKERS,
                              data_augment_transformer)



