"""Dispatch segmentation, patchification, and embedding on WSIs or pre-extracted patches.

Delegates orchestration to `pathfmtools.dispatch.slide_dispatch`.
"""

from __future__ import annotations

import datetime
import glob
import logging
import warnings
from pathlib import Path
from typing import Any

import pytz
import torch
import typer
from rich.console import Console
from rich.table import Table

from pathfmtools.dispatch.slide_dispatch import dispatch_slide_workers
from pathfmtools.embedding_models.registry import get_capabilities
from pathfmtools.utils.devices import parse_device

logger = logging.getLogger(__name__)

# Run identifier: UTC wall-clock time captured once, when this module is imported,
# with microsecond precision. `cli` uses it to name the run's log file.
TIMESTAMP = datetime.datetime.now(pytz.UTC).strftime("%Y-%m-%d-%H-%M-%S-%f-%Z")

# Typer application; `cli` below is registered on it via `@app.command`.
app = typer.Typer()

# Silence a known deprecation FutureWarning emitted by loaded models.
warnings.filterwarnings(
    action="ignore",
    message="Importing from timm.models.registry is deprecated",
    category=FutureWarning,
)


def _print_model_capabilities_table() -> None:
    """Print a Rich table listing every registered model and its capabilities.

    Models are listed alphabetically by name. A ``None`` embedding or zero-shot
    dimension is shown as ``-``; zero-shot support is shown as ``Y`` or ``N``.
    """

    def _fmt_dim(dim: object) -> str:
        # A None dimension is rendered as a dash rather than the string "None".
        return "-" if dim is None else str(dim)

    registry = get_capabilities(None)

    table = Table(title="Model Capabilities", show_lines=False)
    table.add_column("Model", style="bold")
    table.add_column("Embedding Dim", justify="right")
    table.add_column("Zeroshot Dim", justify="right")
    table.add_column("Supports Zeroshot", justify="center")

    for name in sorted(registry):
        entry = registry[name]
        supports = "Y" if entry.get("supports_zeroshot", False) else "N"
        table.add_row(
            name,
            _fmt_dim(entry["embedding_dim"]),
            _fmt_dim(entry["zeroshot_dim"]),
            supports,
        )

    console = Console()
    console.print("\n")
    console.print(table)


def _help_callback(ctx: Any, param: Any, value: bool) -> None:  # noqa: ARG001, ANN001
    """Extend default --help with a Rich table of model capabilities.

    Implementation detail:
        - Prints standard Click/Typer help via ctx.get_help(), then renders table.
        - Exits immediately to prevent command execution.
    """
    if not value or ctx.resilient_parsing:
        return
    typer.echo(ctx.get_help())
    _print_model_capabilities_table()
    raise typer.Exit()


@app.command(context_settings={"help_option_names": []})
def cli(
    slide_path: str = typer.Option(
        ...,
        help="Path to a slide file or a glob pattern for slide files",
    ),
    store_root: Path = typer.Option(  # noqa: B008
        ...,
        help="Path to the root directory for storing processed slide data",
    ),
    gpu: list[int] | None = typer.Option(  # noqa: B008
        None,
        help=(
            "GPU device(s) to run the script on. If None, the script will run on CPU. Specify the "
            "same GPU multiple times to dispatch multiple worker processes to the same GPU."
        ),
    ),
    n_workers: int = typer.Option(
        2,
        help="Number of multiprocessing workers to use. Ignored if --gpu is provided.",
    ),
    patch_size: int | None = typer.Option(
        None,
        help="Side length (in pixels) of patches to extract",
    ),
    batch_size: int = typer.Option(8, help="Batch size for embedding patches"),
    segmenter: str = typer.Option("otsu", help="Segmentation method to use"),
    model: list[str] | None = typer.Option(  # noqa: B008
        None,
        help="Models to use for embedding patches. If None, patch embedding will be skipped.",
    ),
    delete_tiles: bool = typer.Option(False, help="Whether to delete extracted tiles"),
    continue_on_error: bool = typer.Option(False, help="Whether to continue on error"),
    no_auto_rescale: bool = typer.Option(
        False,
        help="Disable automatic patch rescaling to match model's expected receptive field",
    ),
    skip_feature_embeddings: bool = typer.Option(
        False,
        help="Skip generating and saving feature embeddings.",
    ),
    skip_zeroshot_embeddings: bool = typer.Option(
        False,
        help="Skip generating and saving zero-shot embeddings.",
    ),
    _help: bool = typer.Option(  # noqa: B008, FBT001
        False,
        "--help",
        "-h",
        help="Show this message and a table of model capabilities, then exit.",
        is_flag=True,
        is_eager=True,
        expose_value=False,
        callback=_help_callback,
    ),
) -> None:
    """Dispatch segmentation, patchification, and embedding on WSIs or pre-extracted patches."""
    # NOTE: the docstring above doubles as the Typer `--help` text, so developer
    # notes live in comments instead.
    #
    # Expand --slide-path as a glob. glob.glob returns matches in arbitrary,
    # filesystem-dependent order; sorting makes the processing order reproducible.
    slide_paths: list[Path] = sorted(Path(p) for p in glob.glob(slide_path))  # noqa: PTH207
    if not slide_paths:
        # Fail fast -- before creating store_root or the log file -- rather than
        # silently dispatching a run that has nothing to process.
        raise typer.BadParameter(
            f"No slide files matched {slide_path!r}.",
            param_hint="'--slide-path'",
        )

    # One log file per run, named by the import-time UTC timestamp.
    configure_root_logger(log_debug_info=False, log_file=store_root / f"{TIMESTAMP}.log")

    main(
        store_root=store_root,
        slide_paths=slide_paths,
        gpu=gpu,
        n_workers=n_workers,
        patch_size=patch_size,
        batch_size=batch_size,
        segmenter=segmenter,
        model=model,
        delete_tiles=delete_tiles,
        continue_on_error=continue_on_error,
        no_auto_rescale=no_auto_rescale,
        skip_feature_embeddings=skip_feature_embeddings,
        skip_zeroshot_embeddings=skip_zeroshot_embeddings,
    )


