import os
import random
import torch
import numpy as np
from PIL import Image
from torchvision import transforms as transforms_lib
from torchvision.datasets import CIFAR10, MNIST, FashionMNIST, SVHN
from torch.utils.data import random_split, Subset
from utils import *
from tqdm import tqdm


# Module-level logger used by the dataset code below.
# NOTE(review): setup_logger is not defined in this file; presumably it is
# provided by `from utils import *` above - confirm.
logger = setup_logger(__name__)

# The Custom Image Dataset expects its files in the ImageFolder layout:
# root/class1/xxx.png
# root/class1/xxy.png
# root/class2/xxx.png
# root/class2/xxy.png
class CustomImageDataset(torch.utils.data.Dataset):
    """Map-style dataset over a directory tree with one sub-directory per class.

    Every immediate sub-directory of ``image_dir`` is treated as one class.
    Class directories are sorted by name and given consecutive integer labels
    starting at 0; every entry inside a class directory is recorded as one
    sample of that class.

    Args:
        image_dir: Root directory holding one sub-directory per class.
        transform: Optional callable applied to the loaded RGB PIL image.
        target_transform: Optional callable applied to the integer label.
        pre_shuffle: When truthy, shuffle the (path, label) pairs once at
            construction time (uses the global ``random`` module state).
    """

    def __init__(self, image_dir, transform = None, target_transform = None, pre_shuffle = None):
        """Store the configuration and index the images found on disk."""
        self.images = []
        self.labels = []
        self.image_dir = image_dir
        self.transform = transform
        self.target_transform = target_transform

        # Setup the dataset
        self.setup(pre_shuffle = pre_shuffle)

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``.

        The image is opened from disk, converted to RGB and passed through
        ``transform`` (if set); the label is passed through
        ``target_transform`` (if set).
        """
        image_path = self.images[idx]
        # convert() yields a fully loaded copy, so the handle opened by
        # Image.open can be released right away instead of lingering until
        # garbage collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)

        return image, label

    def setup(self, pre_shuffle=False):
        """Collect image paths and labels from the folder structure.

        Appends to ``self.images`` / ``self.labels``; when ``pre_shuffle`` is
        truthy the two lists are then shuffled together so that every path
        keeps its own label.

        Args:
            pre_shuffle: Whether to shuffle the collected samples once.
        """
        logger.debug('Setting up the dataset')
        # Only the immediate sub-directories of image_dir are class folders.
        list_of_classes = sorted(next(os.walk(self.image_dir))[1])

        for class_index, class_id in enumerate(list_of_classes):
            full_path = os.path.join(self.image_dir, class_id)
            images = sorted(os.path.join(full_path, file) for file in os.listdir(full_path))

            # extend() avoids re-copying the accumulated lists for every class.
            self.images.extend(images)
            self.labels.extend([class_index] * len(images))

        # NOTE Pre-shuffle originally implemented for use during use of custom samplers
        if pre_shuffle:
            logger.debug('Applying pre-shuffling to the dataset')
            # Shuffle paths and labels as pairs so they stay aligned.
            # No "element moved" check afterwards: a uniform shuffle may
            # legitimately leave any given index unchanged, and datasets with
            # fewer than two samples have nothing to move.
            paired = list(zip(self.images, self.labels))
            random.shuffle(paired)
            self.images = [path for path, _ in paired]
            self.labels = [label for _, label in paired]

# Dataset download: https://github.com/fastai/imagenette
class Imagenette(CustomImageDataset):
    """Imagenette images served through ``CustomImageDataset``.

    Each image goes through Resize(256), CenterCrop(224), ToTensor and a
    fixed per-channel Normalize before being returned.
    """

    def __init__(self, image_dir, pre_shuffle = False):
        # Per-channel statistics, taken from torchvision models
        channel_mean = (0.485, 0.456, 0.406)
        channel_std = (0.229, 0.224, 0.225)

        preprocessing_steps = [
            transforms_lib.Resize(256),
            transforms_lib.CenterCrop(224),
            transforms_lib.ToTensor(),
            transforms_lib.Normalize(mean = channel_mean, std = channel_std),
        ]
        preprocessing = transforms_lib.Compose(preprocessing_steps)

        # No target transform: labels are passed through unchanged.
        CustomImageDataset.__init__(
            self,
            image_dir,
            transform = preprocessing,
            target_transform = None,
            pre_shuffle = pre_shuffle,
        )

# Dataset download: https://github.com/fastai/imagenette
class Imagewoof(CustomImageDataset):
    """Imagewoof images served through ``CustomImageDataset``.

    Each image goes through Resize(256), CenterCrop(224), ToTensor and a
    fixed per-channel Normalize before being returned.
    """

    def __init__(self, image_dir, pre_shuffle = False):
        # Per-channel statistics, taken from torchvision models
        channel_mean = (0.485, 0.456, 0.406)
        channel_std = (0.229, 0.224, 0.225)

        preprocessing_steps = [
            transforms_lib.Resize(256),
            transforms_lib.CenterCrop(224),
            transforms_lib.ToTensor(),
            transforms_lib.Normalize(mean = channel_mean, std = channel_std),
        ]
        preprocessing = transforms_lib.Compose(preprocessing_steps)

        # No target transform: labels are passed through unchanged.
        CustomImageDataset.__init__(
            self,
            image_dir,
            transform = preprocessing,
            target_transform = None,
            pre_shuffle = pre_shuffle,
        )

class Imagenet(CustomImageDataset):
    """ImageNet-style folder dataset served through ``CustomImageDataset``.

    Each image goes through Resize(256), CenterCrop(224), ToTensor and a
    fixed per-channel Normalize before being returned.
    """

    def __init__(self, image_dir, pre_shuffle = False):
        # Per-channel statistics, taken from torchvision models
        channel_mean = (0.485, 0.456, 0.406)
        channel_std = (0.229, 0.224, 0.225)

        preprocessing_steps = [
            transforms_lib.Resize(256),
            transforms_lib.CenterCrop(224),
            transforms_lib.ToTensor(),
            transforms_lib.Normalize(mean = channel_mean, std = channel_std),
        ]
        preprocessing = transforms_lib.Compose(preprocessing_steps)

        # No target transform: labels are passed through unchanged.
        CustomImageDataset.__init__(
            self,
            image_dir,
            transform = preprocessing,
            target_transform = None,
            pre_shuffle = pre_shuffle,
        )

class CustomCIFAR10():
    """CIFAR-10 provider exposing ``train``, ``val`` and ``test`` datasets.

    ``setup`` carves a 10% validation split off the official training set and
    attaches a ``labels`` attribute to each resulting dataset.

    Args:
        image_dir: Root directory passed to torchvision's ``CIFAR10``.
        pre_shuffle: When true, the training indices are shuffled (NumPy seed
            42) before the train/val split is taken.
    """

    def __init__(self, image_dir, pre_shuffle = False):
        self.transform = transforms_lib.Compose(
            [
                transforms_lib.ToTensor(),
                transforms_lib.Normalize(
                    mean = (0.49139968, 0.48215827 ,0.44653124), 
                    std = (0.24703233, 0.24348505, 0.26158768)
                ),
            ]
        )
 
        self.dims = (3, 32, 32)
        self.image_dir = image_dir
        self.pre_shuffle = pre_shuffle

    def prepare_data(self):
        """Download both CIFAR-10 splits into ``image_dir``."""
        # download
        CIFAR10(self.image_dir, train=True, download=True)
        CIFAR10(self.image_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets for the requested stage.

        Args:
            stage: ``"fit"`` builds ``self.train`` / ``self.val``, ``"test"``
                builds ``self.test``, ``None`` builds all three.
        """

        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            cifar10_full = CIFAR10(self.image_dir, train = True, download = True, transform = self.transform)
            dataset_size = len(cifar10_full)
            indicies = list(range(dataset_size))
            split = int(np.floor(0.1 * dataset_size))
            if self.pre_shuffle:
                # Reseeds the global NumPy RNG so the split is reproducible.
                np.random.seed(42)
                np.random.shuffle(indicies)
            # First 10% of the (possibly shuffled) indices form the validation set.
            train_indicies, val_indicies = indicies[split:], indicies[:split]
            
            # Labels are looked up through the same index lists as the Subsets
            # so they stay aligned with the samples even when pre_shuffle has
            # reordered the indices.
            self.train = Subset(cifar10_full, train_indicies)
            self.train.labels = [cifar10_full.targets[i] for i in train_indicies]

            self.val = Subset(cifar10_full,val_indicies)
            self.val.labels = [cifar10_full.targets[i] for i in val_indicies]
            
        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.test = CIFAR10(self.image_dir, train=False, transform=self.transform)
            self.test.labels = self.test.targets

