"""
Chunked Distributed Sampler that preserves sequential ordering.

Unlike PyTorch's DistributedSampler which interleaves samples across GPUs,
this sampler splits the dataset into consecutive chunks, preserving the
ordering provided by batch samplers (e.g., for curriculum learning).
"""

import math
import torch
from torch.utils.data import Sampler
from typing import Iterator, List


class ChunkedDistributedSampler(Sampler):
    """
    Sampler that splits dataset into consecutive chunks across GPUs.

    Example with 32 samples, 4 GPUs, batch_size=4:
        Query order: [1, 2, 3, 4, 5, 6, ..., 32]

        GPU 0: [1, 2, 3, 4], [17, 18, 19, 20]
        GPU 1: [5, 6, 7, 8], [21, 22, 23, 24]
        GPU 2: [9, 10, 11, 12], [25, 26, 27, 28]
        GPU 3: [13, 14, 15, 16], [29, 30, 31, 32]

    This ensures that:
    1. Each GPU processes consecutive batches from the original order
    2. All GPUs see different data in parallel
    3. The ordering from batch samplers is preserved

    The index range is processed in "rounds" of ``batch_size * num_replicas``
    indices. In every round, rank ``r`` receives the contiguous chunk
    ``[r * batch_size, (r + 1) * batch_size)`` of that round. With
    ``drop_last=True`` the trailing partial round is discarded. Otherwise it
    is padded by wrapping around to the start of the dataset, so every rank
    yields exactly ``len(self)`` indices.
    """

    def __init__(
        self,
        dataset,
        num_replicas: int,
        rank: int,
        batch_size: int,
        shuffle: bool = False,
        seed: int = 0,
        drop_last: bool = True,
    ):
        """
        Args:
            dataset: Dataset to sample from (only ``len(dataset)`` is used)
            num_replicas: Number of processes (GPUs); must be positive
            rank: Rank of the current process, in ``[0, num_replicas)``
            batch_size: Batch size per GPU; must be positive
            shuffle: Whether to shuffle within each chunk (not recommended).
                Only the order inside each ``batch_size`` chunk is permuted.
                Which indices form a chunk, and the order of chunks, are kept.
            seed: Random seed for shuffling; combined with the epoch set via
                ``set_epoch`` so each epoch gets a different permutation
            drop_last: Whether to drop the trailing incomplete round
                (if False, it is padded by wrapping around instead)

        Raises:
            ValueError: If ``num_replicas`` or ``batch_size`` is not positive,
                or ``rank`` is outside ``[0, num_replicas)``.
        """
        if num_replicas <= 0:
            raise ValueError(f"num_replicas must be positive, got {num_replicas}")
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if not 0 <= rank < num_replicas:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval "
                f"[0, {num_replicas - 1}]"
            )

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.seed = seed
        self.drop_last = drop_last
        self.epoch = 0

        self.total_size = len(self.dataset)

        # Each "round" consists of (batch_size * num_replicas) samples,
        # so every GPU gets exactly one batch per round.
        self.samples_per_round = self.batch_size * self.num_replicas
        self.num_complete_rounds = self.total_size // self.samples_per_round

        if self.drop_last:
            # Only use complete rounds.
            num_rounds = self.num_complete_rounds
        else:
            # Include the partial last round (it will be padded to full size).
            # Integer ceiling division avoids float rounding on huge datasets.
            num_rounds = -(-self.total_size // self.samples_per_round)

        # Number of (possibly padded) indices consumed across all ranks.
        self.usable_size = num_rounds * self.samples_per_round
        # Each rank receives one batch_size chunk per round.
        self.num_samples = num_rounds * self.batch_size

    def __iter__(self) -> Iterator[int]:
        """Generate indices for this rank.

        Returns an iterator over ``len(self)`` indices. It contains this
        rank's ``batch_size`` chunk from every round, in round order.
        """
        indices = list(range(self.total_size))

        if len(indices) < self.usable_size:
            # Pad by wrapping around to the start. This branch is only
            # reachable with drop_last=False and a non-empty dataset. The
            # padding can exceed the dataset size (tiny dataset, many
            # replicas), so repeat the index list as often as needed rather
            # than slicing it once.
            padding = self.usable_size - len(indices)
            repeats = -(-padding // len(indices))
            indices += (indices * repeats)[:padding]
        else:
            # Truncate to complete rounds only (a no-op when nothing is dropped).
            indices = indices[: self.usable_size]

        generator = None
        if self.shuffle:
            # Seeded identically on every rank and deterministic per epoch.
            generator = torch.Generator()
            generator.manual_seed(self.seed + self.epoch)

        # len(indices) is a multiple of samples_per_round, and this rank's
        # offset is < samples_per_round, so each step yields one full chunk.
        my_indices: List[int] = []
        first = self.rank * self.batch_size
        for chunk_start in range(first, len(indices), self.samples_per_round):
            chunk = indices[chunk_start : chunk_start + self.batch_size]
            if generator is not None:
                order = torch.randperm(len(chunk), generator=generator).tolist()
                chunk = [chunk[i] for i in order]
            my_indices.extend(chunk)

        return iter(my_indices)

    def __len__(self) -> int:
        """Number of samples for this rank."""
        return self.num_samples

    def set_epoch(self, epoch: int):
        """Set the epoch; combined with ``seed`` when ``shuffle`` is enabled."""
        self.epoch = epoch


class ChunkedDistributedBatchSampler(Sampler):
    """
    Alternative implementation as a batch sampler.

    This directly yields batches instead of indices, which can be more intuitive.

    In every complete round of ``batch_size * num_replicas`` indices, rank
    ``r`` receives the batch ``[r * batch_size, (r + 1) * batch_size)`` of
    that round. With ``drop_last=False`` the trailing partial round is split
    the same way without padding. Some ranks may then get a short final
    batch, and others no final batch at all.

    NOTE: with ``drop_last=False`` ranks can therefore yield different
    numbers of batches. Under DDP this can desynchronize collective ops
    (e.g. gradient all-reduce). Prefer ``drop_last=True`` there, or a
    padding sampler.
    """

    def __init__(
        self,
        dataset,
        num_replicas: int,
        rank: int,
        batch_size: int,
        drop_last: bool = True,
    ):
        """
        Args:
            dataset: Dataset to sample from (only ``len(dataset)`` is used)
            num_replicas: Number of processes (GPUs); must be positive
            rank: Rank of the current process, in ``[0, num_replicas)``
            batch_size: Batch size per GPU; must be positive
            drop_last: Whether to drop incomplete batches

        Raises:
            ValueError: If ``num_replicas`` or ``batch_size`` is not positive,
                or ``rank`` is outside ``[0, num_replicas)``.
        """
        if num_replicas <= 0:
            raise ValueError(f"num_replicas must be positive, got {num_replicas}")
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if not 0 <= rank < num_replicas:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval "
                f"[0, {num_replicas - 1}]"
            )

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.epoch = 0

        # Total samples that can be cleanly divided
        self.total_size = len(self.dataset)
        self.samples_per_round = self.batch_size * self.num_replicas
        self.num_complete_rounds = self.total_size // self.samples_per_round

        # Samples left over after the last complete round.
        remainder = self.total_size - self.num_complete_rounds * self.samples_per_round

        # Count only the batches *this* rank will actually yield. In the
        # partial round, ranks whose offset lies past the remainder get
        # nothing, so a global ceil() would overstate their length.
        self.num_batches = self.num_complete_rounds
        if not self.drop_last and remainder > self.rank * self.batch_size:
            self.num_batches += 1

    def __iter__(self) -> Iterator[List[int]]:
        """Generate batches for this rank (``len(self)`` lists of indices)."""
        offset = self.rank * self.batch_size
        complete_end = self.num_complete_rounds * self.samples_per_round

        # One full batch per complete round, taken at this rank's offset.
        batches = [
            list(range(start, start + self.batch_size))
            for start in range(offset, complete_end, self.samples_per_round)
        ]

        # Handle last incomplete round if not dropping: this rank takes its
        # slice of the leftover samples, which may be short or empty.
        if not self.drop_last:
            start = complete_end + offset
            if start < self.total_size:
                end = min(start + self.batch_size, self.total_size)
                batches.append(list(range(start, end)))

        return iter(batches)

    def __len__(self) -> int:
        """Number of batches for this rank."""
        return self.num_batches

    def set_epoch(self, epoch: int):
        """Set epoch for compatibility (no shuffling is performed)."""
        self.epoch = epoch
