import io
import json
import math
import os
import random
import time
from collections import deque
from typing import Dict, Iterator, List, Optional, Tuple, Union

import numpy as np
import pytorch_lightning as pl
import torch
import zstandard as zstd
from torch.utils.data import DataLoader, Dataset, IterableDataset
from transformers import AutoTokenizer, GPT2Tokenizer

class PileIterableDataset(IterableDataset):
    """Stream fixed-length token sequences from ``.jsonl.zst`` files.

    Every document's ``"text"`` field is tokenized, followed by one
    end-of-document token, and appended to a flat token buffer.  Whenever the
    buffer holds at least ``buffer_size`` tokens it is cut into sequences of
    ``max_length`` tokens.  Each sequence becomes one example with
    ``input_ids`` (all but the last token), ``labels`` (all but the first
    token) and an all-ones ``attention_mask``.

    Sharding is done at sequence level: every rank/worker reads *all* files
    and keeps only the sequences whose running index modulo
    ``world_size * num_workers`` equals its own global worker id.  This is
    only a partition if every worker builds the same stream, so all workers
    shuffle with the same ``seed``.

    When ``shuffle_files`` is true the file order is shuffled and, in
    addition, the buffered tokens are permuted in chunks of
    ``_SHUFFLE_CHUNK_TOKENS`` tokens before being cut into sequences, so a
    shuffled sequence is a concatenation of such chunks rather than
    contiguous text.
    """

    # Granularity, in tokens, of the in-buffer shuffle.
    _SHUFFLE_CHUNK_TOKENS = 50

    def __init__(
        self,
        file_paths: List[str],
        tokenizer,
        max_length: int = 2049,
        shuffle_files: bool = True,
        rank: int = 0,
        world_size: int = 1,
        buffer_size: int = 100000,
        seed: int = 42,
    ):
        """Configure the stream.

        Args:
            file_paths: Paths of the ``.jsonl.zst`` files to read.
            tokenizer: Callable used as
                ``tokenizer(text, return_tensors="pt", add_special_tokens=False)``
                whose result has an ``"input_ids"`` tensor; must expose
                ``eos_token_id``.
            max_length: Tokens per packed sequence (examples have
                ``max_length - 1`` positions).
            shuffle_files: Shuffle file order and buffered token chunks.
            rank: Index of this process among ``world_size`` processes.
            world_size: Number of processes sharing the stream.
            buffer_size: Token count at which the buffer is cut into
                sequences.
            seed: Seed of the shuffling generators; identical on every
                rank/worker so that all of them see the same stream.

        Raises:
            ValueError: If ``max_length`` is below 2 or the tokenizer has no
                EOS token id.
        """
        super().__init__()
        if max_length < 2:
            raise ValueError(f"max_length must be at least 2, got {max_length}")
        if tokenizer.eos_token_id is None:
            # Without this check the failure would only surface inside
            # __iter__, where it is caught and reported once per file.
            raise ValueError("tokenizer must define eos_token_id")

        self.file_paths = file_paths
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.shuffle_files = shuffle_files
        self.rank = rank
        self.world_size = world_size
        self.eod_token_id = tokenizer.eos_token_id
        self.buffer_size = buffer_size
        self.seed = seed

        # Room for a full buffer plus one sequence; grown on demand.
        self._token_buffer = np.zeros(buffer_size + max_length, dtype=np.int32)
        self._buffer_size = 0  # number of valid tokens in _token_buffer
        # Private generator so the process-wide NumPy RNG is left untouched.
        self._np_rng = np.random.RandomState(seed)

    def parse_jsonl_zst(self, file_path: str) -> Iterator[Dict]:
        """Yield the decoded JSON value of each non-blank line of a .jsonl.zst file.

        Lines that are not valid JSON are skipped.  No sharding happens here.
        """
        with open(file_path, 'rb') as f_in:
            dctx = zstd.ZstdDecompressor()
            with dctx.stream_reader(f_in) as reader:
                text_stream = io.TextIOWrapper(reader, encoding='utf-8')
                for line in text_stream:
                    if not line.strip():
                        continue
                    try:
                        record = json.loads(line)
                    except json.JSONDecodeError:
                        continue
                    yield record

    @staticmethod
    def _make_example(sequence: np.ndarray) -> Dict[str, torch.Tensor]:
        """Turn one packed token sequence into a next-token-prediction example."""
        # astype copies, so the example does not alias the reusable buffer.
        tokens = torch.from_numpy(sequence.astype(np.int64))
        input_ids = tokens[:-1]
        return {
            "input_ids": input_ids,
            "attention_mask": torch.ones_like(input_ids),
            "labels": tokens[1:],
        }

    def _shuffle_chunks(self, tokens: np.ndarray) -> None:
        """Permute ``tokens`` in place in whole chunks; a trailing partial chunk stays put."""
        chunk = self._SHUFFLE_CHUNK_TOKENS
        num_chunks = len(tokens) // chunk
        if num_chunks <= 1:
            return
        order = np.arange(num_chunks)
        self._np_rng.shuffle(order)
        span = num_chunks * chunk
        # Fancy indexing yields a copy, so writing back into the view is safe.
        tokens[:span] = tokens[:span].reshape(num_chunks, chunk)[order].reshape(-1)

    def _process_token_buffer(self, batch_index: int, global_worker_id: int, global_num_workers: int) -> Tuple[List[Dict], int]:
        """Cut the buffer into sequences and keep those owned by this worker.

        Args:
            batch_index: Running index of the next sequence in the stream.
            global_worker_id: This worker's id among ``global_num_workers``.
            global_num_workers: Total number of workers over all ranks.

        Returns:
            ``(examples, next_batch_index)``.  ``batch_index`` advances by one
            for every complete sequence, owned or not.  Tokens that do not
            fill a sequence are moved to the front of the buffer.  Nothing
            happens when fewer than ``max_length`` tokens are buffered.
        """
        if self._buffer_size < self.max_length:
            return [], batch_index

        filled = self._token_buffer[:self._buffer_size]
        if self.shuffle_files:
            self._shuffle_chunks(filled)

        num_sequences = self._buffer_size // self.max_length
        sequences = []
        for i in range(num_sequences):
            if batch_index % global_num_workers == global_worker_id:
                start = i * self.max_length
                sequences.append(self._make_example(filled[start:start + self.max_length]))
            batch_index += 1

        # Carry the incomplete tail over to the next call.
        leftover_start = num_sequences * self.max_length
        leftover_count = self._buffer_size - leftover_start
        if leftover_count > 0:
            self._token_buffer[:leftover_count] = filled[leftover_start:]
        self._buffer_size = leftover_count

        return sequences, batch_index

    def _add_to_buffer(self, tokens):
        """Append ``tokens`` to the buffer, growing it (at least doubling) if needed."""
        length = len(tokens)
        required = self._buffer_size + length

        if required > len(self._token_buffer):
            new_buffer = np.zeros(max(2 * len(self._token_buffer), required), dtype=np.int32)
            new_buffer[:self._buffer_size] = self._token_buffer[:self._buffer_size]
            self._token_buffer = new_buffer

        self._token_buffer[self._buffer_size:required] = tokens
        self._buffer_size = required

    def _add_document(self, text) -> None:
        """Tokenize one document and append its tokens plus one EOD token."""
        tokens = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)
        # flatten() always yields a 1-D array, also for single-token documents.
        token_ids = tokens["input_ids"].flatten().cpu().numpy()
        self._add_to_buffer(token_ids)
        self._add_to_buffer(np.array([self.eod_token_id], dtype=np.int32))

    def __iter__(self):
        """Yield this worker's examples from one pass over all files.

        The final, incomplete sequence is padded with EOD tokens (its
        attention mask is still all ones).  A file that raises while being
        read or tokenized is reported and the remaining files are processed.
        """
        start_time = time.time()

        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            num_workers, worker_id = 1, 0
        else:
            num_workers, worker_id = worker_info.num_workers, worker_info.id

        # One seed for every rank and worker: sequence-level sharding below
        # only partitions the data if all of them build the same stream.
        seed = self.seed
        file_rng = random.Random(seed)
        self._np_rng = np.random.RandomState(seed)

        global_worker_id = self.rank * num_workers + worker_id
        global_num_workers = self.world_size * num_workers

        print(f"Dataset iterator: rank={self.rank}, worker={worker_id}, "
              f"global_worker={global_worker_id}, global_workers={global_num_workers}, "
              f"shuffle={self.shuffle_files}, seed={seed}")

        files_to_process = list(self.file_paths)
        if self.shuffle_files:
            file_rng.shuffle(files_to_process)
            print(f"Rank {self.rank}, worker {worker_id}: Shuffled file order with seed {seed}")

        self._buffer_size = 0
        batch_index = 0

        for file_idx, file_path in enumerate(files_to_process):
            print(f"Rank {self.rank}, worker {worker_id}: Starting file {file_path} ({file_idx+1}/{len(files_to_process)})")

            try:
                for item in self.parse_jsonl_zst(file_path):
                    if "text" not in item:
                        continue
                    self._add_document(item["text"])

                    if self._buffer_size >= self.buffer_size:
                        sequences, batch_index = self._process_token_buffer(
                            batch_index, global_worker_id, global_num_workers
                        )
                        yield from sequences
            except Exception as e:
                # Best effort: report the failure and move on to the next file.
                print(f"Error processing {file_path}: {e}")
                continue

        # Flush every complete sequence that is still buffered.
        sequences, batch_index = self._process_token_buffer(
            batch_index, global_worker_id, global_num_workers
        )
        yield from sequences

        # After the flush fewer than max_length tokens remain; pad them with
        # EOD tokens into one last sequence.
        if self._buffer_size > 0:
            padding_needed = self.max_length - self._buffer_size
            self._add_to_buffer(np.full(padding_needed, self.eod_token_id, dtype=np.int32))

            if batch_index % global_num_workers == global_worker_id:
                yield self._make_example(self._token_buffer[:self.max_length])

        total_elapsed = time.time() - start_time
        print(f"Rank {self.rank}, worker {worker_id}: Dataset iteration completed in {total_elapsed:.1f}s")


