"""Slide grouping tools."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Literal

import numpy as np

from pathfmtools.io.schema import StoreKeys as SK
from pathfmtools.utils.torch import EmbeddingDataset

if TYPE_CHECKING:
    from collections.abc import Iterator

    from pathfmtools.image import Slide

logger = logging.getLogger(__name__)


class SlideGroup:
    """A collection of slides that are processed together.

    The group behaves like a read-only mapping from slide ID to slide. It also lazily builds
    and caches, per embedding type and model, a single array holding the patch embeddings of
    every slide stacked row-wise, and records which rows of that array belong to which slide.
    """

    def __init__(
        self,
        slide_list: list[Slide],
    ) -> None:
        """Initialize a SlideGroup.

        Args:
            slide_list: A list of Slide instances to be processed together. Slides are keyed by
                their ``id_`` attribute; when several slides share an ID only the last one is
                kept in ``slide_dict`` and a warning is logged.

        """
        self.slide_list = slide_list
        self.slide_dict = {slide.id_: slide for slide in slide_list}
        if len(self.slide_dict) != len(slide_list):
            # slide_list keeps the duplicates while slide_dict (used by len(), iteration and
            # embedding concatenation) does not, so make the discrepancy visible.
            logger.warning(
                "SlideGroup received %d slides but only %d unique slide IDs; slides sharing an "
                "ID are collapsed in slide_dict.",
                len(slide_list),
                len(self.slide_dict),
            )
        # _concat_embedding_arrs caches the concatenated embeddings loaded from all slides. For
        # example, _concat_embedding_arrs["patch_feature_embeddings"][M] is a 2D numpy array
        # with shape (N, D) where N is the total number of patches across all slides and D is
        # the embedding dimension produced by model M. The embeddings of a given slide occupy a
        # contiguous block of rows; the bounds of that block are kept in
        # _concat_embedding_idx_ranges to facilitate slicing.
        self._concat_embedding_arrs: dict[
            Literal["patch_feature_embeddings", "patch_zeroshot_embeddings"],
            dict[str, np.ndarray],
        ] = {
            "patch_feature_embeddings": {},
            "patch_zeroshot_embeddings": {},
        }
        # _concat_embedding_idx_ranges stores, per embedding type and slide ID, the half-open
        # row range (start_idx, end_idx) of that slide's block in the concatenated arrays, i.e.
        # the slide's embeddings are rows start_idx .. end_idx - 1. For example,
        # _concat_embedding_idx_ranges["patch_feature_embeddings"]["S1"] locates the patch
        # feature embeddings of slide S1.
        self._concat_embedding_idx_ranges: dict[
            Literal["patch_feature_embeddings", "patch_zeroshot_embeddings"],
            dict[str, tuple[int, int]],
        ] = {
            "patch_feature_embeddings": {},
            "patch_zeroshot_embeddings": {},
        }

    def __len__(self) -> int:
        """Get the number of slides in the SlideGroup.

        Returns:
            int: The number of slides (unique slide IDs) in the SlideGroup.

        """
        return len(self.slide_dict)

    def __iter__(self) -> Iterator[Slide]:
        """Iterate over the slides in the SlideGroup.

        Returns:
            Iterator[Slide]: An iterator over the slides in the SlideGroup.

        """
        return iter(self.slide_dict.values())

    def __getitem__(self, slide_id: str) -> Slide:
        """Get a slide from the SlideGroup by its ID.

        Args:
            slide_id (str): The ID of the slide to get.

        Raises:
            KeyError: If no slide with the given ID is in the group.

        Returns:
            Slide: The slide with the given ID.

        """
        return self.slide_dict[slide_id]

    def get_slide_ids(self) -> list[str]:
        """Get the IDs of the slides in the SlideGroup.

        Returns:
            list[str]: The IDs of the slides in the SlideGroup, in iteration order.

        """
        return list(self.slide_dict)

    def create_torch_dataset(
        self,
        labels: dict[str, float | np.ndarray],
        embedding_type: Literal["patch_feature_embeddings", "patch_zeroshot_embeddings"],
        model_name: str,
    ) -> EmbeddingDataset:
        """Create a torch Dataset from the SlideGroup.

        Args:
            labels (dict[str, float | np.ndarray]): A dictionary with keys that are slide IDs and
                values that are numpy arrays of labels or floats.
            embedding_type (Literal["patch_feature_embeddings", "patch_zeroshot_embeddings"]): The
                type of features to use.
            model_name (str): The name of the model to use for feature extraction.

        Returns:
            EmbeddingDataset: A dataset built from ``slide_list`` and the given labels.

        """
        return EmbeddingDataset(
            slides=self.slide_list,
            labels=labels,
            feature_type=embedding_type,
            model_name=model_name,
        )

    def get_concatenated_embedding_array(
        self,
        model_name: str,
        embedding_type: Literal["patch_feature_embeddings", "patch_zeroshot_embeddings"],
    ) -> np.ndarray:
        """Get the concatenated embedding array, loading and caching it on first use.

        Args:
            model_name (str): The name of the model whose embeddings are requested.
            embedding_type (Literal["patch_feature_embeddings", "patch_zeroshot_embeddings"]): The
                type of embeddings to return.

        Raises:
            KeyError: If embedding_type is not one of the supported types.
            ValueError: If the embeddings cannot be concatenated (see
                _load_and_concatenate_embeddings).

        Returns:
            np.ndarray: The cached array of shape (N, D) with one row per patch, where the rows
                of each slide form a contiguous block in slide iteration order.

        """
        if model_name not in self._concat_embedding_arrs[embedding_type]:
            self._load_and_concatenate_embeddings(model_name, embedding_type)
        return self._concat_embedding_arrs[embedding_type][model_name]

    @staticmethod
    def _embedding_store_key(
        embedding_type: Literal["patch_feature_embeddings", "patch_zeroshot_embeddings"],
    ):  # noqa: ANN205
        """Translate an embedding type into the store key its embeddings are read from.

        Raises:
            ValueError: If embedding_type is not a supported embedding type.

        """
        if embedding_type == "patch_feature_embeddings":
            return SK.TILE_FEATURE_EMBEDDINGS
        if embedding_type == "patch_zeroshot_embeddings":
            return SK.TILE_ZEROSHOT_EMBEDDINGS
        msg = f"Invalid embedding type: {embedding_type}"
        # Not inside an exception handler, so logger.exception would log a bogus traceback.
        logger.error(msg)
        raise ValueError(msg)

    def _slide_idx_ranges(self) -> dict[str, tuple[int, int]]:
        """Compute the half-open row range each slide occupies in a concatenated array."""
        idx_ranges: dict[str, tuple[int, int]] = {}
        start = 0
        for slide_id, slide in self.slide_dict.items():
            stop = start + slide.tile_index.n_foreground_tiles
            idx_ranges[slide_id] = (start, stop)
            start = stop
        return idx_ranges

    def _load_and_concatenate_embeddings(
        self,
        model_name: str,
        embedding_type: Literal["patch_feature_embeddings", "patch_zeroshot_embeddings"],
    ) -> None:
        """Load the patch embeddings for all slides into a single array.

        The resulting array is cached in _concat_embedding_arrs and the row range of each slide
        is recorded in _concat_embedding_idx_ranges. Neither is modified if loading fails.

        Args:
            model_name (str): The name of the model to use for feature extraction.
            embedding_type (Literal["patch_feature_embeddings", "patch_zeroshot_embeddings"]): The
                type of embeddings to load.

        Raises:
            ValueError: If embedding_type is invalid, the group contains no slides, the patch
                embedding dimensions are not the same for all slides, a slide's stored
                embeddings do not have one row per foreground tile, or a slide's row range
                conflicts with the range recorded by an earlier load.

        """
        store_key = self._embedding_store_key(embedding_type)
        if not self.slide_dict:
            msg = "Cannot concatenate embeddings: the SlideGroup contains no slides."
            raise ValueError(msg)

        # Row ranges depend only on the number of foreground tiles per slide, so they must be
        # identical for every model loaded under the same embedding type. Check this against
        # what earlier loads recorded before doing any I/O.
        idx_ranges = self._slide_idx_ranges()
        recorded_ranges = self._concat_embedding_idx_ranges[embedding_type]
        for slide_id, idx_range in idx_ranges.items():
            recorded = recorded_ranges.get(slide_id, idx_range)
            if recorded != idx_range:
                msg = (
                    f"Slide {slide_id} maps to rows {idx_range} of the concatenated array, but "
                    f"rows {recorded} were recorded by an earlier load of {embedding_type}."
                )
                raise ValueError(msg)
        n_patches = max(stop for _, stop in idx_ranges.values())

        # Read the first slide to learn the embedding dimension, then read every other slide
        # exactly once while filling the preallocated array.
        first_slide = next(iter(self.slide_dict.values()))
        first_embeddings = first_slide.store.read_embeddings(
            first_slide.id_, model_name, store_key,
        )
        embedding_dim = first_embeddings.shape[1]  # type: ignore[index]
        embedding_arr = np.empty((n_patches, embedding_dim))
        for slide_id, slide in self.slide_dict.items():
            start, stop = idx_ranges[slide_id]
            slide_embeddings = (
                first_embeddings
                if slide is first_slide
                else slide.store.read_embeddings(slide.id_, model_name, store_key)
            )
            # If slides have different patch embedding dimensions, they cannot be concatenated.
            # The most likely cause is that different patch embedding models were used to
            # generate the embeddings.
            if slide_embeddings.shape[1] != embedding_dim:  # type: ignore[index]
                msg = "All slides must have the same embedding dimension."
                raise ValueError(msg)
            # Guard against numpy silently broadcasting (e.g. a single stored row) or failing
            # with an opaque shape error when the store and the tile index disagree.
            n_rows = slide_embeddings.shape[0]  # type: ignore[index]
            if n_rows != stop - start:
                msg = (
                    f"Slide {slide_id} has {n_rows} stored {embedding_type} for model "
                    f"{model_name} but {stop - start} foreground tiles."
                )
                raise ValueError(msg)
            embedding_arr[start:stop] = slide_embeddings

        # Commit only after every slide loaded successfully so a failure leaves no partial state.
        recorded_ranges.update(idx_ranges)
        self._concat_embedding_arrs[embedding_type][model_name] = embedding_arr

    def map_embedding_idx_to_source_patches(
        self,
        embedding_idx: int,
        embedding_type: Literal["patch_feature_embeddings", "patch_zeroshot_embeddings"],
    ) -> tuple[Slide, int]:
        """Map an embedding in the concatenated embedding array to the source slide and patch index.

        The lookup uses the row ranges recorded when the concatenated array for embedding_type
        was loaded, so at least one such array must have been loaded beforehand.

        Args:
            embedding_idx (int): The index of the embedding to map.
            embedding_type (Literal["patch_feature_embeddings", "patch_zeroshot_embeddings"]): The
                type of embeddings to map.

        Raises:
            ValueError: If the embedding index is not found in any slide.

        Returns:
            tuple[Slide, int]: The source slide and the index of the patch within that slide.

        """
        idx_ranges = self._concat_embedding_idx_ranges[embedding_type]
        for slide_id, (start, stop) in idx_ranges.items():
            if start <= embedding_idx < stop:
                return self.slide_dict[slide_id], embedding_idx - start
        msg = f"Embedding index {embedding_idx} not found in any slide."
        if not idx_ranges:
            msg += f" No {embedding_type} have been loaded for this SlideGroup yet."
        # Not inside an exception handler, so logger.exception would log a bogus traceback.
        logger.error(msg)
        raise ValueError(msg)

    def map_vals_to_source_patches(
        self,
        vals: np.ndarray,
        embedding_type: Literal["patch_feature_embeddings", "patch_zeroshot_embeddings"],
    ) -> dict[str, np.ndarray]:
        """Map values ordered by concatenated embedding index to the source patches.

        Concretely, given a concatenated embedding array with shape (N, D) where N is the total
        number of patches across all slides and D is the embedding dimension, vals is expected to
        have N entries along its first axis. The values in vals are assumed to be ordered by the
        concatenated embedding index. This method then maps the values in vals to the source
        patches, grouping them by slide and returning a dictionary with slide IDs as keys and
        numpy arrays of values (ordered by patch index) as values.

        Args:
            vals (np.ndarray): A numpy array of values to map.
            embedding_type (Literal["patch_feature_embeddings", "patch_zeroshot_embeddings"]): The
                type of embeddings to map.

        Raises:
            KeyError: If no row range has been recorded for a slide, i.e. no concatenated array
                of this embedding type has been loaded yet.
            ValueError: If the length of vals differs from the number of concatenated embeddings.

        Returns:
            dict[str, np.ndarray]: A dictionary of slide IDs and numpy arrays of values.

        """
        idx_ranges = self._concat_embedding_idx_ranges[embedding_type]
        missing = [slide_id for slide_id in self.slide_dict if slide_id not in idx_ranges]
        if missing:
            msg = (
                f"No {embedding_type} index range recorded for slides {missing}; load the "
                "concatenated embeddings first."
            )
            raise KeyError(msg)
        # Slicing past the end of vals would silently yield truncated or empty arrays.
        n_expected = max((stop for _, stop in idx_ranges.values()), default=0)
        if len(vals) != n_expected:
            msg = (
                f"Expected {n_expected} values (one per concatenated embedding) but got "
                f"{len(vals)}."
            )
            raise ValueError(msg)
        return {slide_id: vals[slice(*idx_ranges[slide_id])] for slide_id in self.slide_dict}
