"""
Data Loading for FineWeb-Edu
Anonymous ICML 2026 Submission

This module implements streaming data loading from the FineWeb-Edu dataset,
as described in the paper (Section 4.2).
"""

import torch
from torch.utils.data import IterableDataset
from datasets import load_dataset


class StreamDataset(IterableDataset):
    """
    Streaming dataset for FineWeb-Edu corpus.

    Documents are tokenized one at a time and concatenated into a single
    token stream, which is then cut into fixed-size training examples.

    Paper details (Section 4.2):
    - Dataset: FineWeb-Edu (Penedo et al., 2024)
    - Block size: 512 tokens
    - Training: 50k iterations, batch size 128 (effective)
    - Split: First 200k samples for training, rest for validation

    NOTE: the train/validation separation is implemented only through
    ``skip_samples``. The underlying load always uses the "train" split of
    the "sample-10BT" configuration, and no upper bound is placed on the
    number of samples read, so a stream started with ``skip_samples=0``
    keeps reading past any validation offset unless the caller stops it.

    Args:
        tokenizer: Callable invoked as
            ``tokenizer(text, truncation=False, add_special_tokens=False)``
            that returns a mapping with an ``'input_ids'`` list
            (e.g. a HuggingFace tokenizer instance).
        block_size (int): Sequence length (default: 512). Must be >= 1.
        skip_samples (int): Number of samples to skip (for train/val split)
        split (str): Label for this stream ('train' or 'validation'). It is
            stored and printed, but it is not passed to ``load_dataset``.
        future_window (int): Number of tokens following the labels that are
            returned as ``future_tokens`` (default: 20). Must be >= 0.

    Raises:
        ValueError: If ``block_size < 1`` or ``future_window < 0``.
    """

    def __init__(self, tokenizer, block_size=512, skip_samples=0, split='train',
                 future_window=20):
        super().__init__()

        # A non-positive block size would never consume the buffer and the
        # emit loop in __iter__ would spin forever.
        if block_size < 1:
            raise ValueError(f"block_size must be >= 1, got {block_size}")
        if future_window < 0:
            raise ValueError(f"future_window must be >= 0, got {future_window}")

        self.tokenizer = tokenizer
        self.block_size = block_size
        self.skip_samples = skip_samples
        self.split = split
        self.future_window = future_window

        print(f"Loading FineWeb-Edu dataset (split={split}, skip={skip_samples})...")

        # Load FineWeb-Edu from HuggingFace
        # Paper citation: Penedo et al., 2024
        self.dataset = load_dataset(
            "HuggingFaceFW/fineweb-edu",
            name="sample-10BT",  # 10B token sample (manageable size)
            split="train",
            streaming=True
        )

        # Skip samples for train/val split
        if skip_samples > 0:
            self.dataset = self.dataset.skip(skip_samples)

    def __iter__(self):
        """
        Yields single (unbatched) examples of (input_ids, labels, future_tokens).

        Consecutive examples start ``block_size`` tokens apart in the
        concatenated token stream, so the labels/future tokens of one example
        overlap the inputs of the next. Tokens left over at the end of the
        stream that cannot fill a complete example are dropped.

        Yields:
            input_ids (torch.Tensor): Input sequence [block_size]
            labels (torch.Tensor): Inputs shifted by one token [block_size]
            future_tokens (torch.Tensor): The ``future_window`` tokens that
                follow the last label [future_window]
        """
        block_size = self.block_size
        future_window = self.future_window

        # One example reads tokens [0, block_size + 1 + future_window):
        # block_size inputs, one extra token for the shifted labels, and the
        # future window that starts right after the last label.
        required_length = block_size + 1 + future_window

        buffer = []

        for sample in self.dataset:
            tokens = self.tokenizer(
                sample['text'],
                truncation=False,
                add_special_tokens=False
            )['input_ids']
            buffer.extend(tokens)

            # Walk an offset through the buffer instead of re-slicing the
            # list after every example; consumed tokens are dropped once per
            # document below.
            start = 0
            while len(buffer) - start >= required_length:
                end = start + block_size

                input_ids = torch.tensor(buffer[start:end], dtype=torch.long)
                labels = torch.tensor(buffer[start + 1:end + 1], dtype=torch.long)
                future_tokens = torch.tensor(
                    buffer[end + 1:end + 1 + future_window],
                    dtype=torch.long
                )

                yield input_ids, labels, future_tokens

                # Advance by one block; the tail stays available as context
                # for the next example.
                start = end

            if start:
                del buffer[:start]


def create_idea_target(future_tokens, vocab_size, window_size, device):
    """
    Creates multi-hot target vector for semantic head training.

    For every position, the vocabulary entries of the first ``window_size``
    token ids along the last dimension of ``future_tokens`` are set to 1.0;
    all other entries are 0.0. Repeated ids and their order do not change
    the result.

    Paper details (Section 3.6):
    - Window size K=20 tokens
    - Multi-hot encoding (Bag-of-Words)
    - Permutation invariant

    Args:
        future_tokens (torch.Tensor): Integer token ids whose last dimension
            holds the future tokens, e.g. [batch, K] or [batch, seq_len, K].
            Ids must lie in ``[0, vocab_size)``.
        vocab_size (int): Vocabulary size (32000 for Mistral)
        window_size (int): Number of future tokens to use (K=20). If the last
            dimension is shorter than this, all available tokens are used.
        device (str): Device to create tensor on

    Returns:
        torch.Tensor: Float multi-hot target with the last dimension of
        ``future_tokens`` replaced by ``vocab_size``: [batch, vocab_size]
        for [batch, K] input, [batch, seq_len, vocab_size] for
        [batch, seq_len, K] input.
    """
    # A negative window would slice from the end of the token axis; treat it
    # as an empty window instead.
    num_future = max(window_size, 0)

    target = torch.zeros(*future_tokens.shape[:-1], vocab_size, device=device)

    # scatter_ needs int64 indices on the same device as the target; the
    # tokens usually arrive from the data loader on the CPU.
    indices = future_tokens[..., :num_future].to(
        device=target.device, dtype=torch.long
    )

    # One scatter over the vocabulary axis marks every future token at once;
    # duplicates simply write 1.0 to the same slot again.
    target.scatter_(-1, indices, 1.0)
    return target

