import os
import numpy as np
import torch
import torch.utils.data as data_utils
import torchvision
from sc2image.dataset import StarCraftMNIST

def split_image_into_patches(image, n_device):
    """Cut every image of a batch into an ``s x s`` grid of equal patches.

    ``s = sqrt(n_device)``, so ``n_device`` must be a positive perfect square.
    Patches are ordered row-major over the grid: left-to-right, then
    top-to-bottom.

    Args:
        image: Tensor of shape ``(B, C, H, W)``. ``H`` and ``W`` must both be
            divisible by ``s``.
        n_device: Number of patches to produce per image.

    Returns:
        Tensor of shape ``(B, n_device, C, H // s, W // s)``.

    Raises:
        ValueError: if ``n_device`` is not a positive perfect square, or if
            ``H``/``W`` is not divisible by ``s``.
    """
    B, C, H, W = image.shape

    if n_device < 1:
        raise ValueError(f'n_device must be a positive perfect square, got {n_device}')
    n_patch_per_dim = int(round(float(np.sqrt(n_device))))
    # int(np.sqrt(n)) used to truncate silently (e.g. 5 -> 2x2 = 4 patches);
    # reject any n_device that does not form an exact square grid.
    if n_patch_per_dim * n_patch_per_dim != n_device:
        raise ValueError(f'n_device must be a perfect square, got {n_device}')
    if H % n_patch_per_dim or W % n_patch_per_dim:
        raise ValueError(
            f'image size {H}x{W} must be divisible by sqrt(n_device)={n_patch_per_dim}')
    patch_h = H // n_patch_per_dim
    patch_w = W // n_patch_per_dim

    # (B, C, H, W) -> (B, C, grid_y, patch_h, grid_x, patch_w). Pinning B (rather
    # than -1) makes a wrong spatial size fail loudly instead of mis-tiling.
    patches = image.reshape(B, C, n_patch_per_dim, patch_h, n_patch_per_dim, patch_w)
    # -> (B, grid_y, grid_x, C, patch_h, patch_w), then merge the two grid axes
    # into a single row-major patch axis.
    patches = patches.permute(0, 2, 4, 1, 3, 5).reshape(B, -1, C, patch_h, patch_w)

    return patches

class MNISTPatch(data_utils.Dataset):
    """torchvision MNIST split held fully in memory, each image cut into patches.

    At construction the requested split is converted with ``ToTensor``,
    collated into a single batch and passed through
    ``split_image_into_patches``. Items are ``(patches, label)`` pairs where
    ``patches`` is the per-image stack of shape ``(n_patches, C, ph, pw)``.

    Args:
        n_device: Patch count forwarded to ``split_image_into_patches``.
        root: Dataset directory; a leading ``~`` is expanded.
        train: Forwarded as ``train=`` to the torchvision dataset.
        transform: Optional callable applied in ``__getitem__`` to the patch
            tensor (not to a PIL image).
        download: Forwarded as ``download=`` to the torchvision dataset.
    """

    def __init__(self, n_device, root='./data', train=True, transform=None, download=True):
        self.root = os.path.expanduser(root)
        self.train = train
        self.download = download
        self.transform = transform
        self.n_device = n_device
        # Loading is eager: the whole split is materialised right here.
        self.images, self.labels = self._get_data()

    def _get_data(self):
        """Load the MNIST split and return ``(patched_images, labels)``."""
        as_tensor = torchvision.transforms.ToTensor()
        mnist = torchvision.datasets.MNIST(
            self.root,
            train=self.train,
            download=self.download,
            transform=as_tensor,
        )
        # One batch as large as the split collates every sample into one tensor.
        whole_split = torch.utils.data.DataLoader(mnist, batch_size=len(mnist), shuffle=False)
        pixels, targets = next(iter(whole_split))
        return split_image_into_patches(pixels, self.n_device), targets

    def __len__(self):
        """Number of samples in the loaded split."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(patches, label)`` at ``index``, applying ``transform`` if set."""
        patches, label = self.images[index], self.labels[index]
        if self.transform is None:
            return patches, label
        return self.transform(patches), label

