"""DatasetStore: focused write/read APIs with per-call locking.

This module introduces single-purpose write methods and simple read façades to
support streaming tile writes, metadata creation, and optional segmentation mask
storage. It coexists with the legacy ``pathfmtools.io.store.DatasetStore`` for
backwards compatibility during migration.

Conventions
----------
- Datasets:
  - ``/tiles``: shape (N, P, P, 3), dtype uint8
  - ``/tile_metadata``: group of per-tile column datasets (legacy files: JSON-serialized dict)
  - ``/tile_segmentation_mask``: implementation-defined (bool grid)
- Chunking for ``/tiles``: default (min(64, N), P, P, 3)
- Locking defaults: timeout=300s (5s for ``delete_tiles`` and ``write_embeddings``), poll=0.25s, jitter=0.05s
"""

from __future__ import annotations

import json
import logging
import time
import uuid
from contextlib import contextmanager
from datetime import UTC, datetime
from pathlib import Path
from typing import TYPE_CHECKING, Literal, Self, cast

import h5py
import numpy as np

from pathfmtools.io.pid_lock import PIDLock
from pathfmtools.io.schema import EMBEDDINGS_DTYPE, JSON_DTYPE, SCHEMA_VERSION
from pathfmtools.io.schema import StoreKeys as SK
from pathfmtools.io.tile_data_writer import TileDataWriter
from pathfmtools.utils.errors import DatasetExistsError, ShapeMismatchError
from pathfmtools.utils.model_id import canon_model_id

if TYPE_CHECKING:
    from collections.abc import Iterator


logger = logging.getLogger(__name__)


