import torch.utils.data as data
import torch
from os.path import isfile
import pickle
import numpy as np

class SSLDataset(data.Dataset):
    """Wrap a labeled dataset and hide the labels of a random subset of it.

    Items whose index is in ``self.unlabel`` are returned with their target
    replaced by ``UNLABELED_TARGET`` (999); every other item is returned
    exactly as the base dataset returns it.

    After a split is generated it is pickled to ``check_exists``; when
    ``sampling_type == 'load'`` and that file exists, the split is read back
    from it instead of being regenerated. Generation draws from both the
    global torch RNG (``torch.rand`` / ``torch.multinomial``) and the global
    numpy RNG (``np.random.choice``), so seed both for a reproducible split.

    Args:
        base_dataset: Dataset to wrap. Must support ``len()`` and indexing
            returning ``(sample, target)``; when ``balanced`` is true it must
            also expose ``targets`` (a tensor or a sequence of class labels).
        sampling_type: ``'load'`` reuses the split stored in ``check_exists``
            if that file exists; ``'fixed'`` interprets ``sampling`` as the
            number of samples that keep their label; any other value
            (including ``'load'`` with a missing file) interprets ``sampling``
            as the fraction of samples that keep their label.
        sampling: Count or fraction of labeled samples (see ``sampling_type``).
            Values that would imply a negative or oversized unlabeled set are
            clamped to ``[0, len(base_dataset)]``.
        balanced: If true, first pick about the same number of samples from
            every class (capped at the class size); those picks are never
            marked unlabeled.
        check_exists: Path of the pickle file that caches ``self.unlabel``.
    """

    # Target value returned for items whose label is hidden.
    UNLABELED_TARGET = 999

    def __init__(self, base_dataset, sampling_type, sampling, balanced, check_exists):
        super().__init__()

        self.base_dataset = base_dataset

        if sampling_type == 'load' and isfile(check_exists):
            # NOTE(review): pickle.load can execute arbitrary code; only point
            # check_exists at files this code wrote itself.
            with open(check_exists, 'rb') as f:
                self.unlabel = pickle.load(f)
        else:
            self.unlabel = self._sample_unlabel(sampling_type, sampling, balanced)

            with open(check_exists, 'wb') as f:
                pickle.dump(self.unlabel, f)

    @property
    def unlabel(self):
        """Indices whose target is hidden (a 1-D int64 tensor when freshly
        sampled; whatever object was unpickled when loaded from disk)."""
        return self._unlabel

    @unlabel.setter
    def unlabel(self, value):
        self._unlabel = value
        # `x in tensor` scans the whole tensor on every lookup; a set keeps
        # __getitem__ O(1). Rebuilt on every assignment to stay in sync.
        self._unlabel_set = set(torch.as_tensor(value).reshape(-1).tolist())

    def _sample_unlabel(self, sampling_type, sampling, balanced):
        """Draw the random set of indices whose labels will be hidden."""
        dataset_size = len(self.base_dataset)
        probs = torch.rand(dataset_size)

        if sampling_type == 'fixed':
            n_unlabel = int(dataset_size - sampling)
        else:
            n_unlabel = round((1 - sampling) * dataset_size)
        # Out-of-range `sampling` must not yield a negative or oversized count.
        n_unlabel = min(max(n_unlabel, 0), dataset_size)

        if balanced:
            # as_tensor also accepts list/ndarray targets (e.g. torchvision CIFAR).
            targets = torch.as_tensor(self.base_dataset.targets)
            unique_targets = torch.unique(targets)
            if len(unique_targets) > 0:
                samples_per_target = round((dataset_size - n_unlabel) / len(unique_targets))
                preserved = []
                for target in unique_targets:
                    # flatten() keeps the result 1-D even for single-sample classes.
                    target_indices = (targets == target).nonzero().flatten()
                    k = min(samples_per_target, len(target_indices))
                    chosen = np.random.choice(len(target_indices), k, replace=False)
                    preserved.append(target_indices[torch.as_tensor(chosen, dtype=torch.long)])
                # Zero weight => torch.multinomial never selects these indices.
                probs[torch.cat(preserved)] = 0

        # Rounding samples_per_target up can preserve more samples than
        # dataset_size - n_unlabel; never request more indices than have
        # non-zero weight (multinomial without replacement would fail).
        n_available = int((probs > 0).sum())
        n_unlabel = min(n_unlabel, n_available)
        if n_unlabel == 0:
            # torch.multinomial rejects a request for zero samples.
            return torch.empty(0, dtype=torch.long)
        return torch.multinomial(probs, n_unlabel)

    def __getitem__(self, index):
        sample, target = self.base_dataset[index]

        # int() makes numpy / 0-d tensor indices hash like plain Python ints.
        if int(index) in self._unlabel_set:
            target = self.UNLABELED_TARGET

        return sample, target

    def __len__(self):
        return len(self.base_dataset)
