import os
import multiprocessing
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel
from typing import Optional, Union, Dict, List, Any
from datasets import load_from_disk, load_dataset, Dataset
import numpy as np
# Hugging Face dataset types (re-imports Dataset and load_dataset from above, adds IterableDataset).
from datasets import Dataset, IterableDataset, load_dataset
from PIL import Image
import hashlib
import json
import time
from tqdm import tqdm
import glob

class CaptionDataset(Dataset):
    """Pre-tokenized caption dataset that also emits a randomly masked copy.

    All captions are tokenized once in ``__init__`` (padded/truncated to
    ``max_length``) and stored as plain Python lists. Indexing with an ``int``
    returns a single example; indexing with a ``list`` of ints returns a
    stacked batch. Both return a dict with the keys ``input_ids``,
    ``attention_mask``, ``target_ids``, ``aug_ids`` and ``aug_attention_mask``.
    """

    def __init__(self, data, tokenizer, max_length=40, augmentation_prob=0.0):
        """
        Args:
            data: Source of captions. Accepted forms are an object exposing
                ``features`` with a ``'caption'`` column, a dict with a
                ``'caption'`` entry, a list of dicts each holding ``'caption'``,
                or a list / NumPy array of strings.
            tokenizer: HuggingFace tokenizer (must provide ``pad_token_id`` and
                ``mask_token_id``)
            max_length: Maximum sequence length
            augmentation_prob: Probability of masking a token for augmentation

        Raises:
            ValueError: If ``data`` matches none of the accepted forms.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.augmentation_prob = augmentation_prob

        texts = self._extract_captions(data)

        # Tokenize captions
        print(f"Tokenizing {len(texts)} captions...")
        tokenized = self.tokenizer(
            texts,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors=None
        )

        # Keep plain Python lists to avoid indexing issues
        self.input_ids = list(tokenized['input_ids'])
        self.attention_mask = list(tokenized['attention_mask'])
        print("Tokenization complete.")

    @staticmethod
    def _as_text_list(values):
        """Materialize a column of captions as a plain ``list``."""
        if isinstance(values, str):
            return [values]
        # Tokenizers accept ``str`` / ``list`` inputs only; a column may be a
        # NumPy array or a lazy column object, so copy it into a list.
        return list(values)

    @classmethod
    def _extract_captions(cls, data):
        """Pull the caption strings out of any of the supported containers."""
        if hasattr(data, 'features') and 'caption' in data.features:
            # Handle datasets.Dataset object
            return cls._as_text_list(data['caption'])
        if isinstance(data, dict) and 'caption' in data:
            return cls._as_text_list(data['caption'])
        if isinstance(data, list) and all(isinstance(item, dict) and 'caption' in item for item in data):
            return [item['caption'] for item in data]
        if isinstance(data, (list, np.ndarray)) and all(isinstance(item, str) for item in data):
            # ``str(...)`` turns NumPy string scalars into built-in strings.
            return [str(item) for item in data]
        raise ValueError("Dataset must contain 'caption' field or be a list of strings")

    def create_augmentation_mask(self, attention_mask):
        """Create a mask for token augmentation (masking).

        A position is ``True`` (token kept) when it is a real token according
        to ``attention_mask`` and its uniform random draw is greater than
        ``augmentation_prob``. Padding positions are always ``False``.
        """
        device = attention_mask.device
        aug_mask = (torch.rand(attention_mask.shape, device=device) > self.augmentation_prob)
        # Make sure we keep padding tokens masked
        aug_mask = aug_mask & attention_mask
        return aug_mask

    def __len__(self):
        return len(self.input_ids)

    def _load_row(self, idx):
        """Return ``(input_ids, attention_mask)`` tensors for one stored row.

        One-dimensional rows are truncated or right-padded so that they are
        exactly ``max_length`` long.
        """
        input_ids = torch.tensor(self.input_ids[idx], dtype=torch.long)
        attention_mask = torch.tensor(self.attention_mask[idx], dtype=torch.bool)

        if input_ids.dim() == 1:
            input_ids = input_ids[:self.max_length]
            attention_mask = attention_mask[:self.max_length]

            pad_length = self.max_length - input_ids.shape[0]
            if pad_length > 0:
                input_ids = torch.nn.functional.pad(
                    input_ids, (0, pad_length), value=self.tokenizer.pad_token_id
                )
                attention_mask = torch.nn.functional.pad(
                    attention_mask, (0, pad_length), value=False
                )
        return input_ids, attention_mask

    def _assemble(self, input_ids, attention_mask):
        """Build the output dict, including the masked (augmented) copy."""
        aug_attention_mask = self.create_augmentation_mask(attention_mask)
        # Every position where the augmentation mask is False (randomly
        # dropped tokens and padding alike) is replaced by the mask token.
        aug_input_ids = torch.where(aug_attention_mask, input_ids, self.tokenizer.mask_token_id)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'target_ids': input_ids.clone(),
            'aug_ids': aug_input_ids,
            'aug_attention_mask': aug_attention_mask
        }

    def __getitem__(self, idx):
        """Return one example, or a stacked batch when ``idx`` is a list."""
        # Handle both single item and batch (list) indexing
        if isinstance(idx, list):
            return self._get_batch(idx)

        input_ids, attention_mask = self._load_row(idx)
        return self._assemble(input_ids, attention_mask)

    def _get_batch(self, indices):
        """Handle batch indexing directly."""
        rows = [self._load_row(idx) for idx in indices]

        # Stack into (len(indices), max_length) tensors
        batch_input_ids = torch.stack([ids for ids, _ in rows])
        batch_attention_mask = torch.stack([mask for _, mask in rows])

        return self._assemble(batch_input_ids, batch_attention_mask)
        
class PixmoCapDataModule(pl.LightningDataModule):
    """Lightning data module serving tokenized captions from ``allenai/pixmo-cap``.

    The Hugging Face ``train`` split is loaded in ``setup``, optionally
    subsampled, and divided into train/validation/test partitions with a fixed
    seed. Each partition is wrapped in a ``CaptionDataset``.
    """

    def __init__(
        self,
        batch_size: int = 32,
        max_length: int = 30,
        num_workers: Optional[int] = 24,
        augmentation_prob: float = 0.0,
        cache_dir: str = "/home/user/datasets/wiki_cache",
        tokenizer_path: str = None,  # Path to local tokenizer
        max_examples: Optional[Dict[str, int]] = None,
        train_val_test_split: Optional[List[float]] = None,  # None -> [0.8, 0.1, 0.1]
        dataset_fraction: float = 1.0,  # Fraction of dataset to load (0.0-1.0)
        streaming: bool = False,  # Whether to use streaming for large datasets
    ):
        """
        Args:
            batch_size: Batch size for dataloaders
            max_length: Maximum sequence length
            num_workers: Number of dataloader workers. ``None`` selects half of
                the CPU count (at least 1); ``0`` loads in the main process.
            augmentation_prob: Probability of masking a token for augmentation
            cache_dir: Cache directory for HuggingFace
            tokenizer_path: Path to local tokenizer files (if None, use ModernBERT from HF)
            max_examples: Optional dict with limits for each split
                (keys ``"train"``, ``"validation"``, ``"test"``; missing keys mean no limit)
            train_val_test_split: Ratios for train/val/test splits (must sum to 1);
                defaults to ``[0.8, 0.1, 0.1]``
            dataset_fraction: Fraction of the dataset to load (clamped to 0.0-1.0)
            streaming: Whether to use streaming mode for large datasets
                (not implemented; ``setup`` falls back to regular loading)

        Raises:
            ValueError: If the split ratios are not three values summing to 1,
                or if the tokenizer cannot be loaded.
        """
        super().__init__()
        self.batch_size = batch_size
        self.max_length = max_length
        # ``is None`` (not ``or``) so that an explicit 0 is honoured.
        if num_workers is None:
            num_workers = max(1, multiprocessing.cpu_count() // 2)
        self.num_workers = num_workers
        self.augmentation_prob = augmentation_prob
        self.cache_dir = cache_dir
        self.tokenizer_path = tokenizer_path
        # Always provide the three split keys so partial dicts cannot KeyError.
        self.max_examples = {"train": None, "validation": None, "test": None}
        self.max_examples.update(max_examples or {})
        # Copy so that neither a shared default nor the caller's list is aliased.
        if train_val_test_split is None:
            self.split_ratios = [0.8, 0.1, 0.1]
        else:
            self.split_ratios = list(train_val_test_split)
        self.dataset_fraction = min(max(0.0, dataset_fraction), 1.0)
        self.streaming = streaming

        # Validate split ratios
        if len(self.split_ratios) != 3:
            raise ValueError(
                f"Expected three split ratios (train, val, test), got {self.split_ratios}"
            )
        if abs(sum(self.split_ratios) - 1.0) > 1e-5:
            raise ValueError(f"Split ratios must sum to 1, got {self.split_ratios}")

        # Load tokenizer
        self._load_tokenizer()

        # Ensure tokenizer has all required special tokens
        special_tokens = {
            'pad_token': '[PAD]' if self.tokenizer.pad_token is None else self.tokenizer.pad_token,
            'unk_token': '[UNK]' if self.tokenizer.unk_token is None else self.tokenizer.unk_token,
            'mask_token': '[MASK]' if self.tokenizer.mask_token is None else self.tokenizer.mask_token,
            'bos_token': '[CLS]' if self.tokenizer.bos_token is None else self.tokenizer.bos_token,
            'eos_token': '[SEP]' if self.tokenizer.eos_token is None else self.tokenizer.eos_token
        }
        self.tokenizer.add_special_tokens(special_tokens)

    def _load_tokenizer(self):
        """Load the tokenizer from local path or HuggingFace.

        Raises:
            ValueError: If the HuggingFace download/lookup fails.
        """
        if self.tokenizer_path and os.path.exists(self.tokenizer_path):
            print(f"Loading tokenizer from local path: {self.tokenizer_path}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.tokenizer_path,
                local_files_only=True,
                trust_remote_code=False
            )
            return

        print("Loading tokenizer from HuggingFace")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                "answerdotai/ModernBERT-base",
                cache_dir=self.cache_dir
            )
        except Exception as e:
            raise ValueError(
                f"Could not load tokenizer: {e}. Please provide a valid local tokenizer path."
            ) from e

    def setup(self, stage: Optional[str] = None):
        """Set up datasets for train/val/test by loading from HuggingFace and creating splits.

        Args:
            stage: ``'fit'`` builds the train and validation datasets,
                ``'test'`` builds the test dataset, ``None`` builds all three.
        """
        # Load the dataset from HuggingFace
        print("Loading dataset from HuggingFace: allenai/pixmo-cap")
        if self.streaming:
            # The split logic below needs len()/select(), which a streamed
            # dataset does not offer, so the regular loader is always used.
            print("Streaming mode not fully implemented - falling back to regular loading")
        try:
            full_dataset = load_dataset(
                "allenai/pixmo-cap",
                split="train",
                cache_dir=self.cache_dir
            )
            print(f"Successfully loaded dataset with {len(full_dataset)} examples")
        except Exception as e:
            print(f"Error loading dataset from HuggingFace: {e}")
            raise

        # Apply dataset fraction if needed
        if self.dataset_fraction < 1.0:
            total_examples = len(full_dataset)
            num_examples = int(total_examples * self.dataset_fraction)
            print(f"Using {self.dataset_fraction:.1%} of the dataset: {num_examples} out of {total_examples} examples")
            full_dataset = full_dataset.shuffle(seed=42)
            full_dataset = full_dataset.select(range(num_examples))

        # Apply global max examples limit if specified
        # NOTE(review): the cap is the sum of whichever per-split limits are
        # set and is applied before splitting, so the resulting splits can be
        # smaller than their individual limits - confirm this is intended.
        max_total = sum(limit for limit in self.max_examples.values() if limit is not None)
        if max_total and max_total < len(full_dataset):
            full_dataset = full_dataset.select(range(max_total))
            print(f"Limited dataset to {len(full_dataset)} examples")

        # Calculate split sizes
        dataset_size = len(full_dataset)
        train_size = int(dataset_size * self.split_ratios[0])
        val_size = int(dataset_size * self.split_ratios[1])
        test_size = dataset_size - train_size - val_size

        print(f"Creating splits from {dataset_size} examples:")
        print(f"  - Train: {train_size} examples")
        print(f"  - Validation: {val_size} examples")
        print(f"  - Test: {test_size} examples")

        # Create reproducible splits
        seed = 42
        splits = full_dataset.train_test_split(
            test_size=val_size + test_size,
            seed=seed
        )
        train_dataset = splits['train']

        # Further split the held-out portion into validation and test
        if test_size > 0:
            # An absolute count keeps the sizes equal to the ones printed
            # above; a float ratio can round to one example more or less.
            remaining_splits = splits['test'].train_test_split(
                test_size=test_size,
                seed=seed
            )
            val_dataset = remaining_splits['train']
            test_dataset = remaining_splits['test']
        else:
            # No dedicated test rows: reuse up to 100 validation rows.
            val_dataset = splits['test']
            test_dataset = splits['test'].select(range(min(100, len(splits['test']))))

        # Apply per-split max examples limit if specified
        train_limit = self.max_examples.get("train")
        val_limit = self.max_examples.get("validation")
        test_limit = self.max_examples.get("test")
        if train_limit is not None:
            train_dataset = train_dataset.select(range(min(train_limit, len(train_dataset))))
        if val_limit is not None:
            val_dataset = val_dataset.select(range(min(val_limit, len(val_dataset))))
        if test_limit is not None:
            test_dataset = test_dataset.select(range(min(test_limit, len(test_dataset))))

        # Debug dataset structure
        print(f"Dataset features: {full_dataset.features}")
        print(f"Dataset columns: {full_dataset.column_names}")
        sample = full_dataset[0]
        print(f"Sample keys: {sample.keys() if isinstance(sample, dict) else type(sample)}")

        # Create dataset objects for training stages
        if stage == 'fit' or stage is None:
            self.train_dataset = CaptionDataset(
                train_dataset,
                self.tokenizer,
                max_length=self.max_length,
                augmentation_prob=self.augmentation_prob
            )
            self.val_dataset = CaptionDataset(
                val_dataset,
                self.tokenizer,
                max_length=self.max_length,
                augmentation_prob=0.0  # No augmentation for validation
            )

        if stage == 'test' or stage is None:
            self.test_dataset = CaptionDataset(
                test_dataset,
                self.tokenizer,
                max_length=self.max_length,
                augmentation_prob=0.0  # No augmentation for test
            )

        # Print final dataset sizes
        if stage is None:
            print(f"Final dataset sizes:")
            print(f"  Train: {len(self.train_dataset):,}")
            print(f"  Validation: {len(self.val_dataset):,}")
            print(f"  Test: {len(self.test_dataset):,}")

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Build a DataLoader with the settings shared by all three splits."""
        use_workers = self.num_workers > 0
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            # Both options are only passed when worker processes exist.
            persistent_workers=use_workers,
            prefetch_factor=2 if use_workers else None
        )

    def train_dataloader(self):
        """Return the shuffled training DataLoader."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the validation DataLoader (no shuffling)."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return the test DataLoader (no shuffling)."""
        return self._make_loader(self.test_dataset, shuffle=False)

    @property
    def vocab_size(self):
        """Number of tokens in the tokenizer, including added special tokens."""
        return len(self.tokenizer)


