"""Immutable TileIndex core (geometry + mappings).

Supports two initialization modes:
- With a segmentation mask: only tiles marked foreground (True) in the mask
  are kept.
- Without a segmentation mask: constructs a full grid over the largest
  tile-aligned crop inside the slide dimensions (all tiles kept).

Centralizes mapping between extracted tile index space (N), grid coordinates
(R, C), and slide pixel coordinates (x, y) for square, non-overlapping tiles
of side P = tile_size (i.e. stride P).
"""

from __future__ import annotations

import contextlib
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, overload

import numpy as np
import numpy.typing as npt

from pathfmtools.io.schema import StoreKeys as SK

if TYPE_CHECKING:
    from collections.abc import Iterator

    from pathfmtools.io.slide_data_store import SlideDataStore

# Array aliases used in this module's type annotations.
IntArray = npt.NDArray[np.int32]  # tile indices, grid rows/cols, pixel coordinates
BoolArray = npt.NDArray[np.bool_]  # segmentation / foreground masks
FloatArray = npt.NDArray[np.float32]  # floating per-tile values and grids

logger = logging.getLogger(__name__)


@dataclass(frozen=True, slots=True, eq=False)
class TileIndex:
    """Canonical immutable geometry/indexing for tiles.

    The instance centralizes conversions among:
    - extracted index space (0..N-1)
    - grid coordinates (row, col) within (R,C)
    - slide pixel coordinates (x, y) for tile top-left anchors

    Attributes are normalized to contiguous arrays with stable dtypes.

    Invariants (validated in `_validate`):
    - seg_mask.sum() == N
    - grid_to_idx[r, c] ∈ [0, N) for kept cells; -1 otherwise
    - xs == cols * P; ys == rows * P
    """

    tile_size: int  # Size of each tile in pixels.
    slide_w: int  # Width of the slide in pixels (full resolution).
    slide_h: int  # Height of the slide in pixels (full resolution).

    # Optional segmentation mask (True=foreground). If None, treat all tiles as foreground.
    seg_mask: BoolArray | None = field(default=None, init=True, repr=False)
    segmentation_mask: BoolArray = field(init=False, repr=False)

    # Derived grid geometry
    n_tile_rows: int = field(init=False, repr=True)
    n_tile_cols: int = field(init=False, repr=True)
    n_foreground_tiles: int = field(init=False, repr=True)
    # Indexing caches (derived)
    fg_tile_idxs: IntArray = field(init=False, repr=False)  # (N,)
    fg_tile_rows: IntArray = field(init=False, repr=False)  # (N,)
    fg_tile_cols: IntArray = field(init=False, repr=False)  # (N,)
    fg_tile_xs: IntArray = field(init=False, repr=False)  # (N,)
    fg_tile_ys: IntArray = field(init=False, repr=False)  # (N,)
    grid_to_idx: IntArray = field(init=False, repr=False)  # (R, C)
    _fg_bbox_rowcol: tuple[int, int, int, int] | None = field(
        init=False,
        repr=False,
        default=None,
    )

    def __post_init__(self) -> None:
        """Standardize inputs and compute caches for mappings."""
        # --- Normalize primitives ---
        tile_size = int(self.tile_size)
        slide_w = int(self.slide_w)
        slide_h = int(self.slide_h)
        if tile_size <= 0:
            msg = "tile_size must be positive"
            raise ValueError(msg)
        if slide_w <= 0 or slide_h <= 0:
            msg = "slide dimensions must be positive"
            raise ValueError(msg)
        object.__setattr__(self, "tile_size", tile_size)
        object.__setattr__(self, "slide_w", slide_w)
        object.__setattr__(self, "slide_h", slide_h)

        # --- Derive tile grid from slide dims (cropped to multiples of tile_size) ---
        n_tile_rows = slide_h // tile_size
        n_tile_cols = slide_w // tile_size
        object.__setattr__(self, "n_tile_rows", int(n_tile_rows))
        object.__setattr__(self, "n_tile_cols", int(n_tile_cols))
        if (n_tile_rows == 0) or (n_tile_cols == 0):
            msg = "Slide too small for at least one tile; ensure slide dims >= tile_size"
            raise ValueError(msg)

        # --- Normalize/construct segmentation mask ---
        if self.seg_mask is None:
            segmentation_mask = np.ones((n_tile_rows, n_tile_cols), dtype=np.bool_)
        else:
            segmentation_mask = np.ascontiguousarray(self.seg_mask, dtype=np.bool_)
            if segmentation_mask.shape != (n_tile_rows, n_tile_cols):
                msg = (
                    f"seg_mask shape {segmentation_mask.shape} != ({n_tile_rows},{n_tile_cols}) "
                    f"derived from slide dims and tile_size"
                )
                raise ValueError(msg)
        object.__setattr__(self, "segmentation_mask", segmentation_mask)

        # --- Compute caches (row-major kept indices) ---
        foreground_tile_idxs = (
            segmentation_mask.reshape(-1).nonzero()[0].astype(np.int32, copy=False)
        )
        n_foreground_tiles = int(foreground_tile_idxs.size)
        object.__setattr__(self, "n_foreground_tiles", n_foreground_tiles)
        object.__setattr__(self, "fg_tile_idxs", np.ascontiguousarray(foreground_tile_idxs))

        fg_rows = (foreground_tile_idxs // n_tile_cols).astype(np.int32, copy=False)
        fg_cols = (foreground_tile_idxs % n_tile_cols).astype(np.int32, copy=False)
        fg_xs = (fg_cols * np.int32(tile_size)).astype(np.int32, copy=False)
        fg_ys = (fg_rows * np.int32(tile_size)).astype(np.int32, copy=False)
        object.__setattr__(self, "fg_tile_rows", np.ascontiguousarray(fg_rows))
        object.__setattr__(self, "fg_tile_cols", np.ascontiguousarray(fg_cols))
        object.__setattr__(self, "fg_tile_xs", np.ascontiguousarray(fg_xs))
        object.__setattr__(self, "fg_tile_ys", np.ascontiguousarray(fg_ys))

        grid_to_idx = np.full((n_tile_rows, n_tile_cols), -1, dtype=np.int32)
        if n_foreground_tiles:
            grid_to_idx[fg_rows, fg_cols] = np.arange(n_foreground_tiles, dtype=np.int32)
        object.__setattr__(self, "grid_to_idx", np.ascontiguousarray(grid_to_idx))

        # Cache foreground bounding box for repeated queries
        if np.any(segmentation_mask):
            rr_all, cc_all = np.where(segmentation_mask)
            r0 = int(rr_all.min())
            r1 = int(rr_all.max()) + 1
            c0 = int(cc_all.min())
            c1 = int(cc_all.max()) + 1
            object.__setattr__(self, "_fg_bbox_rowcol", (r0, r1, c0, c1))
        else:
            object.__setattr__(self, "_fg_bbox_rowcol", None)

        # Make internal arrays read-only to ensure immutability.
        for name in (
            "seg_mask",
            "fg_tile_idxs",
            "fg_tile_rows",
            "fg_tile_cols",
            "fg_tile_xs",
            "fg_tile_ys",
            "grid_to_idx",
        ):
            arr = getattr(self, name)
            with contextlib.suppress(Exception):
                arr.setflags(write=False)

        self._validate()

    def iter_tiles_row_major(
        self,
        *,
        foreground_only: bool = True,
    ) -> Iterator[tuple[int, int, int, int]]:
        """Yield tiles in row-major order as ``(row, col, x, y)`` tuples.

        Args:
            foreground_only: If True, yield only foreground tiles (kept tiles
                from ``seg_mask``). If False, yield all tiles in the grid.

        Yields:
            Tuples of ``(row, col, x, y)`` where ``x, y`` are the tile's
            top-left pixel coordinates.

        """
        if foreground_only:
            rr = self.fg_tile_rows
            cc = self.fg_tile_cols
            xs = self.fg_tile_xs
            ys = self.fg_tile_ys
            n = self.n_foreground_tiles
            for i in range(n):
                yield int(rr[i]), int(cc[i]), int(xs[i]), int(ys[i])
            return

        for row in range(self.n_tile_rows):
            y = row * self.tile_size
            for col in range(self.n_tile_cols):
                x = col * self.tile_size
                yield row, col, x, y

    # --- Mapping methods ---
    def idx_to_rowcol(
        self,
        idx: IntArray,
    ) -> tuple[IntArray, IntArray]:
        """Map extracted tile indices to (rows, cols) in patch space.

        Args:
            idx: np.ndarray[int32], shape (K,)

        Returns:
            (rows, cols): both np.ndarray[int32] of shape (K,)

        """
        idx = np.ascontiguousarray(idx, dtype=np.int32)
        if idx.ndim != 1:
            msg = f"idx must be 1D; got shape {idx.shape}"
            raise ValueError(msg)
        n = self.fg_tile_idxs.size
        if (idx.size > 0) and ((idx.min() < 0) or (idx.max() >= n)):
            msg = f"idx out of range; expected in [0,{n}); got min={idx.min()} max={idx.max()}"
            raise IndexError(msg)
        return self.fg_tile_rows[idx], self.fg_tile_cols[idx]

    def rowcol_to_idx(
        self,
        rows: IntArray,
        cols: IntArray,
    ) -> IntArray:
        """Map (rows, cols) in patch space to extracted tile indices.

        Fails if any (r,c) is background (no tile extracted) or out-of-bounds.
        """
        rows = np.ascontiguousarray(rows, dtype=np.int32)
        cols = np.ascontiguousarray(cols, dtype=np.int32)

        if rows.shape != cols.shape:
            msg = f"rows/cols shape mismatch: {rows.shape} vs {cols.shape}"
            raise ValueError(msg)
        if rows.ndim != 1:
            msg = f"rows/cols must be 1D; got ndim={rows.ndim}"
            raise ValueError(msg)
        if (rows.size > 0) and (
            (rows.min() < 0)
            or (cols.min() < 0)
            or (rows.max() >= self.n_tile_rows)
            or (cols.max() >= self.n_tile_cols)
        ):
            msg = (
                f"row/col out of bounds; rows in [0,{self.n_tile_rows}), cols in "
                f"[0,{self.n_tile_cols})"
            )
            raise IndexError(msg)

        idx = self.grid_to_idx[rows, cols]
        bad = idx < 0  # Background tiles have -1 index
        if np.any(bad):
            first_bad = np.flatnonzero(bad)[:5]
            examples = [f"(r={int(rows[i])},c={int(cols[i])})" for i in first_bad]
            msg = "rowcol_to_idx encountered background tiles at: " + ", ".join(examples)
            raise IndexError(msg)
        return idx.astype(np.int32, copy=False)

    def idx_to_xy(
        self,
        idx: IntArray,
    ) -> tuple[IntArray, IntArray]:
        """Map extracted tile indices to (x,y) tile top-left pixel coordinates."""
        r, c = self.idx_to_rowcol(idx)
        return c * self.tile_size, r * self.tile_size

    def rowcol_to_xy(
        self,
        rows: IntArray,
        cols: IntArray,
    ) -> tuple[IntArray, IntArray]:
        """Map (rows, cols) to (x,y) tile top-left pixel coordinates."""
        rows = np.ascontiguousarray(rows, dtype=np.int32)
        cols = np.ascontiguousarray(cols, dtype=np.int32)
        # Validate using rowcol_to_idx for bounds/background; discard result
        _ = self.rowcol_to_idx(rows, cols)
        return cols * self.tile_size, rows * self.tile_size

    # --- Crop ---
    def foreground_bbox_rowcol(self) -> tuple[int, int, int, int]:
        """Get patch-space bounding box over foreground tiles.

        Returns (r0, r1, c0, c1) with r1/c1 exclusive.
        """
        bbox = getattr(self, "_fg_bbox_rowcol", None)
        if bbox is None:
            msg = "No foreground cells found in seg_mask"
            raise ValueError(msg)
        return bbox

    def foreground_bbox_xy(self) -> tuple[int, int, int, int]:
        """Return pixel space bounding box over foreground tiles.

        Returns:
            (x0, x1, y0, y1): tuple of ints. Bounding box in pixel space (x1/y1 are exclusive).

        """
        r0, r1, c0, c1 = self.foreground_bbox_rowcol()
        return c0 * self.tile_size, c1 * self.tile_size, r0 * self.tile_size, r1 * self.tile_size

    # --- Grid/flat transforms ---
    @overload
    def to_grid(self, values: IntArray) -> tuple[IntArray, BoolArray]: ...
    @overload
    def to_grid(self, values: FloatArray) -> FloatArray: ...
    def to_grid(self, values):
        """Scatter per-tile values (N or NxD) into an (RxC[xD]) grid.

        Background handling:
        - If values.dtype is floating, a single grid is returned.
        - Otherwise, returns (grid, seg_mask).
        """
        values = np.ascontiguousarray(values)
        if values.ndim not in (1, 2):
            msg = f"values must be 1D or 2D; got shape {values.shape}"
            raise ValueError(msg)
        n = self.fg_tile_idxs.size
        if values.shape[0] != n:
            msg = f"values.shape[0] ({values.shape[0]}) != N ({n})"
            raise ValueError(msg)

        out_shape = _make_grid_shape(values, self.n_tile_rows, self.n_tile_cols)
        grid = _make_background_grid(values, out_shape)

        rr = self.fg_tile_rows
        cc = self.fg_tile_cols
        if values.ndim == 1:
            grid[rr, cc] = values
        else:
            grid[rr, cc, :] = values

        if _is_float_dtype(grid):
            return grid

        return grid, self.segmentation_mask.copy()

    def to_flat(self, grid: FloatArray) -> FloatArray:
        """Gather kept-cell values from a grid (RxC[xD]) to a flat (N[,D]).

        Fails if any kept cell is missing.
        """
        grid = np.ascontiguousarray(grid)
        if grid.ndim not in (2, 3):
            msg = f"grid must be 2D or 3D; got shape {grid.shape}"
            raise ValueError(msg)
        if not (grid.shape[0] == self.n_tile_rows and grid.shape[1] == self.n_tile_cols):
            msg = f"grid shape {grid.shape} != (R,C,*), {(self.n_tile_rows, self.n_tile_cols)}"
            raise ValueError(msg)

        rr = self.fg_tile_rows
        cc = self.fg_tile_cols
        # Warn if background and foreground share the same fill value, raise if NaN values are found
        # for foreground tiles
        if _is_float_dtype(grid) and np.isnan(grid[rr, cc, ...]).any():
            msg = "to_flat: NaN values found for foreground tiles"
            raise ValueError(msg)
        self._warn_if_bg_fg_overlap(grid)

        return grid[rr, cc, ...]

    # --- Thumbnail projection ---
    @overload
    def project_to_thumbnail(
        self,
        values: IntArray,
        *,
        thumb_w: int,
        thumb_h: int,
        crop: bool = True,
    ) -> tuple[IntArray, BoolArray]: ...
    @overload
    def project_to_thumbnail(
        self,
        values: FloatArray,
        *,
        thumb_w: int,
        thumb_h: int,
        crop: bool = True,
    ) -> FloatArray: ...
    def project_to_thumbnail(
        self,
        values: IntArray | FloatArray,
        *,
        thumb_w: int,
        thumb_h: int,
        crop: bool = True,
    ) -> tuple[IntArray, BoolArray] | FloatArray:
        """Project per-tile values onto a thumbnail-sized overlay.

        Uses nearest-neighbor tiling: each PxP tile becomes a rectangle of size
        (round(P*sx), round(P*sy)) where sx=thumb_w/slide_w, sy=thumb_h/slide_h.

        If crop=True, restricts to bounding box of foreground rows/cols first.
        Returns overlay (and mask when non-float).
        """
        thumb_w = int(thumb_w)
        thumb_h = int(thumb_h)
        if (thumb_w <= 0) or (thumb_h <= 0):
            msg = "thumb_w and thumb_h must be positive integers"
            raise ValueError(msg)

        sx = thumb_w / float(self.slide_w)
        sy = thumb_h / float(self.slide_h)
        tile_w = max(1, round(self.tile_size * sx))
        tile_h = max(1, round(self.tile_size * sy))

        # Prepare grid (and mask if needed)
        out = self.to_grid(values)
        if isinstance(out, tuple):
            grid, mask = out
        else:
            grid = out
            mask = self.segmentation_mask

        # Crop in grid space if requested
        r0, r1, c0, c1 = (
            self.foreground_bbox_rowcol() if crop else (0, self.n_tile_rows, 0, self.n_tile_cols)
        )
        grid = grid[r0:r1, c0:c1, ...]
        mask_sub = mask[r0:r1, c0:c1]

        # Expand via repeat to thumbnail resolution
        if grid.ndim == 2:
            overlay = np.repeat(np.repeat(grid, tile_h, axis=0), tile_w, axis=1)
            mask_ov = np.repeat(np.repeat(mask_sub, tile_h, axis=0), tile_w, axis=1)
        else:
            # (R,C,D) → expand first two dims
            overlay = np.repeat(np.repeat(grid, tile_h, axis=0), tile_w, axis=1)
            mask_ov = np.repeat(np.repeat(mask_sub, tile_h, axis=0), tile_w, axis=1)

        if _is_float_dtype(values):
            return overlay
        return overlay, mask_ov

    # --- Validation ---
    def _validate(self) -> None:
        """Validate internal invariants.

        Raises:
            ValueError | IndexError: On inconsistent shapes/mappings.
        """
        if self.fg_tile_idxs.shape != (self.n_foreground_tiles,):
            msg = (
                f"fg_tile_idxs shape {self.fg_tile_idxs.shape} != "
                f"(N,)=({(self.n_foreground_tiles,)})"
            )
            raise ValueError(msg)

        if self.n_foreground_tiles > 0:
            vals = self.grid_to_idx[self.fg_tile_rows, self.fg_tile_cols]
            # All kept positions must map to 0..N-1
            if (vals.min() < 0) or (vals.max() >= self.n_foreground_tiles):
                msg = "grid_to_idx has out-of-range values at kept cells"
                raise IndexError(msg)
            # Each extracted index appears exactly once
            hist = np.bincount(vals, minlength=self.n_foreground_tiles)
            if not np.all(hist == 1):
                msg = "grid_to_idx values are not a permutation of [0..N-1]"
                raise ValueError(msg)

        if (self.n_foreground_tiles > 0) and not np.all(
            self.segmentation_mask[self.fg_tile_rows, self.fg_tile_cols],
        ):
            msg = "fg_tile_idxs includes cells not marked True in seg_mask"
            raise ValueError(msg)

    def _warn_if_bg_fg_overlap(self, grid: np.ndarray) -> None:
        """Warn if background and foreground share the same fill value."""
        bg_mask = ~self.segmentation_mask
        # No background, nothing to check
        if not bg_mask.any():
            return

        rr = self.fg_tile_rows
        cc = self.fg_tile_cols
        fg_vals = set(grid[rr, cc])
        bg_vals = set(grid[bg_mask])
        overlap = fg_vals & bg_vals
        if overlap:
            logger.warning("Foreground and background share fill values: %s", overlap)


def tile_index_from_slide_store(store: SlideDataStore, slide_id: str) -> TileIndex:
    """Construct a TileIndex from SlideDataStore datasets.

    Reads the segmentation mask, tile metadata and slide metadata for
    ``slide_id`` and validates that they describe the same tiling. Each tile
    metadata column must have one entry per foreground cell. Tiles must be
    square with a single size. Stored rows/cols/top-left anchors must match
    the ones ``TileIndex`` derives.

    Args:
        store: Store providing ``read_seg_mask``, ``read_tile_metadata`` and
            ``read_slide_metadata``.
        slide_id: Identifier of the slide to load.

    Returns:
        The validated ``TileIndex``.

    Raises:
        ValueError: If the datasets are inconsistent, or if there are no
            foreground tiles (the tile size cannot be derived without a tile).
    """
    # Read core datasets
    seg_mask = store.read_seg_mask(slide_id).astype(np.bool_, copy=False)
    tile_meta = store.read_tile_metadata(slide_id)

    # np.asarray gives positional arrays whatever the metadata container is
    # (presumably ndarrays or pandas columns — confirm against the store).
    def _int32_column(key: str) -> np.ndarray:
        return np.asarray(tile_meta[key]).astype(np.int32)

    tile_rows = _int32_column("rows")
    tile_cols = _int32_column("cols")
    tile_top_left_xs = _int32_column("top_left_xs")
    tile_top_left_ys = _int32_column("top_left_ys")
    tile_widths = _int32_column("widths")
    tile_heights = _int32_column("heights")

    # Basic length check: one metadata entry per foreground cell.
    n = int(seg_mask.sum())
    columns = (
        tile_rows,
        tile_cols,
        tile_top_left_xs,
        tile_top_left_ys,
        tile_widths,
        tile_heights,
    )
    if any(len(column) != n for column in columns):
        msg = "tile metadata column lengths must all equal N"
        raise ValueError(msg)

    # The tile size is read from the first tile, so at least one tile is needed.
    if n == 0:
        msg = "Cannot derive tile_size: no foreground tiles in seg_mask/tile metadata"
        logger.error(msg)
        raise ValueError(msg)

    # Derive patch size; every tile must be square with that same size.
    tile_size = int(tile_widths[0])
    if not (np.all(tile_widths == tile_size) and np.all(tile_heights == tile_size)):
        msg = "Tile widths/heights are not consistent or not square"
        logger.error(msg)
        raise ValueError(msg)

    slide_meta = store.read_slide_metadata(slide_id)
    slide_w = int(slide_meta[SK.SLIDE_META_WIDTH])
    slide_h = int(slide_meta[SK.SLIDE_META_HEIGHT])

    ti = TileIndex(
        tile_size=tile_size,
        slide_w=slide_w,
        slide_h=slide_h,
        seg_mask=seg_mask,
    )

    # Stored tile positions must agree with the geometry TileIndex derives.
    comparisons = (
        (tile_rows, ti.fg_tile_rows, "tile_rows does not match fg_tile_rows"),
        (tile_cols, ti.fg_tile_cols, "tile_cols does not match fg_tile_cols"),
        (tile_top_left_xs, ti.fg_tile_xs, "tile_top_left_xs does not match fg_tile_xs"),
        (tile_top_left_ys, ti.fg_tile_ys, "tile_top_left_ys does not match fg_tile_ys"),
    )
    for stored, derived, msg in comparisons:
        if not np.array_equal(stored, derived):
            logger.error(msg)
            raise ValueError(msg)

    return ti


def _is_float_dtype(arr: np.ndarray) -> bool:
    return np.issubdtype(arr.dtype, np.floating)


def _is_int_dtype(arr: np.ndarray) -> bool:
    return np.issubdtype(arr.dtype, np.integer)


def _make_grid_shape(values: np.ndarray, nrows: int, ncols: int) -> tuple[int, ...]:
    """Return output grid shape for values of shape (N,) or (N,D)."""
    if values.ndim == 1:
        return (nrows, ncols)
    if values.ndim == 2:
        return (nrows, ncols, int(values.shape[1]))
    msg = f"values must be 1D or 2D; got shape {values.shape}"
    raise ValueError(msg)


def _make_background_grid(values: np.ndarray, shape: tuple[int, ...]) -> np.ndarray:
    """Allocate a grid of ``shape`` pre-filled with the background value for ``values``.

    Integer dtypes are filled with 0 and floating dtypes with NaN; the grid keeps
    the dtype of ``values``.

    Raises:
        ValueError: If ``values`` is neither integer nor floating (e.g. bool).
    """
    dtype = values.dtype
    if _is_int_dtype(values):
        return np.zeros(shape, dtype=dtype)
    if not _is_float_dtype(values):
        msg = f"values must be float or int; got dtype {values.dtype}"
        raise ValueError(msg)
    return np.full(shape, np.nan, dtype=dtype)
