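"""Sequential TileDataWriter bound to temporary HDF5 datasets.

Tiles are streamed into uuid-suffixed temporary datasets/groups. Finalization
validates the tile count, writes the per-tile metadata columns, and then renames
the temporary datasets/group to their final names. The final names therefore
only appear once all data has been written.
"""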

from __future__ import annotations

import json
import logging
import uuid
from typing import TYPE_CHECKING

import numpy as np

from pathfmtools.io.schema import StoreKeys as SK

# Module-level logger; not referenced by any code visible in this file.
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    import h5py


class TileDataWriter:
    """Stream foreground tiles into temporary HDF5 objects, then publish them.

    Construction creates three objects in ``handle`` under uuid-suffixed
    temporary names: an ``(N, P, P, 3)`` uint8 pixel dataset, an empty metadata
    group, and a boolean copy of ``seg_mask``. :meth:`write_batch` appends
    tiles and buffers their metadata in memory. :meth:`cleanup` then writes the
    metadata columns and moves the three temporary objects to
    ``SK.DS_TILES``, ``SK.DS_TILE_META`` and ``SK.DS_SEG_MASK``.

    ``N`` is the number of nonzero cells in ``seg_mask``. Exactly ``N`` tiles
    must be written, in strictly increasing row-major order of their
    ``(row, col)`` grid cell across *all* batches. As a result, tile ``i`` in
    the pixel dataset corresponds to the ``i``-th foreground cell of the mask.
    """

    # Upper bound on tiles per HDF5 chunk along the leading (tile) axis.
    _MAX_CHUNK_TILES = 64

    def __init__(
        self,
        handle: h5py.File,
        seg_mask: np.ndarray,
        patch_size: int,
        level: int,
        magnification: int | None,
        info: dict | None,
    ) -> None:
        """Bind the writer to ``handle`` and create the temporary objects.

        Args:
            handle: Open HDF5 file in which the temporary objects are created.
            seg_mask: 2-D tile grid; nonzero cells mark the tiles to be written.
            patch_size: Side length ``P`` (pixels) of each square 3-channel tile.
            level: Kept on the instance; not written to the file by this class.
            magnification: Kept on the instance; not written by this class.
            info: Optional JSON-serializable dict. :meth:`cleanup` stores it
                as the ``info`` attribute of the metadata group.

        Raises:
            ValueError: If ``seg_mask`` is not 2-D.
        """
        seg_mask = np.asarray(seg_mask)
        if seg_mask.ndim != 2:
            msg = f"seg_mask must be 2-D, got shape {seg_mask.shape}"
            raise ValueError(msg)

        self.handle = handle
        self.seg_mask = seg_mask
        self.patch_size = patch_size
        # NOTE(review): level/magnification are not persisted anywhere in this
        # class; confirm whether callers rely on them being written.
        self.level = level
        self.magnification = magnification
        self.info = info

        # Use count_nonzero rather than sum: a non-boolean mask (e.g. 0/255
        # uint8) must still yield the number of foreground tiles, matching the
        # truthiness used by write_batch and the bool copy written to disk.
        self.n = int(np.count_nonzero(seg_mask))
        self.n_tile_rows, self.n_tile_cols = seg_mask.shape

        self._cursor = 0
        # Row-major linear index of the last tile written; -1 before any write.
        self._last_lin = -1
        self._rows: list[int] = []
        self._cols: list[int] = []
        self._x: list[int] = []
        self._y: list[int] = []

        self._write_initial()

    @staticmethod
    def _pixel_chunks(n: int, patch_size: int) -> tuple[int, int, int, int] | None:
        """Return the chunk shape for an ``(n, P, P, 3)`` pixel dataset.

        HDF5 rejects zero-length chunk dimensions. An empty slide (``n == 0``)
        therefore gets ``None``, which lets h5py pick its default layout,
        instead of ``(0, P, P, 3)``.
        """
        if n <= 0:
            return None
        return (min(TileDataWriter._MAX_CHUNK_TILES, n), patch_size, patch_size, 3)

    def _write_initial(self) -> None:
        """Create the temporary pixel dataset, metadata group and mask dataset."""
        self.tmp_pixel_ds_name = f"{SK.DS_TILES}.__tmp__{uuid.uuid4().hex}"
        self.pixel_ds = self.handle.create_dataset(
            self.tmp_pixel_ds_name,
            shape=(self.n, self.patch_size, self.patch_size, 3),
            dtype=np.uint8,
            chunks=self._pixel_chunks(self.n, self.patch_size),
        )
        self.tmp_meta_grp_name = f"{SK.DS_TILE_META}.__tmp__{uuid.uuid4().hex}"
        self.meta_grp = self.handle.create_group(self.tmp_meta_grp_name)
        self.tmp_seg_mask_ds_name = f"{SK.DS_SEG_MASK}.__tmp__{uuid.uuid4().hex}"
        self.seg_mask_ds = self.handle.create_dataset(
            self.tmp_seg_mask_ds_name,
            data=self.seg_mask.astype(np.bool_, copy=False),
            dtype=np.bool_,
        )

    # Streaming API
    def write_batch(
        self,
        tiles: np.ndarray,
        rows: np.ndarray,
        cols: np.ndarray,
        x_px: np.ndarray,
        y_px: np.ndarray,
    ) -> None:
        """Append a row-major batch of foreground tiles and metadata.

        Args:
            tiles: ``(B, P, P, 3)`` array; cast to uint8 when written.
            rows: Grid row of each tile.
            cols: Grid column of each tile.
            x_px: Top-left x pixel coordinate of each tile (converted with ``int``).
            y_px: Top-left y pixel coordinate of each tile (converted with ``int``).

        Raises:
            ValueError: On a tile-shape or length mismatch, an out-of-bounds or
                background ``(row, col)``, an ordering violation, or when the
                batch would exceed the declared N tiles.

        Notes:
            - Tiles must be strictly increasing in row-major order over
              foreground cells, both within a batch and relative to all
              previous batches. This protects against subtle corruption and
              matches Slide.preprocess behavior.
            - An empty batch (``B == 0``) is a no-op.
            - All validation happens before anything is written, so a rejected
              batch leaves the writer's state unchanged.

        """
        tiles = np.asarray(tiles)
        if tiles.ndim != 4:
            msg = "tiles must be of shape (B,P,P,3) with declared P"
            raise ValueError(msg)
        b, p, p2, c = tiles.shape
        if p != self.patch_size or p2 != self.patch_size or c != 3:
            msg = "tiles must be of shape (B,P,P,3) with declared P"
            raise ValueError(msg)
        if not (len(rows) == len(cols) == len(x_px) == len(y_px) == b):
            msg = "metadata arrays must match batch size"
            raise ValueError(msg)
        if b == 0:
            # Nothing to write; min()/max() below would fail on empty arrays.
            return

        # int64 keeps the row-major linear index free of int32 overflow.
        rows = np.asarray(rows).astype(np.int64)
        cols = np.asarray(cols).astype(np.int64)
        # Validate mask inclusion and bounds
        if (
            (rows.min() < 0)
            or (cols.min() < 0)
            or (rows.max() >= self.n_tile_rows)
            or (cols.max() >= self.n_tile_cols)
        ):
            msg = f"(row,col) out of bounds for grid_shape {self.n_tile_rows}x{self.n_tile_cols}"
            raise ValueError(msg)
        if not np.all(self.seg_mask[rows, cols]):
            msg = "(row,col) must correspond to True in seg_mask"
            raise ValueError(msg)
        lin = rows * self.n_tile_cols + cols
        # Check ordering against the previous batch as well as within this one.
        # Otherwise a repeated cell in a later batch could slip through.
        if lin[0] <= self._last_lin or np.any(lin[:-1] >= lin[1:]):
            msg = "(row,col) must be strictly increasing in row-major order"
            raise ValueError(msg)

        # Assign into tmp tiles at next cursor positions
        start = self._cursor
        end = start + b
        if end > self.n:
            msg = "writing beyond declared N tiles"
            raise ValueError(msg)
        self.pixel_ds[start:end, ...] = tiles.astype(np.uint8, copy=False)
        self._cursor = end
        self._last_lin = int(lin[-1])

        # Buffer metadata until cleanup() writes it as columns.
        self._rows.extend(rows.tolist())
        self._cols.extend(cols.tolist())
        self._x.extend(int(v) for v in x_px)
        self._y.extend(int(v) for v in y_px)

    def cleanup(self) -> None:
        """Finalize the temporary objects and rename them to their final names.

        Validates that exactly N tiles were written. Writes the buffered x/y,
        row/col and constant width/height metadata columns (int32). Stores
        ``info`` as JSON. Finally moves the temporary objects into place.

        Raises:
            RuntimeError: If the number of tiles written differs from N.
        """
        if self._cursor != self.n:
            msg = f"written {self._cursor} does not equal N={self.n}"
            raise RuntimeError(msg)

        # Write metadata columns (same creation order as before).
        columns = (
            (SK.TILE_META_TOP_LEFT_X, self._x),
            (SK.TILE_META_TOP_LEFT_Y, self._y),
            (SK.TILE_META_ROW, self._rows),
            (SK.TILE_META_COL, self._cols),
        )
        for key, values in columns:
            self.meta_grp.create_dataset(key, data=np.asarray(values, dtype=np.int32))
        # Every tile is patch_size x patch_size, so width/height are constant columns.
        sizes = np.full(self.n, self.patch_size, dtype=np.int32)
        for key in (SK.TILE_META_WIDTH, SK.TILE_META_HEIGHT):
            self.meta_grp.create_dataset(key, data=sizes)
        if self.info is not None:
            self.meta_grp.attrs["info"] = json.dumps(self.info)

        tmp_names = (self.tmp_pixel_ds_name, self.tmp_meta_grp_name, self.tmp_seg_mask_ds_name)
        final_names = (SK.DS_TILES, SK.DS_TILE_META, SK.DS_SEG_MASK)
        # Persist all data before the renames expose it under the final names.
        self.handle.flush()
        # NOTE(review): these are three separate link moves, not one atomic
        # transaction. h5py also refuses to move onto an existing name, so
        # confirm that callers remove stale final objects first.
        for tmp_name, final_name in zip(tmp_names, final_names, strict=True):
            self.handle.move(tmp_name, final_name)
        self.handle.flush()
