"""Storage backends for probe data.

This module provides storage backends for persisting collected probe data:
- NPZStorage: NumPy compressed format (simple, portable)
- HDF5Storage: HDF5 format (efficient for large datasets)
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any

import numpy as np

if TYPE_CHECKING:
    from numpy.typing import NDArray

    from expected_gradcam.probes.collector import IntermediateCollector


class StorageBackend(ABC):
    """Abstract base class for storage backends.

    Storage backends persist collected probe data to disk. Different
    backends offer different tradeoffs:
    - NPZ: Simple, portable, good for small/medium datasets
    - HDF5: Efficient for large datasets, supports partial reads

    Every backend can be used as a context manager: ``close()`` is called
    when the ``with`` block exits, whether normally or via an exception.
    """

    @abstractmethod
    def save(
        self,
        sample_id: str,
        data: dict[str, Any],
    ) -> None:
        """Save data for a sample.

        Args:
            sample_id: Unique identifier for the sample.
            data: Dictionary of arrays and metadata to save.
        """
        ...

    @abstractmethod
    def load(self, sample_id: str) -> dict[str, Any]:
        """Load data for a sample.

        Args:
            sample_id: Unique identifier for the sample.

        Returns:
            Dictionary of loaded arrays and metadata.
        """
        ...

    @abstractmethod
    def list_samples(self) -> list[str]:
        """List all stored samples.

        Returns:
            List of sample IDs.
        """
        ...

    @abstractmethod
    def close(self) -> None:
        """Close any open resources."""
        ...

    def __enter__(self) -> StorageBackend:
        """Enter a ``with`` block; returns the backend itself."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Close the backend on ``with`` exit.

        Returns None, so any exception raised inside the block propagates.
        """
        self.close()


class NPZStorage(StorageBackend):
    """Storage backend using NumPy NPZ format.

    Each sample is stored as a separate ``<sample_id>.npz`` file in the
    output directory. This is simple and portable but may be slow for very
    large datasets.

    Attributes:
        output_dir: Directory to store NPZ files.
        compress: Whether to use compression.

    Example:
        >>> storage = NPZStorage("experiment_data/")
        >>> storage.save("sample_001", {"weights": weights_array})
        >>> data = storage.load("sample_001")
    """

    def __init__(
        self,
        output_dir: str | Path,
        compress: bool = True,
    ) -> None:
        """Initialize NPZ storage.

        Args:
            output_dir: Directory to store NPZ files; created (including
                parents) if it does not already exist.
            compress: Whether to use compression (``np.savez_compressed``
                instead of ``np.savez``).
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.compress = compress

    def _get_path(self, sample_id: str) -> Path:
        """Return the ``.npz`` file path used for ``sample_id``."""
        return self.output_dir / f"{sample_id}.npz"

    @staticmethod
    def _to_array(value: Any) -> np.ndarray | None:
        """Convert one value to a NumPy array, or return None to skip it.

        Objects exposing ``.cpu()`` (PyTorch-style tensors) are detached and
        copied to host memory; Python scalars and strings become 0-d arrays.
        Nested dicts have no NPZ representation and are skipped, as are
        values NumPy refuses to convert.
        """
        if hasattr(value, "cpu"):
            # PyTorch tensor
            return value.detach().cpu().numpy()
        if isinstance(value, np.ndarray):
            return value
        if isinstance(value, (int, float, str, bool)):
            return np.array(value)
        if isinstance(value, dict):
            return None
        try:
            return np.array(value)
        except (TypeError, ValueError):
            # e.g. ragged nested sequences. Saving is best-effort, so the
            # value is dropped rather than aborting the whole sample.
            return None

    def save(
        self,
        sample_id: str,
        data: dict[str, Any],
    ) -> None:
        """Save data to ``<output_dir>/<sample_id>.npz``.

        Args:
            sample_id: Unique identifier for the sample.
            data: Mapping of names to arrays, tensors, or scalars. Nested
                dicts and values NumPy cannot convert are silently skipped.
        """
        arrays: dict[str, np.ndarray] = {}
        for key, value in data.items():
            array = self._to_array(value)
            if array is not None:
                arrays[key] = array

        writer = np.savez_compressed if self.compress else np.savez
        writer(self._get_path(sample_id), **arrays)

    def load(self, sample_id: str) -> dict[str, Any]:
        """Load data from an NPZ file.

        Args:
            sample_id: Unique identifier for the sample.

        Returns:
            Dictionary of loaded arrays, fully read into memory.

        Raises:
            FileNotFoundError: If no file exists for ``sample_id``.
        """
        path = self._get_path(sample_id)

        if not path.exists():
            raise FileNotFoundError(f"Sample not found: {sample_id}")

        # SECURITY NOTE: allow_pickle is kept because save() may write object
        # arrays (via the np.array fallback). Unpickling can run arbitrary
        # code, so only load files from trusted sources.
        # The context manager closes the underlying zip file; dict() reads
        # every member array before that happens.
        with np.load(path, allow_pickle=True) as archive:
            return dict(archive)

    def list_samples(self) -> list[str]:
        """List all stored samples.

        Returns:
            Sample IDs (file stems of ``*.npz`` in ``output_dir``), sorted
            so the order does not depend on the filesystem.
        """
        return sorted(p.stem for p in self.output_dir.glob("*.npz"))

    def close(self) -> None:
        """No resources to close for NPZ storage."""
        pass


