"""Torch-related utility functions."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Literal

import cv2
import h5py
import numpy as np
import torch
from PIL import Image

from pathfmtools.io.schema import StoreKeys as SK

if TYPE_CHECKING:
    from collections.abc import Callable

    from pathfmtools.image import Slide
    from pathfmtools.image.tile_index import TileIndex
    from pathfmtools.io.slide_data_store import SlideDataStore

logger = logging.getLogger(__name__)


class EmbeddingDataset(torch.utils.data.Dataset):
    """A torch Dataset corresponding to a list of slides and a collection of associated labels.

    Item ``i`` is the pair ``(features, label)`` for ``slides[i]``: the features are read from
    the slide's store on every access, and the label is looked up by the slide's ``id_``.
    """

    def __init__(
        self,
        slides: list[Slide],
        labels: dict[str, float | np.ndarray],
        feature_type: Literal["patch_feature_embeddings", "patch_zeroshot_embeddings"],
        model_name: str,
    ) -> None:
        """Create a torch Dataset from a list of slides and a collection of associated labels.

        Args:
            slides (list[Slide]): A list of Slide instances.
            labels (dict[str, float | np.ndarray]): A dictionary with keys that are slide IDs and
                values that are either numpy arrays of labels or a single scalar label per slide
                (Python ``float``/``int`` or a NumPy numeric/boolean scalar).
            feature_type (Literal["patch_feature_embeddings", "patch_zeroshot_embeddings"]): The
                type of features to use.
            model_name (str): The name of the model whose embeddings are read from the store.

        Raises:
            TypeError: If ``labels`` is not a dictionary.
            ValueError: If the number of labels does not match the number of slides.
            ValueError: If any label value is neither a numpy array nor a scalar.
            ValueError: If the feature type is not one of "patch_feature_embeddings" or
                "patch_zeroshot_embeddings".

        """
        # NOTE: logger.error (not logger.exception) is used throughout because none of these
        # checks run inside an ``except`` handler; logger.exception would append a meaningless
        # "NoneType: None" traceback to every record.
        if not isinstance(labels, dict):
            msg = "Labels must be a numpy array or a dictionary."
            logger.error(msg)
            raise TypeError(msg)

        if len(labels) != len(slides):
            msg = (
                "The number of labels must match the number of slides. "
                f"Got {len(labels)} labels and {len(slides)} slides."
            )
            logger.error(msg)
            raise ValueError(msg)

        # Allow either per-patch arrays or scalar labels per slide for pooling tasks.
        # NumPy scalars (e.g. np.float32, np.int64) are accepted as well: they are not
        # subclasses of the builtin float/int but convert cleanly via float() in __getitem__.
        allowed_label_types = (np.ndarray, np.number, np.bool_, float, int)
        if not all(isinstance(v, allowed_label_types) for v in labels.values()):
            msg = "All values in the labels dictionary must be numpy arrays or scalars (float/int)."
            logger.error(msg)
            raise ValueError(msg)

        if feature_type not in ("patch_feature_embeddings", "patch_zeroshot_embeddings"):
            msg = (
                "feature_type must be one of 'patch_feature_embeddings' or "
                "'patch_zeroshot_embeddings'."
            )
            logger.error(msg)
            raise ValueError(msg)

        self.slides = slides
        self.label_dict = labels
        self.feature_type = feature_type
        self.model_name = model_name

    def _get_features(self, slide: Slide) -> torch.Tensor:
        """Get the features for a given slide.

        Args:
            slide (Slide): A Slide instance.

        Raises:
            ValueError: If the feature type is not one of "patch_feature_embeddings" or
                "patch_zeroshot_embeddings".

        Returns:
            torch.Tensor: A float32 tensor built from the embeddings returned by the slide's
                store.

        """
        if self.feature_type == "patch_feature_embeddings":
            kind = SK.TILE_FEATURE_EMBEDDINGS
        elif self.feature_type == "patch_zeroshot_embeddings":
            kind = SK.TILE_ZEROSHOT_EMBEDDINGS
        else:
            # Only reachable if ``feature_type`` was reassigned after construction.
            msg = f"Invalid feature type: {self.feature_type}"
            logger.error(msg)
            raise ValueError(msg)

        emb_np = slide.store.read_embeddings(
            slide_id=slide.id_,
            model_id=self.model_name,
            kind=kind,
        )
        return torch.tensor(emb_np, dtype=torch.float32)

    def __len__(self) -> int:
        """Get the number of slides in the dataset."""
        return len(self.slides)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Get the features and label for a given slide index.

        Args:
            idx (int): Position of the slide in ``self.slides``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The slide's features and its label as float32
                tensors. Array labels keep their shape; scalar labels become a tensor of
                shape ``(1,)``.

        """
        slide = self.slides[idx]

        # Read the embeddings for this slide from its store.
        features = self._get_features(slide)

        # Look the label up by slide ID and convert it to a float32 tensor.
        raw_label = self.label_dict[slide.id_]
        if isinstance(raw_label, np.ndarray):
            label = torch.tensor(raw_label, dtype=torch.float32)
        else:
            label = torch.tensor([float(raw_label)], dtype=torch.float32)

        return features, label