class TextReconstructionDataModule(pl.LightningDataModule):
    """Lightning data module for a text-reconstruction dataset stored on disk.

    Each split is loaded by ``_load_dataset`` according to ``dataset_type``
    and wrapped in a ``TextReconstructionDataset``.
    """

    def __init__(
        self,
        dataset_path: str = "reconstruction_dataset/hf_dataset",
        tokenizer_name: str = "answerdotai/ModernBERT-base",
        batch_size: int = 32,
        max_length: int = 40,
        num_workers: Optional[int] = None,
        augmentation_prob: float = 0.1,
        cache_dir: str = None,
        dataset_type: str = "hf",  # Options: "hf", "csv", "numpy", "torch"
        limit_examples: Optional[Dict[str, int]] = None
    ):
        """
        Args:
            dataset_path: Path to the dataset. For ``"hf"`` it is passed to
                ``load_from_disk``; for the other types the split files are
                looked up in the *parent directory* of this path.
            tokenizer_name: HuggingFace tokenizer name or path
            batch_size: Batch size for dataloaders
            max_length: Maximum sequence length
            num_workers: Number of dataloader workers. ``None`` selects half of
                the CPU count (at least 1); ``0`` loads in the main process.
            augmentation_prob: Probability of masking a token for augmentation
            cache_dir: Cache directory for HuggingFace
            dataset_type: Type of dataset loading method ("hf", "csv", "numpy", "torch")
            limit_examples: Optional dict with 'train', 'validation', 'test' keys to limit examples
        """
        super().__init__()
        self.dataset_path = dataset_path
        self.tokenizer_name = tokenizer_name
        self.batch_size = batch_size
        self.max_length = max_length
        # ``is None`` (not ``or``) so that an explicit 0 is honoured.
        if num_workers is None:
            num_workers = max(1, multiprocessing.cpu_count() // 2)
        self.num_workers = num_workers
        self.augmentation_prob = augmentation_prob
        self.cache_dir = cache_dir
        self.dataset_type = dataset_type
        self.limit_examples = limit_examples or {"train": None, "validation": None, "test": None}

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            cache_dir=cache_dir
        )

        # Ensure tokenizer has all required special tokens
        special_tokens = {
            'pad_token': '[PAD]' if self.tokenizer.pad_token is None else self.tokenizer.pad_token,
            'unk_token': '[UNK]' if self.tokenizer.unk_token is None else self.tokenizer.unk_token,
            'mask_token': '[MASK]' if self.tokenizer.mask_token is None else self.tokenizer.mask_token,
            'bos_token': '[CLS]' if self.tokenizer.bos_token is None else self.tokenizer.bos_token,
            'eos_token': '[SEP]' if self.tokenizer.eos_token is None else self.tokenizer.eos_token
        }
        self.tokenizer.add_special_tokens(special_tokens)

    def _load_dataset(self, split):
        """Load dataset based on specified type and split.

        Args:
            split: Split name, e.g. ``'train'``, ``'validation'`` or ``'test'``.

        Returns:
            For ``"hf"`` the selected split of the dataset loaded from disk;
            for ``"csv"`` / ``"numpy"`` a dict ``{"text": [...]}``; for
            ``"torch"`` the loaded object with every value sliced to the limit.

        Raises:
            ValueError: If ``dataset_type`` is not one of the supported values.
        """
        # A falsy limit (None or 0) means "no limit".
        limit = self.limit_examples.get(split)

        if self.dataset_type == "hf":
            # Load HuggingFace dataset
            data = load_from_disk(self.dataset_path)[split]
            if limit:
                data = data.select(range(min(limit, len(data))))
            return data

        # All remaining formats keep one file per split next to dataset_path.
        base_dir = os.path.dirname(self.dataset_path)

        if self.dataset_type == "csv":
            # Load CSV data
            import pandas as pd
            df = pd.read_csv(os.path.join(base_dir, f"{split}.csv"))
            if limit:
                df = df.head(limit)
            return {"text": df["text"].tolist()}

        if self.dataset_type == "numpy":
            # Load NumPy data
            import numpy as np
            texts = np.load(os.path.join(base_dir, f"{split}_texts.npy"))
            if limit:
                texts = texts[:limit]
            return {"text": texts.tolist()}

        if self.dataset_type == "torch":
            # Load tokenized PyTorch tensors
            # SECURITY NOTE: torch.load unpickles the file; only use it on
            # files from a trusted source.
            data = torch.load(os.path.join(base_dir, f"{split}_tokenized.pt"))
            # Already tokenized, so return directly
            if limit:
                data = {k: v[:limit] for k, v in data.items()}
            return data

        raise ValueError(f"Unsupported dataset type: {self.dataset_type}")

    def _build_split(self, split):
        """Load one split and wrap it in a ``TextReconstructionDataset``."""
        # NOTE(review): validation and test use the same augmentation_prob as
        # training - confirm that masking is wanted during evaluation.
        return TextReconstructionDataset(
            self._load_dataset(split),
            self.tokenizer,
            max_length=self.max_length,
            augmentation_prob=self.augmentation_prob
        )

    def setup(self, stage: Optional[str] = None):
        """Set up datasets for train/val/test.

        Args:
            stage: ``'fit'`` builds the train and validation datasets,
                ``'test'`` builds the test dataset, ``None`` builds all three.
        """
        # Set up training data
        if stage == 'fit' or stage is None:
            print("Loading training data...")
            self.train_dataset = self._build_split('train')

            print("Loading validation data...")
            self.val_dataset = self._build_split('validation')

        # Set up test data
        if stage == 'test' or stage is None:
            print("Loading test data...")
            self.test_dataset = self._build_split('test')

        # Print dataset sizes
        if stage is None:
            print(f"Dataset sizes:")
            print(f"  Train: {len(self.train_dataset):,}")
            print(f"  Validation: {len(self.val_dataset):,}")
            print(f"  Test: {len(self.test_dataset):,}")

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Build a DataLoader with the settings shared by all three splits."""
        use_workers = self.num_workers > 0
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            # Both options are only passed when worker processes exist.
            persistent_workers=use_workers,
            prefetch_factor=2 if use_workers else None
        )

    def train_dataloader(self):
        """Return the shuffled training DataLoader."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the validation DataLoader (no shuffling)."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return the test DataLoader (no shuffling)."""
        return self._make_loader(self.test_dataset, shuffle=False)

    @property
    def vocab_size(self):
        """Number of tokens in the tokenizer, including added special tokens."""
        return len(self.tokenizer)

def tokenize_function(texts, tokenizer, max_length):
    """Tokenize ``texts`` with padding and truncation to ``max_length``.

    Kept as a module-level function taking plain arguments (no ``self``).

    Args:
        texts: Text input forwarded unchanged to ``tokenizer``.
        tokenizer: Callable tokenizer accepting the keyword options below.
        max_length: Length every sequence is padded/truncated to.

    Returns:
        Whatever ``tokenizer`` returns (``return_tensors=None`` is requested).
    """
    encode_options = {
        'padding': 'max_length',
        'truncation': True,
        'max_length': max_length,
        'return_tensors': None,
    }
    return tokenizer(texts, **encode_options)


# Path of the COCO 2017 validation caption annotations (previously a bare,
# no-op string literal at module level):
# /home/user/Paper2/coco_dataset/annotations/captions_val2017.json

class Coco2017DataModule(pl.LightningDataModule):
    """Lightning data module over the COCO 2017 caption annotation files.

    Training captions come from ``captions_train2017.json``; both the
    validation and the test datasets are built from
    ``captions_val2017.json`` (with different example limits).
    """

    def __init__(
        self,
        batch_size: int = 32,
        max_length: int = 30,
        num_workers: int = 8,
        cache_dir: str = '/home/user/datasets/wiki_cache',
        annotations_dir: str = '/home/user/Paper2/coco_dataset/annotations'
    ):
        """
        Args:
            batch_size: Batch size for dataloaders
            max_length: Maximum sequence length
            num_workers: Number of dataloader workers
            cache_dir: Cache directory for HuggingFace
            annotations_dir: Directory holding ``captions_train2017.json`` and
                ``captions_val2017.json``
        """
        super().__init__()
        self.batch_size = batch_size
        self.max_length = max_length
        self.num_workers = num_workers
        self.cache_dir = cache_dir
        self.annotations_dir = annotations_dir

        # Load ModernBERT tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            "answerdotai/ModernBERT-base",
            cache_dir=cache_dir
        )

        # Add special tokens
        special_tokens = {
            'pad_token': '[PAD]',
            'unk_token': '[UNK]',
            'bos_token': '[CLS]',  # Use [CLS] as BOS token
            'eos_token': '[SEP]'   # Use [SEP] as EOS token
        }
        self.tokenizer.add_special_tokens(special_tokens)

    def _tokenize_into(self, texts, processed_data):
        """Tokenize ``texts`` and append the results to ``processed_data``."""
        tokenized = self.tokenizer(
            texts,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors=None
        )
        processed_data['input_ids'].extend(tokenized['input_ids'])
        processed_data['attention_mask'].extend(tokenized['attention_mask'])

    def preprocess_dataset(self, dataset, max_examples=None):
        """Filter, tokenize, and remove empty results.

        Items that are not text, are blank, start with ``'='`` or have three
        words or fewer are skipped. Accepted texts are tokenized in chunks of
        1000.

        Args:
            dataset: Iterable of caption strings (``bytes`` are decoded as UTF-8).
            max_examples: Maximum number of accepted texts; a falsy value
                (``None`` or ``0``) means no limit.

        Returns:
            Dict with the lists ``'input_ids'`` and ``'attention_mask'``.
        """
        processed_data = {
            'input_ids': [],
            'attention_mask': []
        }
        chunk_size = 1000
        current_chunk = []
        accepted_count = 0

        for text in dataset:
            # Count every accepted text, so the limit is exact even when it
            # is not a multiple of chunk_size.
            if max_examples and accepted_count >= max_examples:
                break

            if isinstance(text, bytes):
                text = text.decode('utf-8', errors='replace')
            if not isinstance(text, str):
                continue

            stripped = text.strip()
            if not stripped or stripped.startswith('='):
                continue
            if len(text.split()) <= 3:
                continue

            current_chunk.append(text)
            accepted_count += 1

            if len(current_chunk) >= chunk_size:
                self._tokenize_into(current_chunk, processed_data)
                current_chunk = []

                if accepted_count % (chunk_size * 10) == 0:
                    print(f"Processed {accepted_count} examples")

        # Process any remaining examples
        if current_chunk:
            self._tokenize_into(current_chunk, processed_data)

        print(f"Final processed count: {len(processed_data['input_ids'])} examples")
        return processed_data

    def _load_captions(self, filename):
        """Read the caption strings from one COCO annotation JSON file."""
        json_path = os.path.join(self.annotations_dir, filename)
        with open(json_path, encoding='utf-8') as handle:
            annotations = json.load(handle)["annotations"]
        return [entry["caption"] for entry in annotations]

    def setup(self, stage: Optional[str] = None):
        """Build the datasets required for ``stage``.

        Args:
            stage: ``'fit'`` builds the train and validation datasets,
                ``'test'`` builds the test dataset, ``None`` builds all three.
        """
        build_fit = stage == 'fit' or stage is None
        build_test = stage == 'test' or stage is None
        if not (build_fit or build_test):
            return

        # Validation captions feed both the validation and the test dataset.
        val_captions = self._load_captions("captions_val2017.json")

        # Optional: limit the number of examples for testing
        max_examples = {
            'train': 300000,  # 30000 was used for fine-tuning with LoRA; None for the full dataset
            'validation': 10000,
            'test': 1000
        }

        # NOTE(review): WikiTextDataset is not defined in the visible part of
        # this file - confirm it exists (a CocoVal2017Dataset class is defined below).
        if build_fit:
            # Training captions must come from the *train* annotation file.
            train_captions = self._load_captions("captions_train2017.json")

            print("Processing training data...")
            train_processed = self.preprocess_dataset(
                train_captions,
                max_examples=max_examples['train']
            )
            self.train_dataset = WikiTextDataset(train_processed, self.tokenizer, self.max_length)

            print("Processing validation data...")
            val_processed = self.preprocess_dataset(
                val_captions,
                max_examples=max_examples['validation']
            )
            self.val_dataset = WikiTextDataset(val_processed, self.tokenizer, self.max_length)

        if build_test:
            print("Processing test data...")
            test_processed = self.preprocess_dataset(
                val_captions,
                max_examples=max_examples['test']
            )
            self.test_dataset = WikiTextDataset(test_processed, self.tokenizer, self.max_length)

    def train_dataloader(self):
        """Return the shuffled training DataLoader (incomplete last batch dropped)."""
        use_workers = self.num_workers > 0
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            # Both options are rejected by DataLoader when num_workers == 0.
            persistent_workers=use_workers,
            prefetch_factor=2 if use_workers else None,
            drop_last=True
        )

    def val_dataloader(self):
        """Return the validation DataLoader (incomplete last batch dropped)."""
        use_workers = self.num_workers > 0
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            # Both options are rejected by DataLoader when num_workers == 0.
            persistent_workers=use_workers,
            prefetch_factor=2 if use_workers else None,
            drop_last=True
        )

    def test_dataloader(self):
        """Return the test DataLoader (no shuffling, no persistent workers)."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=False,
            prefetch_factor=None
        )

    @property
    def vocab_size(self):
        """Number of tokens in the tokenizer, including added special tokens."""
        return len(self.tokenizer)