def main(
    store_root: Path,
    slide_paths: list[Path],
    gpu: list[int] | None,
    n_workers: int,
    patch_size: int | None,
    batch_size: int,
    segmenter: str,
    model: list[str] | None,
    delete_tiles: bool,
    continue_on_error: bool,
    no_auto_rescale: bool,
    skip_feature_embeddings: bool,
    skip_zeroshot_embeddings: bool,
) -> None:
    """Dispatch slide processing run."""
    # Validate embedding flag combinations before processing
    if model is not None:
        model_list = filter_models_by_capabilities(
            skip_feature_embeddings=skip_feature_embeddings,
            skip_zeroshot_embeddings=skip_zeroshot_embeddings,
            model_list=model,
        )
    else:
        model_list = None

    device_list = (
        [torch.device("cpu")] * n_workers if gpu is None else [parse_device(gpu) for gpu in gpu]
    )

    dispatch_slide_workers(
        store_root=store_root,
        slide_paths=slide_paths,
        device_list=device_list,
        patch_size=patch_size,
        batch_size=batch_size,
        segmenter=segmenter,
        model_list=model_list,
        delete_tiles=delete_tiles,
        continue_on_error=continue_on_error,
        no_auto_rescale=no_auto_rescale,
        skip_feature_embeddings=skip_feature_embeddings,
        skip_zeroshot_embeddings=skip_zeroshot_embeddings,
    )


def configure_root_logger(log_debug_info: bool, log_file: Path) -> None:
    """Configure the root logger to write to the console and to ``log_file``.

    Any handlers already attached to the root logger are removed and closed
    first, so this configuration always takes effect.

    Args:
        log_debug_info (bool): Whether to log debug information (DEBUG level);
            otherwise INFO level is used.
        log_file (Path): The path to the file to which the log will be written.
            Missing parent directories are created. The file is written as UTF-8.

    """
    level = logging.DEBUG if log_debug_info else logging.INFO
    formatter = logging.Formatter(
        (
            "%(asctime)s, %(levelname)-8s"
            "[%(filename)s:%(module)s:%(funcName)s:%(lineno)d] %(message)s"
        ),
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    log_file.parent.mkdir(parents=True, exist_ok=True)

    handlers: list[logging.Handler] = [
        logging.StreamHandler(),
        # Explicit encoding: the platform default may be unable to encode
        # non-ASCII text (e.g. slide paths) on some systems.
        logging.FileHandler(log_file, encoding="utf-8"),
    ]
    for handler in handlers:
        handler.setLevel(level)
        handler.setFormatter(formatter)

    # force=True: without it, basicConfig is a silent no-op when the root logger
    # already has handlers (e.g. installed by an imported library). The file
    # handler created above would then never be attached -- its file left open --
    # and nothing would be written to log_file.
    logging.basicConfig(level=level, handlers=handlers, force=True)


def filter_models_by_capabilities(
    skip_feature_embeddings: bool,
    skip_zeroshot_embeddings: bool,
    model_list: list[str],
) -> list[str]:
    """Restrict ``model_list`` to models able to produce the requested embeddings.

    Only the "zero-shot embeddings only" request (feature embeddings skipped,
    zero-shot embeddings kept) filters anything. Every other flag combination
    returns ``model_list`` itself, unchanged.

    Args:
        skip_feature_embeddings: If True, feature embeddings are not required.
        skip_zeroshot_embeddings: If True, zero-shot embeddings are not required.
        model_list: Candidate model names to consider.

    Returns:
        The models that satisfy the requested capabilities, in their original order.

    Raises:
        ValueError: When only zero-shot embeddings are requested but none of the
            provided models support zero-shot outputs. The message is also
            logged at ERROR level.

    Notes:
        - When only zero-shot embeddings are requested, models that do not
          support zero-shot outputs are dropped with a logged warning.

    """
    zeroshot_only = skip_feature_embeddings and not skip_zeroshot_embeddings
    if not zeroshot_only:
        return model_list

    # Query the registry once per candidate, keeping the caller's ordering.
    capable = [name for name in model_list if get_capabilities(name)["supports_zeroshot"]]

    if not capable:
        msg = (
            "Only zero-shot embeddings were requested, but none of the selected models "
            f"{model_list} support this feature."
        )
        logger.error(msg)
        raise ValueError(msg)

    skipped = [name for name in model_list if name not in capable]
    if skipped:
        msg = (
            "Only zero-shot embeddings were requested, but the following models do not "
            f"support this feature: {skipped}. These models will be skipped."
        )
        logger.warning(msg)

    return capable


if __name__ == "__main__":
    # Script entry point: run the Typer application when executed directly.
    app()