class CustomFashionMNIST():
    """FashionMNIST provider exposing ``train``, ``val`` and ``test`` datasets.

    ``setup`` carves a 10% validation split off the official training set and
    attaches a ``labels`` attribute to each resulting dataset.

    Args:
        image_dir: Root directory passed to torchvision's ``FashionMNIST``.
        pre_shuffle: When true, the training indices are shuffled (NumPy seed
            42) before the train/val split is taken.
    """

    def __init__(self, image_dir, pre_shuffle = False):
        self.transform = transforms_lib.Compose(
            [
                transforms_lib.ToTensor(),
                # TODO - These need checking
                # NOTE(review): mean/std below are the same values CustomMNIST
                # uses in this file - confirm they are intended for FashionMNIST.
                transforms_lib.Lambda(lambda x: x.repeat(3, 1, 1)), # Change to color images
                transforms_lib.Normalize(
                    mean = (0.1307,), 
                    std = (0.3081,)
                ),
            ]
        )
 
        # NOTE(review): dims declares one channel while the transform repeats
        # the tensor three times along dim 0 - confirm which one callers expect.
        self.dims = (1, 28, 28)
        self.image_dir = image_dir
        self.pre_shuffle = pre_shuffle

    def prepare_data(self):
        """Download both FashionMNIST splits into ``image_dir``."""
        # download
        FashionMNIST(self.image_dir, train=True, download=True)
        FashionMNIST(self.image_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets for the requested stage.

        Args:
            stage: ``"fit"`` builds ``self.train`` / ``self.val``, ``"test"``
                builds ``self.test``, ``None`` builds all three.
        """

        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            fashion_mnist_full = FashionMNIST(self.image_dir, train = True, download = True, transform = self.transform)
            dataset_size = len(fashion_mnist_full)
            indicies = list(range(dataset_size))
            split = int(np.floor(0.1 * dataset_size))
            if self.pre_shuffle:
                # Reseeds the global NumPy RNG so the split is reproducible.
                np.random.seed(42)
                np.random.shuffle(indicies)
            # First 10% of the (possibly shuffled) indices form the validation set.
            train_indicies, val_indicies = indicies[split:], indicies[:split]
            
            # Index the targets with the same index lists as the Subsets so the
            # labels stay aligned with the samples even when pre_shuffle has
            # reordered the indices (assumes targets is a tensor, as in
            # torchvision's FashionMNIST).
            self.train = Subset(fashion_mnist_full, train_indicies)
            self.train.labels = fashion_mnist_full.targets[train_indicies]

            self.val = Subset(fashion_mnist_full,val_indicies)
            self.val.labels = fashion_mnist_full.targets[val_indicies]
            
        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.test = FashionMNIST(self.image_dir, train=False, transform=self.transform)
            self.test.labels = self.test.targets

class CustomMNIST():
    """MNIST provider exposing ``train``, ``val`` and ``test`` datasets.

    ``setup`` carves a 10% validation split off the official training set and
    attaches a ``labels`` attribute to each resulting dataset.

    Args:
        image_dir: Root directory passed to torchvision's ``MNIST``.
        pre_shuffle: When true, the training indices are shuffled (NumPy seed
            42) before the train/val split is taken.
    """

    def __init__(self, image_dir, pre_shuffle = False):
        self.transform = transforms_lib.Compose(
            [
                transforms_lib.ToTensor(),
                # TODO - These need checking
                transforms_lib.Lambda(lambda x: x.repeat(3, 1, 1)), # Change to color images
                transforms_lib.Normalize(
                    mean = (0.1307,), 
                    std = (0.3081,)
                ),
            ]
        )
 
        # NOTE(review): dims declares one channel while the transform repeats
        # the tensor three times along dim 0 - confirm which one callers expect.
        self.dims = (1, 28, 28)
        self.image_dir = image_dir
        self.pre_shuffle = pre_shuffle

    def prepare_data(self):
        """Download both MNIST splits into ``image_dir``."""
        # download
        MNIST(self.image_dir, train=True, download=True)
        MNIST(self.image_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets for the requested stage.

        Args:
            stage: ``"fit"`` builds ``self.train`` / ``self.val``, ``"test"``
                builds ``self.test``, ``None`` builds all three.
        """

        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.image_dir, train = True, download = True, transform = self.transform)
            dataset_size = len(mnist_full)
            indicies = list(range(dataset_size))
            split = int(np.floor(0.1 * dataset_size))
            if self.pre_shuffle:
                # Reseeds the global NumPy RNG so the split is reproducible.
                np.random.seed(42)
                np.random.shuffle(indicies)
            # First 10% of the (possibly shuffled) indices form the validation set.
            train_indicies, val_indicies = indicies[split:], indicies[:split]
            
            # Index the targets with the same index lists as the Subsets so the
            # labels stay aligned with the samples even when pre_shuffle has
            # reordered the indices (assumes targets is a tensor, as in
            # torchvision's MNIST).
            self.train = Subset(mnist_full, train_indicies)
            self.train.labels = mnist_full.targets[train_indicies]

            self.val = Subset(mnist_full,val_indicies)
            self.val.labels = mnist_full.targets[val_indicies]
            
        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.test = MNIST(self.image_dir, train=False, transform=self.transform)
            self.test.labels = self.test.targets

