# From https://github.com/wzekai99/DM-Improves-AT/blob/0831ea493a6210047a84538a9a432a9b4727d78c/core/data/semisup.py
import numpy as np
import torch
from torch.utils.data import Dataset, Sampler


class SemiSupDataset(Dataset):
    """
    A dataset with auxiliary pseudo-labeled data.

    Wraps a base dataset and, in training mode, appends the samples stored in
    an auxiliary NumPy archive to the base dataset's ``data`` / ``targets``.
    The base dataset is modified in place (``data`` is re-assigned and
    ``targets`` is extended).

    Arguments:
        base (Dataset): wrapped dataset; it must expose ``data``, ``targets``,
            ``classes`` and ``transform`` attributes and support ``len()``.
        aux_data_filename (str, optional): path of the archive holding the
            auxiliary samples under the keys ``'image'`` and ``'label'``.
            If None, no auxiliary data is added.
        train (bool): auxiliary data is only loaded when this is True.

    Attributes:
        dataset_size: length of the base dataset, measured before any
            auxiliary data is appended.
        sup_indices: indices of the original samples.
        unsup_indices: indices of the appended auxiliary samples (empty when
            nothing was appended).
    """

    def __init__(self, base: Dataset, aux_data_filename=None, train=False):

        self.dataset = base
        self.dataset_size = len(self.dataset)
        self.classes = self.dataset.classes
        self.transform = self.dataset.transform
        self.train = train

        # Every sample present before augmentation counts as supervised.
        self.sup_indices = list(range(len(self.targets)))
        self.unsup_indices = []

        if self.train and aux_data_filename is not None:
            print('Loading data from %s' % aux_data_filename)
            aux = np.load(aux_data_filename)
            try:
                # Indexing materialises the arrays in memory, so the archive
                # can be closed right after.
                aux_data = aux['image']
                aux_targets = aux['label']
            finally:
                # .npz archives keep a file handle open until closed; objects
                # without a close() method are left untouched.
                close = getattr(aux, 'close', None)
                if close is not None:
                    close()
            print(aux_data.shape, aux_targets.shape)

            orig_len = len(self.data)

            self.data = np.concatenate((self.data, aux_data), axis=0)
            self.targets.extend(aux_targets)

            # Auxiliary samples occupy the indices right after the originals.
            self.unsup_indices.extend(range(orig_len, orig_len + len(aux_data)))

    @property
    def data(self):
        """Sample array of the wrapped dataset."""
        return self.dataset.data

    @data.setter
    def data(self, value):
        self.dataset.data = value

    @property
    def targets(self):
        """Label container of the wrapped dataset."""
        return self.dataset.targets

    @targets.setter
    def targets(self, value):
        self.dataset.targets = value

    def __len__(self):
        """Return the current length reported by the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, item):
        """Return ``self.dataset[item]`` after mirroring targets into ``labels``."""
        # Presumably for base datasets that read their labels from a `labels`
        # attribute instead of `targets` -- verify against the base classes used.
        self.dataset.labels = self.targets
        return self.dataset[item]


class SemiSupSampler(Sampler):
    """
    Balanced sampling from the labeled and unlabeled data.

    Each yielded batch is a list of indices. Supervised indices are visited in
    a fresh random permutation on every pass; unsupervised indices are drawn
    uniformly at random with replacement to fill the remainder of the batch.

    Arguments:
        sup_inds (list): indices of the supervised samples.
        unsup_inds (list): indices of the unsupervised samples.
        batch_size (int): total number of indices per batch.
        unsup_fraction (float, optional): fraction of each batch taken from
            ``unsup_inds``. If None or negative, both pools are merged and
            sampled as supervised data.
        num_batches (int, optional): number of batches per iteration. Defaults
            to one pass over the supervised pool, or, when batches contain
            unsupervised samples only, to ``ceil(len(unsup_inds) / batch_size)``.

    Raises:
        ValueError: if ``batch_size`` is not positive, ``unsup_fraction``
            exceeds 1, or a pool that batches must draw from is empty.
    """
    def __init__(self, sup_inds, unsup_inds, batch_size, unsup_fraction=0.5, num_batches=None):
        if batch_size <= 0:
            raise ValueError('batch_size must be positive, got %r' % (batch_size,))

        if unsup_fraction is None or unsup_fraction < 0:
            # No balancing requested: treat every index as supervised.
            self.sup_inds = list(sup_inds) + list(unsup_inds)
            self.unsup_inds = []
            unsup_fraction = 0.0
        else:
            if unsup_fraction > 1:
                raise ValueError('unsup_fraction must not exceed 1, got %r' % (unsup_fraction,))
            self.sup_inds = sup_inds
            self.unsup_inds = unsup_inds

        self.batch_size = batch_size
        unsup_batch_size = int(batch_size * unsup_fraction)
        self.sup_batch_size = batch_size - unsup_batch_size

        if num_batches is not None:
            self.num_batches = num_batches
        elif self.sup_batch_size > 0:
            self.num_batches = int(np.ceil(len(self.sup_inds) / self.sup_batch_size))
        else:
            # Unsupervised-only batches: there is no supervised pass to count,
            # so size the epoch by the unsupervised pool instead.
            self.num_batches = int(np.ceil(len(self.unsup_inds) / batch_size))

        if self.num_batches > 0:
            # An empty supervised pool would make __iter__ spin forever, and an
            # empty unsupervised pool cannot be sampled from; fail early.
            if self.sup_batch_size > 0 and len(self.sup_inds) == 0:
                raise ValueError('sup_inds is empty but batches require supervised samples')
            if self.sup_batch_size < batch_size and len(self.unsup_inds) == 0:
                raise ValueError('unsup_inds is empty but batches require unsupervised samples')
        super().__init__(None)

    def _draw_unsup(self, count):
        """Return ``count`` unsupervised indices drawn uniformly with replacement."""
        picks = torch.randint(high=len(self.unsup_inds), size=(count,), dtype=torch.int64)
        return [self.unsup_inds[i] for i in picks.tolist()]

    def __iter__(self):
        """Yield ``num_batches`` lists of indices."""
        batch_counter = 0
        while batch_counter < self.num_batches:
            if self.sup_batch_size != 0:
                # New permutation of the supervised pool for every pass.
                order = torch.randperm(len(self.sup_inds)).tolist()
                sup_inds_shuffled = [self.sup_inds[i] for i in order]
                for sup_k in range(0, len(self.sup_inds), self.sup_batch_size):
                    if batch_counter >= self.num_batches:
                        break
                    batch = sup_inds_shuffled[sup_k:(sup_k + self.sup_batch_size)]
                    if self.sup_batch_size < self.batch_size:
                        # Top up to the full batch size; this also pads a short
                        # final supervised slice with unsupervised samples.
                        batch.extend(self._draw_unsup(self.batch_size - len(batch)))
                    np.random.shuffle(batch)
                    yield batch
                    batch_counter += 1
            else:  # unsupervised only
                yield self._draw_unsup(self.batch_size)
                batch_counter += 1

    def __len__(self):
        """Return the number of batches produced per iteration."""
        return self.num_batches