"""Thin reader over OpenSlide for thumbnails and regions.

Provides simple, read-only helpers to fetch slide dimensions, thumbnails,
arbitrary regions, magnification, and to iterate patches in row-major order.
"""

from __future__ import annotations

import gc
import itertools
import logging
from typing import TYPE_CHECKING

import numpy as np
import openslide
from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn

# Module-level logger; SlideReader uses it to warn about magnification fallbacks and mismatches.
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from collections.abc import Iterator
    from pathlib import Path


class SlideReader:
    """Wrapper around OpenSlide for basic read operations.

    Notes:
        - This class opens the slide on initialization and keeps a handle for
          subsequent reads. It does not manage write operations.
        - All image arrays are returned as RGB (channels-last) uint8.
        - The handle can be released explicitly with :meth:`close`, or by using
          the reader as a context manager (``with SlideReader(path) as reader:``).

    """

    DEFAULT_MAGNIFICATION = 20

    # Slide property holding the objective power (stored as a string).
    _OBJECTIVE_POWER_PROPERTY = "openslide.objective-power"
    # Number of yielded patches between explicit garbage collections.
    _GC_INTERVAL = 100

    def __init__(self, slide_path: Path, magnification: int | None = None) -> None:
        """Initialize a SlideReader.

        Args:
            slide_path: Path to the whole-slide image.
            magnification: Optional magnification override. If None, attempts to
                read from slide properties and falls back to DEFAULT_MAGNIFICATION.

        Side Effects:
            - Opens the slide via OpenSlide and caches width/height/magnification.

        """
        self.slide_path = slide_path
        self._specified_magnification = magnification

        self._slide: openslide.OpenSlide
        # Marked closed until ``_open`` succeeds so ``close`` is always safe to call.
        self._closed = True
        self._open()
        try:
            self._width, self._height = self._slide.dimensions
            self._magnification = self.get_slide_magnification()
        except Exception:
            # Do not leak the native handle if post-open setup fails.
            self.close()
            raise

    def __enter__(self) -> SlideReader:
        """Return the reader itself so it can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type: object, exc_value: object, traceback: object) -> None:
        """Close the slide handle when leaving the ``with`` block."""
        self.close()

    @property
    def slide(self) -> openslide.OpenSlide:
        """Underlying OpenSlide handle (read-only)."""
        return self._slide

    @property
    def width(self) -> int:
        """Full-resolution slide width in pixels (level 0)."""
        return self._width

    @property
    def height(self) -> int:
        """Full-resolution slide height in pixels (level 0)."""
        return self._height

    @property
    def magnification(self) -> int:
        """Magnification used for this reader (resolved or user-specified)."""
        return self._magnification

    def close(self) -> None:
        """Release the underlying OpenSlide handle.

        Safe to call more than once; only the first call closes the handle.
        Reads performed after closing are not supported.
        """
        if self._closed:
            return
        self._closed = True
        self._slide.close()

    def get_thumbnail(self, size: tuple[int, int]) -> np.ndarray:
        """Return a thumbnail as an RGB array.

        Args:
            size: Requested thumbnail size as (width, height) in pixels.

        Returns:
            Numpy array of shape (H, W, 3), dtype uint8 (RGB).

        """
        im = self._slide.get_thumbnail(size)
        return np.array(im)[:, :, :3]

    def read_region(self, x: int, y: int, width: int, height: int) -> np.ndarray:
        """Read an RGB region at level 0.

        Args:
            x: Top-left X coordinate in pixels.
            y: Top-left Y coordinate in pixels.
            width: Region width in pixels.
            height: Region height in pixels.

        Returns:
            Numpy array of shape (H, W, 3), dtype uint8 (RGB).

        """
        im = self._slide.read_region((x, y), level=0, size=(width, height))
        # OpenSlide returns a 4-channel image (RGBA); only keep RGB.
        return np.array(im)[:, :, :3]

    def get_dimensions(self) -> tuple[int, int]:
        """Return (width, height) at level 0 in pixels."""
        return self._slide.dimensions

    def get_slide_magnification(self) -> int:
        """Resolve slide magnification with sensible fallbacks.

        Resolution order:
        1) Use the user-specified magnification if provided at init (a warning
           is logged if it disagrees with the slide properties).
        2) Otherwise, read the ``openslide.objective-power`` slide property.
        3) If unavailable or unparseable, fall back to ``DEFAULT_MAGNIFICATION``
           and log a warning.

        Returns:
            Magnification as an integer.

        """
        specified = self._specified_magnification
        from_properties = self._read_objective_power()

        if specified is not None:
            if (from_properties is not None) and (from_properties != specified):
                msg = (
                    f"Magnification level specified during Slide initialization "
                    f"({self._specified_magnification}) does not match magnification level found in"
                    f" slide properties ({from_properties}). Believing user-specified "
                    f"magnification level ({self._specified_magnification})."
                )
                logger.warning(msg)
            return int(specified)

        if from_properties is not None:
            return from_properties

        # Only here is the default actually used, so only here do we say so.
        msg = (
            f"Unable to find magnification level in slide properties for slide "
            f"{self.slide_path} and none was specified. Using default magnification of "
            f"{self.DEFAULT_MAGNIFICATION}."
        )
        logger.warning(msg)
        return int(self.DEFAULT_MAGNIFICATION)

    def _read_objective_power(self) -> int | None:
        """Return the objective power from the slide properties, or None.

        The property is a string that may be written as a decimal (for example
        ``"40.0"``), so it is parsed as a float and rounded to the nearest
        integer. Returns None when the property is absent or cannot be parsed.
        """
        raw = self.slide.properties.get(self._OBJECTIVE_POWER_PROPERTY)
        if raw is None:
            return None
        try:
            return round(float(raw))
        except (TypeError, ValueError, OverflowError):
            msg = (
                f"Unable to parse magnification level {raw!r} found in slide properties for "
                f"slide {self.slide_path}."
            )
            logger.warning(msg)
            return None

    def _open(self) -> None:
        """Open the slide using OpenSlide."""
        self._slide = openslide.OpenSlide(str(self.slide_path))
        self._closed = False

    def iter_patches_row_major(
        self,
        patch_size: int,
        cropped_width: int,
        cropped_height: int,
        segmentation_mask: np.ndarray | None = None,
        desc: str = "Iterating through slide patches",
        show_progress: bool = False,
    ) -> Iterator[tuple[np.ndarray, int, int]]:
        """Yield patches in row-major order as they are read.

        Args:
            patch_size: Square patch size in pixels.
            cropped_width: Width of the cropped area to iterate.
            cropped_height: Height of the cropped area to iterate.
            segmentation_mask: Optional boolean mask (rows x cols) indicating
                which patch locations are foreground. If None, iterates all.
            desc: Progress description when ``show_progress`` is True.
            show_progress: If True, display a progress bar via Rich.

        Yields:
            Tuples of (patch_rgb, x_top_left, y_top_left), where ``patch_rgb`` is
            an ``H x W x 3`` uint8 array.

        """
        n_rows = cropped_height // patch_size
        n_cols = cropped_width // patch_size

        if segmentation_mask is None:
            n_patches = n_rows * n_cols
        else:
            # Count only truthy cells inside the iterated grid so the progress
            # total matches the number of patches actually yielded.
            grid_mask = np.asarray(segmentation_mask)[:n_rows, :n_cols]
            n_patches = int(np.count_nonzero(grid_mask))

        def _iter(update_cb):
            """Read and yield foreground patches, calling ``update_cb`` per patch."""
            n_yielded = 0
            for patch_row, patch_col in itertools.product(range(n_rows), range(n_cols)):
                if (segmentation_mask is not None) and (
                    not segmentation_mask[patch_row, patch_col]
                ):
                    continue
                y_top_left = patch_row * patch_size
                x_top_left = patch_col * patch_size
                patch_arr = self.read_region(x_top_left, y_top_left, patch_size, patch_size)
                update_cb()
                yield patch_arr, x_top_left, y_top_left
                # Drop our reference so the consumer controls the patch lifetime.
                del patch_arr
                n_yielded += 1
                if n_yielded % self._GC_INTERVAL == 0:
                    gc.collect()
            gc.collect()

        if show_progress:
            with Progress(
                TextColumn("[bold blue]{task.description}", justify="right"),
                BarColumn(bar_width=None, complete_style="cyan", finished_style="cyan"),
                "[progress.percentage]{task.percentage:>3.1f}%",
                "•",
                TimeRemainingColumn(),
            ) as progress:
                task = progress.add_task(desc, total=n_patches)
                yield from _iter(lambda: progress.update(task, advance=1))
        else:
            yield from _iter(lambda: None)
