import math
import torch
from torch.utils.data.sampler import RandomSampler
import numpy as np
import random
import pandas as pd


class BatchSchedulerSampler(torch.utils.data.sampler.Sampler):
    """Yield indices of a concatenated dataset in single-task runs.

    The sampler emits a flat stream of indices into ``dataset`` (not lists of
    indices).  The stream is made of consecutive runs of ``batch_size``
    indices, and every run is drawn from exactly one of
    ``dataset.datasets``.  Presumably it is paired with a ``DataLoader``
    using the same ``batch_size`` so that each mini-batch contains a single
    task -- confirm against the caller.

    An epoch consists of ``ceil(largest_dataset_size / batch_size)`` rounds.
    In every round the sub-datasets are visited in a fresh random order.  In
    training mode (``test_mode=False``) sub-dataset ``i`` contributes a run
    in a given round only with probability
    ``max(0.1, sqrt(len_i) / sqrt(len_largest))``; in test mode every
    sub-dataset contributes a run in every round.  Sub-datasets that run out
    of samples are reshuffled and sampled again.

    Args:
        dataset: object exposing ``datasets`` (a sequence of sized datasets)
            and ``cumulative_sizes`` (running totals of their lengths), e.g.
            a ``torch.utils.data.ConcatDataset``.
        batch_size: number of consecutive indices taken from one sub-dataset
            per run.
        args: opaque configuration object; stored on the instance but not
            read by this class.
        test_mode: when true, never skip a sub-dataset.
    """

    # Epoch generated by ``__len__`` that the next ``__iter__`` will hand out.
    # Declared on the class so instances created without ``__init__`` work.
    _pending_epoch = None

    def __init__(self, dataset, batch_size, args, test_mode=False):
        self.dataset = dataset
        self.batch_size = batch_size
        sizes = [len(cur_dataset) for cur_dataset in dataset.datasets]
        self.number_of_datasets = len(sizes)
        self.largest_dataset_size = max(sizes)
        # Square-root dampening: small datasets are sampled less often than
        # large ones, but not proportionally less.
        self.datasets_size_p = np.array(sizes) ** 0.5
        self.max_size_p = max(self.datasets_size_p)
        self.test_mode = test_mode
        self.args = args
        self._pending_epoch = None
        print("sizes:", self.datasets_size_p)

    def __len__(self):
        """Return the number of indices the next ``__iter__`` call yields.

        In training mode the epoch length is random, so the epoch is
        generated here and kept for the next ``__iter__`` call; that keeps
        the reported length and the iterated epoch in agreement.  Repeated
        calls before iterating return the same value.
        """
        if self._pending_epoch is None:
            self._pending_epoch = self._build_epoch()
        return len(self._pending_epoch)

    def __iter__(self):
        """Return an iterator over one epoch of concatenated-dataset indices."""
        epoch, self._pending_epoch = self._pending_epoch, None
        if epoch is None:
            epoch = self._build_epoch()
        return iter(epoch)

    def _build_epoch(self):
        """Generate the full list of indices for one epoch."""
        samplers = [
            RandomSampler(cur_dataset) for cur_dataset in self.dataset.datasets
        ]
        sampler_iterators = [iter(sampler) for sampler in samplers]

        # Offset that maps a sub-dataset-local index to the concatenated one.
        push_index_val = [0, *self.dataset.cumulative_sizes[:-1]]
        step = self.batch_size * self.number_of_datasets
        # Enough rounds to walk the largest dataset once; the smaller
        # datasets are therefore resampled.
        epoch_samples = self.largest_dataset_size * self.number_of_datasets

        final_samples_list = []
        for _ in range(0, epoch_samples, step):
            visit_order = random.sample(
                range(self.number_of_datasets), self.number_of_datasets
            )
            for i in visit_order:
                p = max(0.1, self.datasets_size_p[i] / self.max_size_p)
                if not self.test_mode and random.random() > p:
                    continue
                for _ in range(self.batch_size):
                    try:
                        local_index = next(sampler_iterators[i])
                    except StopIteration:
                        # Sub-dataset exhausted: reshuffle it and keep going.
                        sampler_iterators[i] = iter(samplers[i])
                        local_index = next(sampler_iterators[i])
                    final_samples_list.append(local_index + push_index_val[i])

        return final_samples_list