class SlideDataStore:
    """A directory-managed store of per-slide HDF5 files."""

    def __init__(self, root: Path, retries: int = 10, backoff_s: float = 1.0) -> None:
        """Initialize the store."""
        self.root = root
        self.retries = max(1, int(retries))
        self.backoff_s = float(backoff_s)

    def open(self, mode: Literal["r", "a"] = "a") -> Self:
        """Open the store for reading or appending.

        If in read mode, the store must exist. If in append mode, the store will be created if it
        does not exist.
        """
        if mode not in ("r", "a"):
            msg = "mode must be 'r' or 'a'"
            raise ValueError(msg)
        self.root = Path(self.root)
        if not self.root.exists():
            if mode == "r":
                raise FileNotFoundError(self.root)
            self.root.mkdir(parents=True, exist_ok=True)
        return self

    def get_slide_h5_path(self, slide_id: str, *, create_if_missing: bool = False) -> Path:
        """Return the per-slide HDF5 path, validating schema if present."""
        p = self.root / f"{slide_id}.h5"
        if not p.exists():
            if create_if_missing:
                tmp = p.with_name(p.name + f".tmp_{int(time.time())}")
                with PIDLock(p):
                    with h5py.File(tmp, "w") as f:
                        f.attrs["schema_version"] = SCHEMA_VERSION
                        f.flush()
                    tmp.replace(p)
            else:
                raise FileNotFoundError(p)

        with h5py.File(p, mode="r", libver="latest") as f:
            self._validate_schema(f)

        return p

    @contextmanager
    def get_tile_data_writer(
        self,
        slide_id: str,
        patch_size: int,
        seg_mask: np.ndarray,
        *,
        level: int,
        magnification: int | None = None,
        info: dict | None = None,
        force: bool = False,
    ) -> Iterator[TileDataWriter]:
        """Prepare temporary targets and yield a writer placeholder."""
        if patch_size <= 0:
            msg = "patch_size must be positive"
            raise ValueError(msg)

        h5_path = self.get_slide_h5_path(slide_id, create_if_missing=True)
        lock = PIDLock(h5_path, timeout=300.0, poll=0.25, jitter=0.05)
        lock.acquire()
        try:
            with h5py.File(h5_path, mode="a", libver="latest") as f:
                out_ds_names = [SK.DS_TILES, SK.DS_SEG_MASK, SK.DS_TILE_META]
                existing = [name for name in out_ds_names if name in f]
                if existing and not force:
                    msg = f"Datasets already exist: {existing}; use force=True to overwrite"
                    raise DatasetExistsError(msg)
                if existing and force:
                    for name in existing:
                        del f[name]

                writer = TileDataWriter(
                    handle=f,
                    seg_mask=seg_mask,
                    patch_size=patch_size,
                    level=level,
                    magnification=magnification,
                    info=info,
                )
                yield writer
                writer.cleanup()
        finally:
            lock.release()

    def delete_tiles(
        self,
        slide_id: str,
        *,
        lock_timeout: float = 5.0,
        lock_poll: float = 0.25,
        lock_jitter: float = 0.05,
    ) -> None:
        """Remove the tile dataset, rebuilding the file to reclaim space."""
        h5_path = self.get_slide_h5_path(slide_id)
        tmp = h5_path.with_name(h5_path.name + f".tmp_{int(time.time())}")
        with PIDLock(h5_path, timeout=lock_timeout, poll=lock_poll, jitter=lock_jitter):
            with h5py.File(h5_path, "r") as src, h5py.File(tmp, "w") as dst:
                # Copy all keys except patches
                for key in src:
                    if key == SK.DS_TILES:
                        continue
                    src.copy(key, dst)
                # Copy attrs
                for k, v in src.attrs.items():
                    dst.attrs[k] = v
            tmp.replace(h5_path)

    def read_tile_metadata(self, slide_id: str) -> dict[str, np.ndarray]:
        """Read tile metadata JSON dataset and return a dict copy."""

        def _read_ds_as_np(grp: h5py.Group, ds_name: str) -> np.ndarray:
            ds = cast("h5py.Dataset", grp[ds_name])
            return cast("np.ndarray", ds[()])

        h5_path = self.get_slide_h5_path(slide_id)
        with h5py.File(h5_path, mode="r", libver="latest") as f:
            if SK.DS_TILE_META not in f:
                raise KeyError(SK.DS_TILE_META)
            tile_meta_grp = cast("h5py.Group", f[SK.DS_TILE_META])
            tile_rows = _read_ds_as_np(tile_meta_grp, SK.TILE_META_ROW)
            tile_cols = _read_ds_as_np(tile_meta_grp, SK.TILE_META_COL)
            tile_top_left_xs = _read_ds_as_np(tile_meta_grp, SK.TILE_META_TOP_LEFT_X)
            tile_top_left_ys = _read_ds_as_np(tile_meta_grp, SK.TILE_META_TOP_LEFT_Y)
            tile_widths = _read_ds_as_np(tile_meta_grp, SK.TILE_META_WIDTH)
            tile_heights = _read_ds_as_np(tile_meta_grp, SK.TILE_META_HEIGHT)
            if not (
                len(tile_rows)
                == len(tile_cols)
                == len(tile_top_left_xs)
                == len(tile_top_left_ys)
                == len(tile_widths)
                == len(tile_heights)
            ):
                msg = "Tile metadata lengths must match"
                raise ValueError(msg)
            if not (np.all(tile_widths == tile_heights) and np.all(tile_widths == tile_widths[0])):
                msg = "Tile widths/heights are not consistent or not square"
                raise ValueError(msg)
            if not (
                np.all(tile_top_left_xs == tile_cols * tile_widths[0])
                and np.all(tile_top_left_ys == tile_rows * tile_widths[0])
            ):
                msg = "Tile top-left coordinates do not match (col,row)*P derivation"
                raise ValueError(msg)
            return {
                "rows": tile_rows,
                "cols": tile_cols,
                "top_left_xs": tile_top_left_xs,
                "top_left_ys": tile_top_left_ys,
                "widths": tile_widths,
                "heights": tile_heights,
            }

    def check_tiles_present(self, slide_id: str) -> bool:
        """Check if the tiles dataset exists."""
        try:
            h5_path = self.get_slide_h5_path(slide_id, create_if_missing=False)
        except FileNotFoundError:
            return False
        with h5py.File(h5_path, mode="r", libver="latest") as f:
            try:
                ds = cast("h5py.Dataset", f[SK.DS_TILES])
                return (ds.shape != ()) and not any(dim == 0 for dim in ds.shape)
            except KeyError:
                return False

    def read_seg_mask(self, slide_id: str) -> np.ndarray:
        """Read segmentation mask dataset and return a copy."""
        h5_path = self.get_slide_h5_path(slide_id)
        with h5py.File(h5_path, mode="r", libver="latest") as f:
            if SK.DS_SEG_MASK not in f:
                raise KeyError(SK.DS_SEG_MASK)
            return f[SK.DS_SEG_MASK][()]  # type: ignore[reportIndexIssue]

    def read_dataset(
        self,
        slide_id: str,
        name: str,
    ) -> np.ndarray | dict:
        """Read a dataset by name and return a copy.

        For ``/tile_metadata``, a Python ``dict`` is returned.
        """
        h5_path = self.get_slide_h5_path(slide_id)
        # Reads are non-exclusive; rely on writer-side atomic moves
        with h5py.File(h5_path, mode="r", libver="latest") as f:
            if name not in f:
                raise KeyError(name)
            obj = f[name]
            # Standardize access for known JSON datasets
            if name == SK.DS_SLIDE_META:
                raw = obj[()]  # type: ignore[reportAttributeAccessIssue]
                return json.loads(raw)  # type: ignore[reportArgumentType]
            # Provide a dict-of-arrays view for group-based tile metadata
            if name == SK.DS_TILE_META:
                if isinstance(obj, h5py.Group):
                    tile_meta_grp = obj
                    return {
                        "rows": tile_meta_grp[SK.TILE_META_ROW][()],
                        "cols": tile_meta_grp[SK.TILE_META_COL][()],
                        "top_left_xs": tile_meta_grp[SK.TILE_META_TOP_LEFT_X][()],
                        "top_left_ys": tile_meta_grp[SK.TILE_META_TOP_LEFT_Y][()],
                        "widths": tile_meta_grp[SK.TILE_META_WIDTH][()],
                        "heights": tile_meta_grp[SK.TILE_META_HEIGHT][()],
                    }
                # Back-compat: JSON dataset (legacy). Return parsed dict.
                raw = obj[()]  # type: ignore[reportAttributeAccessIssue]
                return json.loads(raw)  # type: ignore[reportArgumentType]

            # Reject other groups to prevent misuse
            if isinstance(obj, h5py.Group):
                msg = f"'{name}' is a group; use a dedicated reader for structured data"
                raise TypeError(msg)

            return obj[()]  # type: ignore[reportIndexIssue]

    def write_slide_metadata(
        self,
        slide_id: str,
        *,
        slide_width: int,
        slide_height: int,
        magnification: int,
        patch_size: int,
        slide_fpath: Path,
        segmentation_method: str,
        prop_foreground: float,
        force: bool = False,
        info: dict | None = None,
        lock_timeout: float = 300.0,
        lock_poll: float = 0.25,
        lock_jitter: float = 0.05,
    ) -> None:
        """Write slide-level metadata as a JSON dataset.

        Args:
            slide_id: Slide identifier.
            meta: JSON-serializable dict.
            force: If True, delete and rewrite when dataset exists; otherwise error.
            info: Optional dataset-level info attribute (JSON-serialized).
            lock_timeout: PIDLock timeout in seconds.
            lock_poll: Lock poll interval in seconds.
            lock_jitter: Poll jitter in seconds.

        Raises:
            DatasetExistsError: When dataset exists and ``force=False``.

        """
        meta = {
            SK.SLIDE_META_ID: slide_id,
            SK.SLIDE_META_WIDTH: int(slide_width),
            SK.SLIDE_META_HEIGHT: int(slide_height),
            SK.SLIDE_META_MAGNIFICATION: int(magnification),
            SK.SLIDE_META_PATCH_SIZE: int(patch_size),
            SK.SLIDE_META_SLIDE_FPATH: str(slide_fpath),
            SK.SLIDE_META_SEGMENTATION_METHOD: segmentation_method,
            SK.SLIDE_META_PROP_FOREGROUND: prop_foreground,
        }

        try:
            raw = json.dumps(meta)
        except Exception as e:
            msg = f"Slide metadata for {slide_id} is not JSON serializable: {e}"
            raise TypeError(msg) from e

        h5_path = self.get_slide_h5_path(slide_id, create_if_missing=True)
        with PIDLock(h5_path, timeout=lock_timeout, poll=lock_poll, jitter=lock_jitter):  # noqa: SIM117
            with h5py.File(h5_path, mode="a", libver="latest") as f:
                if SK.DS_SLIDE_META in f:
                    if not force:
                        msg = f"Dataset '{SK.DS_SLIDE_META}' already exists in {h5_path}"
                        raise DatasetExistsError(msg)
                    del f[SK.DS_SLIDE_META]
                self._create_dataset(
                    f,
                    SK.DS_SLIDE_META,
                    create_kwargs={"data": raw, "dtype": JSON_DTYPE},
                    info=info,
                )

    def read_slide_metadata(self, slide_id: str) -> dict:
        """Read slide-level metadata JSON dataset and return a dict copy."""
        h5_path = self.get_slide_h5_path(slide_id)
        with h5py.File(h5_path, mode="r", libver="latest") as f:
            if SK.DS_SLIDE_META not in f:
                raise KeyError(SK.DS_SLIDE_META)
            raw = f[SK.DS_SLIDE_META][()]  # type: ignore[reportIndexIssue]
            return json.loads(raw)  # type: ignore[reportArgumentType]

    def write_embeddings(
        self,
        slide_id: str,
        model_id: str,
        kind: Literal[SK.TILE_FEATURE_EMBEDDINGS, SK.TILE_ZEROSHOT_EMBEDDINGS],
        emb: np.ndarray,
        *,
        prompts: dict | None = None,
        force: bool = False,
        lock_timeout: float = 5.0,
        lock_poll: float = 0.25,
        lock_jitter: float = 0.05,
    ) -> dict:
        """Write embeddings dataset under grouped layout.

        Structure: ``/embeddings/{canon}/{kind}`` with ``float16`` dtype, and
        JSON meta at ``/embeddings/{canon}/meta``.

        Returns the meta dict written.
        """
        h5_path = self.get_slide_h5_path(slide_id, create_if_missing=True)
        canon = canon_model_id(model_id)
        if emb.dtype != EMBEDDINGS_DTYPE:
            emb = emb.astype(EMBEDDINGS_DTYPE, copy=False)

        with PIDLock(h5_path, timeout=lock_timeout, poll=lock_poll, jitter=lock_jitter):  # noqa: SIM117
            with h5py.File(h5_path, mode="a", libver="latest") as f:
                if SK.DS_TILE_META not in f:
                    msg = "/tile_metadata must exist before writing embeddings"
                    raise KeyError(msg)
                # Read N from group-based schema (any column length suffices)
                n = int(f[SK.DS_TILE_META][SK.TILE_META_ROW][()].size)
                if int(emb.shape[0]) != n:
                    msg = f"Embeddings length {emb.shape[0]} does not match N={n} tiles"
                    raise ValueError(msg)

                top_grp = f.require_group(SK.TILE_EMBEDDINGS_ROOT)
                model_grp = top_grp.require_group(canon)

                # Write or overwrite the embeddings dataset
                if kind in model_grp:
                    if not force:
                        msg = f"Dataset '{kind}' already exists for model {model_id} in {h5_path}"
                        raise DatasetExistsError(msg)
                    del model_grp[kind]
                ds = model_grp.create_dataset(kind, data=emb, dtype=EMBEDDINGS_DTYPE)
                ds.attrs["dims"] = emb.shape

                meta = {
                    "model_id_raw": model_id,
                    "canon": canon,
                    "dims": list(map(int, emb.shape)),
                    "created_at": datetime.now(UTC).isoformat(),
                }
                if prompts is not None:
                    meta["prompts"] = prompts

                raw = json.dumps(meta)
                if "meta" in model_grp:
                    del model_grp["meta"]
                model_grp.create_dataset("meta", data=raw, dtype=JSON_DTYPE)
                return meta

    def read_embeddings(
        self,
        slide_id: str,
        model_id: str,
        kind: Literal[SK.TILE_FEATURE_EMBEDDINGS, SK.TILE_ZEROSHOT_EMBEDDINGS],
    ) -> np.ndarray:
        h5_path = self.get_slide_h5_path(slide_id)
        canon = canon_model_id(model_id)
        with h5py.File(h5_path, mode="r", libver="latest") as f:
            if SK.TILE_EMBEDDINGS_ROOT not in f:
                raise KeyError(SK.TILE_EMBEDDINGS_ROOT)
            emb_grp = f[SK.TILE_EMBEDDINGS_ROOT]
            if canon not in emb_grp:
                raise KeyError(canon)
            model_grp = emb_grp[canon]
            if kind not in model_grp:
                raise KeyError(kind)
            return model_grp[kind][()]  # type: ignore[reportIndexIssue]

    @contextmanager
    def open_slide_store_file_readonly(
        self,
        slide_id: str,
    ) -> Iterator[h5py.File]:
        """Open a dataset handle (read-only) for zero-copy access within a context manager."""
        h5_path = self.get_slide_h5_path(slide_id)
        f = h5py.File(h5_path, mode="r", libver="latest")
        try:
            yield f
        finally:
            f.close()

    @staticmethod
    def _validate_schema(f: h5py.File) -> None:
        v = f.attrs.get("schema_version", None)
        if v is None:
            msg = "Missing schema_version attribute in HDF5 file"
            raise ValueError(msg)
        if str(v) != SCHEMA_VERSION:
            msg = f"Unsupported schema_version {v}; expected {SCHEMA_VERSION}"
            raise ValueError(msg)

    def _create_dataset(
        self,
        f: h5py.File,
        name: str,
        *,
        create_kwargs: dict,
        info: dict | None = None,
    ) -> None:
        """Create a dataset atomically via tmp+move.

        Args:
            f: Open HDF5 file handle in append mode.
            name: Final dataset path/name.
            create_kwargs: Passed to ``f.create_dataset`` for the temporary dataset.
            info: Optional metadata dict stored in ``ds.attrs['info']`` as JSON.

        """
        if name in f:
            msg = f"Dataset '{name}' already exists in {f.filename}"
            raise DatasetExistsError(msg)

        tmp_name = f"{name}.__tmp__{uuid.uuid4().hex}"
        ds = f.create_dataset(tmp_name, **create_kwargs)
        if info is not None:
            ds.attrs["info"] = json.dumps(info)
        f.flush()
        f.move(tmp_name, name)

    def _check_tile_data_alignment(
        self,
        f: h5py.File,
        *,
        tiles_n: int | None = None,
        meta_len: int | None = None,
        mask_count: int | None = None,
    ) -> None:
        """Validate that tile counts align across available datasets.

        Checks any available combination of tile pixels (``/tiles``), tile metadata
        (``/tile_metadata``), and segmentation mask (``/tile_segmentation_mask``).

        Args:
            f: Open HDF5 file handle.
            tiles_n: Optional override for number of tiles to validate against
                (e.g., when creating the tiles dataset).
            meta_len: Optional override for metadata length to validate against
                (e.g., when writing metadata).
            mask_count: Optional override for mask cell count to validate against
                (e.g., when writing segmentation mask).

        Raises:
            InvalidShapeError: If any two available counts disagree.

        """
        # Derive counts from file when not provided via overrides. Keep fast and simple.
        if tiles_n is None and SK.DS_TILES in f:
            tiles_n = int(f[SK.DS_TILES].shape[0])  # type: ignore[reportAttributeAccessIssue]
        if meta_len is None and SK.DS_TILE_META in f:
            if isinstance(f[SK.DS_TILE_META], h5py.Group):
                meta_len = int(f[f"{SK.DS_TILE_META}/{SK.TILE_META_ROW}"].size)
            else:
                meta_len = len(json.loads(f[SK.DS_TILE_META][()]))  # type: ignore[reportAttributeAccessIssue]
        if mask_count is None and SK.DS_SEG_MASK in f:
            mask_count = f[SK.DS_SEG_MASK][()].sum()  # type: ignore[reportAttributeAccessIssue, reportIndexIssue]

        counts: list[tuple[str, int]] = []
        if tiles_n is not None:
            counts.append(("tiles", int(tiles_n)))
        if meta_len is not None:
            counts.append(("metadata", int(meta_len)))
        if mask_count is not None:
            counts.append(("seg_mask", int(mask_count)))

        # Compare any available pair
        for i in range(len(counts)):
            for j in range(i + 1, len(counts)):
                (ki, vi), (kj, vj) = counts[i], counts[j]
                if vi != vj:
                    msg = f"Inconsistent tile counts in {f.filename}: {ki}={vi} vs {kj}={vj}"
                    raise ShapeMismatchError(msg)
