import os

import numpy as np
from PIL import Image
from torch.utils.data import Dataset

from .config import DATA_PATHS


class DigitsDataset(Dataset):
    """Dataset over one digits domain stored under ``DATA_PATHS["Digits"]/<domain>``.

    Every data file is read with ``np.load(..., allow_pickle=True)`` and must
    unpack into an ``(images, labels)`` pair. The file that gets read depends
    on the constructor arguments:

    * ``filename`` given: ``<domain>/<filename>``.
    * ``train=False``: ``<domain>/test.pkl``.
    * ``train=True`` and ``percent >= 0.1``: the first ``int(percent * 10)``
      files ``<domain>/partitions/train_part{i}.pkl``, concatenated on axis 0.
    * ``train=True`` and ``percent < 0.1``: the leading
      ``int(n * percent * 10)`` samples of ``train_part0.pkl``, where ``n`` is
      the number of samples in that file.

    Args:
        domain: Name of the domain sub-directory (see ``all_domains``).
        percent: Fraction of the training data to use; ignored when
            ``filename`` is given or ``train`` is false.
        filename: Optional explicit data file, relative to the domain
            directory. Overrides ``train`` and ``percent``.
        train: Selects the training partitions (True) or ``test.pkl`` (False).
        transform: Optional callable applied to the PIL image in
            ``__getitem__``.

    Attributes:
        images: Image array; indexed along axis 0.
        labels: 1-D ``int64`` label array.
        transform: The ``transform`` argument, stored unchanged.
        channels: 3 for 'SVHN', 'SynthDigits' and 'MNIST_M', otherwise 1.
        classes: Sorted unique values found in ``labels``.
    """

    all_domains = ['MNIST', 'SVHN', 'USPS', 'SynthDigits', 'MNIST_M']
    # Rotations of ``all_domains``; key ``i`` starts with ``all_domains[i]``.
    resorted_domains = {
        0: ['MNIST',    'SVHN', 'USPS', 'SynthDigits', 'MNIST_M'],
        1: ['SVHN',     'USPS', 'SynthDigits', 'MNIST_M', 'MNIST'],
        2: ['USPS',     'SynthDigits', 'MNIST_M', 'MNIST', 'SVHN'],
        3: ['SynthDigits', 'MNIST_M', 'MNIST', 'SVHN', 'USPS'],
        4: ['MNIST_M',  'MNIST', 'SVHN', 'USPS', 'SynthDigits'],
    }
    num_classes = 10  # may not be correct

    def __init__(self, domain, percent=0.1, filename=None, train=True, transform=None):
        data_path = os.path.join(DATA_PATHS["Digits"], domain)

        if filename is not None:
            images, labels = self._load_pair(os.path.join(data_path, filename))
        elif not train:
            images, labels = self._load_pair(os.path.join(data_path, 'test.pkl'))
        elif percent >= 0.1:
            # Each partition file is treated as one tenth of the training set.
            loaded = [
                self._load_pair(
                    os.path.join(data_path,
                                 'partitions/train_part{}.pkl'.format(part)))
                for part in range(int(percent * 10))
            ]
            if len(loaded) == 1:
                # A single partition is used as loaded, without copying.
                images, labels = loaded[0]
            else:
                # Concatenate once instead of re-copying the accumulated
                # arrays for every additional partition.
                images = np.concatenate([pair[0] for pair in loaded], axis=0)
                labels = np.concatenate([pair[1] for pair in loaded], axis=0)
        else:
            # Less than one partition requested: keep a prefix of partition 0.
            images, labels = self._load_pair(
                os.path.join(data_path, 'partitions/train_part0.pkl'))
            data_len = int(images.shape[0] * percent * 10)
            images = images[:data_len]
            labels = labels[:data_len]

        self.images = images
        self.transform = transform
        self.channels = 3 if domain in ['SVHN', 'SynthDigits', 'MNIST_M'] else 1
        # ``np.long`` was removed in NumPy 1.24 (and is a platform-dependent C
        # long in NumPy 2.x); ``np.int64`` is fixed-width everywhere.
        # ``atleast_1d`` keeps a one-sample label array indexable after
        # ``squeeze`` would otherwise collapse it to 0-d.
        self.labels = np.atleast_1d(np.asarray(labels).astype(np.int64).squeeze())
        self.classes = np.unique(self.labels)

    @staticmethod
    def _load_pair(path):
        """Load one data file and return its ``(images, labels)`` pair."""
        # NOTE(security): allow_pickle=True unpickles the file, which can
        # execute arbitrary code. Only point this at trusted data files.
        images, labels = np.load(path, allow_pickle=True)
        return images, labels

    def __len__(self):
        """Return the number of samples (length of ``images`` along axis 0)."""
        return self.images.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The raw array is wrapped in a PIL image ('L' mode for one channel,
        'RGB' for three) and then passed through ``transform`` when one is set.

        Raises:
            ValueError: If ``channels`` is neither 1 nor 3.
        """
        image = self.images[idx]
        label = self.labels[idx]
        if self.channels == 1:
            image = Image.fromarray(image, mode='L')
        elif self.channels == 3:
            image = Image.fromarray(image, mode='RGB')
        else:
            raise ValueError("{} channel is not allowed.".format(self.channels))

        if self.transform is not None:
            image = self.transform(image)

        return image, label
