import os
import torch
import torchvision
import torchvision.transforms as transforms
from . import lmdb_dataset
from . import torchvision_extension as transforms_extension
from .prefetch_data import fast_collate
from imagenet_train_cfg import cfg as config


class ImageNet12(object):

    def __init__(self, trainFolder, testFolder, num_workers=8, pin_memory=True,
                 size_images=224, scaled_size=256, type_of_data_augmentation='rand_scale',
                 data_config=None):

        self.data_config = data_config
        self.trainFolder = trainFolder
        self.testFolder = testFolder
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.patch_dataset = self.data_config.patch_dataset

        # images will be rescaled to match this size
        if not isinstance(size_images, int):
            raise ValueError('size_images must be an int. It will be scaled to a square image')
        self.size_images = size_images
        self.scaled_size = scaled_size

        type_of_data_augmentation = type_of_data_augmentation.lower()
        if type_of_data_augmentation not in ('rand_scale', 'random_sized'):
            raise ValueError('type_of_data_augmentation must be either rand-scale or random-sized')
        self.type_of_data_augmentation = type_of_data_augmentation

    def _getTransformList(self, aug_type):

        assert aug_type in ['rand_scale', 'random_sized', 'week_train', 'validation']
        list_of_transforms = []

        if aug_type == 'validation':
            list_of_transforms.append(transforms.Resize(self.scaled_size))
            list_of_transforms.append(transforms.CenterCrop(self.size_images))

        else:
            if aug_type == 'rand_scale':
                list_of_transforms.append(transforms_extension.RandomScale(256, 480))
                list_of_transforms.append(transforms.RandomCrop(self.size_images))
                list_of_transforms.append(transforms.RandomHorizontalFlip())

            elif aug_type == 'random_sized':
                list_of_transforms.append(transforms.RandomResizedCrop(self.size_images,
                                                                       scale=(
                                                                       self.data_config.random_sized.min_scale, 1.0)))
                list_of_transforms.append(transforms.RandomHorizontalFlip())

            if self.data_config.color:
                list_of_transforms.append(transforms.ColorJitter(brightness=0.4,
                                                                 contrast=0.4,
                                                                 saturation=0.4))
        return transforms.Compose(list_of_transforms)

    def _getTrainSet(self):

        train_transform = self._getTransformList(self.type_of_data_augmentation)

        if config.mimic.type == 'CRD':
            is_sample = True
        else:
            is_sample = False

        if self.data_config.train_data_type == 'img':
            train_set = torchvision.datasets.ImageFolder(self.trainFolder, train_transform)
        elif self.data_config.train_data_type == 'lmdb':
            train_set = lmdb_dataset.ImageFolder(self.trainFolder,
                                                 os.path.join(self.trainFolder, '..', 'train_datalist'),
                                                 train_transform,
                                                 patch_dataset=self.patch_dataset,
                                                 is_sample = is_sample
                                                 )
        self.train_num_examples = train_set.__len__()

        return train_set

    def _getTestSet(self):

        test_transform = self._getTransformList('validation')
        if self.data_config.val_data_type == 'img':
            test_set = torchvision.datasets.ImageFolder(self.testFolder, test_transform)
        elif self.data_config.val_data_type == 'lmdb':
            test_set = lmdb_dataset.ImageFolder(self.testFolder,
                                                os.path.join(self.testFolder, '..', 'val_datalist'),
                                                test_transform,
                                                is_sample=False)
            self.test_num_examples = test_set.__len__()
        return test_set

    def getTrainLoader(self, batch_size, shuffle=True):

        train_set = self._getTrainSet()

        train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=batch_size, shuffle=shuffle,
            num_workers=self.num_workers, pin_memory=self.pin_memory,
            sampler=None, collate_fn=fast_collate, drop_last=True)

        return train_loader

    def getTestLoader(self, batch_size, shuffle=False):

        test_set = self._getTestSet()

        test_loader = torch.utils.data.DataLoader(
            test_set, batch_size=batch_size, shuffle=shuffle,
            num_workers=self.num_workers, pin_memory=self.pin_memory, sampler=None,
            collate_fn=fast_collate)
        return test_loader

    def getTrainTestLoader(self, batch_size, train_shuffle=True, val_shuffle=False):

        train_loader = self.getTrainLoader(batch_size, train_shuffle)
        test_loader = self.getTestLoader(batch_size, val_shuffle)
        return train_loader, test_loader