class TileDataset(torch.utils.data.Dataset):
    """A PyTorch Dataset for loading patches from HDF5 files with multiprocessing support.

    This dataset is designed to work efficiently with DataLoader when num_workers > 0 by
    opening HDF5 files on-demand in each worker process to avoid sharing file handles
    across processes.
    """

    def __init__(
        self,
        tile_index: TileIndex,
        store: SlideDataStore,
        slide_id: str,
        transform: Callable[[Image.Image], Any] | None = None,
    ) -> None:
        """Initialize a dataset that streams slide patches from the store.

        Args:
            tile_index: Geometry and indexing for the slide's tiles (foreground only).
            store: SlideDataStore instance that manages per-slide HDF5 files.
            slide_id: Identifier for the slide whose patches are read.
            transform: Optional per-item transform applied to each PIL Image
                (return value may be a tensor or dict of tensors).

        """
        self.tile_index = tile_index
        self.store = store
        self.slide_id = slide_id
        self.n_foreground_tiles = self.tile_index.n_foreground_tiles
        self.transform = transform
        # Lazily opened per-worker HDF5 handles. Do not open here to avoid
        # sharing across processes when DataLoader(num_workers>0) forks.
        self._h5: h5py.File | None = None
        self._tiles_ds: h5py.Dataset | None = None

    # Keep file handles out of pickled state so each worker opens its own.
    def __getstate__(self) -> dict[str, Any]:
        """Return the instance state for pickling with the HDF5 handles blanked out."""
        state = self.__dict__.copy()
        state["_h5"] = None
        state["_tiles_ds"] = None
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:  # noqa: D401
        """Restore pickled state; handles remain None and will be opened lazily."""
        self.__dict__.update(state)

    def _ensure_open(self) -> None:
        """Open the HDF5 file/dataset if not already open. Safe to call repeatedly."""
        if self._tiles_ds is not None:
            return
        # Open file within the worker context to avoid invalid handles.
        h5_path = self.store.get_slide_h5_path(self.slide_id)
        h5_file = h5py.File(h5_path, mode="r", libver="latest")
        try:
            tiles_ds = h5_file[SK.DS_TILES]
        except Exception:
            # The tiles dataset could not be resolved: close the file we just opened so a
            # failed (and possibly retried) open does not leak a handle per attempt.
            h5_file.close()
            raise
        # Publish both handles together so the instance never holds a file without its dataset.
        self._h5 = h5_file
        self._tiles_ds = tiles_ds

    def close(self) -> None:
        """Close any open HDF5 handles.

        Note: This is safe to call multiple times.
        """
        try:
            if self._h5 is not None:
                self._h5.close()
        finally:
            # Drop the references even if closing raised, so later calls are no-ops.
            self._h5 = None
            self._tiles_ds = None

    def __del__(self) -> None:  # pragma: no cover - best-effort cleanup
        """Close the HDF5 handles on garbage collection, ignoring any failure."""
        try:
            self.close()
        except Exception:
            # Avoid raising during GC (this also covers a partially initialised instance
            # whose handle attributes were never set).
            pass

    def __len__(self) -> int:
        """Return the total number of patches in the dataset."""
        return self.n_foreground_tiles

    def __getitem__(self, idx: int) -> Any:
        """Get a patch at the specified index.

        Args:
            idx (int): Index of the patch to retrieve.

        Returns:
            Any: The patch as a PIL Image (if no transform) or the transformed output
                (torch.Tensor or dict[str, torch.Tensor]).

        """
        # Lazily open per-worker persistent handles to avoid per-item open/close overhead.
        self._ensure_open()
        assert self._tiles_ds is not None  # For type checkers
        patch_pixels = self._tiles_ds[idx]
        pil_image = Image.fromarray(patch_pixels)

        # Compare against None rather than relying on truthiness: a callable that defines
        # __len__/__bool__ (e.g. an empty container-like transform) must still be applied.
        if self.transform is not None:
            return self.transform(pil_image)

        return pil_image


