import random

import numpy as np
import torch
from torch.utils.data import Sampler


class RandomCycleIter:
    """Endless iterator that walks a fixed collection over and over.

    Whenever a new pass over the items begins (the very first pass included),
    the internal list is reshuffled in place with ``random.shuffle``, unless
    ``test_mode`` is true, in which case the original order is preserved.

    Args:
        data_list: iterable of items; it is copied into a list once.
        test_mode: when true, never shuffle.
    """

    def __init__(self, data_list, test_mode=False):
        self.data_list = list(data_list)
        self.length = len(self.data_list)
        self.test_mode = test_mode
        # Park the cursor on the last slot so the first __next__ wraps to 0
        # and performs the initial shuffle.
        self.i = self.length - 1

    def __iter__(self):
        return self

    def __next__(self):
        position = self.i + 1
        if position == self.length:
            # Starting a new pass: rewind and (optionally) reshuffle.
            position = 0
            if not self.test_mode:
                random.shuffle(self.data_list)
        self.i = position
        # An empty data_list raises IndexError here.
        return self.data_list[position]


def class_aware_sample_generator(cls_iter, data_iter_list, n, num_samples_cls=1):
    """Yield ``n`` items drawn in per-class chunks.

    Repeatedly takes the next class index from ``cls_iter``, draws
    ``num_samples_cls`` consecutive items from ``data_iter_list[class_index]``
    and yields them one at a time, stopping once ``n`` items have been yielded.
    The final chunk is always drawn in full from its class iterator even if
    only part of it is yielded, so the per-class iterators always advance by
    whole chunks.

    Args:
        cls_iter: iterator of indices into ``data_iter_list``.
        data_iter_list: indexable collection of per-class item iterators.
        n: total number of items to yield; ``n <= 0`` yields nothing.
        num_samples_cls: number of consecutive items drawn per selected class.

    Raises:
        ValueError: if ``num_samples_cls < 1``.
        RuntimeError: if ``cls_iter`` or a per-class iterator runs out
            (a ``StopIteration`` escaping a generator is converted, PEP 479).
    """
    if num_samples_cls < 1:
        raise ValueError("num_samples_cls must be >= 1, got {}".format(num_samples_cls))
    produced = 0
    while produced < n:
        data_iter = data_iter_list[next(cls_iter)]
        # Draw the whole chunk eagerly so the class iterator advances by exactly
        # num_samples_cls items, even when the last chunk is only partly yielded.
        chunk = [next(data_iter) for _ in range(num_samples_cls)]
        for item in chunk:
            if produced >= n:
                return
            yield item
            produced += 1


class ClassAwareSampler(Sampler):
    """Sampler that yields dataset indices class by class.

    Classes are visited in a shuffled cyclic order (``RandomCycleIter``). For
    each visited class, ``num_samples_cls`` consecutive items are drawn from
    that class's own shuffled cycle over ``data_source.cls_data_list[c]``.
    The epoch length is
    ``int(max(len(x) for x in cls_data_list) * len(cls_data_list) / reduce)``.

    Args:
        data_source: object exposing ``CLASSES`` and ``cls_data_list``;
            ``cls_data_list[c]`` is presumably the collection of dataset
            indices belonging to class ``c`` (TODO confirm against the dataset).
        num_samples_cls: consecutive items drawn per selected class.
        reduce: divisor applied to the epoch length.

    Note:
        The constructor reseeds the global ``random`` and ``torch`` RNGs with
        0. This is a process-wide side effect.
    """

    def __init__(self, data_source, num_samples_cls=3, reduce=4):
        # NOTE(review): reseeding global RNGs makes sampling reproducible but
        # affects every other user of `random` / `torch` in the process.
        random.seed(0)
        torch.manual_seed(0)
        num_classes = len(np.unique(data_source.CLASSES))

        self.cls_data_list = data_source.cls_data_list

        # Cycle only over classes that actually have items. RandomCycleIter
        # over an empty list raises IndexError on its first next(), so an empty
        # class used to crash iteration as soon as it was drawn. With no empty
        # classes this equals list(range(num_classes)), so sampling is
        # unchanged. Raises IndexError here if num_classes exceeds
        # len(cls_data_list).
        non_empty_classes = [c for c in range(num_classes) if len(self.cls_data_list[c]) > 0]
        self.class_iter = RandomCycleIter(non_empty_classes)

        self.num_classes = num_classes
        self.data_iter_list = [RandomCycleIter(x) for x in self.cls_data_list]
        self.num_samples = int(max(len(x) for x in self.cls_data_list) * len(self.cls_data_list) / reduce)
        self.num_samples_cls = num_samples_cls
        print(">>> Class Aware Sampler Built! Class number: {}, reduce {}".format(num_classes, reduce))
        num_empty = num_classes - len(non_empty_classes)
        if num_empty:
            print(">>> Class Aware Sampler: skipping {} classes with no samples".format(num_empty))

        # Not used by __iter__, which builds a fresh generator per epoch. Kept
        # only so existing code reading this attribute keeps working. Creating
        # a generator consumes nothing until it is iterated.
        self.sampled_data_list = class_aware_sample_generator(self.class_iter, self.data_iter_list,
                                                              self.num_samples, self.num_samples_cls)

    def __iter__(self):
        """Return a generator of ``num_samples`` indices for one epoch.

        The class and per-class iterators are shared across calls, so each
        epoch continues the cycles where the previous one stopped.
        """
        return class_aware_sample_generator(self.class_iter, self.data_iter_list,
                                            self.num_samples, self.num_samples_cls)

    def __len__(self):
        """Number of indices yielded per epoch."""
        return self.num_samples
