# Modified from:
# https://github.com/wzekai99/DM-Improves-AT/blob/main/core/data/semisup.py#L8

import numpy as np
import torch


class ControlRatioSampler(torch.utils.data.Sampler):
    """
    Balanced sampling from the original and augmented data.

    Each yielded batch is a list of dataset indices. It holds up to
    ``batch_size - int(batch_size * aug_fraction)`` indices taken without
    replacement from ``ori_inds``, using a fresh permutation for every pass
    over the original pool. The batch is then filled up to ``batch_size`` with
    indices drawn uniformly *with* replacement from ``aug_inds``. The last
    batch of a pass may hold fewer original indices; it is then topped up
    with extra augmented indices. When no augmented slots exist, it is simply
    left short. Every batch is shuffled before being yielded.

    Args:
        ori_inds: Indices of the original samples (any iterable).
        aug_inds: Indices of the augmented samples (any iterable, or ``None``).
        batch_size: Total number of indices per batch.
        aug_fraction: Fraction of each batch drawn from ``aug_inds``. ``None``
            or a negative value merges both index sets into one pool that is
            sampled like original data, with no ratio control.
        num_batches: Number of batches produced per iteration. Defaults to the
            number of batches needed for one pass over the original pool.
            Larger values start additional shuffled passes.

    Raises:
        ValueError: if no batch slot is left for original indices, if batches
            are requested from an empty original pool (iteration could never
            make progress), or if augmented indices are needed but
            ``aug_inds`` is empty.
    """
    def __init__(self, ori_inds, aug_inds, batch_size, aug_fraction=0.5, num_batches=None):
        # Copy into plain lists so ``+`` below means concatenation (for numpy
        # arrays / tensors it would be element-wise addition) and so the
        # sampler does not alias the caller's containers.
        ori_inds = list(ori_inds)
        aug_inds = [] if aug_inds is None else list(aug_inds)

        if aug_fraction is None or aug_fraction < 0:
            # No ratio control: every index is treated as "original" data.
            self.ori_inds = ori_inds + aug_inds
            self.aug_inds = []
            aug_fraction = 0.0
        else:
            self.ori_inds = ori_inds
            self.aug_inds = aug_inds

        self.batch_size = batch_size
        aug_batch_size = int(batch_size * aug_fraction)
        self.ori_batch_size = batch_size - aug_batch_size

        if self.ori_batch_size <= 0:
            # Would otherwise divide by zero below / use a non-positive range step.
            raise ValueError(
                "ControlRatioSampler needs at least one original index per batch "
                "(batch_size={}, aug_fraction={})".format(batch_size, aug_fraction))
        if self.ori_batch_size < self.batch_size and not self.aug_inds:
            raise ValueError(
                "aug_fraction={} requires augmented indices, but aug_inds is empty".format(aug_fraction))

        print("\nInitiate ControlRatioSampler with ori_batch_size={}, aug_batch_size={}\n".format(self.ori_batch_size, aug_batch_size))

        if num_batches is not None:
            self.num_batches = num_batches
        else:
            # Exact integer ceil(len(ori_inds) / ori_batch_size).
            self.num_batches = -(-len(self.ori_inds) // self.ori_batch_size)

        if self.num_batches > 0 and not self.ori_inds:
            # __iter__ would loop forever: an empty pool never yields a batch.
            raise ValueError(
                "num_batches={} requested but the original index pool is empty".format(self.num_batches))
        super().__init__(None)

    def __iter__(self):
        batch_counter = 0
        while batch_counter < self.num_batches:
            # New random order of the original pool for each pass.
            perm = torch.randperm(len(self.ori_inds)).tolist()
            ori_inds_shuffled = [self.ori_inds[i] for i in perm]
            for start in range(0, len(self.ori_inds), self.ori_batch_size):
                if batch_counter == self.num_batches:
                    break
                batch = ori_inds_shuffled[start:start + self.ori_batch_size]
                if self.ori_batch_size < self.batch_size:
                    # Fill the remaining slots with augmented indices drawn with
                    # replacement; a short final slice gets proportionally more.
                    picks = torch.randint(high=len(self.aug_inds),
                                          size=(self.batch_size - len(batch),),
                                          dtype=torch.int64).tolist()
                    batch.extend(self.aug_inds[i] for i in picks)
                # Mix original and augmented indices within the batch.
                np.random.shuffle(batch)
                yield batch
                batch_counter += 1

    def __len__(self):
        return self.num_batches