from torchvision import datasets, transforms
import torch
import os
import pdb


def load_data(data_folder, batch_size, train, num_workers=0, **kwargs):
    """Build a data loader for one domain and report its number of classes.

    The domain name is the last path component of ``data_folder``. The
    ``mnist`` and ``usps`` domains are loaded via ``get_mnist_usps_datasets``;
    every other domain is read with ``torchvision.datasets.ImageFolder``.

    Args:
        data_folder: Path to the domain's data directory.
        batch_size: Number of samples per batch.
        train: If true, ImageFolder domains use the augmenting transform
            (resize to 256, random 224 crop, random horizontal flip), and by
            default the loader shuffles and drops the last incomplete batch.
            If false, images are resized straight to 224x224 and the loader
            neither shuffles nor drops by default.
        num_workers: Number of loader worker processes.
        **kwargs: Forwarded to ``get_data_loader`` (e.g.
            ``infinite_data_loader``). ``shuffle`` and ``drop_last`` may be
            supplied here to override the defaults derived from ``train``.

    Returns:
        Tuple ``(data_loader, n_class)``.
    """
    domain = os.path.basename(os.path.normpath(data_folder))
    if domain not in ('mnist', 'usps'):
        # ImageNet channel mean/std (the usual torchvision constants).
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        if train:
            transform = transforms.Compose(
                [transforms.Resize([256, 256]),
                 transforms.RandomCrop(224),
                 transforms.RandomHorizontalFlip(),
                 transforms.ToTensor(),
                 normalize])
        else:
            transform = transforms.Compose(
                [transforms.Resize([224, 224]),
                 transforms.ToTensor(),
                 normalize])
        data = datasets.ImageFolder(root=data_folder, transform=transform)
        n_class = len(data.classes)
    else:
        # NOTE(review): ``train`` is not forwarded here, so the same split is
        # loaded for both train and test loaders; ``train`` only changes
        # shuffling / drop_last for these domains. Confirm this is intended.
        data = get_mnist_usps_datasets(data_folder, to_rgb=True)
        n_class = 10

    # Pop caller overrides so they cannot collide with the explicit keywords
    # below (previously this raised "got multiple values for keyword argument").
    shuffle = kwargs.pop('shuffle', bool(train))
    drop_last = kwargs.pop('drop_last', bool(train))
    data_loader = get_data_loader(data, batch_size=batch_size,
                                  shuffle=shuffle,
                                  drop_last=drop_last,
                                  num_workers=num_workers, **kwargs)
    return data_loader, n_class


def get_data_loader(dataset, batch_size, shuffle=True, drop_last=False, num_workers=0, infinite_data_loader=False,
                    **kwargs):
    """Wrap ``dataset`` in either a finite or an endless batch loader.

    Args:
        dataset: The dataset to draw samples from.
        batch_size: Number of samples per batch.
        shuffle: Whether to randomize sample order.
        drop_last: Whether to discard a trailing incomplete batch.
        num_workers: Number of loader worker processes.
        infinite_data_loader: If true, return an ``InfiniteDataLoader``;
            otherwise a regular ``torch.utils.data.DataLoader``.
        **kwargs: Extra options passed to the chosen loader's constructor.

    Returns:
        The constructed loader.
    """
    # The conditional expression only evaluates the branch that is selected.
    loader_factory = InfiniteDataLoader if infinite_data_loader else torch.utils.data.DataLoader
    return loader_factory(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
        **kwargs
    )


