"""Midnight foundation model.

@misc{karasikov_training_2025,
        title = {Training state-of-the-art pathology foundation models with orders of magnitude less
                 data},
        url = {http://arxiv.org/abs/2504.05186},
        doi = {10.48550/arXiv.2504.05186},
        author = {Karasikov, Mikhail and Doorn, Joost van and Känzig, Nicolas and Cesur, Melis Erdal
                  and Horlings, Hugo Mark and Berke, Robert and Tang, Fei and Otálora, Sebastian},
        month = apr,
        year = {2025},
}
https://github.com/kaiko-ai/midnight
https://huggingface.co/kaiko-ai/midnight
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

import torch
import torchvision.transforms.v2 as transforms
from transformers import AutoModel

from pathfmtools.embedding_models.embedding_model import EmbeddingModel
from pathfmtools.embedding_models.registry import register_model

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from collections.abc import Callable

    from PIL import Image


@register_model(
    "midnight-12k",
    embedding_dim=3072,
    supports_zeroshot=False,
    supports_text=False,
)
class Midnight12kModel(EmbeddingModel):
    """Midnight-12k pathology foundation model (kaiko.ai).

    Loads ``kaiko-ai/midnight`` via ``transformers.AutoModel`` and turns each tile into a
    single embedding made of the CLS token concatenated with the mean of the remaining
    tokens of the final hidden state (``POOLING_RULE = "cls+mean"``).
    """

    NAME = "midnight-12k"
    EXPECTED_MAGNIFICATION = 20
    EXPECTED_PATCH_SIZE = 224

    SUPPORTS_TEXT = False
    SUPPORTS_ZEROSHOT = False
    POOLING_RULE = "cls+mean"

    def __init__(self, device: torch.device) -> None:
        """Load the pretrained model in eval mode and build the preprocessing pipeline.

        Args:
            device: Device the model is moved to; batches are moved there as well.
        """
        super().__init__(device)
        self.model = AutoModel.from_pretrained("kaiko-ai/midnight")
        self.model.eval()
        self.model.to(self.device)

        # From https://huggingface.co/kaiko-ai/midnight:
        # "Our models are trained on 224x224 images normalized with a mean of (0.5, 0.5, 0.5) and a
        # standard deviation of (0.5, 0.5, 0.5). Please ensure you apply these exact normalization
        # parameters when preparing your datasets for embedding extraction."
        self.transform = transforms.Compose(
            [
                # Resize the shorter side, then center-crop to a square patch.
                transforms.Resize(self.EXPECTED_PATCH_SIZE),
                transforms.CenterCrop(self.EXPECTED_PATCH_SIZE),
                # `ToTensor` is deprecated in transforms.v2; this pair is its documented
                # replacement (PIL image -> float32 CHW tensor scaled to [0, 1]).
                transforms.ToImage(),
                transforms.ToDtype(torch.float32, scale=True),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
            ],
        )

        # `_extract_embeddings` concatenates two hidden-size vectors (CLS + mean patch token),
        # so this must equal twice the model's hidden size; it also matches the registry entry.
        self.embedding_dim = 3_072

    def _create_embeddings_tensors(
        self,
        n_patches: int,
        **kwargs,
    ) -> tuple[torch.Tensor, None]:
        """Allocate the output buffer for ``n_patches`` embeddings.

        Returns:
            An uninitialized ``(n_patches, embedding_dim)`` float16 tensor on ``self.device``
            and ``None`` (this model produces no secondary output).
        """
        embeddings = torch.empty(
            (n_patches, self.embedding_dim),
            device=self.device,
            dtype=torch.float16,
        )
        return embeddings, None

    def _preprocess_batch(self, pil_images: list[Image.Image]) -> torch.Tensor:
        """Transform a list of PIL images into one batched tensor on ``self.device``."""
        # Stack on the host and transfer once, rather than issuing one copy per image.
        batch = torch.stack([self.transform(image) for image in pil_images])
        return batch.to(self.device)

    def preprocess_input_tile(self) -> Callable[[Image.Image], torch.Tensor]:
        """Return the per-image, CPU-only preprocessing callable for Dataset workers."""
        return self.transform

    def _run_inference(self, batch_tensor: torch.Tensor, **kwargs) -> Any:
        """Run a forward pass of the model on a preprocessed batch."""
        # Embedding extraction never backpropagates; disabling autograd avoids keeping
        # activations alive. Harmless if the caller already runs under no_grad.
        with torch.no_grad():
            return self.model(batch_tensor)

    def _extract_embeddings(self, model_output: Any, **kwargs) -> tuple[torch.Tensor, None]:
        """Pool the final hidden state into one embedding per image.

        Returns:
            A ``(batch, 2 * hidden)`` tensor holding ``[CLS token, mean of remaining tokens]``
            and ``None``.
        """
        # Indexing below assumes a (batch, tokens, hidden) tensor with the CLS token first.
        hidden_states = model_output.last_hidden_state
        cls_token = hidden_states[:, 0, :]
        mean_patch_token = hidden_states[:, 1:, :].mean(dim=1)
        embedding = torch.cat([cls_token, mean_patch_token], dim=-1)
        return embedding, None
