"""Hibou foundation models.

@misc{nechaev_hibou_2024,
        title = {Hibou: {A} {Family} of {Foundational} {Vision} {Transformers} for {Pathology}},
        url = {http://arxiv.org/abs/2406.05074},
        doi = {10.48550/arXiv.2406.05074},
        author = {Nechaev, Dmitry and Pchelnikov, Alexey and Ivanova, Ekaterina},
        month = aug,
        year = {2024},
}
https://github.com/histai/hibou
https://huggingface.co/histai/hibou-b
https://huggingface.co/histai/hibou-L
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

import torch
from transformers import AutoImageProcessor, AutoModel

from pathfmtools.embedding_models.embedding_model import EmbeddingModel
from pathfmtools.embedding_models.registry import register_model

if TYPE_CHECKING:
    from collections.abc import Callable

    from PIL import Image

# Logger named after this module, so it can be configured through the standard logging hierarchy.
logger: logging.Logger = logging.getLogger(__name__)


@register_model(
    "hibou-b",
    embedding_dim=768,
    supports_zeroshot=False,
    supports_text=False,
)
class HibouBModel(EmbeddingModel):
    """Hibou-B patch encoder backed by the ``histai/hibou-b`` Hugging Face repository.

    Every image is reduced to the token at position 0 of the backbone's final
    hidden state (``POOLING_RULE = "cls"``), giving one 768-wide vector per patch.
    """

    NAME = "hibou-b"
    # Neither training patch size nor magnification is stated in the paper. All reported
    # patch-level benchmarks use 224 x 224 patches, at 20X wherever magnification is given.
    # The demo notebook (https://github.com/HistAI/hibou/blob/main/example.ipynb) likewise
    # resizes images to 224 x 224, so those values are used here.
    EXPECTED_MAGNIFICATION = 20
    EXPECTED_PATCH_SIZE = 224
    SUPPORTS_TEXT = False
    SUPPORTS_ZEROSHOT = False
    POOLING_RULE = "cls"

    def __init__(self, device: torch.device) -> None:
        """Load the Hibou-B image processor and weights, then move the model to ``device``.

        Args:
            device: Target device for the model and for every per-batch tensor.
        """
        super().__init__(device)
        repo_id = "histai/hibou-b"
        # The repository ships its own modeling/processing code, hence trust_remote_code.
        self.processor = AutoImageProcessor.from_pretrained(repo_id, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
        self.model.eval()
        self.model.to(self.device)

        # Backbone hidden size, see
        # https://huggingface.co/histai/hibou-b/blob/main/modeling_dinov2.py#L57
        self.embedding_dim = 768

    def _create_embeddings_tensors(
        self,
        n_patches: int,
        **kwargs,
    ) -> tuple[torch.Tensor, None]:
        """Allocate an uninitialised float16 result buffer on the model device.

        Args:
            n_patches: Number of rows (one per patch) to reserve.

        Returns:
            A ``(n_patches, embedding_dim)`` tensor and ``None`` for the unused second slot.
        """
        buffer_shape = (n_patches, self.embedding_dim)
        buffer = torch.empty(buffer_shape, dtype=torch.float16, device=self.device)
        return buffer, None

    def _preprocess_batch(self, pil_images: list[Image.Image]) -> dict[str, torch.Tensor]:
        """Run the HF processor over ``pil_images`` and move every output tensor to the model device."""
        encoded = self.processor(pil_images, return_tensors="pt")
        on_device: dict[str, torch.Tensor] = {}
        for field_name, tensor in encoded.items():
            on_device[field_name] = tensor.to(self.device)
        return on_device

    def _run_inference(self, batch_tensor: dict[str, torch.Tensor], **kwargs) -> Any:
        """Forward the processor outputs as keyword arguments and return the raw model output."""
        model_output = self.model(**batch_tensor)
        return model_output

    def _extract_embeddings(self, model_output: Any, **kwargs) -> tuple[torch.Tensor, None]:
        """Select the position-0 (CLS) token of the final hidden state for each image."""
        final_hidden = model_output.last_hidden_state
        cls_embeddings = final_hidden[:, 0, :]
        return cls_embeddings, None

    def preprocess_input_tile(self) -> Callable[[Image.Image], dict[str, torch.Tensor]]:
        """Return a CPU-only callable that applies the HF processor to a single PIL image.

        The processor is looked up on ``self`` at call time, not captured when this
        method runs.
        """

        def _encode_tile(img: Image.Image) -> dict[str, torch.Tensor]:
            return self.processor(img, return_tensors="pt")  # type: ignore[reportAttributeAccessIssue]

        return _encode_tile


@register_model(
    "hibou-l",
    embedding_dim=1024,
    supports_zeroshot=False,
    supports_text=False,
)
class HibouLModel(EmbeddingModel):
    """Hibou-L patch encoder backed by the ``histai/hibou-L`` Hugging Face repository.

    Hibou-L is the ViT-Large member of the Hibou family. Its checkpoint config sets
    ``hidden_size=1024``, so the CLS embedding is 1024-dimensional, not the 768 of
    Hibou-B. The 768 found in ``modeling_dinov2.py`` is only the config class default
    and is overridden by the checkpoint's ``config.json``.
    """

    NAME = "hibou-l"
    # The paper does not specify the patch size used during training, but all patch level benchmarks
    # reported used (224 X 224) patches. Magnification was similarly not reported for training data,
    # but patch level benchmarks used patches at 20X magnification (if reported).
    # The demo notebook https://github.com/HistAI/hibou/blob/main/example.ipynb also resizes images
    # to (224 X 224).
    EXPECTED_MAGNIFICATION = 20
    EXPECTED_PATCH_SIZE = 224
    SUPPORTS_TEXT = False
    SUPPORTS_ZEROSHOT = False
    POOLING_RULE = "cls"

    def __init__(self, device: torch.device) -> None:
        """Load the Hibou-L image processor and weights, then move the model to ``device``.

        Args:
            device: Target device for the model and for every per-batch tensor.

        Raises:
            ValueError: If the loaded checkpoint's ``config.hidden_size`` differs from
                ``embedding_dim``. Without this check the mismatch would presumably
                surface later as an opaque shape error when embeddings are written
                into the buffer from ``_create_embeddings_tensors``.
        """
        super().__init__(device)
        self.processor = AutoImageProcessor.from_pretrained(
            "histai/hibou-L",
            trust_remote_code=True,
        )
        self.model = AutoModel.from_pretrained("histai/hibou-L", trust_remote_code=True)
        self.model.eval()
        self.model.to(self.device)

        # Hibou-L's config.json sets hidden_size=1024 (ViT-L). The previous value of 768
        # came from the Dinov2Config *default* in modeling_dinov2.py, which the checkpoint overrides.
        self.embedding_dim = 1024
        # Fail fast if the checkpoint disagrees with the registered dimension.
        hidden_size = getattr(getattr(self.model, "config", None), "hidden_size", None)
        if hidden_size is not None and hidden_size != self.embedding_dim:
            msg = (
                f"hibou-l checkpoint reports hidden_size={hidden_size}, "
                f"but embedding_dim is {self.embedding_dim}"
            )
            raise ValueError(msg)

    def _create_embeddings_tensors(
        self,
        n_patches: int,
        **kwargs,
    ) -> tuple[torch.Tensor, None]:
        """Allocate an uninitialised float16 result buffer on the model device.

        Args:
            n_patches: Number of rows (one per patch) to reserve.

        Returns:
            A ``(n_patches, embedding_dim)`` tensor and ``None`` for the unused second slot.
        """
        embeddings = torch.empty(
            (n_patches, self.embedding_dim),
            device=self.device,
            dtype=torch.float16,
        )
        return embeddings, None

    def _preprocess_batch(self, pil_images: list[Image.Image]) -> dict[str, torch.Tensor]:
        """Run the HF processor over ``pil_images`` and move every output tensor to the model device."""
        inputs = self.processor(pil_images, return_tensors="pt")
        return {k: v.to(self.device) for k, v in inputs.items()}

    def _run_inference(self, batch_tensor: dict[str, torch.Tensor], **kwargs) -> Any:
        """Forward the processor outputs as keyword arguments and return the raw model output.

        NOTE(review): no ``torch.no_grad()`` here; presumably the base class wraps
        inference — confirm.
        """
        return self.model(**batch_tensor)

    def _extract_embeddings(self, model_output: Any, **kwargs) -> tuple[torch.Tensor, None]:
        """Select the position-0 (CLS) token of the final hidden state for each image."""
        return model_output.last_hidden_state[:, 0, :], None

    def preprocess_input_tile(self) -> Callable[[Image.Image], dict[str, torch.Tensor]]:
        """Return a CPU-only callable that applies the HF processor to a single PIL image.

        NOTE(review): the HF processor output for a single image presumably keeps a
        leading batch dimension of 1. Confirm callers expect that.
        """
        return lambda img: self.processor(img, return_tensors="pt")  # type: ignore[reportAttributeAccessIssue]
