"""
Sharded HDF5 Data Module for SAE Training
Handles datasets that are split across multiple HDF5 shard files
Includes optional high norm filtering for removing activation outliers
"""
import bisect
import glob
import os
import random
import warnings
from typing import List, Optional, Tuple

import h5py
import lightning.pytorch as pl
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class ShardedHDF5Dataset(Dataset):
    """Dataset that loads data from multiple HDF5 shard files with optional high norm filtering.

    Every shard must contain a ``"non_padding_cache"`` dataset whose first axis
    indexes samples. Shards are concatenated in sorted filename order into one
    global index space, and ``[start_idx, end_idx)`` selects a contiguous
    window of it. Item ``i`` is row ``i`` of that window, or of its filtered
    subset when ``remove_high_norm`` is set.

    HDF5 handles are opened lazily and cached per process. They are closed at
    the end of the filtering pass and dropped when the dataset is pickled, so
    spawned DataLoader workers always open their own handles (forked workers
    do too unless the parent has read items since construction).
    """

    # Rows per contiguous HDF5 read during the filtering pass.
    _FILTER_CHUNK_ROWS = 1000

    def __init__(
        self,
        shard_files: List[str],
        start_idx: int = 0,
        end_idx: Optional[int] = None,
        remove_high_norm: Optional[float] = None,
        filter_cache_size: int = 10000
    ):
        """
        Initialize the sharded dataset.

        Args:
            shard_files: List of paths to HDF5 shard files (sorted internally)
            start_idx: Global starting index across all shards (inclusive)
            end_idx: Global ending index across all shards, exclusive (None means use all data)
            remove_high_norm: If not None, remove samples with norm > median_norm * remove_high_norm
            filter_cache_size: Number of samples to load for computing median norm

        Raises:
            ValueError: If ``[start_idx, end_idx)`` is not inside ``[0, total_size]``,
                or if filtering is enabled with ``filter_cache_size < 1``.
        """
        self.shard_files = sorted(shard_files)
        self.h5_handles = {}
        self.remove_high_norm = remove_high_norm
        self.filter_cache_size = filter_cache_size

        if self.remove_high_norm is not None and self.filter_cache_size < 1:
            raise ValueError(
                f"filter_cache_size must be >= 1 when remove_high_norm is set, got {filter_cache_size}"
            )

        # Build index mapping from global index to (shard_idx, local_idx)
        self._build_index_mapping(start_idx, end_idx)

        # If filtering is enabled, build filtered index mapping
        if self.remove_high_norm is not None:
            self._build_filtered_indices()

    def _build_index_mapping(self, start_idx: int, end_idx: Optional[int]):
        """Record per-shard sizes and cumulative offsets, then validate the requested window."""
        self.shard_info = []
        self.cumulative_sizes = [0]

        for shard_file in self.shard_files:
            with h5py.File(shard_file, 'r') as f:
                shard_size = f["non_padding_cache"].shape[0]
            self.shard_info.append({'file': shard_file, 'size': shard_size})
            self.cumulative_sizes.append(self.cumulative_sizes[-1] + shard_size)

        self.total_size = self.cumulative_sizes[-1]

        self.start_idx = start_idx
        self.end_idx = end_idx if end_idx is not None else self.total_size
        if not 0 <= self.start_idx <= self.end_idx <= self.total_size:
            raise ValueError(
                f"Invalid window [{self.start_idx}, {self.end_idx}) for shards of total size {self.total_size}"
            )
        self.original_length = self.end_idx - self.start_idx

    def _shard_spans(self):
        """Yield ``(window_offset, shard_idx, local_start, length)`` for each shard overlapping the window."""
        for shard_idx in range(len(self.shard_info)):
            shard_start = self.cumulative_sizes[shard_idx]
            shard_end = self.cumulative_sizes[shard_idx + 1]
            lo = max(shard_start, self.start_idx)
            hi = min(shard_end, self.end_idx)
            if lo < hi:
                yield lo - self.start_idx, shard_idx, lo - shard_start, hi - lo

    @staticmethod
    def _row_norms(rows) -> np.ndarray:
        """Return the L2 norm of each sample in ``rows`` (each sample flattened).

        Values are upcast to float64 first so narrow storage dtypes such as
        float16 cannot overflow while summing squares.
        """
        flat = np.asarray(rows, dtype=np.float64).reshape(len(rows), -1)
        return np.linalg.norm(flat, axis=1)

    def _compute_norm_threshold(self) -> float:
        """Return ``median_norm * remove_high_norm``, estimating the median from a random subset.

        At most ``filter_cache_size`` window rows are sampled without replacement
        using the global ``random`` RNG.
        """
        n = self.original_length
        k = min(self.filter_cache_size, n)
        if k >= n:
            sample_indices = range(n)
        else:
            # random.sample draws k indices without building a length-n permutation
            # (legacy np.random.choice(n, k, replace=False) does build one).
            # Sorting makes the reads move forward through the shards.
            sample_indices = sorted(random.sample(range(n), k))

        norms = []
        for idx in sample_indices:
            shard_idx, local_idx = self._global_to_local_idx(idx)
            cache = self._get_h5_handle(self.shard_info[shard_idx]['file'])["non_padding_cache"]
            norms.append(self._row_norms(cache[local_idx:local_idx + 1])[0])

        median_norm = float(np.median(norms))
        threshold = median_norm * self.remove_high_norm
        print(f"Median norm: {median_norm:.4f}, threshold: {threshold:.4f}")
        return threshold

    def _build_filtered_indices(self):
        """Build ``self.filtered_indices``, the window offsets whose norm is <= the threshold.

        ``filtered_indices`` is a sorted int64 numpy array and ``self.length`` is
        set to its size. Samples whose norm is NaN are excluded.
        """
        print(f"Building filtered indices with remove_high_norm={self.remove_high_norm}")

        if self.original_length == 0:
            # An empty window would otherwise hit np.median([]) (NaN) and a 0/0 ratio.
            self.filtered_indices = np.empty(0, dtype=np.int64)
            self.length = 0
            print("Kept 0/0 samples (empty range)")
            return

        kept = []
        try:
            threshold = self._compute_norm_threshold()

            chunk_rows = self._FILTER_CHUNK_ROWS
            spans = list(self._shard_spans())
            # Chunks never straddle a shard boundary, so count them per span.
            num_batches = sum(-(-length // chunk_rows) for _, _, _, length in spans)

            batch_idx = 0
            for offset, shard_idx, local_start, length in spans:
                cache = self._get_h5_handle(self.shard_info[shard_idx]['file'])["non_padding_cache"]
                for step in range(0, length, chunk_rows):
                    if batch_idx % 100 == 0:
                        print(f"Filtering batch {batch_idx}/{num_batches}")
                    stop = min(step + chunk_rows, length)
                    # One contiguous slice read per chunk instead of one read per row.
                    norms = self._row_norms(cache[local_start + step:local_start + stop])
                    kept.append(np.flatnonzero(norms <= threshold) + (offset + step))
                    batch_idx += 1
        finally:
            # Don't leave parent-process handles open for forked DataLoader workers to inherit.
            self.close()

        if kept:
            self.filtered_indices = np.concatenate(kept).astype(np.int64, copy=False)
        else:
            self.filtered_indices = np.empty(0, dtype=np.int64)
        self.length = len(self.filtered_indices)

        filtered_ratio = self.length / self.original_length
        print(f"Kept {self.length}/{self.original_length} samples ({filtered_ratio:.2%})")

    def _get_h5_handle(self, shard_file: str):
        """Get or create the H5 file handle for ``shard_file`` (cached per process)."""
        if shard_file not in self.h5_handles:
            self.h5_handles[shard_file] = h5py.File(shard_file, 'r')
        return self.h5_handles[shard_file]

    def _global_to_local_idx(self, global_idx: int) -> Tuple[int, int]:
        """Convert a window-relative index to ``(shard_idx, local_idx)``.

        ``global_idx`` is relative to ``start_idx``.

        Raises:
            IndexError: If the absolute index lies outside all shards.
        """
        actual_idx = self.start_idx + global_idx
        if not 0 <= actual_idx < self.total_size:
            raise IndexError(
                f"Index {global_idx} (absolute {actual_idx}) outside shards of total size {self.total_size}"
            )
        # Last shard whose start offset is <= actual_idx. bisect_right steps past
        # repeated offsets, so empty shards are never selected.
        shard_idx = bisect.bisect_right(self.cumulative_sizes, actual_idx) - 1
        return shard_idx, actual_idx - self.cumulative_sizes[shard_idx]

    def __len__(self):
        if self.remove_high_norm is not None:
            return self.length
        else:
            return self.original_length

    def __getitem__(self, idx):
        """Return sample ``idx`` of the (possibly filtered) window; negative indices count from the end."""
        length = len(self)
        position = idx + length if idx < 0 else idx
        if not 0 <= position < length:
            raise IndexError(f"Index {idx} out of range for dataset of size {length}")

        # Map through filtered indices if filtering is enabled
        if self.remove_high_norm is not None:
            actual_idx = int(self.filtered_indices[position])
        else:
            actual_idx = position

        shard_idx, local_idx = self._global_to_local_idx(actual_idx)
        h5_handle = self._get_h5_handle(self.shard_info[shard_idx]['file'])
        return h5_handle["non_padding_cache"][local_idx]

    def close(self):
        """Close all open file handles (safe if ``__init__`` failed before handles existed)."""
        handles = getattr(self, 'h5_handles', None) or {}
        for handle in handles.values():
            handle.close()
        self.h5_handles = {}

    def __getstate__(self):
        """Pickle without open HDF5 handles; the unpickled copy reopens files lazily."""
        state = self.__dict__.copy()
        state['h5_handles'] = {}
        return state

    def __del__(self):
        """Ensure file handles are closed when object is deleted."""
        self.close()


class ShardedBufferedBatchHDF5DataModule(pl.LightningDataModule):
    """
    [DEPRECATED] DataModule for sharded HDF5 files with buffering support.

    This class is deprecated. Please use:
    - MemoryBlockDataModule for pre-shuffled data (best performance)
    - DirectShardedHDF5DataModule for data requiring runtime shuffling

    The training portion (the first ``int(split_ratio * total_samples)``
    samples, in sorted shard order) is exposed one window of
    ``buffer_size_samples`` at a time. ``update_buffer`` advances the window and
    wraps back to the start after the last one. The validation portion is
    exposed as a single dataset. Supports optional high norm filtering.
    """

    def __init__(
        self,
        shard_pattern: str,
        buffer_size_samples: int,
        batch_size: int = 128,
        num_workers: int = 0,
        split_ratio: float = 0.9,
        seed: int = 42,
        remove_high_norm: Optional[float] = None,
        filter_cache_size: int = 10000,
        **kwargs
    ):
        """
        Emit a DeprecationWarning, discover the shards, and count their samples.

        Args:
            shard_pattern: Glob pattern matching shard files (e.g., "/path/to/data_shard_*.h5")
            buffer_size_samples: Size of each training window (coerced to int)
            batch_size: Default batch size for both DataLoaders
            num_workers: Default worker count for the training DataLoader
            split_ratio: Fraction of samples assigned to training (sequential split)
            seed: If not None, seeds Python's ``random`` and NumPy's global RNG in ``setup``
            remove_high_norm: If not None, drop samples with norm > median * this factor
            filter_cache_size: Number of samples used to estimate the median norm
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If no files match ``shard_pattern``.
        """
        super().__init__()
        deprecation_message = (
            "ShardedBufferedBatchHDF5DataModule is deprecated. "
            "Use MemoryBlockDataModule for pre-shuffled data (best performance) or "
            "DirectShardedHDF5DataModule for data requiring runtime shuffling."
        )
        warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)

        # Configuration is stored as given, except the buffer size which is coerced to int.
        self.shard_pattern = shard_pattern
        self.buffer_size_samples = int(buffer_size_samples)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.split_ratio = split_ratio
        self.seed = seed
        self.remove_high_norm = remove_high_norm
        self.filter_cache_size = filter_cache_size

        matched_files = sorted(glob.glob(shard_pattern))
        if not matched_files:
            raise ValueError(f"No shard files found matching pattern: {shard_pattern}")
        self.shard_files = matched_files
        print(f"Found {len(self.shard_files)} shard files")

        self.total_samples = sum(self._shard_row_count(path) for path in self.shard_files)
        print(f"Total samples across all shards: {self.total_samples}")

        # Training covers [0, train_sample_end); validation covers the remainder.
        self.train_sample_end = int(self.split_ratio * self.total_samples)
        self.val_sample_start = self.train_sample_end

        # Buffer state: the current window starts here; datasets are built in setup().
        self.current_buffer_start = 0
        self.train_ds = None
        self.val_ds = None

    @staticmethod
    def _shard_row_count(path):
        """Return the number of rows in ``path``'s "non_padding_cache" dataset."""
        with h5py.File(path, 'r') as handle:
            return handle["non_padding_cache"].shape[0]

    def _make_dataset(self, first, last):
        """Build a ShardedHDF5Dataset over global samples ``[first, last)`` with this module's filter settings."""
        return ShardedHDF5Dataset(
            self.shard_files,
            start_idx=first,
            end_idx=last,
            remove_high_norm=self.remove_high_norm,
            filter_cache_size=self.filter_cache_size
        )

    def setup(self, stage=None):
        """Seed global RNGs (if configured), load the current training window, and build the validation set."""
        seed = self.seed
        if seed is not None:
            np.random.seed(seed)
            random.seed(seed)

        self.load_current_buffer()

        # Validation always spans everything after the training portion.
        self.val_ds = self._make_dataset(self.val_sample_start, self.total_samples)
        print(f"Validation set: {len(self.val_ds)} samples")

    def load_current_buffer(self):
        """Replace ``train_ds`` with a dataset over the window starting at ``current_buffer_start``."""
        first = self.current_buffer_start
        last = min(first + self.buffer_size_samples, self.train_sample_end)
        print(f"Loading buffer: samples {first} to {last}")

        # Release the previous window's file handles before opening the new one.
        previous = self.train_ds
        if previous is not None and hasattr(previous, 'close'):
            previous.close()

        self.train_ds = self._make_dataset(first, last)

    def update_buffer(self):
        """Advance to the next training window, wrapping to 0 once past the training portion."""
        next_start = self.current_buffer_start + self.buffer_size_samples
        if next_start >= self.train_sample_end:
            print("All buffers processed. Restarting from beginning.")
            next_start = 0
        self.current_buffer_start = next_start
        self.load_current_buffer()

    def train_dataloader(self, **kwargs):
        """Return a shuffling DataLoader over the current window; ``batch_size``/``num_workers`` may be overridden."""
        loader_options = dict(
            batch_size=kwargs.get('batch_size', self.batch_size),
            shuffle=True,
            num_workers=kwargs.get('num_workers', self.num_workers),
            pin_memory=True,
            # Workers are not persisted because the underlying dataset is swapped by update_buffer().
            persistent_workers=False,
        )
        return DataLoader(self.train_ds, **loader_options)

    def val_dataloader(self, **kwargs):
        """Return a non-shuffling, single-process DataLoader over the validation set."""
        loader_options = dict(
            batch_size=kwargs.get('batch_size', self.batch_size),
            shuffle=False,
            num_workers=0,
            pin_memory=True,
            persistent_workers=False,
        )
        return DataLoader(self.val_ds, **loader_options)


class DirectShardedHDF5DataModule(pl.LightningDataModule):
    """
    DataModule for sharded HDF5 files with runtime shuffling support.

    This module is recommended for:
    - Data that requires runtime shuffling (not pre-shuffled)
    - Smaller datasets where indexed access is needed
    - Cases where you need full control over shuffle behavior

    For pre-shuffled data, consider using MemoryBlockDataModule for better performance.

    Uses sequential train/val split to avoid memory issues with large datasets:
    the first ``int(total_samples * split_ratio)`` samples (in sorted shard
    order) form the training set and the rest form the validation set.
    Supports optional high norm filtering.
    """

    def __init__(
        self,
        shard_pattern: str,
        batch_size: int = 128,
        num_workers: int = 4,
        split_ratio: float = 0.9,
        seed: int = 42,
        shuffle_train: bool = True,
        remove_high_norm: Optional[float] = None,
        filter_cache_size: int = 10000,
        **kwargs
    ):
        """
        Initialize the direct sharded data module.

        Opens every matched shard once to count its samples.

        Args:
            shard_pattern: Glob pattern matching the shard files
            batch_size: Batch size for both DataLoaders
            num_workers: Number of workers for both DataLoaders
            split_ratio: Fraction of samples, in [0, 1], assigned to training (sequential split)
            seed: If not None, seeds Python's ``random`` and NumPy's global RNG at the
                start of ``setup`` (the RNGs used when sampling rows for the median-norm
                estimate). It is not passed to the DataLoader, so the training shuffle
                order is not controlled by this seed.
            shuffle_train: Whether to shuffle training data
            remove_high_norm: If not None, filter samples with norm > median * this factor
            filter_cache_size: Number of samples to use for computing median norm
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``split_ratio`` is outside [0, 1] or no files match ``shard_pattern``.
        """
        super().__init__()
        if not 0.0 <= split_ratio <= 1.0:
            raise ValueError(f"split_ratio must be in [0, 1], got {split_ratio}")

        self.shard_pattern = shard_pattern
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.split_ratio = split_ratio
        self.seed = seed
        self.shuffle_train = shuffle_train
        self.remove_high_norm = remove_high_norm
        self.filter_cache_size = filter_cache_size

        # Datasets are built in setup(); defined here so the attributes always exist.
        self.train_ds = None
        self.val_ds = None

        # Find all shard files
        self.shard_files = sorted(glob.glob(shard_pattern))
        if not self.shard_files:
            raise ValueError(f"No shard files found matching pattern: {shard_pattern}")

        print(f"Found {len(self.shard_files)} shard files")

        # Get total number of samples
        self.total_samples = 0
        for shard_file in self.shard_files:
            with h5py.File(shard_file, 'r') as f:
                self.total_samples += f["non_padding_cache"].shape[0]

        print(f"Total samples across all shards: {self.total_samples}")

    def setup(self, stage=None):
        """Seed global RNGs (if configured) and build sequential train/val datasets."""
        if self.seed is not None:
            np.random.seed(self.seed)
            random.seed(self.seed)

        # Calculate split point
        train_size = int(self.total_samples * self.split_ratio)

        # Sequential windows avoid creating large permutation arrays.
        self.train_ds = ShardedHDF5Dataset(
            self.shard_files,
            start_idx=0,
            end_idx=train_size,
            remove_high_norm=self.remove_high_norm,
            filter_cache_size=self.filter_cache_size
        )

        self.val_ds = ShardedHDF5Dataset(
            self.shard_files,
            start_idx=train_size,
            end_idx=self.total_samples,
            remove_high_norm=self.remove_high_norm,
            filter_cache_size=self.filter_cache_size
        )

        if self.remove_high_norm is None:
            print(f"Train set: {len(self.train_ds)} samples (0 to {train_size})")
            print(f"Val set: {len(self.val_ds)} samples ({train_size} to {self.total_samples})")
        else:
            print(f"Train set: {len(self.train_ds)} samples (filtered)")
            print(f"Val set: {len(self.val_ds)} samples (filtered)")

    def train_dataloader(self):
        """Create the training DataLoader (shuffled iff ``shuffle_train``)."""
        return DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            shuffle=self.shuffle_train,
            num_workers=self.num_workers,
            pin_memory=True,
            # persistent_workers is only valid with at least one worker process.
            persistent_workers=self.num_workers > 0
        )

    def val_dataloader(self):
        """Create the validation DataLoader (never shuffled)."""
        return DataLoader(
            self.val_ds,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0
        )