import numpy as np
import torch
from torchvision import datasets


class CIFAR100SubSet(torch.utils.data.Dataset):  # TODO: use torch.utils.data.Dataset with batch sampling
    """CIFAR-100 dataset, optionally reduced to a random class-balanced subset.

    The images and labels are loaded once in the constructor and kept in
    memory as ``self.data`` and ``self.target``; items are served by index.
    """

    # Number of label classes the subset sampling is balanced over.
    NUM_CLASSES = 100

    def __init__(self, root, train=True, transform=None, download=True, returns="all", num_sample=None):
        """Load CIFAR-100 and optionally draw a class-balanced random subset.

        Args:
            root: data root
            train: whether to use the training or validation split (default: True)
            transform: callable applied to each image in ``__getitem__``
            download: whether to download the data (default: True)
            returns: specify which data the client has. Stored on the
                instance; not consulted by this class itself.
            num_sample: total number of samples to keep. If given,
                ``int(num_sample / 100)`` images are drawn at random from each
                of the 100 classes; if ``None`` the whole split is used.
        Returns:
            A PyTorch dataset
        """
        super().__init__()
        self.root = root
        self.train = train
        self.transform = transform
        self.download = download
        self.returns = returns
        self.data, self.target = self.__build_cifar_subset__(num_sample=num_sample)

    @staticmethod
    def _sample_class_balanced_indices(target, num_sample, num_classes=NUM_CLASSES):
        """Pick ``int(num_sample / num_classes)`` random indices per class.

        Args:
            target: 1-D numpy array of integer class labels.
            num_sample: total number of samples requested across all classes.
            num_classes: number of classes; labels ``0 .. num_classes - 1``
                are sampled.
        Returns:
            1-D numpy array of indices into ``target``, grouped by class.
        """
        per_class = int(num_sample / num_classes)
        picked = []
        for c in range(num_classes):
            class_idx = np.flatnonzero(target == c)
            # Draw without replacement so the subset has no duplicated images.
            # Only when more samples are requested than the class holds do we
            # fall back to sampling with replacement, so such calls still work.
            with_replacement = per_class > len(class_idx)
            picked.append(np.random.choice(class_idx, per_class, replace=with_replacement))
        return np.concatenate(picked)

    def __build_cifar_subset__(self, num_sample):
        """Load the CIFAR-100 split and return ``(data, target)`` arrays.

        Args:
            num_sample: total subset size, or ``None`` to keep the whole split.
        Returns:
            Tuple of the image array and a numpy array of labels, restricted
            to the sampled indices when ``num_sample`` is given.
        """
        cifar_dataobj = datasets.CIFAR100(
            root=self.root, train=self.train, transform=self.transform, download=self.download
        )
        data = cifar_dataobj.data
        target = np.array(cifar_dataobj.targets)

        # if num_sample provided, extract subset, otherwise use the whole set
        if num_sample is not None:
            sample_idx = self._sample_class_balanced_indices(target, num_sample, self.NUM_CLASSES)
            data = data[sample_idx]
            target = target[sample_idx]

        return data, target

    def __getitem__(self, index):
        """Return the ``(data, target)`` pair at ``index``.

        The transform, if set, is applied to the item exactly as stored in
        ``self.data`` (no conversion is done beforehand).
        """
        data, target = self.data[index], self.target[index]
        if self.transform is not None:
            data = self.transform(data)
        return data, target

    def __len__(self):
        """Return the number of samples held by this dataset."""
        return len(self.data)