class FashionMNISTPatch(data_utils.Dataset):
    """torchvision FashionMNIST split held in memory, each image cut into patches.

    At construction the requested split is converted with ``ToTensor``,
    collated into a single batch and passed through
    ``split_image_into_patches``. Items are ``(patches, label)`` pairs where
    ``patches`` is the per-image stack of shape ``(n_patches, C, ph, pw)``.

    Args:
        n_device: Patch count forwarded to ``split_image_into_patches``.
        root: Dataset directory; a leading ``~`` is expanded.
        train: Forwarded as ``train=`` to the torchvision dataset.
        transform: Optional callable applied in ``__getitem__`` to the patch
            tensor (not to a PIL image).
        download: Forwarded as ``download=`` to the torchvision dataset.
    """

    def __init__(self, n_device, root='./data', train=True, transform=None, download=True):
        self.root = os.path.expanduser(root)
        self.train = train
        self.download = download
        self.transform = transform
        self.n_device = n_device
        # Loading is eager: the whole split is materialised right here.
        self.images, self.labels = self._get_data()

    def _get_data(self):
        """Load the FashionMNIST split and return ``(patched_images, labels)``."""
        as_tensor = torchvision.transforms.ToTensor()
        fashion = torchvision.datasets.FashionMNIST(
            self.root,
            train=self.train,
            download=self.download,
            transform=as_tensor,
        )
        # One batch as large as the split collates every sample into one tensor.
        whole_split = torch.utils.data.DataLoader(fashion, batch_size=len(fashion), shuffle=False)
        pixels, targets = next(iter(whole_split))
        return split_image_into_patches(pixels, self.n_device), targets

    def __len__(self):
        """Number of samples in the loaded split."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(patches, label)`` at ``index``, applying ``transform`` if set."""
        patches, label = self.images[index], self.labels[index]
        if self.transform is None:
            return patches, label
        return self.transform(patches), label

class StarCraftMNISTPatch(data_utils.Dataset):
    """``sc2image`` StarCraftMNIST split held in memory, images cut into patches.

    At construction the requested split is converted with ``ToTensor``,
    collated into a single batch and passed through
    ``split_image_into_patches``. Items are ``(patches, label)`` pairs where
    ``patches`` is the per-image stack of shape ``(n_patches, C, ph, pw)``.

    NOTE(review): ``StarCraftMNIST`` comes from the external ``sc2image``
    package; it is assumed to follow torchvision dataset conventions and to
    yield ``(image, label)`` samples. Confirm against that package.

    Args:
        n_device: Patch count forwarded to ``split_image_into_patches``.
        root: Dataset directory; a leading ``~`` is expanded.
        train: Forwarded as ``train=`` to ``StarCraftMNIST``.
        transform: Optional callable applied in ``__getitem__`` to the patch
            tensor.
        download: Forwarded as ``download=`` to ``StarCraftMNIST``.
    """

    def __init__(self, n_device, root='./data', train=True, transform=None, download=True):
        self.root = os.path.expanduser(root)
        self.train = train
        self.download = download
        self.transform = transform
        self.n_device = n_device
        # Loading is eager: the whole split is materialised right here.
        self.images, self.labels = self._get_data()

    def _get_data(self):
        """Load the StarCraftMNIST split and return ``(patched_images, labels)``."""
        as_tensor = torchvision.transforms.ToTensor()
        sc2 = StarCraftMNIST(
            self.root,
            train=self.train,
            download=self.download,
            transform=as_tensor,
        )
        # One batch as large as the split collates every sample into one tensor.
        whole_split = torch.utils.data.DataLoader(sc2, batch_size=len(sc2), shuffle=False)
        pixels, targets = next(iter(whole_split))
        return split_image_into_patches(pixels, self.n_device), targets

    def __len__(self):
        """Number of samples in the loaded split."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(patches, label)`` at ``index``, applying ``transform`` if set."""
        patches, label = self.images[index], self.labels[index]
        if self.transform is None:
            return patches, label
        return self.transform(patches), label

class CIFAR10Patch(data_utils.Dataset):
    """torchvision CIFAR10 split held fully in memory, each image cut into patches.

    At construction the requested split is converted with ``ToTensor``,
    collated into a single batch and passed through
    ``split_image_into_patches``. Items are ``(patches, label)`` pairs where
    ``patches`` is the per-image stack of shape ``(n_patches, C, ph, pw)``.

    Args:
        n_device: Patch count forwarded to ``split_image_into_patches``.
        root: Dataset directory; a leading ``~`` is expanded.
        train: Forwarded as ``train=`` to the torchvision dataset.
        transform: Optional callable applied in ``__getitem__`` to the patch
            tensor (not to a PIL image).
        download: Forwarded as ``download=`` to the torchvision dataset.
    """

    def __init__(self, n_device, root='./data', train=True, transform=None, download=True):
        self.root = os.path.expanduser(root)
        self.train = train
        self.download = download
        self.transform = transform
        self.n_device = n_device
        # Loading is eager: the whole split is materialised right here.
        self.images, self.labels = self._get_data()

    def _get_data(self):
        """Load the CIFAR10 split and return ``(patched_images, labels)``."""
        as_tensor = torchvision.transforms.ToTensor()
        cifar = torchvision.datasets.CIFAR10(
            self.root,
            train=self.train,
            download=self.download,
            transform=as_tensor,
        )
        # One batch as large as the split collates every sample into one tensor.
        whole_split = torch.utils.data.DataLoader(cifar, batch_size=len(cifar), shuffle=False)
        pixels, targets = next(iter(whole_split))
        return split_image_into_patches(pixels, self.n_device), targets

    def __len__(self):
        """Number of samples in the loaded split."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(patches, label)`` at ``index``, applying ``transform`` if set."""
        patches, label = self.images[index], self.labels[index]
        if self.transform is None:
            return patches, label
        return self.transform(patches), label