"""Custom samplers for training."""

from typing import List, Optional

import torch
from torch.utils.data import Sampler


class GroupedRandomSampler(Sampler[int]):
    """Single-process sampler that regroups each prompt's completions every epoch.

    The dataset is treated as ``n // num_completions_per_prompt`` consecutive runs of
    ``num_completions_per_prompt`` indices (completions of one prompt are contiguous).
    Each iteration:

    1. randomly permutes the completions of every prompt,
    2. splits each permuted run into groups of ``sampler_group_size`` indices
       (a trailing partial group per prompt is dropped),
    3. shuffles the indices inside each group, and
    4. shuffles the order of all groups globally (across prompts) and flattens them.

    Every group therefore holds completions of a single prompt, but which completions
    end up together changes from epoch to epoch. For example, with ``n=8``,
    ``num_completions_per_prompt=4`` and ``sampler_group_size=2`` one possible output
    is ``[5, 7, 1, 0, 6, 4, 3, 2]``.

    Randomness comes from a private ``torch.Generator`` seeded with ``args.seed``
    (42 when ``args`` has no ``seed`` attribute); :meth:`set_epoch` reseeds it with
    ``seed + epoch``.
    """

    def __init__(
        self,
        n,
        args,
        num_completions_per_prompt: Optional[int] = None,
        sampler_group_size: Optional[int] = None,
    ):
        """
        Args:
            n: Number of dataset indices to sample from (``0 .. n-1``).
            args: Config object; only its optional ``seed`` attribute is read.
            num_completions_per_prompt: Number of contiguous indices belonging to each
                prompt. Must be set, positive, and divide ``n`` before iterating.
            sampler_group_size: Number of same-prompt completions per group. Defaults to
                ``num_completions_per_prompt`` (one group per prompt).
        """
        self.n = n
        self.args = args
        self.sampler_group_size = sampler_group_size
        if self.sampler_group_size is None:
            # Default: one group per prompt.
            self.sampler_group_size = num_completions_per_prompt
        print(f"GroupedRandomSampler configured with group_size={self.sampler_group_size}")
        self.seed = getattr(args, 'seed', 42)
        print(f"GroupedRandomSampler configured with seed={self.seed}")
        self.num_completions_per_prompt = num_completions_per_prompt
        self.generator = torch.Generator()
        self.generator.manual_seed(self.seed)
        self.epoch = 0

    def __iter__(self):
        """Return an iterator over one epoch of grouped, shuffled indices.

        Raises:
            ValueError: If ``num_completions_per_prompt`` is unset or not positive, if the
                group size is not positive, or if ``n`` is not divisible by
                ``num_completions_per_prompt``.
        """
        if self.num_completions_per_prompt is None:
            raise ValueError("num_completions_per_prompt must be set")
        if self.num_completions_per_prompt < 1:
            raise ValueError(
                f"num_completions_per_prompt must be positive, got {self.num_completions_per_prompt}"
            )
        if self.sampler_group_size is None or self.sampler_group_size < 1:
            raise ValueError(
                f"sampler_group_size must be a positive integer, got {self.sampler_group_size}"
            )

        return iter(self._dynamic_grouping_per_prompt(list(range(self.n))))

    def _dynamic_grouping_per_prompt(self, indices: List[int]) -> List[int]:
        """Regroup completions per prompt, shuffle within groups, then shuffle groups globally.

        Args:
            indices: Flat list of indices whose consecutive runs of
                ``num_completions_per_prompt`` items belong to the same prompt.

        Returns:
            The flattened list of shuffled groups. Its length is
            ``num_prompts * (per_prompt // group_size) * group_size``.

        Raises:
            ValueError: If ``len(indices)`` is not divisible by ``num_completions_per_prompt``.
        """
        n = len(indices)
        per_prompt = self.num_completions_per_prompt
        group_size = self.sampler_group_size

        if n % per_prompt != 0:
            raise ValueError(f"Dataset size {n} is not divisible by num_completions_per_prompt {self.num_completions_per_prompt}")

        # Only full groups are emitted; ``per_prompt % group_size`` completions per
        # prompt are dropped each epoch (which ones varies with the permutation).
        usable_per_prompt = (per_prompt // group_size) * group_size

        all_groups: List[List[int]] = []
        for prompt_start in range(0, n, per_prompt):
            prompt_tensor = torch.tensor(indices[prompt_start:prompt_start + per_prompt], dtype=torch.long)
            # Re-permute this prompt's completions so group membership changes every epoch.
            permuted = prompt_tensor[torch.randperm(per_prompt, generator=self.generator)]

            for group_start in range(0, usable_per_prompt, group_size):
                group = permuted[group_start:group_start + group_size]
                # Shuffle within this group as well.
                all_groups.append(group[torch.randperm(group_size, generator=self.generator)].tolist())

        if not all_groups:
            return []

        # Shuffle the order of all groups globally (across prompts), then flatten.
        group_order = torch.randperm(len(all_groups), generator=self.generator).tolist()
        return [idx for g in group_order for idx in all_groups[g]]

    def set_epoch(self, epoch: int):
        """Reseed the generator with ``seed + epoch`` so each epoch is reproducible but distinct."""
        self.epoch = epoch
        self.generator.manual_seed(self.seed + epoch)

    def __len__(self):
        """Return the number of indices one call to ``__iter__`` yields.

        ``__iter__`` drops each prompt's trailing completions that do not fill a complete
        group, so this can be smaller than ``n``. For configurations ``__iter__`` rejects,
        ``n`` is returned.
        """
        per_prompt = self.num_completions_per_prompt
        group_size = self.sampler_group_size
        if (
            per_prompt is None
            or per_prompt < 1
            or group_size is None
            or group_size < 1
            or self.n % per_prompt != 0
        ):
            return self.n
        return (self.n // per_prompt) * (per_prompt // group_size) * group_size