class HDF5Storage(StorageBackend):
    """Storage backend using HDF5 format.

    All samples are stored in a single HDF5 file with each sample as a group.
    This is more efficient for large datasets and supports partial reads.

    Attributes:
        file_path: Path to HDF5 file.

    Example:
        >>> storage = HDF5Storage("experiment_data.h5")
        >>> storage.save("sample_001", {"weights": weights_array})
        >>> data = storage.load("sample_001")
        >>> storage.close()
    """

    def __init__(
        self,
        file_path: str | Path,
        mode: str = "a",
    ) -> None:
        """Initialize HDF5 storage.

        Args:
            file_path: Path to HDF5 file. Its parent directory is created
                if missing.
            mode: File mode passed to ``h5py.File`` ('r', 'r+', 'w', 'a').

        Raises:
            ImportError: If h5py is not installed.
        """
        try:
            import h5py

            self._h5py = h5py
        except ImportError as e:
            raise ImportError(
                "HDF5Storage requires h5py. Install with: pip install h5py"
            ) from e

        self.file_path = Path(file_path)
        self.file_path.parent.mkdir(parents=True, exist_ok=True)
        self._file = self._h5py.File(self.file_path, mode)

    def save(
        self,
        sample_id: str,
        data: dict[str, Any],
    ) -> None:
        """Save data to the HDF5 file, replacing any existing sample group.

        Tensors, arrays, ints and floats are stored as datasets. Strings are
        stored as group attributes, and dicts as JSON-string attributes.
        Values NumPy cannot convert are skipped.

        Args:
            sample_id: Unique identifier for the sample.
            data: Dictionary of arrays and metadata to save.
        """
        import json

        # Create or overwrite group for this sample
        if sample_id in self._file:
            del self._file[sample_id]

        group = self._file.create_group(sample_id)

        for key, value in data.items():
            if hasattr(value, "cpu"):
                # PyTorch tensor
                array = value.detach().cpu().numpy()
            elif isinstance(value, np.ndarray):
                array = value
            elif isinstance(value, (int, float)):
                array = np.array(value)
            elif isinstance(value, str):
                # Store strings as attributes
                group.attrs[key] = value
                continue
            elif isinstance(value, dict):
                # Store dicts as JSON string attribute
                group.attrs[key] = json.dumps(value)
                continue
            else:
                try:
                    array = np.array(value)
                except Exception:
                    continue

            if array.ndim == 0 or array.size == 0:
                # h5py rejects chunk/filter options (such as compression) for
                # scalar dataspaces. Compressing an empty array gains nothing.
                group.create_dataset(key, data=array)
            else:
                group.create_dataset(
                    key,
                    data=array,
                    compression="gzip",
                    compression_opts=4,
                )

        self._file.flush()

    def load(self, sample_id: str) -> dict[str, Any]:
        """Load data from the HDF5 file.

        Args:
            sample_id: Unique identifier for the sample.

        Returns:
            Dictionary of datasets (scalar datasets come back as NumPy
            scalars) merged with the group's attributes. String attributes
            that start with ``{`` and parse as JSON are decoded. This means
            a plain string saved with that shape is also returned decoded.

        Raises:
            FileNotFoundError: If the sample group does not exist.
        """
        import json

        if sample_id not in self._file:
            raise FileNotFoundError(f"Sample not found: {sample_id}")

        group = self._file[sample_id]

        # ``[()]`` reads the whole dataset. Unlike ``[:]``, it also works for
        # the scalar (0-d) datasets that save() writes for ints/floats.
        data: dict[str, Any] = {key: group[key][()] for key in group.keys()}

        # Load attributes
        for key, value in group.attrs.items():
            if isinstance(value, str) and value.startswith("{"):
                try:
                    value = json.loads(value)
                except json.JSONDecodeError:
                    pass
            data[key] = value

        return data

    def list_samples(self) -> list[str]:
        """List all stored samples.

        Returns:
            List of sample IDs (top-level group names).
        """
        return list(self._file.keys())

    def close(self) -> None:
        """Close the HDF5 file; safe to call repeatedly.

        Also safe on a partially constructed instance (e.g. when ``__init__``
        raised before the file was opened), which ``__del__`` relies on.
        """
        h5file = getattr(self, "_file", None)
        if h5file is not None:
            h5file.close()
        self._file = None

    def __del__(self) -> None:
        """Ensure file is closed on deletion."""
        self.close()

    def __enter__(self) -> "HDF5Storage":
        """Context manager entry."""
        return self

    def __exit__(self, *args: Any) -> None:
        """Context manager exit."""
        self.close()
