"""Virchow foundation models.

@article{vorontsov_foundation_2024,
        title = {A foundation model for clinical-grade computational pathology and rare cancers
                 detection},
        issn = {1546-170X},
        url = {https://www.nature.com/articles/s41591-024-03141-0},
        doi = {10.1038/s41591-024-03141-0},
        journal = {Nature Medicine},
        author = {Vorontsov, Eugene and Bozkurt, Alican and Casson, Adam and Shaikovski, George and
                  Zelechowski, Michal and Severson, Kristen and Zimmermann, Eric and Hall, James and
                  Tenenholtz, Neil and Fusi, Nicolo and Yang, Ellen and Mathieu, Philippe and
                  van Eck, Alexander and Lee, Donghun and Viret, Julian and Robert, Eric and
                  Wang, Yi Kan and Kunz, Jeremy D. and Lee, Matthew C. H. and Bernhard, Jan H. and
                  Godrich, Ran A. and Oakley, Gerard and Millar, Ewan and Hanna, Matthew and
                  Wen, Hannah and Retamero, Juan A. and Moye, William A. and Yousfi, Razik and
                  Kanan, Christopher and Klimstra, David S. and Rothrock, Brandon and
                  Liu, Siqi and Fuchs, Thomas J.},
        month = oct,
        year = {2024},
}
https://huggingface.co/paige-ai/Virchow

@misc{zimmermann_virchow2_2024,
        title = {Virchow2: {Scaling} {Self}-{Supervised} {Mixed} {Magnification} {Models} in
                 {Pathology}},
        url = {http://arxiv.org/abs/2408.00738},
        doi = {10.48550/arXiv.2408.00738},
        publisher = {arXiv},
        author = {Zimmermann, Eric and Vorontsov, Eugene and Viret, Julian and Casson, Adam and
                  Zelechowski, Michal and Shaikovski, George and Tenenholtz, Neil and Hall, James
                  and Klimstra, David and Yousfi, Razik and Fuchs, Thomas and Fusi, Nicolo and
                  Liu, Siqi and Severson, Kristen},
        month = nov,
        year = {2024},
}
https://huggingface.co/paige-ai/Virchow2
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import timm
import torch
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers import SwiGLUPacked

from pathfmtools.embedding_models.embedding_model import EmbeddingModel
from pathfmtools.embedding_models.registry import register_model

if TYPE_CHECKING:
    from collections.abc import Callable

    from PIL import Image

logger = logging.getLogger(__name__)


@register_model(
    "virchow",
    embedding_dim=2560,
    supports_zeroshot=False,
    supports_text=False,
)
class VirchowModel(EmbeddingModel):
    """Tile embedder backed by the ``paige-ai/Virchow`` checkpoint.

    A tile embedding is the class token concatenated with the mean of all remaining output
    tokens (see ``_extract_embeddings``), stored as a 2560-wide vector.
    """

    NAME = "virchow"
    # Rationale for the expected tile geometry and pooling rule, quoted from
    # https://www.nature.com/articles/s41591-024-03141-0:
    # "For a 224 X 224 input tile image, a Virchow embedding is defined as the concatenation of the
    # class token and the mean across all 256 of the other predicted tokens"
    # "Virchow is used to generate tile-level embeddings on all the evaluated datasets with
    # 224 X 224 resolution at X20 magnification"
    EXPECTED_MAGNIFICATION = 20
    EXPECTED_PATCH_SIZE = 224
    SUPPORTS_TEXT = False
    SUPPORTS_ZEROSHOT = False
    POOLING_RULE = "cls+mean"

    def __init__(self, device: torch.device) -> None:
        """Load the pretrained backbone and build its evaluation transform.

        Args:
            device: Device handed to the base class; the backbone is moved onto ``self.device``.
        """
        super().__init__(device)

        # The checkpoint is fetched through timm's Hugging Face hub integration; the MLP and
        # activation layers are overridden explicitly when instantiating the architecture.
        backbone = timm.create_model(
            "hf-hub:paige-ai/Virchow",
            pretrained=True,
            mlp_layer=SwiGLUPacked,
            act_layer=torch.nn.SiLU,
        )
        # ``eval()`` and ``to()`` both hand back the module itself, so they can be chained.
        self.model = backbone.eval().to(self.device)

        # Derive the preprocessing pipeline from the checkpoint's own data configuration.
        data_cfg = resolve_data_config(self.model.pretrained_cfg, model=self.model)
        self.transform = create_transform(**data_cfg, is_training=False)

        # Width of the pooled embedding produced by ``_extract_embeddings``.
        self.embedding_dim = 2560

    def _create_embeddings_tensors(
        self,
        n_patches: int,
        **kwargs,
    ) -> tuple[torch.Tensor, None]:
        """Allocate the output buffer that will hold one embedding per patch.

        Args:
            n_patches: Number of rows to allocate.
            **kwargs: Accepted but not used here.

        Returns:
            An uninitialised float16 tensor of shape ``(n_patches, embedding_dim)`` on
            ``self.device``, paired with ``None``.
        """
        shape = (n_patches, self.embedding_dim)
        buffer = torch.empty(shape, dtype=torch.float16, device=self.device)
        return buffer, None

    def preprocess_input_tile(self) -> Callable[[Image.Image], torch.Tensor]:
        """Return the per-image transform built in ``__init__``.

        The callable is meant to be applied image-by-image on the CPU inside Dataset workers.
        """
        return self.transform

    def _run_inference(self, batch_tensor: torch.Tensor, **kwargs) -> torch.Tensor:
        """Forward an already-preprocessed batch through the backbone.

        Args:
            batch_tensor: Batch of transformed tiles.
            **kwargs: Accepted but not used here.

        Returns:
            The backbone's raw output, unmodified.
        """
        # NOTE(review): no grad/autocast context is entered here -- presumably the caller
        # handles that; confirm against the base class.
        token_output = self.model(batch_tensor)
        return token_output

    def _extract_embeddings(
        self, model_output: torch.Tensor, **kwargs
    ) -> tuple[torch.Tensor, None]:
        """Pool the backbone output into one vector per tile ("cls+mean").

        Args:
            model_output: Backbone output indexed as ``[batch, token, feature]``, with the
                class token at token index 0.
            **kwargs: Accepted but not used here.

        Returns:
            The class token concatenated with the mean of all other tokens along the feature
            axis, paired with ``None``.
        """
        cls_vector = model_output[:, 0]
        # Everything after the class token is averaged over the token axis.
        mean_patch_vector = model_output[:, 1:].mean(1)
        pooled = torch.cat([cls_vector, mean_patch_vector], dim=-1)
        return pooled, None

@register_model(
    "virchow2",
    embedding_dim=2560,
    supports_zeroshot=False,
    supports_text=False,
)
class Virchow2Model(EmbeddingModel):
    """Tile embedder backed by the ``paige-ai/Virchow2`` checkpoint.

    A tile embedding is the class token concatenated with the mean of the patch tokens; the
    register tokens that sit between them in the output are discarded
    (see ``_extract_embeddings``). The result is stored as a 2560-wide vector.
    """

    NAME = "virchow2"
    # Quoted from https://arxiv.org/pdf/2408.00738:
    # "A ViT-B/16 is trained using variations of DINOv2 [45] on 224 X 224 global views and 96 X 96
    # local views sampled from image tiles"
    # The model is designed to work at multiple magnification scales.
    # NOTE(review): the quoted sentence describes a ViT-B/16 run -- confirm it is the right
    # citation for the tile size expected by the checkpoint loaded below.
    EXPECTED_MAGNIFICATION = 20
    EXPECTED_PATCH_SIZE = 224
    SUPPORTS_TEXT = False
    SUPPORTS_ZEROSHOT = False
    POOLING_RULE = "cls+mean"

    def __init__(self, device: torch.device) -> None:
        """Load the pretrained backbone and build its evaluation transform.

        Args:
            device: Device handed to the base class; the backbone is moved onto ``self.device``.
        """
        super().__init__(device)

        # Pull the weights via timm's Hugging Face hub integration, overriding the MLP and
        # activation layers explicitly.
        backbone = timm.create_model(
            "hf-hub:paige-ai/Virchow2",
            pretrained=True,
            mlp_layer=SwiGLUPacked,
            act_layer=torch.nn.SiLU,
        )
        backbone.eval()
        # ``Module.to`` returns the module itself, so this both moves and stores it.
        self.model = backbone.to(self.device)

        # The preprocessing pipeline comes from the checkpoint's own data configuration.
        data_cfg = resolve_data_config(self.model.pretrained_cfg, model=self.model)
        self.transform = create_transform(**data_cfg, is_training=False)

        # Width of the pooled embedding produced by ``_extract_embeddings``.
        self.embedding_dim = 2560

    def _create_embeddings_tensors(
        self,
        n_patches: int,
        **kwargs,
    ) -> tuple[torch.Tensor, None]:
        """Allocate the output buffer that will hold one embedding per patch.

        Args:
            n_patches: Number of rows to allocate.
            **kwargs: Accepted but not used here.

        Returns:
            An uninitialised float16 tensor of shape ``(n_patches, embedding_dim)`` on
            ``self.device``, paired with ``None``.
        """
        shape = (n_patches, self.embedding_dim)
        buffer = torch.empty(shape, dtype=torch.float16, device=self.device)
        return buffer, None

    def preprocess_input_tile(self) -> Callable[[Image.Image], torch.Tensor]:
        """Return the per-image transform built in ``__init__``.

        The callable is meant to be applied image-by-image on the CPU inside Dataset workers.
        """
        return self.transform

    def _run_inference(self, batch_tensor: torch.Tensor, **kwargs) -> torch.Tensor:
        """Forward an already-preprocessed batch through the backbone.

        Args:
            batch_tensor: Batch of transformed tiles.
            **kwargs: Accepted but not used here.

        Returns:
            The backbone's raw output, unmodified.
        """
        # NOTE(review): no grad/autocast context is entered here -- presumably the caller
        # handles that; confirm against the base class.
        token_output = self.model(batch_tensor)
        return token_output

    def _extract_embeddings(
        self, model_output: torch.Tensor, **kwargs
    ) -> tuple[torch.Tensor, None]:
        """Pool the backbone output into one vector per tile ("cls+mean").

        Args:
            model_output: Backbone output indexed as ``[batch, token, feature]``. Token 0 is
                the class token, tokens 1-4 are register tokens, and patch tokens follow.
            **kwargs: Accepted but not used here.

        Returns:
            The class token concatenated with the mean of the patch tokens along the feature
            axis, paired with ``None``.
        """
        # One class token plus four register tokens precede the patch tokens.
        first_patch_index = 1 + 4
        cls_vector = model_output[:, 0]
        mean_patch_vector = model_output[:, first_patch_index:].mean(1)
        pooled = torch.cat([cls_vector, mean_patch_vector], dim=-1)
        return pooled, None
