"""Phikon foundation models.

@misc{filiot_scaling_2023,
        title = {Scaling {Self}-{Supervised} {Learning} for {Histopathology} with {Masked} {Image}
                 {Modeling}},
        url = {https://www.medrxiv.org/content/10.1101/2023.07.21.23292757v2},
        doi = {10.1101/2023.07.21.23292757},
        author = {Filiot, Alexandre and Ghermi, Ridouane and Olivier, Antoine and Jacob, Paul and
                  Fidon, Lucas and Kain, Alice Mac and Saillard, Charlie and
                  Schiratti, Jean-Baptiste},
        month = sep,
        year = {2023},
}
https://github.com/owkin/HistoSSLscaling
https://huggingface.co/owkin/phikon

@misc{filiot_phikon-v2_2024,
        title = {Phikon-v2, {A} large and public feature extractor for biomarker prediction},
        url = {http://arxiv.org/abs/2409.09173},
        doi = {10.48550/arXiv.2409.09173},
        author = {Filiot, Alexandre and Jacob, Paul and Kain, Alice Mac and Saillard, Charlie},
        month = sep,
        year = {2024},
}
https://huggingface.co/owkin/phikon-v2
"""

from __future__ import annotations

import logging
from functools import partial
from typing import TYPE_CHECKING, Any

import torch
from transformers import AutoImageProcessor, AutoModel

from pathfmtools.embedding_models.embedding_model import EmbeddingModel
from pathfmtools.embedding_models.registry import register_model
from pathfmtools.utils.torch import hf_apply_processor

if TYPE_CHECKING:
    from collections.abc import Callable

    from PIL import Image

logger = logging.getLogger(__name__)


@register_model(
    "phikon",
    embedding_dim=768,
    supports_zeroshot=False,
    supports_text=False,
)
class PhikonModel(EmbeddingModel):
    """Image-only embedding model backed by the ``owkin/phikon`` checkpoint.

    The image processor and the weights are both loaded through Hugging Face
    ``transformers``. Each patch embedding is the first (CLS) token of the model's
    last hidden state and has 768 dimensions.
    """

    NAME = "phikon"
    # From https://www.medrxiv.org/content/10.1101/2023.07.21.23292757v2.full.pdf:
    # "For each slide, non-overlapping tiles are extracted at 20X magnification (0.5µm/px) with a
    # fixed size of 224 X 224 pixels"
    EXPECTED_MAGNIFICATION = 20
    EXPECTED_PATCH_SIZE = 224
    SUPPORTS_TEXT = False
    SUPPORTS_ZEROSHOT = False
    POOLING_RULE = "cls"

    def __init__(self, device: torch.device) -> None:
        """Load the Phikon processor and weights.

        Args:
            device: Forwarded to the base-class initializer; the model is then moved
                to ``self.device``.
        """
        super().__init__(device)

        checkpoint = "owkin/phikon"
        self.processor = AutoImageProcessor.from_pretrained(checkpoint)
        self.model = AutoModel.from_pretrained(checkpoint)
        # ``eval()`` hands back the module itself, so the device move can be chained.
        self.model.eval().to(self.device)

        self.embedding_dim = 768

    def _create_embeddings_tensors(
        self,
        n_patches: int,
        **kwargs,
    ) -> tuple[torch.Tensor, None]:
        """Allocate the output buffer that will hold one embedding per patch.

        Args:
            n_patches: Number of rows to allocate.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            An uninitialized float16 tensor of shape ``(n_patches, embedding_dim)`` on
            ``self.device``, paired with ``None``.
        """
        shape = (n_patches, self.embedding_dim)
        buffer = torch.empty(shape, dtype=torch.float16, device=self.device)
        return buffer, None

    def _preprocess_batch(self, pil_images: list[Image.Image]) -> dict[str, torch.Tensor]:
        """Run the HF processor over a list of PIL images and move the result to the device.

        Args:
            pil_images: Images to preprocess together as one batch.

        Returns:
            Mapping from processor output name to a tensor located on ``self.device``.
        """
        encoded = self.processor(pil_images, return_tensors="pt")
        target = self.device
        return {name: tensor.to(target) for name, tensor in encoded.items()}

    def preprocess_input_tile(self) -> Callable[[Image.Image], dict[str, torch.Tensor]]:
        """Return a single-image preprocessing callable.

        The callable is ``hf_apply_processor`` with this model's processor bound as its
        first argument.
        """
        processor = self.processor
        return partial(hf_apply_processor, processor)

    def _run_inference(self, batch_tensor: dict[str, torch.Tensor], **kwargs) -> Any:
        """Forward a preprocessed batch through the model.

        Args:
            batch_tensor: Processor outputs, unpacked as keyword arguments to the model.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The raw model output object.
        """
        return self.model(**batch_tensor)

    def _extract_embeddings(self, model_output: Any, **kwargs) -> tuple[torch.Tensor, None]:
        """Pull the CLS-token embedding out of the raw model output.

        Args:
            model_output: Object exposing a ``last_hidden_state`` tensor.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The slice of ``last_hidden_state`` at token index 0, paired with ``None``.
        """
        hidden_states = model_output.last_hidden_state
        # Index 0 along the token axis is the CLS token.
        cls_embedding = hidden_states[:, 0, :]
        return cls_embedding, None