class TransformChain:
    """Sequential composition of per-item transforms that can be pickled.

    Every step is a callable taking one item and returning one item. Calling the chain feeds
    the input through the steps in the order they were given, passing each step's output to
    the next one.
    """

    def __init__(self, steps: list[Callable[[Any], Any]]):
        """Store the ordered list of transform steps."""
        self.steps = steps

    def __call__(self, x: Any) -> Any:  # noqa: ANN401
        """Run ``x`` through every step in order and return the final result."""
        current = x
        for transform in self.steps:
            current = transform(current)
        return current


def cv2_resize_pil(img: Image.Image, target_size: tuple[int, int]) -> Image.Image:
    """Resize a PIL image using OpenCV with quality-aware interpolation.

    Uses INTER_AREA for downscale and INTER_CUBIC for upscale. Assumes RGB input and returns RGB.

    Args:
        img: The image to resize.
        target_size: Output size, passed to ``cv2.resize`` as ``dsize`` (so the first entry is
            compared against the source width). Both entries must be positive.

    Raises:
        ValueError: If either entry of ``target_size`` is not positive.

    Returns:
        The resized image as a new PIL Image.

    """
    target_w, target_h = target_size[0], target_size[1]
    if target_w <= 0 or target_h <= 0:
        msg = f"Invalid target_size: {target_size}"
        # Not inside an exception handler, so log with error() rather than exception().
        logger.error(msg)
        raise ValueError(msg)

    np_img = np.asarray(img)
    _, src_w = np_img.shape[:2]
    # Assume square resize for patches; use width alone to decide between down- and upscale.
    # A direct comparison is equivalent to testing target/source < 1 and avoids a division.
    interp = cv2.INTER_AREA if target_w < src_w else cv2.INTER_CUBIC
    resized = cv2.resize(np_img, dsize=target_size, interpolation=interp)
    return Image.fromarray(resized)


def hf_apply_processor(processor: Any, img: Image.Image) -> dict[str, torch.Tensor]:  # noqa: ANN401
    """Run a HuggingFace image processor on one image and drop the batch dimension.

    The processor is called with ``return_tensors="pt"``. Every tensor in its output has
    ``squeeze(0)`` applied; non-tensor values are passed through unchanged.
    """
    batched = processor(img, return_tensors="pt")

    unbatched = {}
    for name, value in batched.items():
        if isinstance(value, torch.Tensor):
            # Remove the leading batch dim the processor adds for a single image.
            unbatched[name] = value.squeeze(0)
        else:
            unbatched[name] = value
    return unbatched


def collate_dict_or_tensor(batch: list[Any]) -> Any:  # noqa: ANN401
    """Collate a list of tensors or dicts of tensors into a batch.

    - If items are torch.Tensor: stacks along dim 0
    - If items are dict[str, torch.Tensor]: merges keys and stacks per-key
    - Otherwise: raises TypeError

    Args:
        batch: The items to collate. The type of the first item selects the strategy.

    Raises:
        TypeError: If the first item is neither a tensor nor a dict, if dict items do not all
            share the same key set, or if any dict value is not a tensor.

    Returns:
        The stacked tensor, or a dict of stacked tensors whose keys follow the insertion order
        of the first item. An empty batch is returned unchanged.

    """
    if not batch:
        return batch

    first = batch[0]
    if isinstance(first, torch.Tensor):
        return torch.stack(batch, dim=0)

    if isinstance(first, dict):
        expected_keys = set(first)
        for item in batch[1:]:
            if not isinstance(item, dict) or set(item) != expected_keys:
                msg = "Inconsistent dict items in batch for collation"
                # Not inside an exception handler, so log with error() rather than exception().
                logger.error(msg)
                raise TypeError(msg)

        stacked: dict[str, torch.Tensor] = {}
        # Iterate the first item (not the key set) so the output key order is deterministic
        # instead of depending on string hash randomisation.
        for k in first:
            vals = [it[k] for it in batch]
            if not all(isinstance(v, torch.Tensor) for v in vals):
                msg = f"Non-tensor values found for key '{k}' during collation"
                logger.error(msg)
                raise TypeError(msg)
            stacked[k] = torch.stack(vals, dim=0)
        return stacked

    msg = f"Unsupported batch item type for collation: {type(first)}"
    logger.error(msg)
    raise TypeError(msg)