class PileDataModule(pl.LightningDataModule):
    """PyTorch Lightning data module serving The Pile via ``PileIterableDataset``.

    Training data is read from the files ``00.jsonl.zst`` .. ``29.jsonl.zst``
    that exist in ``train_dir``; validation data from the single ``val_file``.
    """

    def __init__(
        self,
        train_dir: str,
        val_file: str,
        tokenizer,
        max_length: int = 2049,
        batch_size: int = 8,
        num_workers: int = 4,
        shuffle: bool = True,
        process_all_files: bool = True,
        max_files: Optional[int] = None,
        buffer_size: int = 100000
    ):
        """Discover the training files and validate the validation file.

        Args:
            train_dir: Directory holding the numbered training files.
            val_file: Path of the validation ``.jsonl.zst`` file.
            tokenizer: Tokenizer handed to every dataset.
            max_length: Tokens per packed sequence.
            batch_size: Batch size of both data loaders.
            num_workers: Worker processes of both data loaders.
            shuffle: Enable shuffling for the training dataset (validation is
                never shuffled).
            process_all_files: If false, only the first training file is used.
            max_files: If set (and non-zero), keep at most this many training
                files.
            buffer_size: Token buffer size handed to every dataset.

        Raises:
            FileNotFoundError: If ``val_file`` does not exist.
        """
        super().__init__()
        self.train_dir = train_dir
        self.val_file = val_file
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.process_all_files = process_all_files
        self.max_files = max_files
        self.buffer_size = buffer_size

        # Only the fixed names 00..29 are considered; missing ones are skipped.
        candidates = [os.path.join(train_dir, f"{i:02d}.jsonl.zst") for i in range(30)]
        self.train_files = sorted(path for path in candidates if os.path.exists(path))

        if not process_all_files and self.train_files:
            self.train_files = [self.train_files[0]]
            print(f"Using only the first file: {self.train_files[0]}")

        if max_files and len(self.train_files) > max_files:
            self.train_files = self.train_files[:max_files]
            print(f"Limited to first {max_files} files")

        if not os.path.exists(val_file):
            raise FileNotFoundError(f"Validation file not found: {val_file}")

        print(f"Found {len(self.train_files)} training files")
        print(f"Using validation file: {val_file}")
        print(f"Shuffling is {'enabled' if shuffle else 'disabled'}")
        print(f"Using token buffer size of {buffer_size}")

    def _distributed_context(self) -> Tuple[int, int]:
        """Return ``(rank, world_size)`` from the attached trainer, or ``(0, 1)`` without one."""
        trainer = getattr(self, "trainer", None)
        if not trainer:
            return 0, 1
        return getattr(trainer, "global_rank", 0), getattr(trainer, "world_size", 1)

    def _build_loader(self, dataset) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader with this module's batch and worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            persistent_workers=self.num_workers > 0
        )

    def train_dataloader(self):
        """Build a new training dataset over ``self.train_files`` and its DataLoader."""
        rank, world_size = self._distributed_context()

        print(f"Creating train dataloader with rank={rank}, world_size={world_size}, shuffle={self.shuffle}")

        train_dataset = PileIterableDataset(
            file_paths=self.train_files,
            tokenizer=self.tokenizer,
            max_length=self.max_length,
            shuffle_files=self.shuffle,
            rank=rank,
            world_size=world_size,
            buffer_size=self.buffer_size
        )

        return self._build_loader(train_dataset)

    def val_dataloader(self):
        """Build a new, unshuffled validation dataset over ``self.val_file`` and its DataLoader."""
        rank, world_size = self._distributed_context()

        val_dataset = PileIterableDataset(
            file_paths=[self.val_file],
            tokenizer=self.tokenizer,
            max_length=self.max_length,
            shuffle_files=False,
            rank=rank,
            world_size=world_size,
            buffer_size=self.buffer_size
        )

        return self._build_loader(val_dataset)



