"""
Memory Block Data Module for Sharded HDF5 Datasets
Loads entire shards into memory and performs in-memory shuffling for optimal performance.
Combines the benefits of fast I/O (reading entire shards) with runtime shuffling.
"""
import os
import glob
import h5py
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader, IterableDataset
import lightning.pytorch as pl
from typing import List, Optional, Iterator
import gc
import time
from datetime import datetime


class MemoryBlockDataset(IterableDataset):
    """
    Iterable dataset that loads whole HDF5 shards into memory and yields samples.

    Each shard's ``"non_padding_cache"`` dataset is read in one go, converted to a
    float32 tensor, optionally filtered by norm, optionally divided by the mean
    norm, optionally shuffled, and then yielded one row at a time.
    Shards are divided among DataLoader workers as contiguous ranges.
    """

    # Norm statistics are estimated from at most this many leading shards.
    _STATS_MAX_SHARDS = 5

    def __init__(
        self,
        shard_files: List[str],
        block_size_gb: float = 4.0,  # Not used anymore, kept for compatibility
        start_shard: int = 0,
        end_shard: Optional[int] = None,
        dimensions: Optional[int] = None,
        remove_high_norm: Optional[float] = None,
        filter_cache_size: int = 10000,
        shuffle: bool = True,
        normalize_by_mean: bool = False,
        computed_mean_norm: Optional[float] = None,
        computed_median_norm: Optional[float] = None,
        computed_norm_threshold: Optional[float] = None
    ):
        """
        Initialize memory block dataset.

        Args:
            shard_files: List of shard file paths (stored sorted)
            block_size_gb: (Deprecated) Previously used for block size, now loads entire shards
            start_shard: Starting shard index (inclusive)
            end_shard: Ending shard index (exclusive; None = use all)
            dimensions: Data dimensions (read from the first shard if None)
            remove_high_norm: If not None, filter samples with norm > median * this factor
            filter_cache_size: Number of samples to use for computing norm statistics
            shuffle: Whether to shuffle samples within each shard
            normalize_by_mean: Whether to divide data by the mean norm
            computed_mean_norm: Pre-computed mean norm (if available from previous run)
            computed_median_norm: Pre-computed median norm (if available from previous run)
            computed_norm_threshold: Pre-computed threshold (if available from previous run)

        Pre-computed values always take precedence. Any statistic that is required
        by the enabled options but was not supplied is derived (threshold from the
        median) or estimated by sampling the shards.
        """
        super().__init__()
        self.shard_files = sorted(shard_files)
        self.block_size_gb = block_size_gb  # Kept for compatibility
        self.start_shard = start_shard
        self.end_shard = end_shard if end_shard is not None else len(self.shard_files)
        self.remove_high_norm = remove_high_norm
        self.filter_cache_size = filter_cache_size
        self.shuffle = shuffle
        self.normalize_by_mean = normalize_by_mean

        # Infer dimensions if not provided
        if dimensions is None:
            with h5py.File(self.shard_files[0], 'r') as f:
                self.dimensions = f["non_padding_cache"].shape[1]
        else:
            self.dimensions = dimensions

        # Start from whatever the caller supplied; all three attributes always exist.
        self.mean_norm = computed_mean_norm
        self.median_norm = computed_median_norm
        self.norm_threshold = computed_norm_threshold

        if computed_mean_norm is not None or computed_median_norm is not None or computed_norm_threshold is not None:
            print(f"Using pre-computed norm statistics from config:")
            if self.mean_norm is not None:
                print(f"  Mean norm: {self.mean_norm:.4f}")
            if self.median_norm is not None:
                print(f"  Median norm: {self.median_norm:.4f}")
            if self.norm_threshold is not None:
                print(f"  Threshold: {self.norm_threshold:.4f}")

        # A supplied median is enough to rebuild a missing threshold without I/O.
        if (self.remove_high_norm is not None
                and self.norm_threshold is None
                and self.median_norm is not None):
            self.norm_threshold = float(self.median_norm * self.remove_high_norm)
            print(f"  Threshold derived from median norm: {self.norm_threshold:.4f}")

        # Sample the shards only for statistics that are needed and still unknown.
        missing_threshold = self.remove_high_norm is not None and self.norm_threshold is None
        missing_mean = self.normalize_by_mean and self.mean_norm is None
        if missing_threshold or missing_mean:
            known_mean = self.mean_norm
            known_median = self.median_norm
            known_threshold = self.norm_threshold
            self._compute_norm_statistics()
            # Values known before sampling win over freshly sampled estimates.
            if known_mean is not None:
                self.mean_norm = known_mean
            if known_median is not None:
                self.median_norm = known_median
            if known_threshold is not None:
                self.norm_threshold = known_threshold

    def _compute_norm_statistics(self):
        """
        Estimate mean and median sample norm from a random subset of samples.

        Draws up to ``filter_cache_size`` rows (without replacement within a shard)
        from the first few shards of this dataset's range, then sets
        ``median_norm``, ``mean_norm`` and, when ``remove_high_norm`` is set,
        ``norm_threshold = median_norm * remove_high_norm``.
        If no sample could be read, a warning is printed and the statistics are
        left unset instead of becoming NaN.
        """
        print(f"Computing norm statistics...")
        if self.remove_high_norm is not None:
            print(f"  High norm filtering enabled: remove_high_norm={self.remove_high_norm}")
        if self.normalize_by_mean:
            print(f"  Mean normalization enabled")

        self.median_norm = None

        norms = []
        target_samples = self.filter_cache_size
        last_shard = min(self.end_shard, self.start_shard + self._STATS_MAX_SHARDS)

        for shard_idx in range(self.start_shard, last_shard):
            remaining = target_samples - len(norms)
            if remaining <= 0:
                break

            with h5py.File(self.shard_files[shard_idx], 'r') as f:
                data = f["non_padding_cache"]
                num_samples = min(remaining, data.shape[0])
                # Sorted so the rows are visited in on-disk order.
                indices = np.sort(np.random.choice(data.shape[0], num_samples, replace=False))

                for idx in indices:
                    # float32 matches the dtype used when filtering in _prepare_shard_tensor
                    # and avoids computing the norm in a narrower stored dtype.
                    sample = np.asarray(data[idx], dtype=np.float32)
                    norms.append(np.linalg.norm(sample))

        if not norms:
            # NaN statistics would make the filter drop every sample, so bail out.
            print("  WARNING: no samples available for norm statistics; "
                  "filtering/normalization statistics left unset")
            return

        self.median_norm = float(np.median(norms))
        self.mean_norm = float(np.mean(norms))

        # Set threshold if filtering is enabled
        if self.remove_high_norm is not None:
            self.norm_threshold = float(self.median_norm * self.remove_high_norm)
            print(f"  Median norm: {self.median_norm:.4f}, Threshold: {self.norm_threshold:.4f}")
            print(f"  Samples with norm > {self.norm_threshold:.4f} will be filtered out")

        # Print mean norm info
        if self.normalize_by_mean:
            print(f"  Mean norm: {self.mean_norm:.4f} (will be used for normalization)")
        else:
            print(f"  Mean norm: {self.mean_norm:.4f} (computed but not used for normalization)")

    def __iter__(self) -> Iterator[torch.Tensor]:
        """
        Yield samples shard by shard for the current DataLoader worker.

        The shard range ``[start_shard, end_shard)`` is split into contiguous
        chunks, one per worker; the first ``extra_shards`` workers receive one
        additional shard. A worker may receive no shards when there are more
        workers than shards.
        """
        worker_info = torch.utils.data.get_worker_info()

        if worker_info is None:
            # Single-process loading
            worker_id = 0
            num_workers = 1
        else:
            worker_id = worker_info.id
            num_workers = worker_info.num_workers

        # Divide shards among workers
        total_shards = self.end_shard - self.start_shard
        shards_per_worker = total_shards // num_workers
        extra_shards = total_shards % num_workers

        if worker_id < extra_shards:
            worker_start_shard = self.start_shard + worker_id * (shards_per_worker + 1)
            worker_end_shard = worker_start_shard + shards_per_worker + 1
        else:
            worker_start_shard = self.start_shard + extra_shards * (shards_per_worker + 1) + \
                               (worker_id - extra_shards) * shards_per_worker
            worker_end_shard = worker_start_shard + shards_per_worker

        print(f"\n[Worker {worker_id}/{num_workers}] Assigned shards {worker_start_shard} to {worker_end_shard-1}")
        print(f"[Worker {worker_id}] Will process {worker_end_shard - worker_start_shard} shards")

        for shard_idx in range(worker_start_shard, worker_end_shard):
            yield from self._process_shard(shard_idx)

    def _prepare_shard_tensor(self, shard_tensor: torch.Tensor) -> torch.Tensor:
        """
        Filter, normalize and shuffle an in-memory shard.

        Args:
            shard_tensor: 2-D float tensor with one sample per row. It may be
                modified in place by the normalization step.

        Returns:
            The tensor of samples to yield, in yield order.
        """
        # Apply filtering if enabled
        if self.norm_threshold is not None:
            filter_start = time.time()
            print(f"Applying high norm filtering with threshold {self.norm_threshold:.4f}")
            total_before = shard_tensor.shape[0]
            keep_mask = torch.norm(shard_tensor, dim=1) <= self.norm_threshold
            shard_tensor = shard_tensor[keep_mask]
            del keep_mask
            samples_kept = shard_tensor.shape[0]
            samples_removed = total_before - samples_kept
            # An empty shard must not crash the percentage computation.
            removed_pct = samples_removed / total_before * 100 if total_before else 0.0
            filter_end = time.time()
            print(f"✓ Filtering completed in {filter_end - filter_start:.2f} seconds")
            print(f"  Kept: {samples_kept:,}/{total_before:,} samples")
            print(f"  Removed: {samples_removed:,} samples ({removed_pct:.1f}%)")

        # Apply mean normalization if enabled
        if self.normalize_by_mean and self.mean_norm is not None and self.mean_norm > 0:
            norm_start = time.time()
            # In place: avoids a second full-size copy of the shard.
            shard_tensor.div_(self.mean_norm)
            norm_end = time.time()
            print(f"✓ Normalized by mean norm ({self.mean_norm:.4f}) in {norm_end - norm_start:.2f} seconds")

        # Shuffle the samples within the shard if enabled
        num_samples = shard_tensor.shape[0]
        if self.shuffle:
            shuffle_start = time.time()
            shard_tensor = shard_tensor[torch.randperm(num_samples)]
            shuffle_end = time.time()
            print(f"✓ Shuffled {num_samples:,} samples in {shuffle_end - shuffle_start:.2f} seconds")
        else:
            print(f"Processing {num_samples:,} samples without shuffling")

        return shard_tensor

    def _process_shard(self, shard_idx: int) -> Iterator[torch.Tensor]:
        """
        Load one shard fully into memory, prepare it, and yield its samples.

        The HDF5 file is only kept open while the data is being read; it is
        closed before any sample is yielded.
        """
        shard_file = self.shard_files[shard_idx]
        shard_name = os.path.basename(shard_file)

        print(f"\n{'='*80}")
        print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Starting to process shard {shard_idx}")
        print(f"Shard file: {shard_name}")
        print(f"Full path: {shard_file}")

        start_time = time.time()

        with h5py.File(shard_file, 'r') as f:
            data = f["non_padding_cache"]
            total_samples = data.shape[0]

            print(f"Shard shape: {data.shape}, dtype: {data.dtype}")
            print(f"Loading {total_samples:,} samples into memory...")

            load_start = time.time()
            # Load entire shard into memory at once
            shard_data = data[:]
            load_end = time.time()

        print(f"✓ Data loaded in {load_end - load_start:.2f} seconds")
        print(f"  Memory size: {shard_data.nbytes / (1024**3):.2f} GB")

        # Convert to a float32 tensor
        tensor_start = time.time()
        shard_tensor = torch.from_numpy(shard_data).float()
        del shard_data  # Drop our reference to the numpy array
        tensor_end = time.time()

        print(f"✓ Converted to tensor in {tensor_end - tensor_start:.2f} seconds")

        shard_tensor = self._prepare_shard_tensor(shard_tensor)
        num_samples = shard_tensor.shape[0]

        total_prep_time = time.time() - start_time
        print(f"Total shard preparation time: {total_prep_time:.2f} seconds")
        print(f"Starting to yield {num_samples:,} samples...")
        print(f"{'='*80}\n")

        # Track yielding progress
        yield_start = time.time()
        samples_yielded = 0
        log_interval = 100000  # Log every 100k samples

        # Yield samples one by one (each is a view into the shard tensor)
        for i in range(num_samples):
            yield shard_tensor[i]
            samples_yielded += 1

            if samples_yielded % log_interval == 0:
                elapsed = time.time() - yield_start
                rate = samples_yielded / max(elapsed, 1e-9)
                print(f"  [{datetime.now().strftime('%H:%M:%S')}] Yielded {samples_yielded:,}/{num_samples:,} samples "
                      f"({samples_yielded/num_samples*100:.1f}%) - Rate: {rate:.0f} samples/sec")

        # Final summary for this shard
        yield_end = time.time()
        total_yield_time = yield_end - yield_start
        total_shard_time = yield_end - start_time
        # Guard against a zero-length interval (e.g. an empty shard on a coarse clock).
        yield_rate = samples_yielded / max(total_yield_time, 1e-9)

        print(f"\n{'='*80}")
        print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Completed shard {shard_idx} ({shard_name})")
        print(f"  Total samples yielded: {samples_yielded:,}")
        print(f"  Yielding time: {total_yield_time:.2f} seconds")
        print(f"  Yielding rate: {yield_rate:.0f} samples/sec")
        print(f"  Total shard processing time: {total_shard_time:.2f} seconds")
        print(f"{'='*80}\n")

        # Clear memory after shard is consumed
        del shard_tensor
        gc.collect()


