"""Slide dispatching utilities (process/thread orchestration + progress)."""

from __future__ import annotations

import logging
import multiprocessing as mp
import queue
import threading
import traceback
from collections import Counter
from concurrent.futures import ProcessPoolExecutor, as_completed
from contextlib import ExitStack
from typing import TYPE_CHECKING, Any, cast

from rich.progress import (
    BarColumn,
    Progress,
    ProgressColumn,
    SpinnerColumn,
    TaskID,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
from rich.text import Text

from pathfmtools.embedding_models import get_embedding_model
from pathfmtools.image import Slide
from pathfmtools.io.pid_lock import LockTimeoutError, PIDLock
from pathfmtools.io.schema import StoreKeys as SK
from pathfmtools.io.slide_data_store import SlideDataStore
from pathfmtools.utils import log_profiling_info

if TYPE_CHECKING:
    from pathlib import Path

    import torch

# Module-level logger (e.g. for the GPU oversubscription warning in dispatch_slide_workers).
logger = logging.getLogger(__name__)


def dispatch_slide_workers(
    store_root: Path,
    slide_paths: list[Path],
    device_list: list[torch.device],
    *,
    patch_size: int | None,
    batch_size: int,
    segmenter: str,
    model_list: list[str] | None,
    delete_tiles: bool,
    continue_on_error: bool,
    no_auto_rescale: bool,
    skip_feature_embeddings: bool,
    skip_zeroshot_embeddings: bool,
) -> None:
    """Run one long-lived worker per device, fed by a shared queue, with central progress.

    Every entry of ``device_list`` gets its own worker process (listing a device twice starts
    two workers on it). Workers pull slide paths from a Manager-backed task queue and report
    finished slides on a progress queue; a listener thread in this process renders them.

    Args:
        store_root: Root directory of the slide data store.
        slide_paths: Slides to process; a slide is identified by its file stem.
        device_list: One entry per worker process.
        patch_size: Tile size for preprocessing. Required whenever at least one slide does not
            yet have tiles in the store (regardless of how many devices are used).
        batch_size: Forwarded to ``Slide.embed_tiles``.
        segmenter: Forwarded to ``Slide.preprocess``.
        model_list: Embedding model names loaded once per worker; ``None`` skips embedding.
        delete_tiles: Delete tiles from disk after each slide is processed.
        continue_on_error: If True, log failures and keep going; otherwise signal all workers
            to stop and re-raise the first worker failure.
        no_auto_rescale: Passed to ``Slide.embed_tiles`` as ``auto_rescale=not no_auto_rescale``.
        skip_feature_embeddings: Forwarded to ``Slide.embed_tiles``.
        skip_zeroshot_embeddings: Forwarded to ``Slide.embed_tiles``.

    Raises:
        ValueError: If ``device_list`` is empty, or ``patch_size`` is None while some slides
            still need preprocessing.
        BaseException: The first worker failure, when ``continue_on_error`` is False.

    Notes:
        - Worker processes are started with the ``spawn`` start method.

    """
    if not device_list:
        msg = "device_list must contain at least one device."
        raise ValueError(msg)

    # NOTE(review): this store handle is never explicitly closed; confirm whether
    # SlideDataStore needs a close()/context manager here.
    store = SlideDataStore(store_root).open("a")
    needs_pre = [p for p in slide_paths if not store.check_tiles_present(p.stem)]
    if patch_size is None and needs_pre:
        msg = (
            f"--patch_size is required because {len(needs_pre)} slide(s) need preprocessing. "
            "Either provide --patch_size or preprocess slides first."
        )
        raise ValueError(msg)

    if not slide_paths:
        # Nothing to do: avoid spawning workers (and loading models onto devices) for no work.
        logger.info("No slides to process.")
        return

    # Warn if multiple processes target the same device (possible GPU oversubscription)
    gpu_counts = {dev: n for dev, n in Counter(device_list).items() if dev.type == "cuda"}
    if any(n > 1 for n in gpu_counts.values()):
        logger.warning(
            "Multiple workers per GPU detected: %s. This may oversubscribe device memory; adjust "
            "--splits if needed.",
            ", ".join([f"device {d}: {c} procs" for d, c in gpu_counts.items()]),
        )

    n_total = len(slide_paths)
    worker_keys = [f"{dev!s}#{idx}" for idx, dev in enumerate(device_list, start=1)]

    # Centralized progress + task distribution via manager
    with mp.Manager() as manager:
        progress_q = cast("mp.Queue[dict[str, Any] | None]", manager.Queue())
        task_q = cast("mp.Queue[Path | None]", manager.Queue())
        abort_event = manager.Event()

        for slide_path in slide_paths:
            task_q.put(slide_path)
        # One sentinel per worker, enqueued behind every slide, tells each worker to stop.
        for _ in worker_keys:
            task_q.put(None)

        with Progress(*_get_progress_columns()) as progress:
            # Aggregate (determinate) row across all slides.
            agg_task_id = progress.add_task(f"All slides • slides={n_total}", total=n_total)
            # Per-worker rows are indeterminate: spinner plus a completed count.
            worker_task_id_map: dict[str, TaskID] = {
                key: progress.add_task(key, total=None) for key in worker_keys
            }

            listener_thread = threading.Thread(
                target=_progress_listener,
                args=(progress, progress_q, worker_task_id_map, agg_task_id),
                daemon=True,
            )
            listener_thread.start()

            try:
                first_exc: BaseException | None = None
                mp_ctx = mp.get_context("spawn")
                with ProcessPoolExecutor(max_workers=len(worker_keys), mp_context=mp_ctx) as ex:
                    # Launch long-lived workers (one per device entry), remembering which
                    # future belongs to which worker for error reporting.
                    future_to_key = {
                        ex.submit(
                            _device_worker,
                            task_q=task_q,
                            progress_q=progress_q,
                            abort_event=abort_event,
                            worker_key=worker_key,
                            device=dev,
                            patch_size=patch_size,
                            batch_size=batch_size,
                            segmenter=segmenter,
                            model_list=model_list,
                            delete_tiles=delete_tiles,
                            store_root=store_root,
                            continue_on_error=continue_on_error,
                            no_auto_rescale=no_auto_rescale,
                            skip_feature_embeddings=skip_feature_embeddings,
                            skip_zeroshot_embeddings=skip_zeroshot_embeddings,
                        ): worker_key
                        for worker_key, dev in zip(worker_keys, device_list)
                    }

                    # Surface exceptions and coordinate abort
                    for fut in as_completed(future_to_key):
                        try:
                            fut.result()
                        except BaseException as exc:  # noqa: BLE001
                            if not continue_on_error and first_exc is None:
                                # Keep the first failure for re-raising; stop the other workers.
                                first_exc = exc
                                abort_event.set()
                            else:
                                # Report failures that would otherwise go unnoticed.
                                logger.error(
                                    "Worker %s failed: %r",
                                    future_to_key[fut],
                                    exc,
                                    exc_info=exc,
                                )
                if first_exc is not None:
                    raise first_exc
            finally:
                # Always stop the listener, on success, failure, or interruption, so it does
                # not stay blocked on a queue owned by the Manager being shut down.
                progress_q.put(None)
                listener_thread.join()


def _device_worker(
    task_q: mp.Queue,
    progress_q: mp.Queue | None,
    abort_event,
    worker_key: str,
    device: torch.device,
    patch_size: int | None,
    batch_size: int,
    segmenter: str,
    model_list: list[str] | None,
    delete_tiles: bool,
    store_root: Path,
    continue_on_error: bool,
    no_auto_rescale: bool,
    skip_feature_embeddings: bool,
    skip_zeroshot_embeddings: bool,
    max_lease_retries: int = 3,
) -> None:
    """Long-lived worker: initializes models once, then processes slides pulled from a queue.

    The worker stops at the first ``None`` sentinel on ``task_q`` or once ``abort_event`` is
    set. Each slide is processed while holding a ``<store_root>/<slide_id>.lease`` PID lock so
    two workers never handle the same slide concurrently. Slides whose lease cannot be acquired
    are deferred and retried by this worker after the queue has been drained.

    Args:
        task_q: Queue of slide paths terminated by ``None`` sentinels.
        progress_q: Optional queue receiving one ``"slide_done"`` event per finished slide.
        abort_event: Shared event (a Manager ``Event`` when called from
            ``dispatch_slide_workers``); when set, no further slides are started.
        worker_key: Identifier included in progress events.
        device: Device passed to every embedding model constructor.
        patch_size: Tile size, used only for slides that are not yet preprocessed.
        batch_size: Forwarded to ``Slide.embed_tiles``.
        segmenter: Forwarded to ``Slide.preprocess``.
        model_list: Embedding model names; ``None`` skips embedding.
        delete_tiles: Delete tiles from disk after each slide.
        store_root: Store directory (created if missing; holds the lease files).
        continue_on_error: If True, print per-slide tracebacks and continue; else re-raise.
        no_auto_rescale: Passed as ``auto_rescale=not no_auto_rescale``.
        skip_feature_embeddings: Forwarded to ``Slide.embed_tiles``.
        skip_zeroshot_embeddings: Forwarded to ``Slide.embed_tiles``.
        max_lease_retries: Extra passes over lease-contended slides before giving up on them
            with a warning.

    Raises:
        Exception: Any per-slide failure when ``continue_on_error`` is False.

    """
    # Initialize model objects once per process/device
    patch_embedding_models = (
        [get_embedding_model(model_name)(device=device) for model_name in model_list]
        if model_list is not None
        else None
    )

    # SlideDataStore can create the store_root when opened, but here we need to check for a lease
    # before the store is opened.
    store_root.mkdir(parents=True, exist_ok=True)

    def attempt(slide_path: Path) -> bool:
        """Process one slide under its lease; return False if the lease could not be taken."""
        slide_id = slide_path.stem
        lease_target = store_root / f"{slide_id}.lease"
        try:
            with ExitStack() as stack:
                # Only a timeout while *acquiring* the lease means "busy, retry later"; a
                # LockTimeoutError raised during processing is handled like any other failure.
                try:
                    stack.enter_context(
                        PIDLock(lease_target, timeout=2, poll=0.05, jitter=0.01),
                    )
                except LockTimeoutError:
                    return False

                slide = Slide(slide_path=slide_path, store_root=store_root)
                log_profiling_info("Slide loaded")

                if not slide.preprocessed:
                    slide.preprocess(patch_size=patch_size, segmenter=segmenter, verbose=False)

                if patch_embedding_models is not None and not (
                    skip_feature_embeddings and skip_zeroshot_embeddings
                ):
                    for patch_embedding_model in patch_embedding_models:
                        slide.embed_tiles(
                            model=patch_embedding_model,
                            batch_size=batch_size,
                            auto_rescale=not no_auto_rescale,
                            skip_feature_embeddings=skip_feature_embeddings,
                            skip_zeroshot_embeddings=skip_zeroshot_embeddings,
                            verbose=False,
                        )
                    log_profiling_info("dispatch_on_slides: patches embedded")

                if delete_tiles:
                    slide.delete_tiles_from_disk()

                log_profiling_info("dispatch_on_slides: end (pre-cleanup)")

                # Notify central progress tracker if available
                if progress_q is not None:
                    progress_q.put(
                        {
                            "event": "slide_done",
                            "device": device,
                            "worker_key": worker_key,
                            "slide_id": slide_id,
                        },
                    )
        except Exception:
            if not continue_on_error:
                raise
            traceback.print_exc()
        return True

    # Consume tasks until sentinel or abort
    deferred: list[Path] = []
    while not abort_event.is_set():
        try:
            slide_path = task_q.get(timeout=2)
        except queue.Empty:
            continue
        if slide_path is None:
            break
        if not attempt(slide_path):
            deferred.append(slide_path)

    # Lease-contended slides are retried here instead of being put back on task_q:
    # dispatch_slide_workers enqueues every worker's sentinel before the workers start, so a
    # re-queued slide would sit behind all sentinels and never be picked up.
    for _ in range(max_lease_retries):
        if not deferred or abort_event.is_set():
            break
        still_deferred: list[Path] = []
        for slide_path in deferred:
            if abort_event.is_set() or not attempt(slide_path):
                still_deferred.append(slide_path)
        deferred = still_deferred

    if deferred and not abort_event.is_set():
        logger.warning(
            "%s: giving up on %d slide(s) whose lease stayed held elsewhere: %s",
            worker_key,
            len(deferred),
            ", ".join(p.stem for p in deferred),
        )


class PercentColumn(ProgressColumn):
    """Show the completion percentage; render nothing for tasks without a total."""

    def render(self, task) -> Text:  # type: ignore[override]
        """Return text such as ``"42.0%"`` for determinate tasks, empty text otherwise."""
        has_total = task.total is not None
        if has_total:
            # rich reports task.percentage as a float in [0, 100].
            pct = task.percentage
            return Text(f"{pct:>3.1f}%", style="progress.percentage")
        return Text("")


class HybridSpinnerOrBarColumn(ProgressColumn):
    """Draw a spinner for tasks without a total and a progress bar for all other tasks.

    Notes:
        - Indeterminate tasks (``task.total is None``) get a spinner instead of a pulsing bar,
          which can flicker.
        - Determinate tasks keep a standard ``BarColumn`` so overall progress stays visible.
    """

    def __init__(self) -> None:
        super().__init__()
        accent = "cyan"
        # "dots" is a compact, low-contrast spinner that reads well in terminals.
        self._spinner = SpinnerColumn(spinner_name="dots", style=accent)
        self._bar = BarColumn(
            bar_width=None,
            complete_style=accent,
            finished_style=accent,
            pulse_style=accent,
        )

    def render(self, task) -> Text:  # type: ignore[override]
        # NOTE: the bar branch returns whatever BarColumn.render produces (not strictly Text).
        delegate = self._bar if task.total is not None else self._spinner
        return delegate.render(task)


class CompletedForIndeterminateColumn(ProgressColumn):
    """Show "<n> done" on rows without a total; leave rows with a total blank."""

    def render(self, task) -> Text:  # type: ignore[override]
        if task.total is not None:
            return Text("")
        finished_count = int(task.completed)
        return Text(f"{finished_count} done", style="dim")


class SlidesPerMinuteColumn(ProgressColumn):
    """Show throughput in slides per minute on rows that have a total.

    - The rate is ``task.completed / task.elapsed`` scaled to minutes, an easy-to-read figure.
    - Rows without a total (the per-worker rows) render nothing, to keep them uncluttered.
    """

    def render(self, task) -> Text:  # type: ignore[override]
        # Throughput only makes sense for the determinate aggregate row.
        if task.total is None:
            return Text("")
        seconds = getattr(task, "elapsed", None)
        done = getattr(task, "completed", 0) or 0
        if not seconds or seconds <= 0:
            # Not started yet (or no measurable time): show a placeholder instead of a rate.
            return Text("— slides/min", style="dim")
        per_minute = (float(done) / float(seconds)) * 60.0
        return Text(f"{per_minute:0.2f} slides/min", style="magenta")


def _get_progress_columns() -> list[ProgressColumn | str]:
    """Build the column layout shared by the aggregate row and the per-worker rows.

    The custom columns adapt to each task (determinate vs. indeterminate), which keeps the
    display calm and avoids irrelevant fields on any given row.
    """
    separator = "•"
    description = TextColumn("[bold blue]{task.description}", justify="right")
    return [
        description,
        HybridSpinnerOrBarColumn(),
        PercentColumn(),
        separator,
        CompletedForIndeterminateColumn(),
        separator,
        TimeElapsedColumn(),
        separator,
        SlidesPerMinuteColumn(),
        "←→",
        TimeRemainingColumn(),
    ]


def _progress_listener(
    progress: Progress,
    progress_q: mp.Queue,
    task_map: dict[str, TaskID],
    agg_task: TaskID,
) -> None:
    while True:
        # Blocking, but threaded
        evt = progress_q.get()

        # Progress queue sentinel
        if evt is None:
            break
        if evt.get("event") == "slide_done":
            key = evt.get("worker_key")
            worker_task_id = cast("TaskID", task_map.get(key))
            progress.update(worker_task_id, advance=1)
            progress.update(agg_task, advance=1)