def setup_pile_data_for_training(config):
    """Build the ``PileDataModule`` used for training on The Pile.

    Loads the ``EleutherAI/pythia-160m`` tokenizer (EOS token reused as pad
    token, left padding) and points the data module at fixed dataset paths.

    Args:
        config: Mapping with a ``get`` method.  Keys read:
            ``train_micro_batch_size_per_gpu`` (required), ``num_workers``
            (default 2), ``shuffle`` (default True), ``process_all_files``
            (default True), ``max_files`` (default None) and ``buffer_size``
            (default 100000).

    Returns:
        The configured ``PileDataModule``.

    Raises:
        ValueError: If ``train_micro_batch_size_per_gpu`` is missing or None.
    """
    # A None batch size would reach DataLoader, where it disables automatic
    # batching instead of failing, so reject it here before doing any work.
    batch_size = config.get("train_micro_batch_size_per_gpu")
    if batch_size is None:
        raise ValueError("config must define 'train_micro_batch_size_per_gpu'")

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    train_dir = "/disk/u/datasets/datasets/pile/the-eye.eu/public/AI/pile/train"
    val_file = "/disk/u/datasets/datasets/pile/the-eye.eu/public/AI/pile/val.jsonl.zst"

    data_module = PileDataModule(
        train_dir=train_dir,
        val_file=val_file,
        tokenizer=tokenizer,
        max_length=2049,
        batch_size=batch_size,
        num_workers=config.get("num_workers", 2),
        shuffle=config.get("shuffle", True),
        process_all_files=config.get("process_all_files", True),
        max_files=config.get("max_files"),
        buffer_size=config.get("buffer_size", 100000)
    )

    return data_module