"""Slide processing tools."""

from __future__ import annotations

import logging
import tempfile
from contextlib import nullcontext
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
from PIL import Image
from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn

from pathfmtools.image.segmentation import get_segmentation_method
from pathfmtools.io.schema import StoreKeys as SK
from pathfmtools.io.slide_data_store import SlideDataStore
from pathfmtools.io.slide_reader import SlideReader

from .tile_index import TileIndex, tile_index_from_slide_store

if TYPE_CHECKING:
    import torch

    from pathfmtools.analysis.zeroshot_classification import ZeroShotPatchClassifier
    from pathfmtools.embedding_models.embedding_model import EmbeddingModel

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Setting MAX_IMAGE_PIXELS to None switches off Pillow's decompression-bomb size check, and it
# does so for every PIL user in the process, not only this module.
# NOTE(review): presumably required because slide images exceed Pillow's default pixel limit --
# confirm, and avoid opening untrusted images while this override is active.
Image.MAX_IMAGE_PIXELS = None


class Slide:
    """Slide processing class."""

    def __init__(
        self,
        slide_path: Path,
        store_root: Path | None = None,
        magnification: int | None = None,
        suppress_logs: bool = False,
    ) -> None:
        """Initialize Slide object.

        One of `slide_fpath` or `h5_fpath` must be provided, otherwise the Slide object will have
        no data to reference. If both are provided, the associated SlideIO object will attempt to
        use cached data where possible to speed up processing.

        Args:
            slide_id (str): The slide's unique identifier.
            slide_fpath (Path | None, optional): The path to the raw slide file (e.g. TIFF, NDPI).
                Defaults to None.
            h5_fpath (Path | None, optional): The path to the exiting HDF5 file for the slide which
                was produced by interacting with Slide and/or SlideIO. This file contains processed
                slide data (e.g. extracted patches, computed embeddings). Defaults to None.
            magnification (int | None, optional): A user-specified magnification level for the
                slide. If not provided, the magnification level will be inferred from the slide file
                (if available), or will be set to the default magnification level. Defaults to None.
            out_dir (Path | None, optional): The directory to which the slide's h5 file will be
                written (if an existing h5 file is not provided in the `h5_fpath` argument). This
                is the file that will contain processed slide data, such as extracted patches and
                computed embeddings. If an existing h5 file is specified, this argument has no
                effect. If no h5 file is specified and this argument is not provided, the h5 file
                will be written to a temporary directory and deleted after the Slide object is
                deleted. Defaults to None.

        """
        self.id_ = slide_path.stem
        self.slide_reader = SlideReader(slide_path, magnification=magnification)
        if store_root is not None:
            self.store_root = store_root
        else:
            self.store_root = Path(tempfile.mkdtemp())
        self.suppress_logs = suppress_logs
        self.store = SlideDataStore(self.store_root)
        self._tile_index = None

        self.preprocessed = self.store.check_tiles_present(self.id_)

    @property
    def tile_index(self) -> TileIndex:
        """Get the tile index for the slide."""
        if not self.preprocessed:
            msg = "Slide has not been preprocessed."
            raise RuntimeError(msg)
        if self._tile_index is None:
            self._tile_index = tile_index_from_slide_store(self.store, self.id_)
        return self._tile_index

    def preprocess(
        self,
        patch_size: int,
        segmenter: str = "otsu",
        *,
        verbose: bool = True,
        **segmenter_kwargs,
    ) -> None:
        """Preprocess the slide: build tile grid, segment, and persist tiles.

        - Uses TileIndex to define the grid over a cropped area that is
          divisible by patch_size.
        - Streams tiles using SlideReader and writes tiles, metadata, and
          segmentation to the per-slide store.
        """
        if self.preprocessed:
            logger.info("Slide %s already preprocessed; skipping.", self.id_)
            return

        if patch_size <= 0:
            msg = "patch_size must be positive"
            raise ValueError(msg)

        # Initialize pre-segmentation tile index over the full grid
        pre_ti = TileIndex(
            tile_size=patch_size,
            slide_w=self.slide_reader.width,
            slide_h=self.slide_reader.height,
        )

        seg_cls = get_segmentation_method(segmenter)
        logger.info("Segmenting slide using method %s (%s)", segmenter, seg_cls.__name__)
        # Propagate progress preference to segmenter if supported.
        segmenter_kwargs = {"show_progress": verbose, **segmenter_kwargs}
        seg_mask = seg_cls.create_patch_segmentation_mask(
            tile_index=pre_ti,
            slide_reader=self.slide_reader,
            **segmenter_kwargs,
        ).astype(bool)

        prop_foreground = float(np.mean(seg_mask))
        logger.info(
            "Kept %s%% of patches (%s patches) as foreground",
            format(prop_foreground * 100, ".2f"),
            format(int(seg_mask.sum()), ","),
        )

        self.store.write_slide_metadata(
            self.id_,
            slide_width=self.slide_reader.width,
            slide_height=self.slide_reader.height,
            magnification=int(self.slide_reader.magnification),
            patch_size=patch_size,
            slide_fpath=self.slide_reader.slide_path,
            segmentation_method=segmenter,
            prop_foreground=prop_foreground,
        )

        # Stream tiles to the store with segmentation mask

        with self.store.get_tile_data_writer(
            slide_id=self.id_,
            patch_size=patch_size,
            seg_mask=seg_mask,
            level=0,
            magnification=int(self.slide_reader.magnification),
        ) as writer:
            # With no segmentation mask, TileIndex considers all tiles as foreground
            total_kept = int(seg_mask.sum())
            if verbose:
                with Progress(
                    TextColumn("[bold blue]{task.description}", justify="right"),
                    BarColumn(bar_width=None, complete_style="cyan", finished_style="cyan"),
                    "[progress.percentage]{task.percentage:>3.1f}%",
                    "•",
                    TimeRemainingColumn(),
                ) as progress:
                    task = progress.add_task(
                        f"Writing tiles for {self.id_}",
                        total=total_kept or None,
                    )
                    for r, c, x, y in pre_ti.iter_tiles_row_major(foreground_only=True):
                        if not seg_mask[r, c]:
                            continue
                        patch = self.slide_reader.read_region(x, y, patch_size, patch_size)
                        writer.write_batch(
                            tiles=patch[np.newaxis, ...],
                            rows=np.array([r], dtype=np.int32),
                            cols=np.array([c], dtype=np.int32),
                            x_px=np.array([x], dtype=np.int32),
                            y_px=np.array([y], dtype=np.int32),
                        )
                        progress.update(task, advance=1)
            else:
                for r, c, x, y in pre_ti.iter_tiles_row_major(foreground_only=True):
                    if not seg_mask[r, c]:
                        continue
                    patch = self.slide_reader.read_region(x, y, patch_size, patch_size)
                    writer.write_batch(
                        tiles=patch[np.newaxis, ...],
                        rows=np.array([r], dtype=np.int32),
                        cols=np.array([c], dtype=np.int32),
                        x_px=np.array([x], dtype=np.int32),
                        y_px=np.array([y], dtype=np.int32),
                    )

        self.preprocessed = True
        self._tile_index = tile_index_from_slide_store(self.store, self.id_)

    def embed_tiles(
        self,
        model: EmbeddingModel,
        batch_size: int,
        auto_rescale: bool = True,
        skip_feature_embeddings: bool = False,
        skip_zeroshot_embeddings: bool = False,
        verbose: bool = True,
        num_workers: int | None = None,
        pin_memory: bool = True,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Embed the patches using an EmbeddingModel subclass which supports patch-level embeddings.

        If supported by the model (e.g. CONCH, MUSK), zero-shot embeddings which are aligned with
        language are also computed for each patch.

        Args:
            model (EmbeddingModel): The patch embedding model to use.
            batch_size (int): The number of patches to embed in each batch.
            auto_rescale (bool, optional): Whether to automatically rescale the patches to match
                the model's expected receptive field. If False, the patch will be resized to the
                model's expected patch size if the receptive field is incompatible. If True, the
                patch will be resized to the model's expected patch size if the receptive field is
                incompatible. Defaults to True.
            skip_feature_embeddings (bool, optional): Whether to skip generating and saving feature
                embeddings. Defaults to False.
            skip_zeroshot_embeddings (bool, optional): Whether to skip generating and saving
                zero-shot embeddings. Defaults to False.

        Raises:
            RuntimeError: If the slide has not been preprocessed.

        Returns:
            tuple[torch.Tensor, torch.Tensor | None]: A tuple containing the patch embeddings and
                the patch zero-shot embeddings (if implemented by the model).

        """
        if not self.preprocessed:
            msg = "Slide has not been preprocessed."
            raise RuntimeError(msg)

        feature_embeddings, zeroshot_embeddings = model.embed_tiles(
            batch_size=batch_size,
            tile_index=self.tile_index,
            slide_reader=self.slide_reader,
            slide_data_store=self.store,
            verbose=verbose,
            skip_feature_embeddings=skip_feature_embeddings,
            skip_zeroshot_embeddings=skip_zeroshot_embeddings,
            auto_rescale=auto_rescale,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )
        if feature_embeddings is not None and not skip_feature_embeddings:
            self.store.write_embeddings(
                model_id=model.NAME,
                kind=SK.TILE_FEATURE_EMBEDDINGS,
                emb=feature_embeddings.detach().cpu().numpy(),
                slide_id=self.id_,
            )
        if zeroshot_embeddings is not None and not skip_zeroshot_embeddings:
            self.store.write_embeddings(
                model_id=model.NAME,
                kind=SK.TILE_ZEROSHOT_EMBEDDINGS,
                emb=zeroshot_embeddings.detach().cpu().numpy(),
                slide_id=self.id_,
            )

        return feature_embeddings, zeroshot_embeddings

    def run_zeroshot_classification(
        self,
        zero_shot_classifier: ZeroShotPatchClassifier,
        model_name: str,
        text_list: list[str],
        device: torch.device,
    ) -> dict[str, np.ndarray]:
        """Run zero-shot classification over this slide's tiles.

        This is a thin wrapper around ``ZeroShotPatchClassifier.classify`` that
        supplies the current ``Slide`` instance so the classifier can read
        zero-shot embeddings from the per-slide store.

        Args:
            zero_shot_classifier: The zero-shot classifier instance.
            model_name: Embedding model name whose zero-shot tile embeddings to use.
            text_list: List of class prompts to score against each tile.
            device: Torch device for computing text embeddings.

        Returns:
            dict[str, np.ndarray]: Mapping from class text to arrays of logits, as
            well as probabilities and predictions keyed under the classifier's
            standard output keys when used directly.

        Raises:
            RuntimeError: If the slide has not been preprocessed yet.
            ValueError: If zero-shot embeddings for the given model are missing.
        """
        if not self.preprocessed:
            msg = "Slide has not been preprocessed."
            raise RuntimeError(msg)

        return zero_shot_classifier.classify(
            model_name=model_name,
            classes=text_list,
            slide=self,
            device=device,
        )

    def delete_tiles_from_disk(self) -> None:
        """Delete the tiles from the HDF5 file."""
        self.store.delete_tiles(self.id_)

    def _mark_preprocessed(self) -> None:
        self.preprocessed = True
        self._tile_index = tile_index_from_slide_store(self.store, self.id_)
