"""Image segmentation tools."""

from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

import numpy as np
from skimage import filters as skfilters

if TYPE_CHECKING:
    from pathfmtools.image.tile_index import TileIndex
    from pathfmtools.io.slide_reader import SlideReader

logger = logging.getLogger(__name__)


def get_segmentation_method(segmenter: str) -> type[SlideSegmenter]:
    """Look up the SlideSegmenter subclass registered under a method name.

    The comparison is case-insensitive; ``"otsu"`` is currently the only
    recognised name.

    Args:
        segmenter (str): Name of the segmentation method, e.g. ``"otsu"``.

    Returns:
        type[SlideSegmenter]: The matching SlideSegmenter subclass (the class
            itself, not an instance).

    Raises:
        ValueError: If ``segmenter`` does not name a known method. The message
            is logged at error level before being raised.

    """
    normalized_name = segmenter.lower()
    if normalized_name != "otsu":
        error_message = f"Invalid segmentation method: {segmenter}"
        logger.error(error_message)
        raise ValueError(error_message)
    return OtsuSegmenter


class SlideSegmenter(ABC):
    """Abstract base class for slide segmenters.

    Subclasses implement ``create_patch_segmentation_mask`` as a classmethod,
    so a segmenter is used through its class rather than an instance.
    """

    @classmethod
    @abstractmethod
    def create_patch_segmentation_mask(
        cls,
        tile_index: TileIndex,
        slide_reader: SlideReader,
        **kwargs: Any,
    ) -> np.ndarray:
        """Create a per-tile binary segmentation mask for a slide.

        Args:
            tile_index (TileIndex): Tile geometry for the slide (tile size and
                tile-grid dimensions).
            slide_reader (SlideReader): Reader used to obtain the pixel data of
                the tiles to segment.
            **kwargs: Segmenter-specific options; see each subclass.

        Returns:
            np.ndarray: Binary segmentation mask with one entry per tile.
                A value of 1 indicates foreground, 0 indicates background.

        """


class OtsuSegmenter(SlideSegmenter):
    """Segment a slide into tissue and background regions using Otsu's method.

    Each tile is reduced to one grayscale intensity (a luminance-weighted
    combination of its per-channel means). Otsu's threshold over those per-tile
    values then splits tiles into darker ones (foreground, 1) and brighter ones
    (background, 0).
    """

    # Luminance weights applied to the R, G and B channel means (the same
    # weights skimage.color.rgb2gray uses).
    _GRAY_WEIGHTS = (0.2125, 0.7154, 0.0721)

    @classmethod
    def create_patch_segmentation_mask(
        cls,
        tile_index: TileIndex,
        slide_reader: SlideReader,
        **kwargs: Any,
    ) -> np.ndarray:
        """Create a per-tile binary segmentation mask via Otsu on grayscale means.

        Computes a grayscale mean for every tile in the tile grid and thresholds
        the resulting grid with Otsu's method.

        Args:
            tile_index: Tile geometry for the slide; ``tile_size``,
                ``n_tile_rows`` and ``n_tile_cols`` are read from it.
            slide_reader: Reader whose ``iter_patches_row_major`` yields
                ``(patch, x_top_left, y_top_left)`` tuples for tile-sized regions.
            **kwargs: Optional parameters.
                - show_progress (bool, optional): Forwarded to the reader to
                  control progress display. Defaults to False.
                Any other keyword arguments (e.g. ``patch_size``) are accepted
                for interface compatibility but are not used.

        Returns:
            np.ndarray: ``int32`` mask of shape ``(n_tile_rows, n_tile_cols)``
            with 1 for foreground (tiles darker than the Otsu threshold) and 0
            for background. An empty tile grid yields an empty mask.

        Raises:
            RuntimeError: If the reader did not yield a patch for every tile in
                the grid (thresholding would otherwise use undefined values).

        """
        tile_size = tile_index.tile_size
        n_rows = tile_index.n_tile_rows
        n_cols = tile_index.n_tile_cols

        # No tiles: Otsu's threshold is undefined on an empty array, and there
        # is nothing to read, so return an empty mask straight away.
        if n_rows == 0 or n_cols == 0:
            return np.zeros((n_rows, n_cols), dtype=np.int32)

        gray_means = np.zeros((n_rows, n_cols), dtype=np.float32)
        # Tracks which grid cells actually received a value from the reader.
        visited = np.zeros((n_rows, n_cols), dtype=bool)
        w_red, w_green, w_blue = cls._GRAY_WEIGHTS

        show_progress = bool(kwargs.get("show_progress", False))
        for patch, x_top_left, y_top_left in slide_reader.iter_patches_row_major(
            patch_size=tile_size,
            # Restrict reading to the tile-aligned region covered by the grid.
            cropped_width=n_cols * tile_size,
            cropped_height=n_rows * tile_size,
            segmentation_mask=None,
            desc="Computing per-tile grayscale means",
            show_progress=show_progress,
        ):
            r = y_top_left // tile_size
            c = x_top_left // tile_size
            # Weighting the per-channel means equals the mean of the per-pixel
            # grayscale image (the mean is linear), without building that image.
            ch_means = patch.mean(axis=(0, 1), dtype=np.float64)
            gray_means[r, c] = w_red * ch_means[0] + w_green * ch_means[1] + w_blue * ch_means[2]
            visited[r, c] = True

        if not visited.all():
            msg = (
                f"Slide reader yielded no patch for {int((~visited).sum())} of "
                f"{visited.size} tiles; cannot compute Otsu segmentation mask."
            )
            logger.error(msg)
            raise RuntimeError(msg)

        thresh = skfilters.threshold_otsu(gray_means)

        # Tissue is darker than the (bright) background, so tiles below the
        # threshold are marked as foreground.
        return (gray_means < thresh).astype(np.int32)