class CustomSVHN():
    """SVHN provider exposing ``train``, ``val`` and ``test`` datasets.

    ``setup`` carves a 10% validation split off the ``'train'`` split and
    attaches a ``labels`` attribute to the train/val subsets; the test dataset
    uses the ``labels`` attribute that ``SVHN`` already provides.

    Args:
        image_dir: Root directory passed to torchvision's ``SVHN``.
        pre_shuffle: When true, the training indices are shuffled (NumPy seed
            42) before the train/val split is taken.
    """

    def __init__(self, image_dir, pre_shuffle = False):
        self.transform = transforms_lib.Compose(
            [
                transforms_lib.ToTensor(),
                # TODO - These need checking
                transforms_lib.Normalize(
                    mean = (0.4376821, 0.4437697, 0.47280442), 
                    std = (0.19803012, 0.20101562, 0.19703614)
                ),
            ]
        )
 
        self.dims = (3, 32, 32)
        self.image_dir = image_dir
        self.pre_shuffle = pre_shuffle

    def prepare_data(self):
        """Download the SVHN ``'train'`` and ``'test'`` splits into ``image_dir``."""
        # download
        # SVHN selects its subset through `split=...`; it has no `train` flag.
        SVHN(self.image_dir, split='train', download=True)
        SVHN(self.image_dir, split='test', download=True)

    def setup(self, stage=None):
        """Build the datasets for the requested stage.

        Args:
            stage: ``"fit"`` builds ``self.train`` / ``self.val``, ``"test"``
                builds ``self.test``, ``None`` builds all three.
        """

        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            svhn_full = SVHN(self.image_dir, split = 'train', download = True, transform = self.transform)
            dataset_size = len(svhn_full)
            indicies = list(range(dataset_size))
            split = int(np.floor(0.1 * dataset_size))
            if self.pre_shuffle:
                # Reseeds the global NumPy RNG so the split is reproducible.
                np.random.seed(42)
                np.random.shuffle(indicies)
            # First 10% of the (possibly shuffled) indices form the validation set.
            train_indicies, val_indicies = indicies[split:], indicies[:split]
            
            # Index the labels with the same index lists as the Subsets so they
            # stay aligned with the samples even when pre_shuffle has reordered
            # the indices (assumes labels is a NumPy array, as in torchvision's
            # SVHN).
            self.train = Subset(svhn_full, train_indicies)
            self.train.labels = svhn_full.labels[train_indicies]

            self.val = Subset(svhn_full,val_indicies)
            self.val.labels = svhn_full.labels[val_indicies]
            
        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            # The SVHN dataset object already carries a `labels` attribute, so
            # nothing extra needs to be attached here.
            self.test = SVHN(self.image_dir, split='test', download = True, transform=self.transform)

if __name__ == '__main__':
    # Manual smoke test: load Imagenette from a fixed path and dump one batch.
    print('Testing Datasets')

    imagenette_train = Imagenette(image_dir = '/data/progressive_data_dropout/imagenette/train')
    loader = torch.utils.data.DataLoader(
        imagenette_train,
        batch_size = 32,
        shuffle = False,
        num_workers = 32,
    )

    # Only the first batch is inspected; the loop exits right after printing it.
    for _, (images, labels) in tqdm(enumerate(loader)):
        print(images)
        print(labels)
        break

    print('Finish Looping through dataset')