def get_mnist_usps_datasets(data_folder, resize_to=28, to_rgb=True, train=True):
    """Load the MNIST or USPS digit dataset through torchvision.

    The domain is the last path component of ``data_folder`` (``'mnist'`` or
    ``'usps'``). The dataset root passed to torchvision is the parent
    directory of ``data_folder``, and ``download=True`` is set.

    Args:
        data_folder: Path whose final component is ``mnist`` or ``usps``.
        resize_to: Images are resized to ``resize_to x resize_to``.
        to_rgb: If true, output 3 identical grayscale channels; otherwise 1.
        train: Which split to load; defaults to the training split, matching
            the previous hard-coded behavior.

    Returns:
        The torchvision dataset with the preprocessing transform attached.

    Raises:
        ValueError: If the domain is neither ``'mnist'`` nor ``'usps'``.
    """
    dir_name = os.path.dirname(os.path.normpath(data_folder))
    domain = os.path.basename(os.path.normpath(data_folder))

    # Reject unknown domains up front; otherwise no dataset would be bound
    # and the function would fail with an UnboundLocalError.
    if domain not in ('mnist', 'usps'):
        raise ValueError(
            "Unsupported digit domain {!r}; expected 'mnist' or 'usps'.".format(domain))

    num_channels = 3 if to_rgb else 1
    transform = transforms.Compose([
        transforms.Resize((resize_to, resize_to)),
        transforms.Grayscale(num_output_channels=num_channels),
        transforms.ToTensor(),
        # Per channel: (x - 0.5) / 0.5, mapping ToTensor's [0, 1] to [-1, 1].
        transforms.Normalize(mean=(0.5,) * num_channels, std=(0.5,) * num_channels),
    ])

    dataset_cls = datasets.MNIST if domain == 'mnist' else datasets.USPS
    return dataset_cls(root=dir_name, train=train, download=True, transform=transform)


class _InfiniteSampler(torch.utils.data.Sampler):
    """Wraps another Sampler to yield an infinite stream."""

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            for batch in self.sampler:
                yield batch


class InfiniteDataLoader:
    """Loader whose iteration yields batches forever.

    A single ``torch.utils.data.DataLoader`` iterator is created eagerly in
    ``__init__``. It is driven by a ``BatchSampler`` wrapped in
    ``_InfiniteSampler``, so the underlying sampler restarts whenever it is
    exhausted. Every call to ``__iter__`` continues from that same shared
    iterator rather than starting over.

    Args:
        dataset: The dataset to draw samples from.
        batch_size: Number of samples per batch.
        shuffle: Random order if true, sequential order if false. Ignored
            when ``weights`` is given.
        drop_last: Whether each pass of the sampler discards its trailing
            incomplete batch.
        num_workers: Number of loader worker processes.
        weights: Optional per-sample weights. When given, a
            ``WeightedRandomSampler`` draws ``batch_size`` indices without
            replacement per pass, i.e. one batch per restart.
        **kwargs: Extra ``DataLoader`` options (e.g. ``pin_memory``,
            ``collate_fn``) forwarded to the inner loader. Sampling-related
            options (``batch_size``, ``shuffle``, ``sampler``, ``drop_last``,
            ``batch_sampler``) are managed here and must not be passed.
    """

    def __init__(self, dataset, batch_size, shuffle=True, drop_last=False, num_workers=0, weights=None, **kwargs):
        if weights is not None:
            # NOTE(review): without replacement, torch presumably requires at
            # least ``batch_size`` non-zero weights; confirm for callers.
            sampler = torch.utils.data.WeightedRandomSampler(weights,
                                                             replacement=False,
                                                             num_samples=batch_size)
        elif shuffle:
            sampler = torch.utils.data.RandomSampler(dataset,
                                                     replacement=False)
        else:
            # Honor shuffle=False (previously ignored: order was always random).
            sampler = torch.utils.data.SequentialSampler(dataset)

        batch_sampler = torch.utils.data.BatchSampler(
            sampler,
            batch_size=batch_size,
            drop_last=drop_last)

        # Forward remaining options so the infinite path matches the finite
        # DataLoader path in get_data_loader (they were previously dropped).
        self._infinite_iterator = iter(torch.utils.data.DataLoader(
            dataset,
            num_workers=num_workers,
            batch_sampler=_InfiniteSampler(batch_sampler),
            **kwargs
        ))

    def __iter__(self):
        while True:
            yield next(self._infinite_iterator)

    def __len__(self):
        # The stream is unbounded; 0 is returned deliberately so len() does not raise.
        return 0  # Always return 0