class CocoVal2017Dataset(Dataset):
    """Pre-tokenized dataset that returns each example with a randomly masked copy.

    Every example is truncated/padded to ``max_length``. For the augmented
    view a random subset of attended tokens is replaced by the tokenizer's
    mask token. ``aug_attention_mask`` is True only for attended positions
    that were left untouched; every other position of ``aug_ids`` (randomly
    selected tokens and all padding) holds ``mask_token_id``.
    """

    def __init__(self, data, tokenizer, max_length=30, mask_prob=0.1):
        """
        Args:
            data: Mapping with 'input_ids' and 'attention_mask' entries that
                can be indexed per example.
            tokenizer: Object exposing ``pad_token_id`` and ``mask_token_id``.
            max_length: Length every returned sequence is truncated/padded to.
            mask_prob: Probability that an attended token is replaced by the
                mask token in the augmented view. Defaults to 0.1, the value
                that used to be hard-coded.

        Raises:
            ValueError: If ``mask_prob`` is outside ``[0, 1]``.
        """
        if not 0.0 <= mask_prob <= 1.0:
            raise ValueError(f"mask_prob must be in [0, 1], got {mask_prob}")
        self.input_ids = data['input_ids']
        self.attention_mask = data['attention_mask']
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.mask_prob = mask_prob
        # Not used inside this class; kept for any external code that reads it.
        self.dataset = []

    def create_augmentation_mask(self, attention_mask):
        """Return a bool mask that is True where a token is attended AND kept.

        Each position is kept when a uniform draw exceeds ``self.mask_prob``;
        positions that are False in ``attention_mask`` are always False.
        """
        keep = torch.rand(attention_mask.shape, device=attention_mask.device) > self.mask_prob
        return keep & attention_mask

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return one example, or a stacked batch when ``idx`` is a list."""
        if isinstance(idx, list):
            return self._get_batch(idx)

        input_ids = self._fit_to_length(self.input_ids[idx], torch.long)
        attention_mask = self._fit_to_length(self.attention_mask[idx], torch.bool)
        return self._build_example(input_ids, attention_mask)

    def _get_batch(self, indices):
        """Handle batch indexing directly."""
        batch_input_ids = torch.stack(
            [self._fit_to_length(self.input_ids[i], torch.long) for i in indices]
        )
        batch_attention_mask = torch.stack(
            [self._fit_to_length(self.attention_mask[i], torch.bool) for i in indices]
        )
        return self._build_example(batch_input_ids, batch_attention_mask)

    def _fit_to_length(self, values, dtype):
        """Convert one sequence to a tensor truncated/padded to ``max_length``."""
        seq = torch.tensor(values, dtype=dtype)
        if seq.dim() != 1:
            # As before, only flat sequences are resized.
            return seq
        seq = seq[:self.max_length]
        missing = self.max_length - seq.shape[0]
        if missing > 0:
            # Masks are padded with False, token ids with the pad token.
            fill = False if dtype == torch.bool else self.tokenizer.pad_token_id
            seq = torch.nn.functional.pad(seq, (0, missing), value=fill)
        return seq

    def _build_example(self, input_ids, attention_mask):
        """Assemble the output dict, including the masked (augmented) view."""
        aug_attention_mask = self.create_augmentation_mask(attention_mask)
        # Every position not kept by the augmentation mask becomes the mask token.
        aug_input_ids = input_ids.masked_fill(~aug_attention_mask, self.tokenizer.mask_token_id)
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'target_ids': input_ids.clone(),
            'aug_ids': aug_input_ids,
            'aug_attention_mask': aug_attention_mask
        }

class WikiTextDataset(Dataset):
    """Pre-tokenized dataset that returns each example with a randomly masked copy.

    Every example is truncated/padded to ``max_length``. For the augmented
    view a random subset of attended tokens is replaced by the tokenizer's
    mask token. ``aug_attention_mask`` is True only for attended positions
    that were left untouched; every other position of ``aug_ids`` (randomly
    selected tokens and all padding) holds ``mask_token_id``.
    """

    def __init__(self, data, tokenizer, max_length=30, mask_prob=0.1):
        """
        Args:
            data: Mapping with 'input_ids' and 'attention_mask' entries that
                can be indexed per example.
            tokenizer: Object exposing ``pad_token_id`` and ``mask_token_id``.
            max_length: Length every returned sequence is truncated/padded to.
            mask_prob: Probability that an attended token is replaced by the
                mask token in the augmented view. Defaults to 0.1, the value
                that used to be hard-coded.

        Raises:
            ValueError: If ``mask_prob`` is outside ``[0, 1]``.
        """
        if not 0.0 <= mask_prob <= 1.0:
            raise ValueError(f"mask_prob must be in [0, 1], got {mask_prob}")
        self.input_ids = data['input_ids']
        self.attention_mask = data['attention_mask']
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.mask_prob = mask_prob

    def create_augmentation_mask(self, attention_mask):
        """Return a bool mask that is True where a token is attended AND kept.

        Each position is kept when a uniform draw exceeds ``self.mask_prob``;
        positions that are False in ``attention_mask`` are always False.
        """
        keep = torch.rand(attention_mask.shape, device=attention_mask.device) > self.mask_prob
        return keep & attention_mask

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return one example, or a stacked batch when ``idx`` is a list."""
        if isinstance(idx, list):
            return self._get_batch(idx)

        input_ids = self._fit_to_length(self.input_ids[idx], torch.long)
        attention_mask = self._fit_to_length(self.attention_mask[idx], torch.bool)
        return self._build_example(input_ids, attention_mask)

    def _get_batch(self, indices):
        """Handle batch indexing directly."""
        batch_input_ids = torch.stack(
            [self._fit_to_length(self.input_ids[i], torch.long) for i in indices]
        )
        batch_attention_mask = torch.stack(
            [self._fit_to_length(self.attention_mask[i], torch.bool) for i in indices]
        )
        return self._build_example(batch_input_ids, batch_attention_mask)

    def _fit_to_length(self, values, dtype):
        """Convert one sequence to a tensor truncated/padded to ``max_length``."""
        seq = torch.tensor(values, dtype=dtype)
        if seq.dim() != 1:
            # As before, only flat sequences are resized.
            return seq
        seq = seq[:self.max_length]
        missing = self.max_length - seq.shape[0]
        if missing > 0:
            # Masks are padded with False, token ids with the pad token.
            fill = False if dtype == torch.bool else self.tokenizer.pad_token_id
            seq = torch.nn.functional.pad(seq, (0, missing), value=fill)
        return seq

    def _build_example(self, input_ids, attention_mask):
        """Assemble the output dict, including the masked (augmented) view."""
        aug_attention_mask = self.create_augmentation_mask(attention_mask)
        # Every position not kept by the augmentation mask becomes the mask token.
        aug_input_ids = input_ids.masked_fill(~aug_attention_mask, self.tokenizer.mask_token_id)
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'target_ids': input_ids.clone(),
            'aug_ids': aug_input_ids,
            'aug_attention_mask': aug_attention_mask
        }

