import torch
from torch.utils.data import Sampler
import numpy as np
from typing import List, Iterator, Optional

class DistributedBucketBatchSampler(Sampler):
    """
    Distributed bucket batch sampler that groups sequences of similar lengths
    and distributes the resulting batches across multiple processes.

    Dataset indices are grouped into length buckets, each bucket is cut into
    batches, the batch list is (optionally) shuffled, and it is then split into
    contiguous slices of ``ceil(num_batches / num_replicas)`` batches, one
    slice per rank. All ranks must use the same dataset, seed and epoch so
    they agree on the global batch order before taking their slice.

    Args:
        dataset: Dataset to sample from; must support ``len()`` and integer
            indexing. With the default ``length_key_fn``, ``item[0]`` is
            compared against ``dataset.pad_idx`` and summed with ``torch.sum``
            (so it is expected to be a tensor).
        batch_size: Fixed batch size for every bucket. Required when
            ``batch_sizes`` is not given.
        batch_sizes: Optional list of batch sizes, one per bucket as defined
            by the boundaries (bucket ``i`` uses ``batch_sizes[i]``, even if
            some earlier buckets are empty). Buckets beyond the end of the
            list reuse the last entry. Takes precedence over ``batch_size``.
        drop_last: If True, drop each bucket's trailing incomplete batch.
        num_buckets: Number of equal-width length buckets to create when
            ``bucket_boundaries`` is not given.
        bucket_boundaries: Optional explicit bucket boundaries. An item of
            length ``L`` goes to the bucket of the first boundary (in list
            order) that is ``>= L``; longer items go to one extra final bucket.
        length_key_fn: Optional callable mapping a dataset item to its length.
        shuffle: Whether to shuffle indices within buckets and the order of
            batches across buckets.
        seed: Random seed for shuffling (``None`` is treated as 0).
        rank: Rank of the current process, in ``[0, num_replicas)``.
        num_replicas: Number of processes participating in distributed training.

    Raises:
        ValueError: If no valid batch size is given, ``num_buckets`` is not
            positive when boundaries must be derived, or ``rank`` /
            ``num_replicas`` are out of range.

    NOTE: ranks can receive different numbers of batches (trailing ranks may
    get fewer, or none). Under DDP this may leave ranks waiting on each other
    in collectives; pad or trim upstream if that matters.
    """
    def __init__(
        self,
        dataset,
        batch_size: Optional[int] = None,
        batch_sizes: Optional[List[int]] = None,
        drop_last: bool = False,
        num_buckets: int = 5,
        bucket_boundaries: Optional[List[int]] = None,
        length_key_fn=None,
        shuffle: bool = True,
        seed: Optional[int] = None,
        rank: int = 0,
        num_replicas: int = 1
    ):
        if num_replicas < 1:
            raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
        if not 0 <= rank < num_replicas:
            raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")

        self.num_buckets = num_buckets
        self.bucket_boundaries = bucket_boundaries
        self.dataset = dataset

        # Handle fixed vs variable batch sizes. Validate up front so a missing
        # or non-positive size fails here rather than deep inside range().
        self.use_variable_batching = batch_sizes is not None
        if self.use_variable_batching:
            if len(batch_sizes) == 0 or any(bs <= 0 for bs in batch_sizes):
                raise ValueError("batch_sizes must be a non-empty list of positive integers")
        elif batch_size is None or batch_size <= 0:
            raise ValueError("batch_size must be a positive integer when batch_sizes is not given")
        self.batch_sizes = batch_sizes
        self.batch_size = batch_size

        self.drop_last = drop_last
        self.shuffle = shuffle
        self.seed = seed
        self.rank = rank
        self.num_replicas = num_replicas
        self.epoch = 0

        # Every rank derives the same seed so that all ranks build the same
        # global batch order before slicing out their own share.
        self.base_seed = seed if seed is not None else 0

        # Default length extractor: number of non-padding entries in item[0].
        if length_key_fn is None:
            self.length_key_fn = lambda x: torch.sum(x[0] != dataset.pad_idx).item()
        else:
            self.length_key_fn = length_key_fn

        # Buckets stay in their original (unshuffled) order; each epoch shuffles
        # copies, so the batches for a given epoch are reproducible regardless
        # of which epochs were built before.
        self.buckets = self._create_buckets(num_buckets, bucket_boundaries)
        self.process_batches = self._process_batches_from_buckets()
        self.remaining_batches, self.batches_per_process = self._get_length_of_all_batches()

    def _create_buckets(self, num_buckets, bucket_boundaries):
        """
        Group dataset indices into buckets by sequence length.

        Returns the non-empty buckets (lists of indices in ascending order).
        Also records ``self.bucket_ids``: for each returned bucket, its
        position among all boundary-defined buckets, so that ``batch_sizes``
        stays aligned with the boundaries even when some buckets are empty.
        """
        lengths = [self.length_key_fn(self.dataset[i]) for i in range(len(self.dataset))]
        if not lengths:
            self.bucket_ids = []
            return []

        # Derive equal-width boundaries between min and max length if needed
        if bucket_boundaries is None:
            if num_buckets < 1:
                raise ValueError(f"num_buckets must be >= 1, got {num_buckets}")
            min_len = min(lengths)
            max_len = max(lengths)
            bucket_width = (max_len - min_len) / num_buckets
            bucket_boundaries = [min_len + bucket_width * i for i in range(1, num_buckets)]

        # One bucket per boundary plus a final bucket for longer items
        buckets = [[] for _ in range(len(bucket_boundaries) + 1)]
        overflow_bucket = len(bucket_boundaries)
        for idx, length in enumerate(lengths):
            # First boundary (in list order) that is >= length, else the overflow bucket
            bucket_idx = next(
                (pos for pos, boundary in enumerate(bucket_boundaries) if length <= boundary),
                overflow_bucket,
            )
            buckets[bucket_idx].append(idx)

        # Filter out empty buckets, remembering where the survivors came from
        self.bucket_ids = [pos for pos, bucket in enumerate(buckets) if bucket]
        return [bucket for bucket in buckets if bucket]

    def _batch_size_for(self, bucket_pos):
        """Return the batch size for the non-empty bucket at index ``bucket_pos``."""
        if not self.use_variable_batching:
            return self.batch_size
        # Index batch_sizes by the bucket's boundary position, not its position
        # after empty buckets were dropped; reuse the last size past the end.
        bucket_id = self.bucket_ids[bucket_pos]
        return self.batch_sizes[min(bucket_id, len(self.batch_sizes) - 1)]

    def _process_batches_from_buckets(self):
        """Build the current epoch's batches and return the slice for this rank."""
        epoch_seed = self.base_seed + self.epoch
        # Same seed on every rank -> identical cross-bucket batch order everywhere
        rng = np.random.RandomState(epoch_seed)

        all_batches = []
        for bucket_pos, bucket in enumerate(self.buckets):
            # Work on a copy so self.buckets keeps its original order
            indices = list(bucket)
            if self.shuffle:
                # Per-bucket seed to avoid correlation between buckets. It is
                # derived from the unshuffled bucket, so it depends only on the
                # dataset and the epoch (tuple-of-int hashing is not salted).
                bucket_seed = epoch_seed + hash(tuple(bucket)) % 1000000
                np.random.RandomState(bucket_seed).shuffle(indices)

            current_batch_size = self._batch_size_for(bucket_pos)
            for start in range(0, len(indices), current_batch_size):
                batch = indices[start:start + current_batch_size]
                if self.drop_last and len(batch) < current_batch_size:
                    continue
                all_batches.append(batch)

        # Shuffle batches across buckets if requested
        if self.shuffle:
            rng.shuffle(all_batches)

        # Contiguous ceil-sized slice per rank; trailing ranks may get fewer
        num_batches = len(all_batches)
        num_batches_per_replica = (num_batches + self.num_replicas - 1) // self.num_replicas
        start_idx = self.rank * num_batches_per_replica
        end_idx = min(start_idx + num_batches_per_replica, num_batches)
        return all_batches[start_idx:end_idx]

    def __iter__(self) -> Iterator[List[int]]:
        """Yield this rank's batches (lists of dataset indices) for the current epoch."""
        yield from self.process_batches

    def _get_length_of_all_batches(self):
        """
        Return ``(remaining_batches, batches_per_process)``.

        ``batches_per_process`` is the ceil-divided slice size and
        ``remaining_batches`` is the (non-negative) batch count of the last rank.
        """
        total_batches = 0
        for bucket_pos, bucket in enumerate(self.buckets):
            current_batch_size = self._batch_size_for(bucket_pos)
            if self.drop_last:
                total_batches += len(bucket) // current_batch_size
            else:
                total_batches += (len(bucket) + current_batch_size - 1) // current_batch_size

        batches_per_process = (total_batches + self.num_replicas - 1) // self.num_replicas
        # Clamp: with few batches the last rank's slice starts past the end
        remaining_batches = max(0, total_batches - batches_per_process * (self.num_replicas - 1))
        return remaining_batches, batches_per_process

    def __len__(self) -> int:
        """Return the number of batches this rank yields per epoch."""
        # Count the actual slice: with the ceil-based split, any trailing rank
        # (not only the last) can receive fewer batches, or none at all.
        return len(self.process_batches)

    def set_epoch(self, epoch):
        """Set the epoch and rebuild this rank's batches with that epoch's shuffle."""
        self.epoch = epoch
        self.process_batches = self._process_batches_from_buckets()