
import torch
import numpy as np
import itertools
from PIL import Image
import math


def produce_idx(data, labels):
    """Save, for every unordered pair of digits, the indices of its samples.

    For each pair ``(a, b)`` with ``0 <= a < b <= 9`` this prints the pair and
    writes ``data/MNIST/indices/idx<a><b>.npy`` (``np.save`` appends the
    ``.npy`` suffix). The file holds the sorted integer positions
    ``i < len(data)`` where ``labels[i]`` equals ``a`` or ``b``. The target
    directory must already exist; ``np.save`` does not create it.

    Args:
        data: sequence of samples; only its length is used.
        labels: per-sample labels (numpy array, list, or torch tensor).
    """
    if torch.is_tensor(labels):
        labels = labels.detach().cpu().numpy()
    # Only the first len(data) labels are considered, matching the original
    # per-sample scan over range(len(data)).
    label_arr = np.asarray(labels)[:len(data)]

    for lab in itertools.combinations(range(10), 2):
        print(lab)
        # Vectorised membership test: one pass per pair instead of a Python
        # loop over every sample. flatnonzero always yields integer indices,
        # even when a pair has no samples (np.array([]) would be float64).
        idx = np.flatnonzero((label_arr == lab[0]) | (label_arr == lab[1]))
        file = 'data/MNIST/indices/idx' + str(lab[0]) + str(lab[1])
        np.save(file, idx)


def load_data(idx1, idx2, seed='No'):
    """Load MNIST training samples of two digits with binary labels.

    Reads ``data/MNIST/processed/training.pt`` (a ``(data, labels)`` pair) and
    the index file ``data/MNIST/indices/idx<idx1><idx2>.npy``. It keeps only
    the indexed samples and relabels them: ``idx1`` becomes 0 and every other
    label becomes 1.

    Args:
        idx1: digit mapped to label 0; first digit of the index-file name.
        idx2: second digit of the index-file name; its samples get label 1.
        seed: ``'No'`` to return every selected sample. Otherwise an int ``k``
            that selects the k-th of three contiguous chunks of
            ``floor(n / 3)`` samples each; the remainder is dropped. Values
            outside 0..2 yield an empty or partial slice.

    Returns:
        ``(data, labels)``; ``labels`` keeps the dtype of the stored labels.
    """
    train = torch.load('data/MNIST/processed/training.pt')
    # test = torch.load('data/MNIST/processed/test.pt')

    filename = 'data/MNIST/indices/idx' + str(idx1) + str(idx2) + '.npy'
    idx = np.load(filename)

    # Advanced indexing copies, so relabelling below never touches `train`.
    data = train[0][idx]
    labels = train[1][idx]

    # Vectorised relabel (replaces a per-element Python loop): idx1 -> 0,
    # anything else -> 1, preserving the original tensor dtype.
    labels = (labels != idx1).to(labels.dtype)

    if seed == 'No':
        return data, labels

    print("With different data")
    third = math.floor(len(data) / 3)
    chunk = slice(seed * third, (seed + 1) * third)
    return data[chunk], labels[chunk]


class two_MNIST():
    """MNIST training set restricted to two digit classes, with binary targets.

    Samples come from ``load_data(idx1, idx2, seed)``, which labels digit
    ``idx1`` as 0 and the other selected digit as 1. Each item is an
    ``(image, target)`` pair: the image is converted to a PIL image and passed
    through ``transform`` when one is given.
    """

    # File names of the processed MNIST splits; not read by this class itself.
    training_file = 'training.pt'
    test_file = 'test.pt'
    # Names of all ten MNIST digits. Note that this dataset only ever yields
    # targets 0 and 1, so these are not the names of its targets.
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    def __init__(self, transform, idx1, idx2, seed):
        """
        Args:
            transform: callable applied to each PIL image, or None.
            idx1: digit mapped to target 0.
            idx2: the second digit; its samples get target 1.
            seed: forwarded to ``load_data``. ``'No'`` keeps all samples;
                an int ``k`` keeps the k-th contiguous third.
        """
        self.transform = transform
        self.idx1, self.idx2 = idx1, idx2
        self.data, self.targets = load_data(idx1, idx2, seed)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is 0 or 1 as a Python int.
        """
        img, target = self.data[index], int(self.targets[index])

        # Return a PIL image so transforms written for other image datasets
        # apply unchanged. mode='L' is 8-bit grayscale; this presumably
        # assumes each item is a 2-D uint8 tensor as in MNIST .pt files.
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        return img, target

    def __len__(self):
        """Return the number of samples held by this dataset."""
        return len(self.data)
