from torch.utils.data import Dataset, Sampler, DataLoader
import random

class StratifiedBatchSampler(Sampler):
    """Yield batches of indices containing an equal quota from every task.

    The data source must expose two attributes:

    * ``task_to_indices`` -- a mapping from task id to a sequence of
      dataset indices belonging to that task.
    * ``batch_size`` -- the number of indices per yielded batch.

    Each batch holds ``batch_size // num_tasks`` indices from every task;
    any leftover slots are filled with indices drawn at random from the
    whole dataset.

    Iteration never terminates: a task whose shuffled pool runs dry is
    reshuffled from that task's own indices. ``__len__`` reports the
    number of batches that can be produced before the smallest task would
    need its first reshuffle.

    Args:
        data_source: Object providing ``task_to_indices`` and ``batch_size``.
        seed: Base seed for shuffling.
        rank: Offset added to ``seed`` (e.g. a per-process rank) so that
            different ranks produce different orders.

    Raises:
        ValueError: If ``task_to_indices`` contains no tasks.
    """

    def __init__(self, data_source: Dataset, seed: int = 0, rank: int = 0) -> None:
        self.data_source = data_source
        self.task_to_indices = self.data_source.task_to_indices
        self.task_ids = list(self.task_to_indices.keys())
        self.batch_size = self.data_source.batch_size

        if not self.task_ids:
            # Fail with a clear message instead of a ZeroDivisionError below.
            raise ValueError("`task_to_indices` must contain at least one task.")

        # Guaranteed per-task share of every batch (0 if there are more
        # tasks than batch slots).
        self.samples_per_task = self.batch_size // len(self.task_ids)
        self.seed = seed
        self.rank = rank
        self.all_indices = [
            idx for indices in self.task_to_indices.values() for idx in indices
        ]

    def __iter__(self):
        """Yield an endless stream of batches (lists of dataset indices).

        The order is fully determined by ``seed + rank``; every call
        restarts the same sequence.

        Raises:
            ValueError: If a task has fewer indices than the per-task quota,
                or if the random top-up needs more indices than exist.
        """
        # A private generator keeps the order reproducible without
        # reseeding (or being perturbed by) the process-wide `random` state.
        rng = random.Random(self.seed + self.rank)
        quota = self.samples_per_task
        task_pools = {
            task: rng.sample(indices, len(indices))
            for task, indices in self.task_to_indices.items()
        }

        while True:
            batch = []
            # Take exactly `samples_per_task` indices from each task.
            for task in self.task_ids:
                pool = task_pools[task]
                if len(pool) < quota:
                    # Pool exhausted: reshuffle this task's OWN indices so the
                    # batch stays stratified. Leftover entries are discarded.
                    task_indices = self.task_to_indices[task]
                    pool = rng.sample(task_indices, len(task_indices))
                    task_pools[task] = pool
                    if len(pool) < quota:
                        raise ValueError(f"Not enough samples in task {task} to satisfy `samples_per_task`.")
                batch.extend([pool.pop() for _ in range(quota)])

            # Fill slots left over by the integer division with indices drawn
            # from the whole dataset. NOTE: these draws are independent of the
            # per-task picks, so they may repeat an index already in the batch.
            remaining_samples = self.batch_size - len(batch)
            if remaining_samples > 0:
                batch.extend(rng.sample(self.all_indices, remaining_samples))

            yield batch

    def __len__(self) -> int:
        """Return the nominal number of batches in one pass.

        This is how many batches can be drawn before the smallest task runs
        out of fresh indices. When there is no per-task quota (more tasks
        than batch slots) it falls back to ``len(all_indices) // batch_size``.
        """
        if self.samples_per_task > 0:
            min_size = min(len(indices) for indices in self.task_to_indices.values())
            return min_size // self.samples_per_task
        if self.batch_size > 0:
            return len(self.all_indices) // self.batch_size
        return 0