import random
import torch
import numpy as np
from torch.utils.data import DistributedSampler as _DistributedSampler

from .classaware_sampler import RandomCycleIter, class_aware_sample_generator

class DistributedClassAwareSampler(_DistributedSampler):
    """Class-aware sampler derived from ``DistributedSampler``.

    Dataset indices are grouped by label. Iteration is delegated to
    ``class_aware_sample_generator``, which receives a ``RandomCycleIter``
    over the distinct labels plus one ``RandomCycleIter`` per label over
    that label's indices.

    Args:
        data_source: Dataset forwarded to ``DistributedSampler``. It must
            provide ``get_gt_labels()`` (or ``get_gt_logits()`` when
            ``soft_file`` is true).
        num_replicas: Forwarded to ``DistributedSampler``.
        rank: Forwarded to ``DistributedSampler``.
        shuffle: Stored as ``self.shuffle``; ``__iter__`` of this class does
            not read it.
        round_up: When true, ``total_size`` is
            ``num_samples * num_replicas``; otherwise it equals
            ``num_samples``.
        num_samples_per_category: Forwarded to
            ``class_aware_sample_generator`` on every ``__iter__`` call.
        soft_file: When true, labels are read from
            ``data_source.get_gt_logits()`` instead of
            ``data_source.get_gt_labels()``.

    Raises:
        ValueError: If ``data_source`` provides no labels.
    """

    def __init__(self,
                 data_source,
                 num_replicas=None,
                 rank=None,
                 shuffle=True,
                 round_up=True,
                 num_samples_per_category=4,
                 soft_file=False):
        super().__init__(data_source, num_replicas=num_replicas, rank=rank)

        if soft_file:
            labels = list(data_source.get_gt_logits())
        else:
            labels = list(data_source.get_gt_labels())

        # Fail early with a clear message; otherwise the max() below would
        # raise an opaque "max() arg is an empty sequence".
        if not labels:
            raise ValueError(
                'data_source provided no labels; cannot build a '
                'class-aware sampler from an empty dataset')

        unique_cats = np.unique(labels)
        self.class_iter = RandomCycleIter(unique_cats)

        # Keys are the elements of ``unique_cats``; lookups go through
        # ``int(label)``, so labels are expected to be integer-valued.
        cls_data = {cat: [] for cat in unique_cats}
        for index, label in enumerate(labels):
            cls_data[int(label)].append(index)

        self.data_iter = {
            cat: RandomCycleIter(indices)
            for cat, indices in cls_data.items()
        }

        # Overrides the value set by DistributedSampler: the size of the
        # largest class multiplied by the number of classes.
        largest_class = max(len(indices) for indices in cls_data.values())
        self.num_samples = largest_class * len(unique_cats)
        self.num_samples_per_category = num_samples_per_category

        self.shuffle = shuffle
        self.round_up = round_up
        if self.round_up:
            self.total_size = self.num_samples * self.num_replicas
        else:
            self.total_size = self.num_samples

    def __iter__(self):
        """Return a new generator from ``class_aware_sample_generator``.

        The generator is given the class iterator, the per-class index
        iterators, ``self.num_samples`` and
        ``self.num_samples_per_category``.
        """
        return class_aware_sample_generator(
            self.class_iter,
            self.data_iter,
            self.num_samples,
            self.num_samples_per_category,
        )
