from multiguide.training.bucket_batch_ddp_sampler import DistributedBucketBatchSampler

class CurriculumBucketBatchSampler(DistributedBucketBatchSampler):
    """
    Extends DistributedBucketBatchSampler to implement curriculum learning.

    Items are admitted according to ``dataset.get_length_percentage(i)``.
    At epoch ``e`` only items whose length percentage is at least::

        max(1, int(min_perc_start * (epochs_to_include_all - e) / epochs_to_include_all))

    are bucketed. The threshold therefore decays linearly from
    ``min_perc_start`` at epoch 0 to 1 at epoch ``epochs_to_include_all``, and
    stays at 1 afterwards. Admitted items are bucketed by
    ``dataset.src_lengths[i]``.

    NOTE(review): because of the floor of 1, items whose length percentage is
    below 1 are never sampled, even after the curriculum completes. Confirm
    this is intended.

    Args:
        dataset: Object exposing ``__len__``, ``get_length_percentage(i)``
            (validated to lie in [0, 100]) and an indexable ``src_lengths``.
        min_perc_start: Minimum length percentage required at epoch 0.
        epochs_to_include_all: Number of epochs over which the threshold
            decays to 1. Values <= 0 mean the threshold is 1 from the start.
        **kwargs: Forwarded unchanged to DistributedBucketBatchSampler.
    """

    def __init__(
        self,
        dataset,
        min_perc_start: int = 60,
        epochs_to_include_all: int = 10,
        **kwargs
    ):
        # Stored before super().__init__ in case the parent builds buckets
        # (calling _create_buckets) during construction. Presumably; verify
        # against DistributedBucketBatchSampler.
        self.min_perc_start = min_perc_start
        self.epochs_to_include_all = epochs_to_include_all
        super().__init__(dataset, **kwargs)

    def _curriculum_epoch(self):
        """Return the current epoch, defaulting to 0 if it is not set yet."""
        # NOTE(review): the parent may invoke _create_buckets before it assigns
        # self.epoch; treat that as the start of the curriculum.
        return getattr(self, "epoch", 0)

    def _curriculum_min_perc(self) -> int:
        """Return the minimum length percentage admitted at the current epoch."""
        total = self.epochs_to_include_all
        if total <= 0:
            # No ramp: everything at or above the floor is admitted immediately
            # (also avoids a ZeroDivisionError).
            threshold = 0
        else:
            remaining = max(0, total - self._curriculum_epoch())
            # Scale by an integer ratio instead of ``1 - epoch / total``: the
            # latter suffers float cancellation (e.g. 60 * (1 - 0.9) == 5.999...)
            # which int() would truncate to the wrong integer.
            threshold = int(self.min_perc_start * remaining / total)
        return max(1, threshold)

    @staticmethod
    def _curriculum_bucket_index(length, bucket_boundaries) -> int:
        """Return the index of the first boundary >= ``length`` (or len(boundaries))."""
        # Linear scan (not bisect) so user-supplied, unsorted boundaries keep
        # the exact same assignment semantics.
        for bucket_idx, boundary in enumerate(bucket_boundaries):
            if length <= boundary:
                return bucket_idx
        return len(bucket_boundaries)

    def _create_buckets(self, num_buckets, bucket_boundaries):
        """
        Override to bucket only the items admitted by the current curriculum stage.

        Args:
            num_buckets: Number of equal-width buckets to create when
                ``bucket_boundaries`` is None.
            bucket_boundaries: Sized sequence of upper boundaries (inclusive),
                or None to derive equal-width boundaries from the admitted
                items' ``src_lengths``.

        Returns:
            List of non-empty lists of dataset indices, one per bucket.

        Raises:
            ValueError: If a length percentage falls outside [0, 100], or if no
                item meets the current minimum percentage.
        """
        epoch = self._curriculum_epoch()
        current_min_perc = self._curriculum_min_perc()
        dataset = self.dataset

        # (length percentage, source length) for every item, by dataset index.
        percs = [
            (dataset.get_length_percentage(i), dataset.src_lengths[i])
            for i in range(len(dataset))
        ]
        # Explicit raise instead of assert so the check survives ``python -O``.
        if not all(0 <= perc <= 100 for perc, _ in percs):
            raise ValueError("Length percentage must be between 0 and 100")

        # (dataset index, source length) for items admitted at this stage.
        admitted = [
            (idx, length)
            for idx, (perc, length) in enumerate(percs)
            if perc >= current_min_perc
        ]
        if not admitted:
            raise ValueError(f"No sequences meet the minimum percentage {current_min_perc} in epoch {epoch}")

        # Derive equal-width boundaries over the admitted lengths if none given.
        if bucket_boundaries is None:
            lengths = [length for _, length in admitted]
            min_len = min(lengths)
            max_len = max(lengths)
            bucket_width = (max_len - min_len) / num_buckets
            bucket_boundaries = [min_len + bucket_width * i for i in range(1, num_buckets)]

        # One bucket per boundary plus an overflow bucket for the longest items.
        buckets = [[] for _ in range(len(bucket_boundaries) + 1)]
        for idx, length in admitted:
            buckets[self._curriculum_bucket_index(length, bucket_boundaries)].append(idx)

        buckets = [bucket for bucket in buckets if bucket]
        print(f'Epoch {epoch}, Min length: {current_min_perc}, Buckets: {len(buckets)}')
        return buckets

    def set_epoch(self, epoch):
        """
        Set the epoch and rebuild buckets and batches for the new curriculum stage.

        Recomputes ``self.buckets``, ``self.process_batches``,
        ``self.remaining_batches`` and ``self.batches_per_process`` using the
        parent's helpers.

        NOTE(review): this does not call ``super().set_epoch``. Confirm the
        parent has no additional per-epoch bookkeeping (e.g. shuffle seeding).
        """
        self.epoch = epoch
        self.buckets = self._create_buckets(self.num_buckets, self.bucket_boundaries)
        self.process_batches = self._process_batches_from_buckets()
        self.remaining_batches, self.batches_per_process = self._get_length_of_all_batches()