class WikiTextDataModule(pl.LightningDataModule):
    """LightningDataModule that serves WikiText-103 as fixed-length token datasets.

    ``setup`` streams the "wikitext-103-v1" configuration of "wikitext",
    filters and tokenizes the text with the ModernBERT tokenizer, and wraps
    each split in a ``WikiTextDataset``.
    """

    def __init__(
        self,
        batch_size: int = 32,
        max_length: int = 30,
        num_workers: int = 8,
        cache_dir: str = '/home/user/datasets/wiki_cache'
    ):
        """
        Args:
            batch_size: Batch size used by all three dataloaders.
            max_length: Length every tokenized text is truncated/padded to.
            num_workers: DataLoader worker processes; 0 loads in-process.
            cache_dir: Cache directory for the tokenizer and the dataset.
        """
        super().__init__()
        self.batch_size = batch_size
        self.max_length = max_length
        self.num_workers = num_workers
        self.cache_dir = cache_dir

        # Load ModernBERT tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            "answerdotai/ModernBERT-base",
            cache_dir=cache_dir
        )

        # Add special tokens
        special_tokens = {
            'pad_token': '[PAD]',
            'unk_token': '[UNK]',
            'bos_token': '[CLS]',  # Use [CLS] as BOS token
            'eos_token': '[SEP]'   # Use [SEP] as EOS token
        }
        self.tokenizer.add_special_tokens(special_tokens)

    def _tokenize_into(self, texts, processed_data):
        """Tokenize ``texts`` to ``max_length`` and append the rows to ``processed_data``."""
        tokenized = self.tokenizer(
            texts,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors=None
        )
        processed_data['input_ids'].extend(tokenized['input_ids'])
        processed_data['attention_mask'].extend(tokenized['attention_mask'])

    def preprocess_dataset(self, dataset, max_examples=None):
        """Filter, tokenize, and remove empty results.

        A text is kept when it is non-blank, does not start with '=' after
        stripping, and has more than three whitespace-separated words. Text
        is not lower-cased. ``bytes`` values are decoded as UTF-8 first
        (previously a non-empty ``bytes`` value raised ``TypeError``).

        Args:
            dataset: Iterable of mappings with a 'text' entry.
            max_examples: Maximum number of kept texts; ``None`` or 0 means
                no limit. The limit is now exact; it used to be checked only
                after each full chunk of 1000 texts had been tokenized, so it
                could be overshot.

        Returns:
            Dict with 'input_ids' and 'attention_mask' lists, one row per
            kept text.
        """
        processed_data = {
            'input_ids': [],
            'attention_mask': []
        }
        print("not LOWERING TEXT, was not doing this before")
        chunk_size = 1000
        current_chunk = []
        accepted_count = 0   # texts that passed the filters so far
        processed_count = 0  # texts already handed to the tokenizer

        # Process the streaming dataset
        for item in dataset:
            if max_examples and accepted_count >= max_examples:
                break

            text = item['text']
            if isinstance(text, bytes):
                text = text.decode('utf-8', errors='replace')
            if not isinstance(text, str):
                continue
            stripped = text.strip()
            if not stripped or stripped.startswith('='):
                continue
            if len(text.split()) <= 3:
                continue

            current_chunk.append(text)
            accepted_count += 1

            # Tokenize in chunks so the tokenizer sees batches, not single texts.
            if len(current_chunk) >= chunk_size:
                self._tokenize_into(current_chunk, processed_data)
                processed_count += len(current_chunk)
                current_chunk = []

                if processed_count % (chunk_size * 10) == 0:
                    print(f"Processed {processed_count} examples")

        # Process any remaining examples
        if current_chunk:
            self._tokenize_into(current_chunk, processed_data)

        print(f"Final processed count: {len(processed_data['input_ids'])} examples")
        return processed_data

    def setup(self, stage: Optional[str] = None):
        """Stream WikiText-103 and build the datasets needed for ``stage``.

        'fit' builds the train and validation datasets, 'test' builds the
        test dataset, and ``None`` builds all three.
        """
        dataset = load_dataset(
            "wikitext",
            "wikitext-103-v1",
            cache_dir=self.cache_dir,
            streaming=True
        )

        # Optional: limit the number of examples per split (None = full split)
        max_examples = {
            'train': 300000,
            'validation': None,
            'test': None
        }

        if stage == 'fit' or stage is None:
            print("Processing training data...")
            train_processed = self.preprocess_dataset(
                dataset['train'],
                max_examples=max_examples['train']
            )
            self.train_dataset = WikiTextDataset(train_processed, self.tokenizer, self.max_length)

            print("Processing validation data...")
            val_processed = self.preprocess_dataset(
                dataset['validation'],
                max_examples=max_examples['validation']
            )
            self.val_dataset = WikiTextDataset(val_processed, self.tokenizer, self.max_length)

        if stage == 'test' or stage is None:
            print("Processing test data...")
            test_processed = self.preprocess_dataset(
                dataset['test'],
                max_examples=max_examples['test']
            )
            self.test_dataset = WikiTextDataset(test_processed, self.tokenizer, self.max_length)

    def _worker_options(self):
        """DataLoader options that are only legal when worker processes exist.

        ``DataLoader`` raises ``ValueError`` for ``persistent_workers=True``
        or an explicit ``prefetch_factor`` when ``num_workers`` is 0.
        """
        if self.num_workers > 0:
            return {'persistent_workers': True, 'prefetch_factor': 2}
        return {}

    def train_dataloader(self):
        """Shuffled training loader; the last incomplete batch is dropped."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True,
            **self._worker_options()
        )

    def val_dataloader(self):
        """Ordered validation loader; the last incomplete batch is dropped."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True,
            **self._worker_options()
        )

    def test_dataloader(self):
        """Ordered test loader; keeps every example and does not persist workers."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=False,
            prefetch_factor=None
        )

    @property
    def vocab_size(self):
        """Size of the tokenizer, as reported by ``len(self.tokenizer)``."""
        return len(self.tokenizer)

class WikimediaDataModule(pl.LightningDataModule):
    """LightningDataModule that turns a Wikimedia Wikipedia dump into train/val/test sets.

    The dump is loaded, optionally truncated to ``max_train_samples`` rows
    (applied to the whole dataset before splitting), shuffled, split, and
    tokenized to fixed-length sequences. Each example carries 'input_ids',
    'attention_mask', 'target_ids', 'aug_ids' (randomly masked copy) and
    'aug_attention_mask' (a copy of the attention mask).
    """

    # Worker count used when ``num_workers`` is None.
    _DEFAULT_NUM_WORKERS = 4

    def __init__(
        self,
        batch_size: int = 32,
        max_length: int = 40,
        num_workers: Optional[int] = 16,
        augmentation_prob: float = 0.0,
        cache_dir: Optional[str] = None,
        tokenizer_name: str = "answerdotai/ModernBERT-base",
        max_train_samples: Optional[int] = None,
        max_val_samples: Optional[int] = 10000,
        max_test_samples: Optional[int] = 10000,
        seed: int = 42,
        trust_remote_code: bool = True,
        subset_name: str = "20231101.en",  # English subset from the screenshot
        processing_batch_size: int = 10000,  # Larger batch size for processing
        use_custom_tokenization: bool = True  # Use custom tokenization to avoid pickling issues
    ):
        """Store the hyperparameters and load the tokenizer.

        All arguments are saved via ``save_hyperparameters`` and read back
        through ``self.hparams``. ``num_workers=None`` selects 4 workers;
        ``num_workers=0`` loads data in the main process.
        """
        super().__init__()
        self.save_hyperparameters()

        # Initialize tokenizer first to avoid pickling issues
        self.tokenizer = self._load_tokenizer()
        self._add_special_tokens()

        # Store the processed datasets
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

        # Track whether dataset has been processed
        self.datasets_processed = False

    def _load_tokenizer(self):
        """Load the tokenizer named by ``hparams.tokenizer_name``."""
        print(f"Loading tokenizer '{self.hparams.tokenizer_name}' from Hugging Face Hub.")
        return AutoTokenizer.from_pretrained(
            self.hparams.tokenizer_name,
            cache_dir=self.hparams.cache_dir,
            trust_remote_code=self.hparams.trust_remote_code
        )

    def _add_special_tokens(self):
        """Make sure the tokenizer has PAD and MASK tokens, adding them if missing."""
        special_tokens_dict = {}
        if self.tokenizer.pad_token is None:
            if self.tokenizer.eos_token is not None:
                print(f"Tokenizer has no PAD token, using EOS token ({self.tokenizer.eos_token}) as PAD.")
                special_tokens_dict["pad_token"] = self.tokenizer.eos_token
            else:
                print("Tokenizer has no PAD or EOS token. Adding '[PAD]'.")
                special_tokens_dict["pad_token"] = '[PAD]'
        if self.tokenizer.mask_token is None:
            print("Tokenizer has no MASK token. Adding '[MASK]'.")
            special_tokens_dict["mask_token"] = '[MASK]'
        if special_tokens_dict:
            num_added = self.tokenizer.add_special_tokens(special_tokens_dict)
            print(f"Added {num_added} special tokens.")

    def prepare_data(self):
        # This runs once globally, not per process
        pass

    def _fast_tokenize_and_augment(self, texts, max_length, augmentation_prob):
        """
        Process a batch of texts without using multiprocessing.
        This avoids pickling issues while still being efficient.

        Args:
            texts: Batch of strings to tokenize.
            max_length: Length every sequence is truncated/padded to.
            augmentation_prob: Probability that an attended token is replaced
                by the mask token in 'aug_ids'.

        Returns:
            Dict of plain Python lists with keys 'input_ids',
            'attention_mask', 'target_ids', 'aug_ids', 'aug_attention_mask'.
        """
        # Tokenize in a single batch
        tokenized = self.tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="pt"  # Return PyTorch tensors directly
        )
        input_ids = tokenized['input_ids']
        attention_mask = tokenized['attention_mask']
        attended = attention_mask.bool()
        device = input_ids.device

        # Keep a token when its draw exceeds the probability and it is attended.
        rand_mask = torch.rand(input_ids.shape, device=device)
        keep_mask = (rand_mask > augmentation_prob) & attended

        # Apply masking
        aug_input_ids = torch.where(
            keep_mask,
            input_ids,
            torch.tensor(self.tokenizer.mask_token_id, device=device)
        )
        # Ensure padding tokens remain as padding
        aug_input_ids = torch.where(
            attended,
            aug_input_ids,
            torch.tensor(self.tokenizer.pad_token_id, device=device)
        )

        # Convert to lists for dataset creation
        return {
            'input_ids': input_ids.tolist(),
            'attention_mask': attention_mask.tolist(),
            'target_ids': input_ids.tolist(),  # targets are the unmasked ids
            'aug_ids': aug_input_ids.tolist(),
            'aug_attention_mask': attention_mask.tolist()
        }

    def process_data(self):
        """Process dataset and create train/val/test splits.

        Does nothing when the datasets were already processed. Any exception
        is printed with its traceback and re-raised.
        """
        print("Processing Wikimedia dataset...")

        if self.datasets_processed:
            print("Datasets already processed, skipping.")
            return

        try:
            full_dataset = self._load_full_dataset()
            train_dataset, val_dataset, test_dataset = self._split_dataset(full_dataset)
            print(f"Split sizes: Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

            # Custom tokenization to avoid pickling issues
            if self.hparams.use_custom_tokenization:
                self._build_datasets_custom(train_dataset, val_dataset, test_dataset)
            else:
                self._build_datasets_with_map(train_dataset, val_dataset, test_dataset)

            self.datasets_processed = True
            print("Data processing complete.")
            print(f"  Train dataset: {len(self.train_dataset)} examples")
            print(f"  Validation dataset: {len(self.val_dataset)} examples")
            print(f"  Test dataset: {len(self.test_dataset)} examples")

        except Exception as e:
            print(f"Error processing dataset: {e}")
            import traceback
            traceback.print_exc()
            raise

    def _load_full_dataset(self):
        """Load the configured subset, pick a split and apply ``max_train_samples``."""
        print(f"Loading Wikimedia Wikipedia English subset: {self.hparams.subset_name}")
        dataset_dict = load_dataset(
            "wikimedia/wikipedia",
            name=self.hparams.subset_name,
            cache_dir=self.hparams.cache_dir,
            trust_remote_code=self.hparams.trust_remote_code
        )

        # Prefer the 'train' split; otherwise fall back to the first one available.
        if "train" in dataset_dict:
            full_dataset = dataset_dict["train"]
        else:
            available_splits = list(dataset_dict.keys())
            if not available_splits:
                raise ValueError("No valid splits found in dataset")
            full_dataset = dataset_dict[available_splits[0]]
            print(f"Using split '{available_splits[0]}' as dataset")

        # Apply max_train_samples limit if specified
        if self.hparams.max_train_samples is not None:
            full_dataset = full_dataset.select(range(min(len(full_dataset), self.hparams.max_train_samples)))
            print(f"Limited dataset to {len(full_dataset)} examples")

        return full_dataset

    def _split_dataset(self, full_dataset):
        """Shuffle ``full_dataset`` and return ``(train, val, test)`` splits.

        Validation and test each take 10% of the rows, capped by
        ``max_val_samples`` / ``max_test_samples`` (a falsy cap means 10000).

        Raises:
            ValueError: If the dataset is too small to give every split at
                least one row.
        """
        print("Shuffling and splitting dataset...")
        full_dataset = full_dataset.shuffle(seed=self.hparams.seed)

        total_size = len(full_dataset)
        val_size = min(int(total_size * 0.1), self.hparams.max_val_samples or 10000)
        test_size = min(int(total_size * 0.1), self.hparams.max_test_samples or 10000)
        if val_size <= 0 or test_size <= 0:
            raise ValueError(
                f"Dataset with {total_size} examples is too small to create validation and test splits"
            )

        splits = full_dataset.train_test_split(test_size=val_size + test_size, seed=self.hparams.seed)
        # Pass the absolute row count: a float fraction is rounded up by
        # train_test_split and could shift one row between val and test.
        remaining = splits['test'].train_test_split(test_size=test_size, seed=self.hparams.seed)
        return splits['train'], remaining['train'], remaining['test']

    def _build_datasets_custom(self, train_dataset, val_dataset, test_dataset):
        """Tokenize the splits in-process and store them as torch-formatted datasets."""
        print(f"Using custom fast tokenization with batch size {self.hparams.processing_batch_size}")
        batch_size = self.hparams.processing_batch_size

        print("Processing training split...")
        train_processed = self._process_dataset_in_batches(train_dataset, batch_size=batch_size)

        print("Processing validation split...")
        val_processed = self._process_dataset_in_batches(val_dataset, batch_size=batch_size)

        print("Processing test split...")
        test_processed = self._process_dataset_in_batches(test_dataset, batch_size=batch_size)

        # Create PyTorch datasets
        from datasets import Dataset as HFDataset
        self.train_dataset = HFDataset.from_dict(train_processed).with_format("torch")
        self.val_dataset = HFDataset.from_dict(val_processed).with_format("torch")
        self.test_dataset = HFDataset.from_dict(test_processed).with_format("torch")

    def _build_datasets_with_map(self, train_dataset, val_dataset, test_dataset):
        """Tokenize and augment the splits with ``Dataset.map`` (single process)."""
        # Use single processor to avoid pickling issues (slower)
        map_kwargs = {
            "batched": True,
            "batch_size": 100,
            "num_proc": 1,
        }

        def tokenize_function(examples):
            tokenized = self.tokenizer(
                examples["text"],
                padding="max_length",
                truncation=True,
                max_length=self.hparams.max_length,
                return_tensors=None
            )
            tokenized["target_ids"] = tokenized["input_ids"][:]
            return tokenized

        def augment_function(examples):
            mask_token = torch.tensor(self.tokenizer.mask_token_id)
            pad_token = torch.tensor(self.tokenizer.pad_token_id)

            aug_ids = []
            aug_attention_mask = []
            for ids, mask in zip(examples['input_ids'], examples['attention_mask']):
                ids_tensor = torch.tensor(ids)
                mask_tensor = torch.tensor(mask).bool()

                rand = torch.rand(ids_tensor.shape)
                keep_mask = (rand > self.hparams.augmentation_prob) & mask_tensor

                # Mask the dropped tokens, then restore padding positions.
                aug_ids_tensor = torch.where(keep_mask, ids_tensor, mask_token)
                aug_ids_tensor = torch.where(mask_tensor, aug_ids_tensor, pad_token)

                aug_ids.append(aug_ids_tensor.tolist())
                aug_attention_mask.append(mask_tensor.long().tolist())

            examples['aug_ids'] = aug_ids
            examples['aug_attention_mask'] = aug_attention_mask
            return examples

        # Process each split
        print("Processing training split...")
        train_dataset = train_dataset.map(tokenize_function, **map_kwargs)
        train_dataset = train_dataset.map(augment_function, **map_kwargs)

        print("Processing validation split...")
        val_dataset = val_dataset.map(tokenize_function, **map_kwargs)
        val_dataset = val_dataset.map(augment_function, **map_kwargs)

        print("Processing test split...")
        test_dataset = test_dataset.map(tokenize_function, **map_kwargs)
        test_dataset = test_dataset.map(augment_function, **map_kwargs)

        # Keep only necessary columns
        columns_to_keep = ['input_ids', 'attention_mask', 'target_ids', 'aug_ids', 'aug_attention_mask']
        train_dataset = train_dataset.remove_columns([c for c in train_dataset.column_names if c not in columns_to_keep])
        val_dataset = val_dataset.remove_columns([c for c in val_dataset.column_names if c not in columns_to_keep])
        test_dataset = test_dataset.remove_columns([c for c in test_dataset.column_names if c not in columns_to_keep])

        # Convert to PyTorch format
        self.train_dataset = train_dataset.with_format("torch")
        self.val_dataset = val_dataset.with_format("torch")
        self.test_dataset = test_dataset.with_format("torch")

    def _process_dataset_in_batches(self, dataset, batch_size=1000):
        """Process a dataset in batches to avoid memory issues.

        Returns:
            Dict of lists with the same keys as ``_fast_tokenize_and_augment``,
            concatenated over all batches.
        """
        total = len(dataset)
        num_batches = (total + batch_size - 1) // batch_size
        result = {
            'input_ids': [],
            'attention_mask': [],
            'target_ids': [],
            'aug_ids': [],
            'aug_attention_mask': []
        }

        for start in range(0, total, batch_size):
            end = min(start + batch_size, total)
            print(f"Processing batch {start//batch_size + 1}/{num_batches}: examples {start}-{end-1}")

            # Get batch of texts
            texts = dataset.select(range(start, end))["text"]

            processed = self._fast_tokenize_and_augment(
                texts,
                max_length=self.hparams.max_length,
                augmentation_prob=self.hparams.augmentation_prob
            )

            for key in result:
                result[key].extend(processed[key])

        return result

    def setup(self, stage: Optional[str] = None):
        """Process the data (once); all splits are built regardless of ``stage``."""
        print(f"Setting up data for stage: {stage}")
        # Process data if not already processed
        self.process_data()

    def _dataloader_kwargs(self):
        """Keyword arguments shared by the three dataloaders.

        The worker count is resolved once (``None`` -> 4, anything else as
        given, so an explicit 0 really disables workers) and the
        worker-only options are derived from that same value.
        """
        configured = self.hparams.num_workers
        num_workers = self._DEFAULT_NUM_WORKERS if configured is None else configured
        kwargs = {
            'batch_size': self.hparams.batch_size,
            'num_workers': num_workers,
            'pin_memory': torch.cuda.is_available(),
        }
        if num_workers > 0:
            # Only legal when worker processes exist.
            kwargs['persistent_workers'] = True
            kwargs['prefetch_factor'] = 2
        return kwargs

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        return DataLoader(self.train_dataset, shuffle=True, **self._dataloader_kwargs())

    def val_dataloader(self):
        """Ordered loader over the validation dataset."""
        return DataLoader(self.val_dataset, shuffle=False, **self._dataloader_kwargs())

    def test_dataloader(self):
        """Ordered loader over the test dataset."""
        return DataLoader(self.test_dataset, shuffle=False, **self._dataloader_kwargs())

    @property
    def vocab_size(self):
        """Size of the tokenizer, as reported by ``len(self.tokenizer)``."""
        return len(self.tokenizer)


class WikimediaRobustFineTuneDataModule(pl.LightningDataModule):
    """
    Specialized data module for robust finetuning of a previously trained model.
    
    Key features:
    1. Uses fixed validation and test sets for consistent evaluation
    2. Ensures no overlap between train/val/test sets
    3. Efficiently loads only the required number of training samples
    4. Supports checkpointing of split assignments to ensure consistency across runs
    """
    
    def __init__(
        self,
        batch_size: int = 32,
        max_length: int = 40,
        num_workers: Optional[int] = 16,
        augmentation_prob: float = 0.0,  # Per-token masking probability for the augmented copy
        cache_dir: Optional[str] = None,
        tokenizer_name: str = "answerdotai/ModernBERT-base",
        finetune_samples: int = 100000,  # Number of training examples kept
        max_val_samples: int = 10000,
        max_test_samples: int = 10000,
        seed: int = 42,
        trust_remote_code: bool = True,
        subset_name: str = "20231101.en",
        processing_batch_size: int = 5000,
        use_custom_tokenization: bool = True,
        split_cache_file: Optional[str] = "wikimedia_robust_splits.json",
        # Advanced options
        streaming: bool = False,
        fixed_val_test: bool = True,  # Use fixed validation and test sets
        buffer_size: int = 10000,
        dataset_path: Optional[str] = None,  # Path used to load/save the processed datasets
    ):
        """Record the hyperparameters and prepare the tokenizer.

        All constructor arguments are captured through
        ``save_hyperparameters`` and read back elsewhere via ``self.hparams``.
        The dataset attributes stay ``None`` until ``process_data()`` runs.
        """
        super().__init__()
        self.save_hyperparameters(ignore=['tokenizer'])

        # Split containers and bookkeeping start empty; process_data() fills them.
        self.train_dataset, self.val_dataset, self.test_dataset = None, None, None
        self.datasets_processed = False
        self.split_indices = None

        # The tokenizer is built after the hyperparameters are captured because
        # _load_tokenizer() reads its settings from self.hparams.
        self.tokenizer = self._load_tokenizer()
        self._add_special_tokens()
        
    def _load_tokenizer(self):
        """Return the ``AutoTokenizer`` named by ``hparams.tokenizer_name``.

        ``hparams.cache_dir`` and ``hparams.trust_remote_code`` are forwarded
        to ``AutoTokenizer.from_pretrained``.
        """
        hp = self.hparams
        print(f"Loading tokenizer '{hp.tokenizer_name}' from Hugging Face Hub...")
        loader_options = {
            "cache_dir": hp.cache_dir,
            "trust_remote_code": hp.trust_remote_code,
        }
        return AutoTokenizer.from_pretrained(hp.tokenizer_name, **loader_options)

    def _add_special_tokens(self):
        special_tokens_dict = {}
        if self.tokenizer.pad_token is None:
            if self.tokenizer.eos_token is not None:
                print(f"Tokenizer has no PAD token, using EOS token ({self.tokenizer.eos_token}) as PAD.")
                special_tokens_dict["pad_token"] = self.tokenizer.eos_token
            else:
                print("Tokenizer has no PAD or EOS token. Adding '[PAD]'.")
                special_tokens_dict["pad_token"] = '[PAD]'
        if self.tokenizer.mask_token is None:
            print("Tokenizer has no MASK token. Adding '[MASK]'.")
            special_tokens_dict["mask_token"] = '[MASK]'
        if special_tokens_dict:
            num_added = self.tokenizer.add_special_tokens(special_tokens_dict)
            print(f"Added {num_added} special tokens.")
            
    def _compute_example_hash(self, example):
        """Compute a stable hash for an example to use for deterministic splits."""
        # Use the text content for hashing
        text = example.get("text", "")
        # Create a hash of the text
        hash_obj = hashlib.md5(text.encode('utf-8'))
        # Return as an integer for easier modulo operations
        return int(hash_obj.hexdigest(), 16)
    
    def _get_split_from_hash(self, example_hash, train_threshold=80, val_threshold=90):
        """
        Deterministically assign an example to a split based on its hash.
        
        Args:
            example_hash: The hash value for the example
            train_threshold: Percentile cutoff for training set (default: 80%)
            val_threshold: Percentile cutoff for validation set (default: 90%)
                           Examples beyond this go to test set
        
        Returns:
            str: 'train', 'validation', or 'test'
        """
        bucket = example_hash % 100  # Get a value 0-99
        
        if bucket < train_threshold:
            return 'train'
        elif bucket < val_threshold:
            return 'validation'
        else:
            return 'test'
    
    def _load_or_create_split_indices(self):
        """Load existing split indices or create new ones."""
        if self.split_indices is not None:
            return self.split_indices
            
        if self.hparams.split_cache_file and os.path.exists(self.hparams.split_cache_file):
            print(f"Loading split assignments from {self.hparams.split_cache_file}")
            with open(self.hparams.split_cache_file, 'r') as f:
                self.split_indices = json.load(f)
            return self.split_indices
            
        print("No cached split assignments found. Creating new split assignments...")
        self.split_indices = {
            'train': set(),
            'validation': set(),
            'test': set(),
            'metadata': {
                'finetune_samples': self.hparams.finetune_samples,
                'max_val_samples': self.hparams.max_val_samples,
                'max_test_samples': self.hparams.max_test_samples,
                'seed': self.hparams.seed,
                'subset_name': self.hparams.subset_name
            }
        }
        return self.split_indices
        
    def _save_split_indices(self):
        """Save split indices to cache file."""
        if not self.hparams.split_cache_file:
            return
            
        # Convert sets to lists for JSON serialization
        json_safe_indices = {
            'train': list(self.split_indices['train']),
            'validation': list(self.split_indices['validation']),
            'test': list(self.split_indices['test']),
            'metadata': self.split_indices['metadata']
        }
        
        print(f"Saving split assignments to {self.hparams.split_cache_file}")
        with open(self.hparams.split_cache_file, 'w') as f:
            json.dump(json_safe_indices, f)
            
    def prepare_data(self):
        """Trigger the dataset download ahead of per-process setup.

        When ``hparams.dataset_path`` is set nothing is fetched. Otherwise a
        non-streaming ``load_dataset`` call for the configured Wikipedia
        subset is issued with a ``train[:100]`` slice; its return value is
        discarded because only the download side effect is wanted.
        """
        hp = self.hparams
        if hp.dataset_path is not None:
            # A saved dataset path was supplied, so no download is triggered.
            return

        download_options = {
            "name": hp.subset_name,
            "cache_dir": hp.cache_dir,
            "trust_remote_code": hp.trust_remote_code,
            "streaming": False,  # Just for downloading, not full processing
            "split": "train[:100]",  # Small slice requested purely to trigger the download
        }
        load_dataset("wikimedia/wikipedia", **download_options)
    
    def _generate_diverse_finetune_samples(self, dataset_iter, split_indices):
        """Collect raw train/validation/test examples from an iterable of articles.

        Every example is identified by its position in ``dataset_iter``
        (stored as a string). If ``split_indices`` already holds train or
        validation ids, those cached assignments decide the split and
        unknown positions are skipped. Otherwise the split is derived from
        the example's text hash and recorded in ``split_indices`` in place.
        Collection stops once every split holds its target (requested size
        plus 10%), and each split is then trimmed to the requested size.

        Args:
            dataset_iter: Iterable yielding mapping-like examples with a
                ``"text"`` entry.
            split_indices: Table with ``'train'``, ``'validation'`` and
                ``'test'`` id collections; updated in place when new
                assignments are generated.

        Returns:
            tuple: ``(train_dataset, val_dataset, test_dataset)`` as
            ``datasets.Dataset`` objects (an empty split has a single empty
            ``"text"`` column).
        """
        import math
        from datasets import Dataset

        print("Selecting diverse examples for robust finetuning...")

        split_names = ('train', 'validation', 'test')
        final_sizes = {
            'train': self.hparams.finetune_samples,
            'validation': self.hparams.max_val_samples,
            'test': self.hparams.max_test_samples,
        }
        # 10% head-room per split. Rounding up to whole examples collects exactly
        # as many as the float comparison did and keeps the progress log readable.
        targets = {name: math.ceil(size * 1.1) for name, size in final_sizes.items()}
        collected = {name: [] for name in split_names}

        # Length statistics of the examples that are actually kept.
        length_distribution = {
            'short': 0,   # <50 words
            'medium': 0,  # 50-200 words
            'long': 0     # >200 words
        }

        # Ids loaded from JSON may arrive as lists; sets give O(1) membership
        # tests and support `.add` when new assignments are recorded.
        for name in split_names:
            if not isinstance(split_indices[name], set):
                split_indices[name] = set(split_indices[name])

        # Cached assignments are used as soon as train or validation ids exist.
        using_cached_splits = bool(split_indices['train'] or split_indices['validation'])

        for example_idx, example in enumerate(dataset_iter):
            idx_key = str(example_idx)

            if using_cached_splits:
                split = next(
                    (name for name in split_names if idx_key in split_indices[name]),
                    None,
                )
                if split is None:
                    # Skip examples not in any cached set
                    continue
            else:
                # The hash is only needed when a new assignment is generated.
                split = self._get_split_from_hash(self._compute_example_hash(example))
                split_indices[split].add(idx_key)

            bucket = collected[split]
            if len(bucket) < targets[split]:
                bucket.append(example)

                # Only kept examples count towards the reported distribution.
                word_count = len(example.get("text", "").split())
                if word_count < 50:
                    length_distribution['short'] += 1
                elif word_count < 200:
                    length_distribution['medium'] += 1
                else:
                    length_distribution['long'] += 1

            # Stop as soon as every split has reached its target.
            if all(len(collected[name]) >= targets[name] for name in split_names):
                break

            # Progress reporting
            if example_idx % 10000 == 0:
                print(f"Processed {example_idx} examples. "
                      f"Train: {len(collected['train'])}/{targets['train']}, "
                      f"Val: {len(collected['validation'])}/{targets['validation']}, "
                      f"Test: {len(collected['test'])}/{targets['test']}")

        # Report distribution statistics
        print("\nLength distribution of selected examples:")
        total = sum(length_distribution.values())
        for category, count in length_distribution.items():
            percentage = (count / total) * 100 if total > 0 else 0
            print(f"  {category}: {count} examples ({percentage:.1f}%)")

        def to_dataset(examples, limit):
            """Build a column-wise Dataset from row dicts and cap it at `limit` rows."""
            if not examples:
                return Dataset.from_dict({"text": []})
            # Column names come from the first example; absent keys become None.
            dataset = Dataset.from_dict({
                key: [example.get(key) for example in examples]
                for key in examples[0].keys()
            })
            if len(dataset) > limit:
                dataset = dataset.select(range(limit))
            return dataset

        train_dataset = to_dataset(collected['train'], final_sizes['train'])
        val_dataset = to_dataset(collected['validation'], final_sizes['validation'])
        test_dataset = to_dataset(collected['test'], final_sizes['test'])

        print(f"Final dataset sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

        return train_dataset, val_dataset, test_dataset
    
    def process_data(self):
        """Build the tokenized train/val/test datasets once and cache them on ``self``.

        Order of operations:

        1. Return immediately when ``self.datasets_processed`` is already set.
        2. If ``hparams.dataset_path`` points at an existing file, load the
           three datasets from it with ``torch.load`` and stop.
        3. Otherwise load the Wikipedia subset, collect the raw splits, save
           the split assignments, tokenize each split (custom batch
           tokenization or ``Dataset.map``), and optionally save the result
           to ``hparams.dataset_path``.

        Raises:
            Exception: any error raised during processing is printed together
                with its traceback and then re-raised unchanged.
        """
        if self.datasets_processed:
            print("Datasets already processed, skipping.")
            return

        # Load or create split indices
        split_indices = self._load_or_create_split_indices()

        try:
            saved_path = self.hparams.dataset_path

            # If using a previously saved dataset
            if saved_path and os.path.exists(saved_path):
                print(f"Loading preprocessed dataset from {saved_path}")
                # NOTE(security): the file holds pickled dataset objects written by
                # the save step below, so full unpickling is required. Only point
                # dataset_path at files you produced yourself.
                dataset_dict = torch.load(saved_path, weights_only=False)
                self.train_dataset = dataset_dict.get('train')
                self.val_dataset = dataset_dict.get('val')
                self.test_dataset = dataset_dict.get('test')

                print(f"Loaded datasets - Train: {len(self.train_dataset)}, Val: {len(self.val_dataset)}, Test: {len(self.test_dataset)}")
                self.datasets_processed = True
                return

            print(f"Loading Wikimedia Wikipedia subset: {self.hparams.subset_name}")

            # Streaming is controlled by hparams.streaming
            dataset_dict = load_dataset(
                "wikimedia/wikipedia",
                name=self.hparams.subset_name,
                cache_dir=self.hparams.cache_dir,
                trust_remote_code=self.hparams.trust_remote_code,
                streaming=self.hparams.streaming
            )

            # Get the main dataset split
            if "train" in dataset_dict:
                full_dataset = dataset_dict["train"]
            else:
                available_splits = list(dataset_dict.keys())
                if len(available_splits) > 0:
                    full_dataset = dataset_dict[available_splits[0]]
                    print(f"Using split '{available_splits[0]}' as dataset")
                else:
                    raise ValueError("No valid splits found in dataset")

            # Generate train/val/test splits
            train_raw, val_raw, test_raw = self._generate_diverse_finetune_samples(
                full_dataset, split_indices
            )

            # Save the split indices for future use
            self._save_split_indices()

            # Process each split with tokenization
            print("Processing datasets with tokenization...")

            if self.hparams.use_custom_tokenization:
                print("Using efficient batch tokenization...")
                self.train_dataset = self._process_dataset_with_tokenization(
                    train_raw, batch_size=self.hparams.processing_batch_size
                )
                self.val_dataset = self._process_dataset_with_tokenization(
                    val_raw, batch_size=self.hparams.processing_batch_size
                )
                self.test_dataset = self._process_dataset_with_tokenization(
                    test_raw, batch_size=self.hparams.processing_batch_size
                )
            else:
                # Use HuggingFace's dataset mapping for tokenization
                map_kwargs = {
                    "batched": True,
                    "batch_size": self.hparams.processing_batch_size,
                    "num_proc": min(4, os.cpu_count() or 1),
                }
                columns_to_keep = ['input_ids', 'attention_mask', 'target_ids', 'aug_ids', 'aug_attention_mask']

                def tokenize_function(examples):
                    """Tokenize a batch of texts and copy input_ids into target_ids."""
                    tokenized = self.tokenizer(
                        examples["text"],
                        padding="max_length",
                        truncation=True,
                        max_length=self.hparams.max_length,
                        return_tensors=None
                    )
                    tokenized["target_ids"] = tokenized["input_ids"][:]
                    return tokenized

                def augment_function(examples):
                    """Add a randomly masked copy of input_ids (aug_ids) to the batch."""
                    mask_fill = torch.tensor(self.tokenizer.mask_token_id)
                    pad_fill = torch.tensor(self.tokenizer.pad_token_id)

                    aug_ids = []
                    aug_attention_mask = []

                    for ids, mask in zip(examples['input_ids'], examples['attention_mask']):
                        ids_tensor = torch.tensor(ids)
                        mask_tensor = torch.tensor(mask).bool()

                        # torch.rand samples from [0, 1), so `>=` keeps every token when
                        # augmentation_prob is 0 (a drawn 0.0 would fail a strict `>`).
                        rand = torch.rand(ids_tensor.shape)
                        keep = (rand >= self.hparams.augmentation_prob) & mask_tensor

                        masked = torch.where(keep, ids_tensor, mask_fill)
                        # Padding positions must stay padding, never the mask token.
                        masked = torch.where(mask_tensor, masked, pad_fill)

                        aug_ids.append(masked.tolist())
                        aug_attention_mask.append(mask_tensor.long().tolist())

                    examples['aug_ids'] = aug_ids
                    examples['aug_attention_mask'] = aug_attention_mask
                    return examples

                def tokenize_split(raw_split):
                    """Tokenize, augment, prune columns and switch one split to torch format.

                    Running all splits through the same pipeline guarantees the
                    augmentation step always sees the tokenized data.
                    """
                    processed = raw_split.map(tokenize_function, **map_kwargs)
                    processed = processed.map(augment_function, **map_kwargs)
                    processed = processed.remove_columns(
                        [c for c in processed.column_names if c not in columns_to_keep]
                    )
                    return processed.with_format("torch")

                print("Tokenizing training split...")
                self.train_dataset = tokenize_split(train_raw)

                print("Tokenizing validation split...")
                self.val_dataset = tokenize_split(val_raw)

                print("Tokenizing test split...")
                self.test_dataset = tokenize_split(test_raw)

            # Optionally save processed dataset
            if saved_path:
                print(f"Saving processed dataset to {saved_path}")
                save_dir = os.path.dirname(saved_path)
                # A bare file name has no directory component; makedirs('') would fail.
                if save_dir:
                    os.makedirs(save_dir, exist_ok=True)
                torch.save({
                    'train': self.train_dataset,
                    'val': self.val_dataset,
                    'test': self.test_dataset
                }, saved_path)

            self.datasets_processed = True
            print("Data processing complete.")
            print(f"  Train dataset: {len(self.train_dataset)} examples")
            print(f"  Validation dataset: {len(self.val_dataset)} examples")
            print(f"  Test dataset: {len(self.test_dataset)} examples")

        except Exception as e:
            print(f"Error processing dataset: {e}")
            import traceback
            traceback.print_exc()
            raise
    
    def _process_dataset_with_tokenization(self, dataset, batch_size=1000):
        """Tokenize a raw text dataset in batches and attach a masked copy.

        Each batch of ``"text"`` values is tokenized with padding/truncation
        to ``hparams.max_length``. ``target_ids`` is a copy of ``input_ids``;
        ``aug_ids`` replaces each real (non-padding) token by the mask token
        with probability ``hparams.augmentation_prob`` and keeps padding
        positions as the pad token.

        Args:
            dataset: Dataset supporting ``len()``, ``select()`` and a
                ``"text"`` column.
            batch_size: Number of examples tokenized per step.

        Returns:
            Dataset in torch format with the columns ``input_ids``,
            ``attention_mask``, ``target_ids``, ``aug_ids`` and
            ``aug_attention_mask``.
        """
        from datasets import Dataset

        # Initialize result containers
        all_input_ids = []
        all_attention_masks = []
        all_target_ids = []
        all_aug_ids = []
        all_aug_attention_masks = []

        mask_token_id = self.tokenizer.mask_token_id
        pad_token_id = self.tokenizer.pad_token_id
        augmentation_prob = self.hparams.augmentation_prob

        total = len(dataset)
        num_batches = (total + batch_size - 1) // batch_size
        for batch_no, start in enumerate(range(0, total, batch_size), start=1):
            end = min(start + batch_size, total)
            print(f"Processing batch {batch_no}/{num_batches}: examples {start}-{end-1}")

            # Get batch of texts
            texts = dataset.select(range(start, end))["text"]

            # Tokenize in a single batch
            tokenized = self.tokenizer(
                texts,
                padding="max_length",
                truncation=True,
                max_length=self.hparams.max_length,
                return_tensors="pt"
            )
            input_ids = tokenized['input_ids']
            attention_mask = tokenized['attention_mask']
            is_real_token = attention_mask.bool()
            device = input_ids.device

            # Create target_ids (same as input_ids)
            target_ids = input_ids.clone()

            # torch.rand samples from [0, 1), so `>=` keeps every token when
            # augmentation_prob is 0 (a drawn 0.0 would fail a strict `>`).
            rand_mask = torch.rand(input_ids.shape, device=device)
            keep = (rand_mask >= augmentation_prob) & is_real_token

            # Apply masking
            aug_input_ids = torch.where(
                keep,
                input_ids,
                torch.tensor(mask_token_id, device=device)
            )
            # Ensure padding tokens remain as padding
            aug_input_ids = torch.where(
                is_real_token,
                aug_input_ids,
                torch.tensor(pad_token_id, device=device)
            )

            # Append to results
            all_input_ids.extend(input_ids.tolist())
            all_attention_masks.extend(attention_mask.tolist())
            all_target_ids.extend(target_ids.tolist())
            all_aug_ids.extend(aug_input_ids.tolist())
            all_aug_attention_masks.extend(attention_mask.tolist())

        # Create new dataset with processed features
        processed_dataset = Dataset.from_dict({
            'input_ids': all_input_ids,
            'attention_mask': all_attention_masks,
            'target_ids': all_target_ids,
            'aug_ids': all_aug_ids,
            'aug_attention_mask': all_aug_attention_masks
        })

        return processed_dataset.with_format("torch")
    
    def setup(self, stage: Optional[str] = None):
        print(f"Setting up data for stage: {stage}")
        # Process data if not already processed
        self.process_data()

    def train_dataloader(self):
        """Build the shuffled ``DataLoader`` over ``self.train_dataset``.

        When ``hparams.num_workers`` is unset (``None`` or ``0``) the loader
        falls back to 4 workers, and ``persistent_workers`` follows the worker
        count that is actually used. ``prefetch_factor`` is left at the
        DataLoader default.

        Returns:
            DataLoader: loader with ``shuffle=True`` and ``hparams.batch_size``.
        """
        worker_count = self.hparams.num_workers or 4
        return DataLoader(
            self.train_dataset,
            batch_size=self.hparams.batch_size,
            num_workers=worker_count,
            pin_memory=torch.cuda.is_available(),
            persistent_workers=worker_count > 0,
            shuffle=True,
        )

    def val_dataloader(self):
        """Build the ordered (non-shuffled) ``DataLoader`` over ``self.val_dataset``.

        When ``hparams.num_workers`` is unset (``None`` or ``0``) the loader
        falls back to 4 workers, and ``persistent_workers`` follows the worker
        count that is actually used. ``prefetch_factor`` is left at the
        DataLoader default.

        Returns:
            DataLoader: loader with ``shuffle=False`` and ``hparams.batch_size``.
        """
        worker_count = self.hparams.num_workers or 4
        return DataLoader(
            self.val_dataset,
            batch_size=self.hparams.batch_size,
            num_workers=worker_count,
            pin_memory=torch.cuda.is_available(),
            persistent_workers=worker_count > 0,
            shuffle=False,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers or 4,
            pin_memory=torch.cuda.is_available(),
            persistent_workers=True if (self.hparams.num_workers or 0) > 0 else False,
            shuffle=False,
            prefetch_factor=None if (self.hparams.num_workers or 0) > 0 else None,
        )

    @property
    def vocab_size(self):
        """Size reported by ``len()`` on ``self.tokenizer`` (includes tokens added at init)."""
        tokenizer = self.tokenizer
        return len(tokenizer)

class CompletelyCustomDataset(torch.utils.data.Dataset):
    def __init__(
        self, 
        huggingface_dataset, 
        cls_dir, 
        tokenizer_name="answerdotai/ModernBERT-base", 
        device="cpu",
        max_length=40,
        batch_size=128,
        sample_limit=None,
        force_recompute=False
    ):
        """Extract texts, CLS embeddings and token ids from a HuggingFace dataset.

        Args:
            huggingface_dataset: Indexable dataset whose rows expose a ``"text"`` entry.
            cls_dir: Directory holding the embedding memmap, its index file and
                the tokenized tensors; created when missing.
            tokenizer_name: Checkpoint name used for both tokenizer and model.
            device: Device on which the embedding model is run.
            max_length: Pad/truncate length used when tokenizing.
            batch_size: Batch size used while computing embeddings.
            sample_limit: Optional cap on the number of rows used; a falsy
                value means the whole dataset.
            force_recompute: Ignore existing files in ``cls_dir`` and rebuild them.
        """
        print("Initializing CompleteyCustomDataset...")

        # Plain configuration is copied onto the instance first.
        self.cls_dir, self.tokenizer_name, self.device = cls_dir, tokenizer_name, device
        self.max_length, self.batch_size = max_length, batch_size
        self.force_recompute = force_recompute

        # Create output directory if it doesn't exist
        os.makedirs(cls_dir, exist_ok=True)

        # The dataset size is capped by sample_limit when one is given.
        full_size = len(huggingface_dataset)
        if sample_limit:
            self._size = min(full_size, sample_limit)
        else:
            self._size = full_size
        print(f"Dataset will contain {self._size} samples")

        # Load tokenizer
        from transformers import AutoTokenizer, AutoModel
        print(f"Loading tokenizer: {tokenizer_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # IMPORTANT: texts are copied into a plain list once so that later item
        # access never touches the HuggingFace dataset object.
        print("Extracting text samples from HuggingFace dataset...")
        self.text_samples = [
            huggingface_dataset[row]["text"]
            for row in tqdm(range(self._size), desc="Extracting texts")
        ]
        print(f"Extracted {len(self.text_samples)} text samples")

        # Embeddings first, then token ids.
        self._setup_embeddings(huggingface_dataset)
        self._setup_tokenization()

        print("CustomDataset initialization complete!")
    
    def _get_embedding_dimension(self, sample_text):
        """Return the width of the model's position-0 (CLS) hidden state.

        The model named by ``self.tokenizer_name`` is loaded temporarily, run
        once on ``sample_text`` and released again.

        Args:
            sample_text: Any text; it is truncated to ``self.max_length``
                tokens because only the hidden size is of interest.

        Returns:
            int: size of the last dimension of ``last_hidden_state``.
        """
        from transformers import AutoModel
        print("Determining embedding dimension...")

        # Load model temporarily
        model = AutoModel.from_pretrained(self.tokenizer_name).to(self.device)
        model.eval()

        with torch.no_grad():
            # Truncate the probe: the hidden size does not depend on sequence
            # length, and an untruncated long document wastes time and memory.
            inputs = self.tokenizer(
                sample_text,
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            ).to(self.device)
            outputs = model(**inputs)
            cls_embedding = outputs.last_hidden_state[:, 0, :]
            embedding_dim = cls_embedding.shape[-1]

        # Free up GPU memory; str() also covers "cuda:0" and torch.device values.
        del model
        if str(self.device).startswith("cuda"):
            torch.cuda.empty_cache()

        return embedding_dim
        
    def _setup_embeddings(self, hf_dataset):
        """Set up memory-mapped embeddings."""
        # Get embedding dimension
        self.embedding_dim = self._get_embedding_dimension(self.text_samples[0])
        print(f"CLS embedding dimension: {self.embedding_dim}")
        
        # Setup memory-mapped file for embeddings
        self.memmap_path = os.path.join(self.cls_dir, "cls_embeddings.dat")
        self.index_file = os.path.join(self.cls_dir, "cls_index.npz")
        
        # Check if files already exist
        index_exists = os.path.exists(self.index_file)
        memmap_exists = os.path.exists(self.memmap_path)
        
        # If both files exist and no recompute, load the existing memmap
        if index_exists and memmap_exists and not self.force_recompute:
            print(f"Loading existing memory-mapped embeddings from {self.memmap_path}")
            # Load index file to verify dataset size
            index_data = np.load(self.index_file)
            stored_size = index_data['size'].item()
            stored_dim = index_data['dim'].item()
            
            if stored_size != self._size or stored_dim != self.embedding_dim:
                print(f"⚠️ Warning: Existing memmap has {stored_size} samples with dim {stored_dim}")
                print(f"Current dataset has {self._size} samples with dim {self.embedding_dim}")
                print("Forcing recomputation...")
                self.force_recompute = True
            else:
                # Open the memmap for reading
                self.embeddings = np.memmap(
                    self.memmap_path,
                    dtype='float32',
                    mode='r',
                    shape=(self._size, self.embedding_dim)
                )
                print("Memory-mapped embeddings loaded successfully")
                return
        
        # Create a new memory-mapped file
        print(f"Creating new memory-mapped file for {self._size} embeddings")
        self.embeddings = np.memmap(
            self.memmap_path,
            dtype='float32',
            mode='w+',
            shape=(self._size, self.embedding_dim)
        )
        
        # Compute the embeddings
        self._compute_embeddings(hf_dataset)
        
        # Save the index file with metadata
        np.savez(
            self.index_file,
            size=self._size,
            dim=self.embedding_dim
        )
        
        # Reopen in read mode
        self.embeddings = np.memmap(
            self.memmap_path,
            dtype='float32',
            mode='r',
            shape=(self._size, self.embedding_dim)
        )
        
        print("Memory-mapped embeddings created successfully")
    
    def _compute_embeddings(self, hf_dataset):
        """Fill ``self.embeddings`` with the position-0 (CLS) hidden state of every text.

        Texts are taken from ``self.text_samples`` in chunks of
        ``self.batch_size``, tokenized to ``self.max_length`` and run through
        the model named by ``self.tokenizer_name`` without gradients. The
        memmap is flushed once all rows are written.

        Args:
            hf_dataset: Not used here; the texts were already extracted into
                ``self.text_samples``.
        """
        from transformers import AutoModel
        print(f"Computing CLS embeddings for {self._size} samples...")

        # Load model for embeddings
        print(f"Loading model: {self.tokenizer_name}")
        model = AutoModel.from_pretrained(self.tokenizer_name)
        model.to(self.device)
        model.eval()
        print(f"Model loaded and moved to {self.device}")

        # Process in batches with progress bar
        with torch.no_grad():
            for start in tqdm(range(0, self._size, self.batch_size),
                              desc="Computing CLS embeddings"):
                stop = min(start + self.batch_size, self._size)

                # Get batch texts directly from our extracted list
                batch_texts = self.text_samples[start:stop]

                # Tokenize texts
                encodings = self.tokenizer(
                    batch_texts,
                    max_length=self.max_length,
                    padding="max_length",
                    truncation=True,
                    return_tensors="pt"
                ).to(self.device)

                # Get CLS embeddings
                outputs = model(**encodings)
                cls_embeddings = outputs.last_hidden_state[:, 0, :]

                # Store in memmap
                self.embeddings[start:stop] = cls_embeddings.cpu().numpy()

        # Flush to ensure data is written to disk
        self.embeddings.flush()

        # Free up memory; str() also covers "cuda:0" and torch.device values.
        del model
        if str(self.device).startswith("cuda"):
            torch.cuda.empty_cache()

        print("CLS embedding computation complete")
        
    def _setup_tokenization(self):
        """Pre-tokenize all text samples."""
        tokens_path = os.path.join(self.cls_dir, "tokenized_data.pt")
        
        if os.path.exists(tokens_path) and not self.force_recompute:
            print(f"Loading pre-tokenized data from {tokens_path}")
            self.tokenized_data = torch.load(tokens_path)
            return
            
        print(f"Pre-tokenizing {self._size} samples...")
        all_input_ids = []
        all_attention_masks = []
        
        # Process in batches for efficiency
        batch_size = 1000
        
        for i in tqdm(range(0, self._size, batch_size), desc="Tokenizing"):
            end_idx = min(i + batch_size, self._size)
            batch_texts = self.text_samples[i:end_idx]
            
            # Tokenize batch
            encodings = self.tokenizer(
                batch_texts,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt"
            )
            
            # Add to storage
            all_input_ids.append(encodings["input_ids"])
            all_attention_masks.append(encodings["attention_mask"])
        
        # Concatenate batches
        self.tokenized_data = {
            "input_ids": torch.cat(all_input_ids, dim=0),
            "attention_mask": torch.cat(all_attention_masks, dim=0)
        }
        
        # Save to disk
        print(f"Saving pre-tokenized data to {tokens_path}")
        torch.save(self.tokenized_data, tokens_path)
        print("Tokenization complete and saved")
        
    def __len__(self):
        """Return the number of samples in the dataset."""
        return self._size
    
    def __getitem__(self, idx):
        """Get a sample by index."""
        # Handle tensor indices
        if isinstance(idx, torch.Tensor):
            idx = idx.item()
            
        # Validate index
        if idx >= self._size:
            raise IndexError(f"Index {idx} out of bounds for dataset of size {self._size}")
            
        # Get text from our extracted list
        text_sample = self.text_samples[idx]
        
        # Get CLS embedding from memmap
        cls_embedding = torch.from_numpy(self.embeddings[idx].copy()).float()
        
        # Get tokenized data
        input_ids = self.tokenized_data["input_ids"][idx]
        attention_mask = self.tokenized_data["attention_mask"][idx]
        
        # Return all required fields
        return {
            "text": text_sample,
            "cls_embedding": cls_embedding,
            "input_ids": input_ids,
            "attention_mask": attention_mask
        }
class TextReconstructionDataModule(pl.LightningDataModule):
    """Lightning data module that serves shuffled Wikipedia text splits.

    The ``train`` split of ``wikimedia/wikipedia`` is shuffled with a fixed
    seed and cut into consecutive slices: validation rows first, then test
    rows, then training rows. Each slice is wrapped in a
    ``CompletelyCustomDataset`` whose ``cls_dir`` is the ``train`` /
    ``validation`` / ``test`` sub-folder of ``cls_embedding_dir``.
    """

    # Split names accepted by ``_load_dataset``; also used as sub-folder names
    _SPLITS = ("train", "validation", "test")

    def __init__(
        self,
        dataset_path: Optional[str] = None,
        tokenizer_name: str = "answerdotai/ModernBERT-base",
        batch_size: int = 32,
        max_length: int = 40,
        num_workers: Optional[int] = None,
        augmentation_prob: float = 0.0,
        cache_dir: Optional[str] = None,
        cls_embedding_dir: Optional[str] = None,
        max_train_samples: Optional[int] = 100000,
        max_val_samples: Optional[int] = 10000,
        max_test_samples: Optional[int] = 10000,
        subset_name: str = "20231101.en",
        trust_remote_code: bool = True,
        device: str = "cuda",  # Set to "cuda" for GPU acceleration
        force_recompute: bool = False,  # Force recomputation of embeddings
    ):
        """Store the configuration and load the tokenizer.

        Args:
            dataset_path: Accepted for call compatibility; not read by this
                class.
            tokenizer_name: Hugging Face identifier used to load the tokenizer
                here and forwarded to every split dataset.
            batch_size: Batch size of all dataloaders.
            max_length: Maximum sequence length forwarded to the split
                datasets.
            num_workers: Dataloader worker processes. ``None`` selects half of
                the CPU count (at least 1); ``0`` loads in the main process.
            augmentation_prob: Stored on the instance; not read by this class.
            cache_dir: Cache directory for the tokenizer and the dataset.
            cls_embedding_dir: Root directory of the per-split CLS embedding
                folders. Required before ``setup`` builds any dataset.
            max_train_samples: Upper bound on training rows; ``None`` takes
                every row left after the validation and test slices.
            max_val_samples: Number of validation rows (taken first).
            max_test_samples: Number of test rows (taken after validation).
            subset_name: Configuration name of ``wikimedia/wikipedia``.
            trust_remote_code: Forwarded to ``load_dataset``.
            device: Forwarded to the split datasets.
            force_recompute: Forwarded to the split datasets.
        """
        super().__init__()
        self.tokenizer_name = tokenizer_name
        self.batch_size = batch_size
        self.max_length = max_length
        # Only fall back to the CPU-based default when nothing was requested;
        # an explicit 0 must stay 0 (single-process loading).
        if num_workers is None:
            num_workers = max(1, multiprocessing.cpu_count() // 2)
        self.num_workers = num_workers
        self.augmentation_prob = augmentation_prob
        self.cache_dir = cache_dir
        self.cls_embedding_dir = cls_embedding_dir
        self.max_train_samples = max_train_samples
        self.max_val_samples = max_val_samples
        self.max_test_samples = max_test_samples
        self.subset_name = subset_name
        self.trust_remote_code = trust_remote_code
        self.device = device
        self.force_recompute = force_recompute

        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            cache_dir=cache_dir
        )

    def _split_bounds(self, split, total):
        """Return the ``(start, stop)`` row range of ``split``.

        Args:
            split: ``"train"``, ``"validation"`` or ``"test"``.
            total: Number of rows available in the shuffled dataset.

        Returns:
            Half-open row range, clamped so that it never exceeds ``total``.

        Raises:
            ValueError: If ``split`` is unknown, or if ``max_val_samples`` or
                ``max_test_samples`` is ``None``.
        """
        if split not in self._SPLITS:
            raise ValueError(f"Unknown split: {split}")
        if self.max_val_samples is None or self.max_test_samples is None:
            raise ValueError(
                "max_val_samples and max_test_samples must be integers: "
                "they determine where each split starts"
            )

        val_end = self.max_val_samples
        test_end = val_end + self.max_test_samples

        if split == "validation":
            start, stop = 0, val_end
        elif split == "test":
            start, stop = val_end, test_end
        else:
            start = test_end
            if self.max_train_samples is None:
                stop = total
            else:
                stop = test_end + self.max_train_samples

        # The limits are maxima: a corpus smaller than requested yields
        # shorter (possibly empty) slices instead of out-of-range indices.
        stop = min(stop, total)
        return min(start, stop), stop

    def _load_dataset(self, split):
        """Load the shuffled Wikipedia rows that belong to ``split``.

        Args:
            split: ``"train"``, ``"validation"`` or ``"test"``.

        Returns:
            The selected slice of the shuffled ``train`` split.

        Raises:
            ValueError: If ``split`` is not one of the three known names.
        """
        # Fail before the (expensive) load when the split name is wrong
        if split not in self._SPLITS:
            raise ValueError(f"Unknown split: {split}")

        dataset = load_dataset(
            "wikimedia/wikipedia",
            name=self.subset_name,
            cache_dir=self.cache_dir,
            trust_remote_code=self.trust_remote_code
        )["train"]

        # Every call reshuffles with the same seed, so the row ranges of the
        # three splits always refer to the same ordering.
        dataset = dataset.shuffle(seed=42)
        start, stop = self._split_bounds(split, len(dataset))
        return dataset.select(range(start, stop))

    def _build_split_dataset(self, split):
        """Create the ``CompletelyCustomDataset`` for one split.

        Raises:
            ValueError: If ``cls_embedding_dir`` was never configured.
        """
        if self.cls_embedding_dir is None:
            raise ValueError(
                "cls_embedding_dir must be provided to set up the "
                f"'{split}' dataset"
            )

        split_data = self._load_dataset(split)
        return CompletelyCustomDataset(
            huggingface_dataset=split_data,
            cls_dir=os.path.join(self.cls_embedding_dir, split),
            tokenizer_name=self.tokenizer_name,
            device=self.device,
            max_length=self.max_length,
            force_recompute=self.force_recompute,
            # Equals the configured limit unless the slice had to be clamped
            sample_limit=len(split_data)
        )

    def setup(self, stage: Optional[str] = None):
        """Build the datasets needed for ``stage``.

        Args:
            stage: ``'fit'`` builds the training and validation datasets,
                ``'validate'`` only the validation dataset, ``'test'`` only
                the test dataset, and ``None`` all three. Other values build
                nothing.
        """
        if stage == 'fit' or stage is None:
            print("Setting up training dataset...")
            self.train_dataset = self._build_split_dataset('train')

        if stage in ('fit', 'validate') or stage is None:
            print("Setting up validation dataset...")
            self.val_dataset = self._build_split_dataset('validation')

        if stage == 'test' or stage is None:
            print("Setting up test dataset...")
            self.test_dataset = self._build_split_dataset('test')

    def _make_dataloader(self, dataset, shuffle):
        """Wrap ``dataset`` in a ``DataLoader`` using the shared settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            # DataLoader rejects persistent workers when there are none
            persistent_workers=self.num_workers > 0
        )

    def train_dataloader(self):
        """Return the shuffled dataloader over the training dataset."""
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the in-order dataloader over the validation dataset."""
        return self._make_dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return the in-order dataloader over the test dataset."""
        return self._make_dataloader(self.test_dataset, shuffle=False)

    @property
    def vocab_size(self):
        """Number of entries in the tokenizer, as reported by ``len``."""
        return len(self.tokenizer)
