import random
import numpy as np
from torch.utils.data.sampler import Sampler

class RandomCycleIter:  # only for train
    """Endless iterator over a private copy of *data* in random order.

    The copy is reshuffled in place every time a full pass completes,
    including once before the very first element is returned, so every
    pass visits each element exactly once in a fresh order.

    Attributes:
        data: the private list copy that gets shuffled.
        len: number of elements in ``data``.
        idx: position of the element most recently returned.
    """

    def __init__(self, data):
        self.data = list(data)
        self.len = len(self.data)
        # Park the cursor on the last slot so the first __next__ wraps
        # around and shuffles before anything is handed out.
        self.idx = self.len - 1

    def __iter__(self):
        return self

    def __next__(self):
        position = self.idx + 1
        if position == self.len:
            # A pass just finished: reshuffle and restart from the front.
            random.shuffle(self.data)
            position = 0
        self.idx = position
        return self.data[position]
    
def class_aware_sample_generator(cls_iter, data_iter, n, num_samples_cls=1):
    """Yield ``n`` sample indices drawn in same-class groups.

    Repeatedly takes the next class key from ``cls_iter`` and draws
    ``num_samples_cls`` consecutive items from ``data_iter[key]``, emitting
    them one by one until ``n`` items have been yielded in total.

    Args:
        cls_iter: iterator yielding class keys.
        data_iter: mapping from each class key to an iterator of items
            (sample indices) for that class.
        n: total number of items to yield.
        num_samples_cls: how many consecutive items to take from a class
            before moving on to the next class; must be >= 1.

    Raises:
        ValueError: if ``num_samples_cls`` is less than 1 (raised on the
            first ``next()``, since this is a generator).
    """
    if num_samples_cls < 1:
        # Previously this surfaced as "RuntimeError: generator raised
        # StopIteration" from ``next(zip())``.
        raise ValueError(
            f"num_samples_cls must be >= 1, got {num_samples_cls}")

    emitted = 0
    while emitted < n:
        per_class = data_iter[next(cls_iter)]
        # The whole group is drawn up front, so a final group that ``n`` cuts
        # short still advances the per-class iterator by a full group.
        group = [next(per_class) for _ in range(num_samples_cls)]
        for item in group:
            if emitted >= n:
                return
            yield item
            emitted += 1

class ClassAwareSampler(Sampler):
    """Class-balanced sampler over dataset indices.

    Classes are visited in a shuffled, endlessly repeating order. From each
    visited class, ``num_samples_per_category`` indices are drawn from that
    class's own shuffled cycle. One epoch has
    ``max(class size) * number of classes`` indices.

    Args:
        data_source: object exposing ``get_gt_labels()`` (or
            ``get_gt_logits()`` when ``soft_file`` is true). The result is
            read as one label per sample, in dataset index order.
        num_samples_per_category: consecutive indices drawn from a class
            before moving to the next class.
        soft_file: read labels from ``get_gt_logits()`` instead of
            ``get_gt_labels()``.

    Raises:
        ValueError: if the data source yields no labels.
    """

    def __init__(self, data_source, num_samples_per_category=4, soft_file=False):
        if soft_file:
            # NOTE(review): assumes get_gt_logits() yields one hashable scalar
            # per sample; per-sample logit vectors are not supported - confirm.
            labels = list(data_source.get_gt_logits())
        else:
            labels = list(data_source.get_gt_labels())
        if not labels:
            raise ValueError("ClassAwareSampler requires a non-empty data_source")

        # return_inverse gives each sample's position in unique_cats, so samples
        # are filed under exactly the keys the class iterator yields. The old
        # ``int(label)`` lookup raised KeyError for non-integer labels.
        unique_cats, inverse = np.unique(labels, return_inverse=True)
        num_classes = len(unique_cats)
        self.class_iter = RandomCycleIter(unique_cats)

        cls_data = {cat: [] for cat in unique_cats}
        for i, cat_pos in enumerate(inverse):
            cls_data[unique_cats[cat_pos]].append(i)

        self.data_iter = {
            cat: RandomCycleIter(indices) for cat, indices in cls_data.items()
        }
        # Epoch length: every class is "inflated" to the size of the largest.
        self.num_samples = max(len(v) for v in cls_data.values()) * num_classes
        self.num_samples_per_category = num_samples_per_category

    def __iter__(self):
        """Return a generator of ``num_samples`` class-balanced indices.

        The class and per-class iterators are shared across calls, so each
        epoch continues where the previous one stopped.
        """
        return class_aware_sample_generator(
            self.class_iter, self.data_iter, self.num_samples,
            self.num_samples_per_category)

    def __len__(self):
        """Return the number of indices produced per epoch."""
        return self.num_samples