"""
Custom Data Samplers for Distributed Training

This module provides specialized data samplers designed for distributed training scenarios
in multimodal AI systems. The samplers extend PyTorch's standard sampling mechanisms with
enhanced functionality for epoch-based random sampling, distributed sequential sampling,
and distributed batch sampling with gradient accumulation support.

These samplers are essential for scaling training across multiple GPUs and nodes while
maintaining proper data distribution and reproducibility. They handle complex scenarios
including wrap-around sampling, gradient accumulation, and worker-specific batch allocation.

Key Features:
    - Epoch-aware random sampling for reproducible shuffling
    - Distributed sequential sampling with configurable iteration control
    - Distributed batch sampling with arbitrary base sampler wrapping
    - Gradient accumulation support for large effective batch sizes
    - Wrap-around handling for consistent batch sizes across workers
    - Worker-specific data partitioning for distributed training

Supported Sampling Strategies:
    - RandomSampler: Enhanced random sampling with epoch control
    - DistributedSequentialSampler: Sequential sampling for distributed training
    - DistributedBatchSampler: Batch-level distributed sampling wrapper

Dependencies:
    - torch: PyTorch framework for tensor operations and distributed utilities
    - numpy: Numerical operations for sampling calculations

Technical Background:
    Distributed training requires careful coordination of data sampling to ensure:
    1. Each worker processes different data subsets
    2. All data is covered across workers
    3. Reproducible results across training runs
    4. Efficient gradient accumulation support

Author: AI Model Development Team
License: MIT
"""

import math
import os
import sys

import torch
from torch.utils import data
import numpy as np


class RandomSampler(data.sampler.Sampler):
    """
    Random sampler with optional epoch-based seeding for reproducible shuffling.

    This mirrors PyTorch's ``RandomSampler`` and adds a ``set_epoch`` hook similar to
    ``DistributedSampler``.

    Every call to ``__iter__`` builds a private ``torch.Generator``, seeded as follows:

    * epoch ``>= 0``: the epoch number itself is the seed, so a given epoch always
      yields the same index order.
    * epoch ``< 0`` (the default, ``-1``): the seed is drawn from PyTorch's global RNG.
      Consecutive passes therefore produce different orders, while a run started with
      ``torch.manual_seed(s)`` remains reproducible.

    Without replacement the sampler yields one full permutation of
    ``range(len(data_source))``. With replacement it yields ``num_samples`` indices
    drawn uniformly from that range, generated in chunks of 32.

    Args:
        data_source (Sized): Dataset to sample from; only ``len(data_source)`` is used.
        replacement (bool, optional): Whether to sample with replacement.
            Defaults to False.
        num_samples (int, optional): Number of indices to draw. If None, uses
            ``len(data_source)`` (re-evaluated on every access). May only be
            given when ``replacement=True``.

    Raises:
        ValueError: If ``num_samples`` is given while ``replacement`` is False.
        ValueError: If the resolved number of samples is not a positive integer.
        ValueError: If ``replacement`` is not a bool.

    Example:
        >>> dataset = MyDataset()
        >>> sampler = RandomSampler(dataset)
        >>> sampler.set_epoch(42)  # same order every time epoch 42 is used
        >>> dataloader = DataLoader(dataset, sampler=sampler)
    """

    def __init__(self, data_source, replacement=False, num_samples=None):
        super(RandomSampler, self).__init__(data_source)
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        # Any negative epoch means "no fixed seed" (see __iter__).
        self.epoch = -1

        # A permutation always has len(data_source) elements, so an explicit
        # sample count only makes sense when sampling with replacement.
        if self._num_samples is not None and replacement is False:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))
        if not isinstance(self.replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(self.replacement))

    @property
    def num_samples(self):
        """
        Number of indices one pass of the iterator yields.

        Returns:
            int: The explicit ``num_samples`` if one was given, otherwise the
            current ``len(data_source)``. The length is read on every access so
            that a dataset resized at runtime is picked up.
        """
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        """
        Yield dataset indices for one pass.

        Yields:
            int: Indices in ``[0, len(data_source))``. Without replacement this is
            a full permutation; with replacement it is ``num_samples`` independent
            uniform draws.
        """
        n = len(self.data_source)
        g = torch.Generator()

        if self.epoch >= 0:
            # Deterministic: the epoch number is the seed.
            g.manual_seed(self.epoch)
        else:
            # A freshly constructed torch.Generator always starts from the same
            # built-in default seed. Leaving it unseeded would make every
            # "random" pass produce the identical order. Draw the seed from the
            # global RNG instead: it differs between passes but is still
            # reproducible under torch.manual_seed.
            g.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))

        if self.replacement:
            # Draw indices in chunks of 32 rather than materialising one large tensor.
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=g).tolist()
            # Remaining num_samples % 32 indices (possibly zero).
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64,
                                     generator=g).tolist()
        else:
            yield from torch.randperm(n, generator=g).tolist()

    def __len__(self):
        """
        Return the number of indices one pass yields.

        Returns:
            int: ``self.num_samples``.
        """
        return self.num_samples

    def set_epoch(self, epoch):
        """
        Set the epoch that seeds subsequent passes.

        Args:
            epoch (int): A value ``>= 0`` is used directly as the generator seed,
                so the same epoch reproduces the same order. A negative value
                (e.g. ``-1``) switches back to seeding from the global torch RNG.
        """
        self.epoch = epoch