class MemoryBlockBatchDataset(Dataset):
    """
    Map-style dataset over a global index range spanning several HDF5 shards.

    Samples are served from an in-memory block of up to ``samples_per_block``
    consecutive samples; accessing an index outside the current block replaces
    it with the block containing that index. Sequential access is therefore
    cheap, while random access reloads blocks frequently.
    """

    def __init__(
        self,
        shard_files: List[str],
        block_size_gb: float = 4.0,
        start_idx: int = 0,
        end_idx: Optional[int] = None,
        preload: bool = True
    ):
        """
        Initialize memory block batch dataset.

        Args:
            shard_files: List of shard file paths (stored sorted)
            block_size_gb: Size of each memory block in GB (assuming 4 bytes per value)
            start_idx: Global starting index (inclusive)
            end_idx: Global ending index (exclusive; None = up to the last sample)
            preload: Whether to load the first block immediately
        """
        self.shard_files = sorted(shard_files)
        self.block_size_gb = block_size_gb

        # Build index mapping
        self._build_index_mapping(start_idx, end_idx)

        # Currently loaded block and the dataset-relative range it covers
        self.current_block = None
        self.current_block_start = None
        self.current_block_end = None

        # Calculate block size
        with h5py.File(self.shard_files[0], 'r') as f:
            self.dimensions = f["non_padding_cache"].shape[1]

        bytes_per_sample = max(self.dimensions * 4, 1)
        # At least one sample per block so block arithmetic never divides by zero.
        self.samples_per_block = max(1, int(block_size_gb * 1024**3 / bytes_per_sample))

        if preload:
            self._load_block(0)

    def _build_index_mapping(self, start_idx: int, end_idx: Optional[int]):
        """
        Record per-shard sizes and the global index range served by this dataset.

        ``cumulative_sizes[i]`` is the global index of the first sample of shard
        ``i``. The requested range is clamped to ``[0, total_size]`` so that
        ``length`` is never negative and never extends past the available data.
        """
        self.shard_info = []
        self.cumulative_sizes = [0]

        for shard_file in self.shard_files:
            with h5py.File(shard_file, 'r') as f:
                shard_size = f["non_padding_cache"].shape[0]
            self.shard_info.append({
                'file': shard_file,
                'size': shard_size
            })
            self.cumulative_sizes.append(self.cumulative_sizes[-1] + shard_size)

        self.total_size = self.cumulative_sizes[-1]
        requested_end = self.total_size if end_idx is None else end_idx
        self.start_idx = min(max(start_idx, 0), self.total_size)
        self.end_idx = min(max(requested_end, self.start_idx), self.total_size)
        self.length = self.end_idx - self.start_idx

    def _load_block(self, global_block_idx: int):
        """
        Load block ``global_block_idx`` into ``current_block``.

        Each shard overlapping the block is opened once and its overlapping rows
        are read as a single contiguous slice.

        Returns:
            True if a block was loaded, False if the block index lies outside
            the dataset (the current block is left untouched in that case).
        """
        block_start = global_block_idx * self.samples_per_block
        if global_block_idx < 0 or block_start >= self.length:
            return False
        block_end = min(block_start + self.samples_per_block, self.length)

        print(f"Loading memory block {global_block_idx}: samples {block_start:,}-{block_end:,}")

        # Translate the dataset-relative range into global sample indices.
        global_start = self.start_idx + block_start
        global_end = self.start_idx + block_end

        chunks = []
        for shard_idx, info in enumerate(self.shard_info):
            shard_first = self.cumulative_sizes[shard_idx]
            shard_last = self.cumulative_sizes[shard_idx + 1]
            lo = max(global_start, shard_first)
            hi = min(global_end, shard_last)
            if lo >= hi:
                continue  # this shard does not overlap the block

            with h5py.File(info['file'], 'r') as f:
                chunks.append(f["non_padding_cache"][lo - shard_first:hi - shard_first])

        self.current_block = torch.from_numpy(np.concatenate(chunks, axis=0)).float()
        self.current_block_start = block_start
        self.current_block_end = block_end

        return True

    def __len__(self):
        """Return the number of samples in this dataset's index range."""
        return self.length

    def __getitem__(self, idx):
        """
        Return sample ``idx``, loading the block that contains it if needed.

        Negative indices count from the end.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        position = idx + self.length if idx < 0 else idx
        if position < 0 or position >= self.length:
            raise IndexError(f"Index {idx} out of range")

        # Check if we need to load a new block
        if (self.current_block is None or
            position < self.current_block_start or
            position >= self.current_block_end):
            self._load_block(position // self.samples_per_block)

        # Return sample from current block
        return self.current_block[position - self.current_block_start]


class MemoryBlockDataModule(pl.LightningDataModule):
    """
    Lightning DataModule over sharded HDF5 files.

    In iterable mode (default) the shard list is split into train and validation
    shards and served by MemoryBlockDataset; in indexed mode the global sample
    range is split and served by MemoryBlockBatchDataset.
    """

    def __init__(
        self,
        shard_pattern: str,
        batch_size: int = 8192,
        block_size_gb: float = 4.0,  # Deprecated, kept for compatibility
        num_workers: int = 2,
        split_ratio: float = 0.999,
        use_iterable: bool = True,
        prefetch_factor: int = 2,
        persistent_workers: bool = True,
        remove_high_norm: Optional[float] = None,
        filter_cache_size: int = 10000,
        shuffle: bool = True,
        normalize_by_mean: bool = False,
        computed_mean_norm: Optional[float] = None,
        computed_median_norm: Optional[float] = None,
        computed_norm_threshold: Optional[float] = None,
        dataloader_timeout: float = 60,
        **kwargs
    ):
        """
        Initialize memory block data module.

        Args:
            shard_pattern: Glob pattern to match shard files
            batch_size: Batch size
            block_size_gb: Block size for the indexed dataset (unused in iterable mode)
            num_workers: Number of data loading workers for training
                (validation uses half as many, rounded down)
            split_ratio: Fraction of shards (iterable mode) or samples (indexed mode)
                used for training
            use_iterable: Use IterableDataset (streaming) vs Dataset (indexed)
            prefetch_factor: Number of batches to prefetch per worker
            persistent_workers: Keep training workers alive between epochs
            remove_high_norm: If not None, filter samples with norm > median * this factor
            filter_cache_size: Number of samples to use for computing median norm
            shuffle: Whether to shuffle training samples within each shard
            normalize_by_mean: Whether to normalize data by the mean norm
            computed_mean_norm: Pre-computed mean norm (if available from previous run)
            computed_median_norm: Pre-computed median norm (if available from previous run)
            computed_norm_threshold: Pre-computed threshold (if available from previous run)
            dataloader_timeout: Seconds a DataLoader waits for a batch from its
                workers; only applied when worker processes are used. Loading a
                large shard can take longer than the default, so raise it if
                workers time out.
            **kwargs: Accepted for compatibility and ignored

        Raises:
            ValueError: If no file matches ``shard_pattern``.
        """
        super().__init__()
        self.shard_pattern = shard_pattern
        self.batch_size = batch_size
        self.block_size_gb = block_size_gb  # Kept for compatibility
        self.num_workers = num_workers
        self.split_ratio = split_ratio
        self.use_iterable = use_iterable
        self.prefetch_factor = prefetch_factor
        self.persistent_workers = persistent_workers and num_workers > 0
        self.remove_high_norm = remove_high_norm
        self.filter_cache_size = filter_cache_size
        self.shuffle = shuffle
        self.normalize_by_mean = normalize_by_mean
        self.computed_mean_norm = computed_mean_norm
        self.computed_median_norm = computed_median_norm
        self.computed_norm_threshold = computed_norm_threshold
        self.dataloader_timeout = dataloader_timeout

        # Find shard files
        self.shard_files = sorted(glob.glob(shard_pattern))
        if not self.shard_files:
            raise ValueError(f"No shard files found matching pattern: {shard_pattern}")

        print(f"Found {len(self.shard_files)} shard files")
        print(f"Loading mode: Full shard into memory with {'shuffling' if shuffle else 'no shuffling'}")
        print(f"Batch size: {batch_size}")

        # Calculate split
        self.num_train_shards = int(len(self.shard_files) * split_ratio)
        self.num_val_shards = len(self.shard_files) - self.num_train_shards

        # int() truncation can leave one side of the shard split empty.
        if use_iterable and (self.num_train_shards == 0 or self.num_val_shards == 0):
            print(f"WARNING: shard split is {self.num_train_shards} train / "
                  f"{self.num_val_shards} val; one split will yield no data")

    def _timeout_for(self, worker_count: int) -> float:
        """Return the DataLoader timeout for ``worker_count`` workers.

        A non-zero timeout is only valid with worker processes, so single-process
        loading always gets 0.
        """
        return self.dataloader_timeout if worker_count > 0 else 0

    def _shared_norm_stat(self, attr_name: str, configured_value):
        """Return the training dataset's ``attr_name`` if it is a finite number.

        Falls back to ``configured_value`` (the value passed to this module)
        when the training dataset has no usable value.
        """
        value = getattr(self.train_ds, attr_name, None)
        if value is not None and np.isfinite(value):
            return value
        return configured_value

    def setup(self, stage=None):
        """Create ``train_ds`` and ``val_ds`` according to ``use_iterable``."""
        print(f"\n{'='*80}")
        print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Setting up MemoryBlockDataModule")
        print(f"{'='*80}")

        if self.remove_high_norm is not None:
            print(f"High norm filtering enabled: remove_high_norm={self.remove_high_norm}")
            print(f"Using {self.filter_cache_size:,} samples for median computation")

        if self.normalize_by_mean:
            print(f"Mean normalization enabled")

        if self.use_iterable:
            # Use streaming dataset (recommended for large data)
            self.train_ds = MemoryBlockDataset(
                shard_files=self.shard_files,
                block_size_gb=self.block_size_gb,
                start_shard=0,
                end_shard=self.num_train_shards,
                remove_high_norm=self.remove_high_norm,
                filter_cache_size=self.filter_cache_size,
                shuffle=self.shuffle,  # Enable shuffling for training
                normalize_by_mean=self.normalize_by_mean,
                computed_mean_norm=self.computed_mean_norm,
                computed_median_norm=self.computed_median_norm,
                computed_norm_threshold=self.computed_norm_threshold
            )

            # Validation must be filtered and scaled exactly like training, so it
            # reuses the training statistics instead of estimating its own.
            self.val_ds = MemoryBlockDataset(
                shard_files=self.shard_files,
                block_size_gb=self.block_size_gb,
                start_shard=self.num_train_shards,
                end_shard=len(self.shard_files),
                remove_high_norm=self.remove_high_norm,
                filter_cache_size=self.filter_cache_size,
                shuffle=False,  # No shuffling for validation
                normalize_by_mean=self.normalize_by_mean,
                computed_mean_norm=self._shared_norm_stat('mean_norm', self.computed_mean_norm),
                computed_median_norm=self._shared_norm_stat('median_norm', self.computed_median_norm),
                computed_norm_threshold=self._shared_norm_stat('norm_threshold', self.computed_norm_threshold)
            )
        else:
            # Use indexed dataset (requires more memory)
            if self.remove_high_norm is not None or self.normalize_by_mean:
                print("WARNING: remove_high_norm / normalize_by_mean are not applied "
                      "when use_iterable=False")

            # Calculate total samples
            total_samples = 0
            for shard_file in self.shard_files:
                with h5py.File(shard_file, 'r') as f:
                    total_samples += f["non_padding_cache"].shape[0]

            train_samples = int(total_samples * self.split_ratio)

            self.train_ds = MemoryBlockBatchDataset(
                shard_files=self.shard_files,
                block_size_gb=self.block_size_gb,
                start_idx=0,
                end_idx=train_samples
            )

            self.val_ds = MemoryBlockBatchDataset(
                shard_files=self.shard_files,
                block_size_gb=self.block_size_gb,
                start_idx=train_samples,
                end_idx=total_samples
            )

        print(f"Train shards: {self.num_train_shards}, Val shards: {self.num_val_shards}")
        print(f"Dataset setup complete!")
        print(f"{'='*80}\n")

    def train_dataloader(self):
        """Create the training DataLoader over ``train_ds``."""
        print(f"\n[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Creating training dataloader")
        print(f"  Train dataset type: {type(self.train_ds).__name__}")
        print(f"  Batch size: {self.batch_size}")
        print(f"  Num workers: {self.num_workers}")
        print(f"  Persistent workers: {self.persistent_workers}")
        print(f"  Prefetch factor: {self.prefetch_factor if self.num_workers > 0 else None}")

        loader = DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            shuffle=False,  # Shuffling happens within shards during loading
            num_workers=self.num_workers,
            pin_memory=True,
            # prefetch_factor and timeout are only valid with worker processes
            prefetch_factor=self.prefetch_factor if self.num_workers > 0 else None,
            persistent_workers=self.persistent_workers,
            timeout=self._timeout_for(self.num_workers)
        )
        print(f"  Training dataloader created successfully")
        return loader

    def val_dataloader(self):
        """Create the validation DataLoader over ``val_ds`` with half the workers."""
        val_workers = max(0, self.num_workers // 2)  # Fewer workers for val

        print(f"\n[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Creating validation dataloader")
        print(f"  Val dataset type: {type(self.val_ds).__name__}")
        print(f"  Batch size: {self.batch_size}")
        print(f"  Num workers: {val_workers}")

        loader = DataLoader(
            self.val_ds,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=val_workers,
            pin_memory=True,
            persistent_workers=False,  # Don't persist for validation
            timeout=self._timeout_for(val_workers)
        )
        print(f"  Validation dataloader created successfully")
        return loader

    def get_norm_statistics(self):
        """
        Get computed norm statistics for saving to config.

        Returns:
            Dict with ``computed_mean_norm``, ``computed_median_norm`` and
            ``computed_norm_threshold`` for every statistic the training dataset
            exposes (values may be None); empty if ``setup`` has not run.
        """
        stats = {}

        # Get statistics from training dataset if it exists
        if hasattr(self, 'train_ds') and self.train_ds is not None:
            if hasattr(self.train_ds, 'mean_norm'):
                stats['computed_mean_norm'] = self.train_ds.mean_norm
            if hasattr(self.train_ds, 'median_norm'):
                stats['computed_median_norm'] = self.train_ds.median_norm
            if hasattr(self.train_ds, 'norm_threshold'):
                stats['computed_norm_threshold'] = self.train_ds.norm_threshold

        return stats