import numpy as np
from PIL import Image

from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler
    
    
class BregMNIST(Dataset):
    """
    Pair images of a wrapped MNIST-style dataset and label each pair with a
    divergence of the two original labels.

    Train: on every ``__getitem__`` call the image at ``index`` is paired with
    a partner drawn uniformly at random (global NumPy RNG), so pairs differ
    between calls/epochs.
    Test: one fixed pair ``(i, partner_i)`` per sample is built at
    construction time from a seeded permutation, and its target is
    precomputed, so evaluation is reproducible.

    The target of a pair is ``divergence_fn(label1, label2)``, computed from
    the two original labels.

    Args:
        mnist_dataset: wrapped dataset; must expose ``train``, ``transform``,
            ``data`` and ``targets`` attributes and support ``len()``.
            ``data``/``targets`` may be torch tensors or NumPy arrays.
        divergence_fn: callable taking two labels and returning the target.
        seed: seed for the permutation that defines the fixed test pairs
            (only used when ``mnist_dataset.train`` is false). Defaults to 100.

    Attributes set only in test mode:
        test_pairs: list of ``(index1, index2)`` tuples, one per sample.
        div_target: list of precomputed targets aligned with ``test_pairs``.
    """

    def __init__(self, mnist_dataset, divergence_fn, *, seed=100):
        self.mnist_dataset = mnist_dataset
        self.div_fn = divergence_fn

        self.train = self.mnist_dataset.train
        self.transform = self.mnist_dataset.transform
        # Both modes read images and labels straight from the wrapped dataset.
        self.targets = self.mnist_dataset.targets
        self.data = self.mnist_dataset.data

        if not self.train:
            self._build_test_pairs(seed)

    def _build_test_pairs(self, seed):
        """Fix one (i, partner) pair per sample and precompute its target."""
        # A dedicated RandomState makes the pairing depend only on `seed`,
        # not on (and without disturbing) the global NumPy RNG state.
        partners = np.random.RandomState(seed).permutation(len(self.data))
        self.test_pairs = [(i, partner) for i, partner in enumerate(partners)]
        self.div_target = [
            self.div_fn(self._label(i), self._label(j)) for i, j in self.test_pairs
        ]

    def _label(self, idx):
        """Return the label at ``idx`` as a Python scalar when possible."""
        label = self.targets[idx]
        # Tensor and NumPy scalars expose .item(); plain Python values do not.
        return label.item() if hasattr(label, "item") else label

    @staticmethod
    def _to_image(img):
        """Convert one stored image (tensor or ndarray) to a grayscale PIL image."""
        return Image.fromarray(np.asarray(img), mode='L')

    def __getitem__(self, index):
        if self.train:
            # NOTE(review): the partner comes from the global NumPy RNG; with a
            # multi-worker DataLoader, depending on the PyTorch version and
            # worker_init_fn, workers may share RNG state and draw identical
            # partners — confirm seeding.
            index1, index2 = index, np.random.choice(len(self.data))
            target = self.div_fn(self._label(index1), self._label(index2))
        else:
            index1, index2 = self.test_pairs[index]
            target = self.div_target[index]

        img1 = self._to_image(self.data[index1])
        img2 = self._to_image(self.data[index2])
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return img1, img2, target

    def __len__(self):
        return len(self.mnist_dataset)