class DistributedSequentialSampler(data.sampler.Sampler):
    """
    Batch sampler that walks the dataset sequentially in several strided streams.

    The dataset of ``num_samples`` items is viewed as ``batch_size`` streams. Stream
    ``i`` starts at offset ``i * (num_samples // batch_size)``. At step ``idx`` the
    global batch is ``[(idx + offset) % num_samples for offset in batch_bias]``: every
    stream advances by one index per step and wraps around the end of the dataset.
    Each global batch is then sliced so that every rank receives its own contiguous,
    non-overlapping portion (see ``_batch``).

    Iteration yields lists of indices, so the sampler is meant to be passed as
    ``batch_sampler=`` to a ``DataLoader``.

    Args:
        num_samples (int): Number of items in the dataset; used as the index modulus.
        train_iters (int): Nominal number of training iterations; this is what
            ``__len__`` reports.
        batch_size (int): Global batch size, summed over all ranks.
        rank (int, optional): Rank of this worker. ``-1`` (the default) means a
            single process: rank 0 with world size 1, whatever ``world_size`` says.
        world_size (int, optional): Number of ranks. Defaults to 2; ignored when
            ``rank`` is ``-1``.

    Attributes:
        num_samples (int): Dataset size passed at construction.
        rank (int): Effective rank of this worker.
        world_size (int): Effective number of ranks.
        start_iter (int): First step index generated by ``__iter__``. Initialised
            to 0; presumably set by the caller when resuming from a checkpoint.
        train_iters (int): Value returned by ``__len__``.
        batch_size (int): Global batch size.
        batch_bias (list[int]): Per-stream starting offsets described above.

    Note:
        If ``batch_size > num_samples`` then ``num_samples // batch_size`` is 0. All
        offsets are then 0 and every element of a batch is the same index.

    Example:
        >>> sampler = DistributedSequentialSampler(
        ...     num_samples=10000,
        ...     train_iters=100,
        ...     batch_size=32,
        ...     rank=0,
        ...     world_size=4
        ... )
        >>> dataloader = DataLoader(dataset, batch_sampler=sampler)
    """

    def __init__(self, num_samples, train_iters, batch_size, rank=-1, world_size=2):
        super().__init__(num_samples)

        # rank == -1 means "not distributed": behave as the only rank.
        if rank == -1:
            rank = 0
            world_size = 1

        self.num_samples = num_samples
        self.rank = rank
        self.world_size = world_size
        self.start_iter = 0  # read by __iter__ as the first step index
        self.train_iters = train_iters
        self.batch_size = batch_size

        # Stream i starts i * (num_samples // batch_size) items into the dataset, so
        # the batch positions read from evenly spaced regions of the data.
        self.batch_bias = [i * (num_samples // batch_size) for i in range(batch_size)]

    def __iter__(self):
        """
        Yield this rank's slice of each global batch.

        Steps run from ``start_iter`` up to (but excluding) ``10 * train_iters``. At
        each step the global batch ``[(idx + bias) % num_samples for bias in
        batch_bias]`` is built and then sliced by ``_batch``.

        Yields:
            list[int]: The indices assigned to this rank for one step.

        Note:
            NOTE(review): this yields ``10 * train_iters - start_iter`` batches,
            ten times what ``__len__`` reports. The original code describes it as
            extended cycling. Confirm that consumers stop after the number of
            iterations they need.
        """
        for idx in range(self.start_iter, self.train_iters * 10):
            batch = [(idx + bias) % self.num_samples for bias in self.batch_bias]
            tbatch = self._batch(batch)
            yield tbatch

    def __len__(self):
        """
        Return the nominal number of iterations.

        Returns:
            int: ``train_iters``. This is not the number of batches ``__iter__``
            actually yields (see its note).
        """
        return self.train_iters

    def _batch(self, batch):
        """
        Return this rank's contiguous slice of a global batch.

        The slice is ``batch[rank * batch_size // world_size :
        (rank + 1) * batch_size // world_size]``. Together, the ranks' slices cover
        ``batch[:batch_size]`` exactly once. The lengths differ by at most one:
        when ``world_size`` does not divide ``batch_size``, rank 0 receives the
        shorter slice ``batch_size // world_size`` and the last rank receives the
        longer one. The extra elements are not handed to the first ranks.

        Args:
            batch (list[int]): Global batch of indices.

        Returns:
            list[int]: The sub-list assigned to this rank.
        """
        start = self.rank * self.batch_size // self.world_size
        end = (self.rank + 1) * self.batch_size // self.world_size
        return batch[start:end]


class DistributedBatchSampler(data.sampler.BatchSampler):
    """
    Batch sampler that groups a base sampler's indices into global batches and
    yields only this rank's slice of each batch.

    Distribution happens at batch level rather than sample level, so any base
    sampler works: sequential, random, weighted, and so on. Every rank must wrap
    a base sampler that produces the same index sequence (e.g. identically seeded),
    because each rank takes a different slice of the same global batch.

    Behaviour summary:

    * Indices are collected into global batches of ``batch_size``. Each full batch
      is passed through ``_batch``, which keeps this rank's contiguous slice.
    * Resuming: the first ``start_iter * effective_batch_size`` indices' worth of
      full batches are skipped. That is ``start_iter * gradient_accumulation_steps``
      batches, or ``start_iter`` batches when no accumulation is configured.
    * The trailing partial batch is dropped when ``drop_last`` is True. Otherwise it
      is sliced and yielded as-is; it is not padded.
    * With ``wrap_last`` set, the size of that partial batch is remembered in
      ``wrap_around``. The next pass then skips that many leading indices from the
      base sampler.

    Args:
        sampler: Base sampler to wrap. Its iterator must yield dataset indices, and
            it must accept attribute assignment because ``wrap_around`` is set on it.
        batch_size (int): Global batch size, summed over all ranks.
        drop_last (bool): Whether to drop the trailing incomplete batch.
        rank (int, optional): Rank of this worker. Must be given explicitly; the
            default ``-1`` is rejected.
        world_size (int, optional): Number of ranks. Defaults to 2.
        wrap_last (bool, optional): Enable the wrap-around offset described above.
            Defaults to False.
        gradient_accumulation_steps (int, optional): Micro-batches per optimizer
            step. Only used to convert ``start_iter`` into a number of batches to skip.

    Attributes:
        rank (int): Rank of this worker.
        world_size (int): Number of ranks.
        wrap_around (int): Number of leading base-sampler indices skipped at the
            start of each pass (taken modulo ``batch_size``).
        wrap_last (bool): Whether the trailing partial batch updates ``wrap_around``.
        start_iter (int): Iteration to resume from. Reset to 0 once the first
            batch after the skipped region has been yielded.
        effective_batch_size (int): ``batch_size * gradient_accumulation_steps``,
            or ``batch_size`` when accumulation is not configured.

    Raises:
        AssertionError: If ``rank`` is ``-1``. This is raised explicitly, so the
            check also holds under ``python -O``.

    Example:
        >>> base_sampler = RandomSampler(dataset)
        >>> distributed_sampler = DistributedBatchSampler(
        ...     sampler=base_sampler,
        ...     batch_size=32,
        ...     drop_last=True,
        ...     rank=0,
        ...     world_size=4,
        ...     gradient_accumulation_steps=2
        ... )
        >>> dataloader = DataLoader(dataset, batch_sampler=distributed_sampler)
    """

    def __init__(self, sampler, batch_size, drop_last, rank=-1, world_size=2, wrap_last=False, gradient_accumulation_steps=None):
        super(DistributedBatchSampler, self).__init__(sampler, batch_size, drop_last)

        # An explicit rank is mandatory. Raise instead of `assert False`: asserts are
        # stripped under `python -O`, and rank -1 would then silently produce
        # negative slice bounds (and empty batches) in _batch.
        if rank == -1:
            raise AssertionError('should not be here')

        self.rank = rank
        self.world_size = world_size

        # Wrap-around bookkeeping. The attribute is also stored on the base sampler,
        # which __iter__ adjusts. NOTE(review): no sampler defined in this module
        # reads `wrap_around`; confirm which consumer relies on it.
        self.sampler.wrap_around = 0
        self.wrap_around = 0
        self.wrap_last = wrap_last
        self.start_iter = 0  # iteration to resume from; see __iter__

        # Samples consumed per optimizer step; used to turn start_iter into a skip count.
        self.effective_batch_size = batch_size if gradient_accumulation_steps is None else batch_size * gradient_accumulation_steps

    def __iter__(self):
        """
        Yield this rank's slice of each global batch for one pass over the sampler.

        Full batches are skipped while fewer than
        ``start_iter * effective_batch_size`` indices have been consumed. After the
        first non-skipped batch is yielded, ``start_iter`` is reset to 0. If a
        partial batch remains and ``drop_last`` is False, it is sliced and yielded.

        Yields:
            list: Indices assigned to this rank for one batch.

        Note:
            NOTE(review): ``_batch`` slices with ``batch_size``, not ``len(batch)``.
            For the trailing partial batch, higher ranks can therefore receive a
            short or empty list.
        """
        batch = []
        i = 0  # number of indices consumed in full batches so far

        for idx in self.data_iterator(self.sampler, wrap_around=False):
            batch.append(idx)

            if len(batch) == self.batch_size:
                tbatch = self._batch(batch)
                # Skip batches that precede the resume point.
                if i >= self.start_iter * self.effective_batch_size:
                    yield tbatch
                    self.start_iter = 0  # resume point reached; later passes start at 0
                i += len(batch)
                batch = []

        batch_len = len(batch)
        if batch_len > 0 and not self.drop_last:
            if self.wrap_last:
                # Remember the partial batch's size so that the next pass skips that
                # many leading indices (see data_iterator).
                self.sampler.wrap_around -= self.batch_size
                self.wrap_around += batch_len
                self.wrap_around %= self.batch_size
            yield self._batch(batch)

        # Paired with the decrement above: the sampler's wrap_around is unchanged net
        # when a partial batch was yielded, and grows by batch_size otherwise.
        if self.wrap_last:
            self.sampler.wrap_around += self.batch_size

    def data_iterator(self, _iter, wrap_around=False):
        """
        Iterate over ``_iter``, skipping its first ``self.wrap_around % batch_size`` items.

        Args:
            _iter: Iterable of indices (here, the base sampler).
            wrap_around (bool): If True, advance ``self.wrap_around`` by one
                (mod ``batch_size``) for every item yielded. ``__iter__`` always
                passes False.

        Yields:
            The items of ``_iter`` after the skipped prefix.
        """
        for i, idx in enumerate(_iter):
            # Skip the leading indices recorded by the previous pass's partial batch.
            if i < self.wrap_around % self.batch_size:
                continue

            if wrap_around:
                self.wrap_around += 1
                self.wrap_around %= self.batch_size

            yield idx

    def _batch(self, batch):
        """
        Return this rank's contiguous slice of a global batch.

        The slice is ``batch[rank * batch_size // world_size :
        (rank + 1) * batch_size // world_size]``. For a full batch, the ranks'
        slices partition it exactly. The slice lengths differ by at most one: rank 0
        gets ``batch_size // world_size`` and the last rank gets the rounded-up size.
        The extra elements are not handed to the first ranks.

        Args:
            batch (list): Global batch of indices.

        Returns:
            list: The sub-list assigned to this rank. It may be shorter or empty
            when ``batch`` has fewer than ``batch_size`` elements.
        """
        start = self.rank * self.batch_size // self.world_size
        end = (self.rank + 1) * self.batch_size // self.world_size
        return batch[start:end]
