"""
    wikitext_data.py
    
    Utilities for loading and processing WikiText-2 dataset for language modeling.
    Uses word-level tokenization to create a manageable vocabulary suitable for observing grokking.
"""

import torch
from torch.utils.data import Dataset, DataLoader
import os
import zipfile
from collections import Counter
import re

class WikiTextDataset(Dataset):
    """Map-style dataset of overlapping next-token-prediction windows.

    Sample ``i`` starts at token ``i`` (stride 1), so consecutive samples
    overlap by ``seq_length - 1`` tokens. Each sample is a dict with:

    * ``input_ids``      -- ``seq_length`` token ids (``torch.long``)
    * ``labels``         -- the same window shifted one token to the right
    * ``attention_mask`` -- all ones (``torch.float``); there is no padding
    """

    def __init__(self, tokens, seq_length=32):
        """
        Args:
            tokens: List of token indices
            seq_length: Length of each sequence for training (must be >= 1)

        Raises:
            ValueError: If ``seq_length`` is smaller than 1.
        """
        if seq_length < 1:
            raise ValueError(f"seq_length must be >= 1, got {seq_length}")
        self.tokens = tokens
        self.seq_length = seq_length

    def __len__(self):
        # Every window needs seq_length inputs plus one extra token for the
        # final label, hence len(tokens) - seq_length complete windows.
        return max(0, len(self.tokens) - self.seq_length)

    def __getitem__(self, idx):
        """Return the ``idx``-th window; negative indices count from the end.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        num_sequences = len(self)
        position = int(idx)
        if position < 0:
            position += num_sequences
        # List slicing never raises, so without this check an out-of-range
        # index would yield truncated tensors and plain iteration over the
        # dataset (which stops on IndexError) would never terminate.
        if not 0 <= position < num_sequences:
            raise IndexError(
                f"index {idx} is out of range for dataset with {num_sequences} sequences"
            )

        window_end = position + self.seq_length
        input_ids = torch.tensor(self.tokens[position:window_end], dtype=torch.long)
        # Labels are the next token for each input position.
        labels = torch.tensor(self.tokens[position + 1:window_end + 1], dtype=torch.long)
        attention_mask = torch.ones_like(input_ids, dtype=torch.float)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }


class WordTokenizer:
    """Simple word-level tokenizer for WikiText.

    Ids 0, 1 and 2 are reserved for the PAD, UNK and EOS tokens; the most
    frequent words of the corpus passed to ``build_vocab`` fill the remaining
    ids up to ``vocab_size``. The EOS token is only reserved in the
    vocabulary -- ``encode`` never inserts it.
    """

    # Word characters, or one of a few punctuation marks as its own token.
    # Anything else (e.g. the angle brackets around a "<unk>" marker) is
    # dropped, so "<unk>" is tokenized as the plain word "unk".
    _TOKEN_PATTERN = re.compile(r'\b\w+\b|[.,!?;]')

    def __init__(self, vocab_size=5000):
        """
        Args:
            vocab_size: Maximum vocabulary size, including the 3 special tokens
        """
        self.vocab_size = vocab_size
        self.word2idx = {}
        self.idx2word = {}
        self.pad_token = '<PAD>'
        self.unk_token = '<UNK>'
        self.eos_token = '<EOS>'

    def build_vocab(self, texts):
        """Build vocabulary from list of texts.

        Args:
            texts: Iterable of strings. A single string is treated as a
                one-element corpus instead of being iterated per character.
        """
        if isinstance(texts, str):
            texts = [texts]

        # Tokenize all texts and count words
        word_counts = Counter()
        for text in texts:
            word_counts.update(self.tokenize_text(text))

        # Reserve special tokens
        self.word2idx = {
            self.pad_token: 0,
            self.unk_token: 1,
            self.eos_token: 2,
        }
        num_special = len(self.word2idx)

        # Add most common words up to vocab_size (never a negative count)
        num_words = max(self.vocab_size - num_special, 0)
        for idx, (word, _) in enumerate(word_counts.most_common(num_words), start=num_special):
            self.word2idx[word] = idx

        self.idx2word = {idx: word for word, idx in self.word2idx.items()}
        print(f"Vocabulary built with {len(self.word2idx)} tokens")

    def tokenize_text(self, text):
        """Convert text to a list of lowercased words and punctuation marks."""
        return self._TOKEN_PATTERN.findall(text.lower())

    def encode(self, text):
        """Convert text to list of token indices.

        Words missing from the vocabulary map to the UNK id.

        Raises:
            KeyError: If the text is non-empty and ``build_vocab`` has not
                been called yet.
        """
        words = self.tokenize_text(text)
        if not words:
            return []
        try:
            # Looked up once instead of once per word.
            unk_id = self.word2idx[self.unk_token]
        except KeyError:
            raise KeyError(
                "Vocabulary is empty; call build_vocab() before encode()"
            ) from None
        lookup = self.word2idx.get
        return [lookup(word, unk_id) for word in words]

    def decode(self, indices):
        """Convert token indices back to a space-joined string.

        Args:
            indices: Iterable of integer-like ids (Python ints, or the
                elements of a 1-D tensor/array). Unknown ids decode to the
                UNK token.
        """
        # int() is required for tensor elements: they do not hash like ints,
        # so a direct dict lookup would turn every id into UNK.
        words = [self.idx2word.get(int(idx), self.unk_token) for idx in indices]
        return ' '.join(words)


def download_wikitext2(data_dir='data/wikitext'):
    """Extract WikiText-2 dataset from local zip file if not already present.

    Despite the name, nothing is fetched over the network: the archive is
    expected at ``<parent of this file's directory>/grokking_experiments/wikitext-2.zip``.

    Args:
        data_dir: Directory the archive is extracted into (created if missing).

    Returns:
        ``data_dir`` unchanged.

    Raises:
        FileNotFoundError: If the data is not extracted yet and the zip file
            cannot be found.
    """
    os.makedirs(data_dir, exist_ok=True)

    # Check if already extracted. The archive may unpack either into a
    # 'wikitext-2' sub-folder or directly into data_dir, so accept both;
    # checking only one layout forces a re-extraction on every call for the other.
    candidate_dirs = (os.path.join(data_dir, 'wikitext-2'), data_dir)
    if any(os.path.exists(os.path.join(candidate, 'wiki.train.tokens'))
           for candidate in candidate_dirs):
        print("WikiText-2 already extracted.")
        return data_dir

    # Look for local zip file in grokking_experiments folder
    # Get the project root (parent of utilities folder)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(current_dir)
    local_zip_path = os.path.join(project_root, 'grokking_experiments', 'wikitext-2.zip')

    if not os.path.exists(local_zip_path):
        raise FileNotFoundError(
            f"WikiText-2 zip file not found at {local_zip_path}. "
            "Please place wikitext-2.zip in the grokking_experiments folder."
        )

    print(f"Found local WikiText-2 zip file at {local_zip_path}")
    print("Extracting...")
    with zipfile.ZipFile(local_zip_path, 'r') as zip_ref:
        zip_ref.extractall(data_dir)

    print("WikiText-2 extracted successfully.")

    return data_dir


def load_wikitext_file(filepath):
    """Read a WikiText file and flatten it into one space-separated string.

    Blank lines and section headers (lines whose first non-blank character
    is '=') are dropped; every remaining line is stripped of surrounding
    whitespace before joining.
    """
    with open(filepath, 'r', encoding='utf-8') as handle:
        raw_text = handle.read()

    kept_lines = []
    for raw_line in raw_text.split('\n'):
        stripped = raw_line.strip()
        # Skip empty lines and "= Heading =" style section markers.
        if not stripped or stripped.startswith('='):
            continue
        kept_lines.append(stripped)

    return ' '.join(kept_lines)


def prepare_wikitext_data(data_dir='data/wikitext', 
                          vocab_size=5000, 
                          seq_length=32,
                          train_samples=None,
                          val_samples=None,
                          test_samples=None):
    """
    Prepare WikiText-2 data for language modeling.

    The vocabulary is built from the training split only; all three splits
    are then encoded with it.

    Args:
        data_dir: Directory to store/load WikiText data
        vocab_size: Size of vocabulary
        seq_length: Length of each training sequence
        train_samples: Training token budget, expressed as a number of
            sequences: the split is truncated to
            ``train_samples * (seq_length + 1)`` tokens (None = use all)
        val_samples: Same budget for the validation split (None = use all)
        test_samples: Same budget for the test split (None = use all)

    Returns:
        tokenizer: Trained tokenizer
        train_tokens: Training token indices
        val_tokens: Validation token indices
        test_tokens: Test token indices

    Raises:
        FileNotFoundError: If the archive or the extracted split files
            cannot be found.
    """
    # Extract data if needed
    data_dir = download_wikitext2(data_dir)

    # Locate the split files. download_wikitext2 treats
    # '<data_dir>/wikitext-2/' as the extraction target, but a flat archive
    # unpacks straight into data_dir; use whichever layout is present.
    wikitext_dir = data_dir
    nested_dir = os.path.join(data_dir, 'wikitext-2')
    flat_train = os.path.join(data_dir, 'wiki.train.tokens')
    nested_train = os.path.join(nested_dir, 'wiki.train.tokens')
    if not os.path.exists(flat_train) and os.path.exists(nested_train):
        wikitext_dir = nested_dir

    # Load files
    train_text = load_wikitext_file(os.path.join(wikitext_dir, 'wiki.train.tokens'))
    val_text = load_wikitext_file(os.path.join(wikitext_dir, 'wiki.valid.tokens'))
    test_text = load_wikitext_file(os.path.join(wikitext_dir, 'wiki.test.tokens'))

    print(f"Train text length: {len(train_text)} characters")
    print(f"Val text length: {len(val_text)} characters")
    print(f"Test text length: {len(test_text)} characters")

    # Build vocabulary from training data
    tokenizer = WordTokenizer(vocab_size=vocab_size)
    tokenizer.build_vocab([train_text])

    # Tokenize all splits
    train_tokens = tokenizer.encode(train_text)
    val_tokens = tokenizer.encode(val_text)
    test_tokens = tokenizer.encode(test_text)

    # NOTE(review): the budget below is sized for non-overlapping windows of
    # seq_length inputs + 1 label token. WikiTextDataset in this file slides
    # its window by one token, so the number of sequences it yields from the
    # truncated list is larger than the requested sample count -- confirm
    # which of the two is intended.
    tokens_per_sample = seq_length + 1

    # Limit training data for grokking experiment
    if train_samples is not None:
        train_tokens = train_tokens[:train_samples * tokens_per_sample]
        print(f"Limited training data to {train_samples} samples ({len(train_tokens)} tokens)")

    if val_samples is not None:
        val_tokens = val_tokens[:val_samples * tokens_per_sample]

    if test_samples is not None:
        test_tokens = test_tokens[:test_samples * tokens_per_sample]

    print(f"Train tokens: {len(train_tokens)}")
    print(f"Val tokens: {len(val_tokens)}")
    print(f"Test tokens: {len(test_tokens)}")

    return tokenizer, train_tokens, val_tokens, test_tokens


def create_wikitext_dataloaders(train_tokens, 
                                val_tokens, 
                                test_tokens,
                                seq_length=32,
                                batch_size=32,
                                num_workers=None):
    """
    Create DataLoaders for WikiText data.

    Only the training loader shuffles; validation and test loaders keep the
    dataset order.

    Args:
        train_tokens: Training token indices
        val_tokens: Validation token indices
        test_tokens: Test token indices
        seq_length: Length of each sequence
        batch_size: Batch size for DataLoaders
        num_workers: Worker processes per loader. None keeps the defaults
            (4 for training, 2 for validation/test); 0 loads data in the
            main process, which is convenient for debugging.

    Returns:
        train_loader, val_loader, test_loader
    """
    train_dataset = WikiTextDataset(train_tokens, seq_length)
    val_dataset = WikiTextDataset(val_tokens, seq_length)
    test_dataset = WikiTextDataset(test_tokens, seq_length)

    if num_workers is None:
        train_workers, eval_workers = 4, 2
    else:
        train_workers = eval_workers = num_workers

    # Pinned host memory only speeds up host-to-GPU copies; requesting it
    # without CUDA just produces a warning.
    pin_memory = torch.cuda.is_available()

    def _make_loader(dataset, shuffle, workers):
        # persistent_workers is only valid when there is at least one worker.
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                          num_workers=workers, pin_memory=pin_memory,
                          persistent_workers=workers > 0)

    train_loader = _make_loader(train_dataset, True, train_workers)
    val_loader = _make_loader(val_dataset, False, eval_workers)
    test_loader = _make_loader(test_dataset, False, eval_workers)

    print(f"Created DataLoaders:")
    print(f"  Train: {len(train_dataset)} sequences, {len(train_loader)} batches")
    print(f"  Val: {len(val_dataset)} sequences, {len(val_loader)} batches")
    print(f"  Test: {len(test_dataset)} sequences, {len(test_loader)} batches")

    return train_loader, val_loader, test_loader


def get_wikitext_experiment_setup(vocab_size=5000,
                                  seq_length=32,
                                  train_samples=500,  # Small for grokking
                                  val_samples=200,
                                  test_samples=200,
                                  batch_size=32):
    """
    Build the tokenizer, token splits and DataLoaders for the WikiText
    experiment in a single call.

    Args:
        vocab_size: Requested vocabulary size passed to the tokenizer
        seq_length: Length of each sequence
        train_samples: Sample budget for the training split
        val_samples: Sample budget for the validation split
        test_samples: Sample budget for the test split
        batch_size: Batch size for all DataLoaders

    Returns:
        tokenizer, train_loader, val_loader, test_loader, vocab_size
        (the returned vocab_size is the number of entries actually present
        in the tokenizer's vocabulary, not the requested value)
    """
    banner = "=" * 50
    print("\n" + banner)
    print("Setting up WikiText-2 Language Modeling Experiment")
    print(banner)

    # Step 1: tokenizer plus the (possibly truncated) token id lists.
    prepared = prepare_wikitext_data(
        vocab_size=vocab_size,
        seq_length=seq_length,
        train_samples=train_samples,
        val_samples=val_samples,
        test_samples=test_samples
    )
    tokenizer, train_ids, val_ids, test_ids = prepared

    # Step 2: wrap each split in a DataLoader.
    loaders = create_wikitext_dataloaders(
        train_ids,
        val_ids,
        test_ids,
        seq_length=seq_length,
        batch_size=batch_size
    )
    train_loader, val_loader, test_loader = loaders

    effective_vocab_size = len(tokenizer.word2idx)

    print(banner + "\n")

    return tokenizer, train_loader, val_loader, test_loader, effective_vocab_size


if __name__ == "__main__":
    # Smoke test: run the full setup and show one training batch.
    tokenizer, train_loader, val_loader, test_loader, vocab_size = (
        get_wikitext_experiment_setup()
    )

    # Pull a single batch and preview the first ten tokens of its first row.
    batch = next(iter(train_loader))
    print("\nSample batch:")
    print(f"Input shape: {batch['input_ids'].shape}")
    print(f"Labels shape: {batch['labels'].shape}")
    preview_ids = batch['input_ids'][0][:10]
    print(f"First sequence (tokens): {preview_ids}")
    preview_words = tokenizer.decode(preview_ids.tolist())
    print(f"First sequence (words): {preview_words}")

