"""Lazy loading utilities for large feature datasets.

This module provides memory-efficient loading of large feature datasets
using memory-mapping. This allows working with datasets larger than
available RAM.
"""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Iterator

import numpy as np
import torch
from torch import Tensor

if TYPE_CHECKING:
    from numpy.typing import NDArray


class LazyFeatureLoader:
    """Lazy loader for large feature datasets.

    ``.npy`` files are opened memory-mapped (read-only), so samples are read
    from disk on demand rather than all at once. ``.npz`` archives cannot be
    memory-mapped by NumPy: the selected array is read fully into memory when
    the loader is created.

    Every accessor returns a fresh in-memory copy of the requested samples,
    so results stay valid (and writable) after the loader is closed.

    The loader can be used as a context manager, which closes it on exit.

    Attributes:
        path: Path to the feature file.
        array_key: Key of the array used when the file is an NPZ archive.
        shape: Shape of the feature array.

    Example:
        >>> with LazyFeatureLoader("features.npy") as loader:
        ...     print(f"Dataset shape: {loader.shape}")
        ...     # Get specific samples (loaded on demand)
        ...     batch = loader[0:32]
    """

    def __init__(
        self,
        path: str | Path,
        array_key: str = "features",
    ) -> None:
        """Initialize lazy loader.

        Args:
            path: Path to the feature file (.npy or .npz).
            array_key: Key for array in NPZ files.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            KeyError: If ``array_key`` is not present in an NPZ archive.
        """
        # Assign these first so close()/__del__ work even if loading fails.
        self._npz = None
        self._data = None

        self.path = Path(path)
        self.array_key = array_key

        if not self.path.exists():
            raise FileNotFoundError(f"Feature file not found: {path}")

        if self.path.suffix.lower() == ".npz":
            # np.load ignores mmap_mode for archives; the member array is
            # decompressed/read into memory when it is accessed below.
            self._npz = np.load(self.path)
            try:
                self._data = self._npz[array_key]
            except KeyError:
                available = ", ".join(self._npz.files)
                # Release the archive handle now instead of waiting for GC.
                self.close()
                raise KeyError(
                    f"Array key {array_key!r} not found in {self.path} "
                    f"(available: {available})"
                ) from None
        else:
            # Read-only memory map: nothing is read until it is indexed.
            self._data = np.load(self.path, mmap_mode="r")

    def __enter__(self) -> LazyFeatureLoader:
        """Return the loader itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
        """Close the loader when leaving a ``with`` block."""
        self.close()

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the feature array."""
        return self._data.shape

    @property
    def dtype(self) -> np.dtype:
        """Data type of features."""
        return self._data.dtype

    def __len__(self) -> int:
        """Number of samples (size of the first axis)."""
        return self._data.shape[0]

    def __getitem__(
        self,
        idx: int | slice | list[int] | "NDArray[np.int64]",
    ) -> "NDArray[np.floating]":
        """Get features at index/indices.

        Args:
            idx: Index, slice, or list/array of indices.

        Returns:
            Feature array, copied into memory (never a view of the file).
        """
        # np.array() forces a copy, detaching the result from the memmap.
        return np.array(self._data[idx])

    def get_tensor(
        self,
        idx: int | slice | list[int],
        device: torch.device | str = "cpu",
    ) -> Tensor:
        """Get features as PyTorch tensor.

        Args:
            idx: Index, slice, or list of indices.
            device: Device for tensor.

        Returns:
            Feature tensor.
        """
        return torch.from_numpy(self[idx]).to(device)

    def iter_batches(
        self,
        batch_size: int = 32,
        shuffle: bool = False,
        drop_last: bool = False,
    ) -> Iterator["NDArray[np.floating]"]:
        """Iterate over batches.

        Args:
            batch_size: Number of samples per batch. Must be positive.
            shuffle: Whether to shuffle before iterating (uses NumPy's
                global random state).
            drop_last: Whether to drop incomplete last batch.

        Yields:
            Feature arrays [batch_size, ...].

        Raises:
            ValueError: If ``batch_size`` is not positive. Because this is a
                generator, the error surfaces when iteration starts.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        n = len(self)
        order = None
        if shuffle:
            order = np.arange(n)
            np.random.shuffle(order)

        for start in range(0, n, batch_size):
            end = start + batch_size
            if end > n and drop_last:
                break
            if order is None:
                # Contiguous read: a plain slice returns the same samples as
                # fancy indexing without building an index array.
                yield self[start:end]
            else:
                yield self[order[start:end]]

    def iter_tensor_batches(
        self,
        batch_size: int = 32,
        device: torch.device | str = "cpu",
        shuffle: bool = False,
        drop_last: bool = False,
    ) -> Iterator[Tensor]:
        """Iterate over batches as tensors.

        Args:
            batch_size: Number of samples per batch.
            device: Device for tensors.
            shuffle: Whether to shuffle before iterating.
            drop_last: Whether to drop incomplete last batch.

        Yields:
            Feature tensors [batch_size, ...].
        """
        for batch in self.iter_batches(batch_size, shuffle, drop_last):
            yield torch.from_numpy(batch).to(device)

    def sample(
        self,
        n: int,
        replace: bool = False,
    ) -> "NDArray[np.floating]":
        """Sample random features.

        Args:
            n: Number of features to sample.
            replace: Whether to sample with replacement.

        Returns:
            Sampled features [n, ...].
        """
        indices = np.random.choice(len(self), size=n, replace=replace)
        return self[indices]

    def sample_tensor(
        self,
        n: int,
        device: torch.device | str = "cpu",
        replace: bool = False,
    ) -> Tensor:
        """Sample random features as tensor.

        Args:
            n: Number of features to sample.
            device: Device for tensor.
            replace: Whether to sample with replacement.

        Returns:
            Sampled features [n, ...].
        """
        return torch.from_numpy(self.sample(n, replace)).to(device)

    def get_channel_stats(
        self,
        batch_size: int = 100,
    ) -> dict[str, "NDArray[np.floating]"]:
        """Compute per-channel statistics.

        Axis 1 is treated as the channel axis. Each sample is first averaged
        over all axes after the channel axis (if there are any); the returned
        values are the mean and the population standard deviation of those
        per-sample channel averages across the dataset. The std is therefore
        the spread between samples, not the spread of individual elements.

        Args:
            batch_size: Number of samples read per chunk.

        Returns:
            Dictionary with ``"mean"`` and ``"std"`` arrays, one entry per
            channel. Both are NaN-filled when the dataset has no samples.
        """
        n_channels = self.shape[1]
        # Every axis after the channel axis is reduced; may be empty for
        # plain [N, K] features.
        reduce_axes = tuple(range(2, len(self.shape)))

        running_sum = np.zeros(n_channels)
        running_sq_sum = np.zeros(n_channels)
        count = 0

        # Process in chunks to avoid loading everything.
        for batch in self.iter_batches(batch_size=batch_size):
            # Accumulate in float64 regardless of the stored dtype.
            if reduce_axes:
                per_sample = batch.mean(axis=reduce_axes, dtype=np.float64)
            else:
                per_sample = batch.astype(np.float64)
            running_sum += per_sample.sum(axis=0)
            running_sq_sum += (per_sample**2).sum(axis=0)
            count += batch.shape[0]

        if count == 0:
            empty = np.full(n_channels, np.nan)
            return {"mean": empty, "std": empty.copy()}

        mean = running_sum / count
        var = running_sq_sum / count - mean**2
        # Rounding can push the variance slightly below zero; clamp it.
        std = np.sqrt(np.maximum(var, 0))

        return {"mean": mean, "std": std}

    def close(self) -> None:
        """Close any open file handles.

        Safe to call more than once, and on an instance whose ``__init__``
        did not complete.
        """
        npz = getattr(self, "_npz", None)
        if npz is not None:
            npz.close()
            self._npz = None
        # Dropping the reference lets NumPy release the memory map.
        self._data = None

    def __del__(self) -> None:
        # Also runs when __init__ raised part-way; close() tolerates that.
        self.close()

    def __repr__(self) -> str:
        path = getattr(self, "path", None)
        if getattr(self, "_data", None) is None:
            return f"LazyFeatureLoader(path={path}, closed)"
        return f"LazyFeatureLoader(path={path}, shape={self.shape})"
