import numpy as np
from PIL import Image
from torch.utils.data import Dataset, TensorDataset
from torch.utils.data.sampler import BatchSampler
from functions import *


class TensorDatasetWrapper(Dataset):
    """
    Wraps a TensorDataset so that an optional transform can be applied per image.

    The first tensor of the wrapped dataset is exposed as ``data`` and the
    second one as ``targets``.
    """

    def __init__(self, tensordataset, transform=None):
        self.tensordataset = tensordataset
        self.transform = transform
        self.data = tensordataset.tensors[0]
        self.targets = tensordataset.tensors[1]

    def __len__(self):
        """Number of samples in the wrapped dataset."""
        return len(self.tensordataset)

    def __getitem__(self, index):
        """
        Return ``(image, label)`` for ``index``; the label is a Python scalar.

        Without a transform the raw tensor slice is returned; otherwise the
        slice is converted to a PIL image in mode 'L' and passed to the transform.
        """
        image = self.data[index]
        label = self.targets[index].item()
        if self.transform is None:
            return image, label
        pil_image = Image.fromarray(image.numpy(), mode='L')
        return self.transform(pil_image), label

class TripletDataset(Dataset):
    """
    Dataset yielding (anchor, positive, negative) image triplets.

    Train: For each sample (anchor) randomly chooses a positive and a negative sample
    Test: Creates fixed triplets for testing

    Args:
        mnist_dataset: either a ``TensorDataset`` (images read from ``tensors[0]``,
            labels from ``tensors[1]``) or an object exposing ``data`` and
            ``targets`` tensors.
        train: if True, triplets are sampled anew on every ``__getitem__`` call
            (unless ``triplets_index`` is given); if False, a fixed list of
            triplets is used.
        transform: optional callable applied to each of the three images after
            they have been converted to PIL images in mode 'L'.
        use_target: if True, ``__getitem__`` additionally returns the anchor
            label and the negative label.
        triplets_index: optional precomputed sequence of
            ``(anchor_index, positive_index, negative_index)`` entries. When
            given, it is used in both train and test mode.
    """

    def __init__(self, mnist_dataset, train=True, transform=None, use_target=True, triplets_index=None):
        self.mnist_dataset = mnist_dataset
        self.train = train
        self.transform = transform
        self.use_target = use_target
        self.triplets_index = triplets_index

        # isinstance (rather than an exact type comparison) also accepts subclasses
        if isinstance(self.mnist_dataset, TensorDataset):
            self.targets = self.mnist_dataset.tensors[1]
            self.data = self.mnist_dataset.tensors[0]
        else:
            self.targets = self.mnist_dataset.targets
            self.data = self.mnist_dataset.data

        # Label lookup tables are needed in both modes
        labels = self.targets.numpy()
        self.labels_set = set(labels)
        self.label_to_indices = {label: np.where(labels == label)[0]
                                 for label in self.labels_set}

        if not self.train and self.triplets_index is None:
            # generate fixed triplets for testing
            self.triplets_index = self._make_fixed_triplets()

    def _make_fixed_triplets(self, seed=29):
        """Build one [anchor, positive, negative] index triplet per sample from a seeded generator."""
        # Every draw goes through the same seeded generator so that the
        # triplets are identical across runs, independent of the global
        # numpy random state.
        random_state = np.random.RandomState(seed)
        triplets = []
        for anchor_index in range(len(self.data)):
            anchor_label = self.targets[anchor_index].item()
            positive_index = random_state.choice(self.label_to_indices[anchor_label])
            negative_label = random_state.choice(list(self.labels_set - {anchor_label}))
            negative_index = random_state.choice(self.label_to_indices[negative_label])
            triplets.append([anchor_index, positive_index, negative_index])
        return triplets

    def __getitem__(self, index):
        """
        Return the triplet for ``index``.

        Returns:
            ``((anchor, positive, negative), (anchor_label, negative_label))``
            when ``use_target`` is True, otherwise only the image tuple.
            With random sampling the anchor label is a Python scalar and the
            negative label is whatever ``np.random.choice`` returns; with
            fixed triplets both labels are taken from ``targets`` by indexing
            (no ``.item()`` conversion).
        """
        if self.train and self.triplets_index is None:
            img1, pos_label = self.data[index], self.targets[index].item()
            candidates = self.label_to_indices[pos_label]
            positive_index = index
            # Resample until the positive differs from the anchor. A class
            # with a single sample has no other candidate, so the anchor
            # doubles as its own positive instead of looping forever.
            if len(candidates) > 1:
                while positive_index == index:
                    positive_index = np.random.choice(candidates)
            negative_label = np.random.choice(list(self.labels_set - {pos_label}))
            negative_index = np.random.choice(self.label_to_indices[negative_label])
            img2 = self.data[positive_index]
            img3 = self.data[negative_index]
        else:
            anchor_index, positive_index, negative_index = self.triplets_index[index][:3]
            img1 = self.data[anchor_index]
            img2 = self.data[positive_index]
            img3 = self.data[negative_index]
            pos_label = self.targets[anchor_index]
            negative_label = self.targets[negative_index]

        if self.transform is not None:
            img1 = Image.fromarray(img1.numpy(), mode='L')
            img2 = Image.fromarray(img2.numpy(), mode='L')
            img3 = Image.fromarray(img3.numpy(), mode='L')
            img1 = self.transform(img1)
            img2 = self.transform(img2)
            img3 = self.transform(img3)
        if self.use_target:
            return (img1, img2, img3), (pos_label, negative_label)
        else:
            return (img1, img2, img3)

    def __len__(self):
        """Number of triplets: the fixed triplet list if present, else the dataset size."""
        if self.triplets_index is not None:
            return len(self.triplets_index)
        return len(self.mnist_dataset)