@register_model(
    "phikon2",
    embedding_dim=1024,
    supports_zeroshot=False,
    supports_text=False,
)
class Phikon2Model(EmbeddingModel):
    """Image-only embedding model backed by the ``owkin/phikon-v2`` checkpoint.

    The image processor and the weights are both loaded through Hugging Face
    ``transformers``. Each patch embedding is the first (CLS) token of the model's
    last hidden state and has 1024 dimensions.
    """

    NAME = "phikon2"
    # From https://arxiv.org/pdf/2409.09173:
    # "First for each WSI, Nt non-overlapping histology tiles of size 224x224 are extracted at
    # 20x magnification (0.5 micrometers per pixel)"
    EXPECTED_MAGNIFICATION = 20
    EXPECTED_PATCH_SIZE = 224
    SUPPORTS_TEXT = False
    SUPPORTS_ZEROSHOT = False
    POOLING_RULE = "cls"

    def __init__(self, device: torch.device) -> None:
        """Load the Phikon-v2 processor and weights.

        Args:
            device: Forwarded to the base-class initializer; the model is then moved
                to ``self.device``.
        """
        super().__init__(device)

        checkpoint = "owkin/phikon-v2"
        self.processor = AutoImageProcessor.from_pretrained(checkpoint)
        self.model = AutoModel.from_pretrained(checkpoint)
        # ``eval()`` hands back the module itself, so the device move can be chained.
        self.model.eval().to(self.device)

        self.embedding_dim = 1024

    def _create_embeddings_tensors(
        self,
        n_patches: int,
        **kwargs,
    ) -> tuple[torch.Tensor, None]:
        """Allocate the output buffer that will hold one embedding per patch.

        Args:
            n_patches: Number of rows to allocate.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            An uninitialized float16 tensor of shape ``(n_patches, embedding_dim)`` on
            ``self.device``, paired with ``None``.
        """
        shape = (n_patches, self.embedding_dim)
        buffer = torch.empty(shape, dtype=torch.float16, device=self.device)
        return buffer, None

    def _preprocess_batch(self, pil_images: list[Image.Image]) -> dict[str, torch.Tensor]:
        """Run the HF processor over a list of PIL images and move the result to the device.

        Args:
            pil_images: Images to preprocess together as one batch.

        Returns:
            Mapping from processor output name to a tensor located on ``self.device``.
        """
        encoded = self.processor(pil_images, return_tensors="pt")
        target = self.device
        return {name: tensor.to(target) for name, tensor in encoded.items()}

    def preprocess_input_tile(self) -> Callable[[Image.Image], dict[str, torch.Tensor]]:
        """Return a single-image preprocessing callable.

        The callable is ``hf_apply_processor`` with this model's processor bound as its
        first argument.
        """
        processor = self.processor
        return partial(hf_apply_processor, processor)

    def _run_inference(self, batch_tensor: dict[str, torch.Tensor], **kwargs) -> Any:
        """Forward a preprocessed batch through the model.

        Args:
            batch_tensor: Processor outputs, unpacked as keyword arguments to the model.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The raw model output object.
        """
        return self.model(**batch_tensor)

    def _extract_embeddings(self, model_output: Any, **kwargs) -> tuple[torch.Tensor, None]:
        """Pull the CLS-token embedding out of the raw model output.

        Args:
            model_output: Object exposing a ``last_hidden_state`` tensor.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The slice of ``last_hidden_state`` at token index 0, paired with ``None``.
        """
        hidden_states = model_output.last_hidden_state
        # Index 0 along the token axis is the CLS token.
        cls_embedding = hidden_states[:, 0, :]
        return cls_embedding, None
