"""UNI-2 foundation model.

@article{chen_towards_2024,
        title = {Towards a general-purpose foundation model for computational pathology},
        issn = {1078-8956, 1546-170X},
        url = {https://www.nature.com/articles/s41591-024-02857-3},
        doi = {10.1038/s41591-024-02857-3},
        journal = {Nature Medicine},
        author = {Chen, Richard J. and Ding, Tong and Lu, Ming Y. and Williamson, Drew F. K. and
                  Jaume, Guillaume and Song, Andrew H. and Chen, Bowen and Zhang, Andrew and
                  Shao, Daniel and Shaban, Muhammad and Williams, Mane and Oldenburg, Lukas and
                  Weishaupt, Luca L. and Wang, Judy J. and Vaidya, Anurag and Le, Long Phi and
                  Gerber, Georg and Sahai, Sharifa and Williams, Walt and Mahmood, Faisal},
        month = mar,
        year = {2024},
}
https://github.com/mahmoodlab/UNI
https://huggingface.co/MahmoodLab/UNI2-h
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, cast

import timm
import torch
from timm.data.config import resolve_data_config
from timm.data.transforms_factory import create_transform

from pathfmtools.embedding_models.embedding_model import EmbeddingModel
from pathfmtools.embedding_models.registry import register_model

if TYPE_CHECKING:
    from collections.abc import Callable

    from PIL import Image

logger = logging.getLogger(__name__)


@register_model(
    "uni2",
    embedding_dim=1536,
    supports_zeroshot=False,
    supports_text=False,
)
class UNI2Model(EmbeddingModel):
    """UNI2-h patch embedding model.

    Wraps the ``hf-hub:MahmoodLab/UNI2-h`` vision transformer loaded through ``timm`` and
    exposes the preprocessing / inference hooks that ``EmbeddingModel`` drives. The model is
    image-only: text and zero-shot support are both disabled.
    """

    NAME = "uni2"
    # From https://pmc.ncbi.nlm.nih.gov/articles/PMC11403354/:
    # "For all images used in ROI tasks and extracted patches for MIL in slide tasks, across all
    # models, all feature extraction operations are performed on resized 224 X 224 images at X20
    # magnification"
    EXPECTED_MAGNIFICATION = 20
    EXPECTED_PATCH_SIZE = 224
    SUPPORTS_TEXT = False
    SUPPORTS_ZEROSHOT = False
    POOLING_RULE = "global"

    def __init__(self, device: torch.device) -> None:
        """Load the UNI2-h weights, build the eval transform and move the model to ``device``.

        Args:
            device: Device the model is moved to; it is forwarded to the base-class
                initializer and later read back as ``self.device``.

        """
        super().__init__(device)

        # Architecture overrides forwarded verbatim to ``timm.create_model``.
        # pretrained=True needed to load UNI2-h weights (and download weights for the first time)
        timm_kwargs = {
            "img_size": 224,
            "patch_size": 14,
            "depth": 24,
            "num_heads": 24,
            "init_values": 1e-5,
            "embed_dim": 1536,
            "mlp_ratio": 2.66667 * 2,
            "num_classes": 0,
            "no_embed_class": True,
            "mlp_layer": timm.layers.SwiGLUPacked,  # type: ignore[reportPrivateImportUsage]
            "act_layer": torch.nn.SiLU,
            "reg_tokens": 8,
            "dynamic_img_size": True,
        }
        self.model = timm.create_model("hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs)  # type: ignore[reportPrivateImportUsage]
        # Derive the inference-time transform from the checkpoint's own data config.
        self.preprocess = create_transform(
            **resolve_data_config(self.model.pretrained_cfg, model=self.model),
            is_training=False,
        )
        self.model.eval()
        self.model.to(self.device)
        # Read the width back from the instantiated model rather than repeating the literal.
        self.embedding_dim = cast(int, self.model.embed_dim)

    def _create_embeddings_tensors(
        self,
        n_patches: int,
        **kwargs,
    ) -> tuple[torch.Tensor, None]:
        """Allocate the uninitialized buffer that will hold one embedding per patch.

        Args:
            n_patches: Number of rows (patches) to allocate.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            A ``(n_patches, embedding_dim)`` float16 tensor on ``self.device`` and ``None``.

        """
        # Allocate directly with the final dtype and device; this avoids a transient float32
        # buffer on the default device followed by a device copy and a dtype cast.
        embeddings = torch.empty(
            (n_patches, self.embedding_dim),
            dtype=torch.float16,
            device=self.device,
        )
        return embeddings, None

    def _preprocess_batch(self, pil_images: list[Image.Image]) -> torch.Tensor:
        """Preprocess a batch of PIL images for the model.

        Args:
            pil_images: Images to run through ``self.preprocess``.

        Returns:
            The transformed images stacked along a new leading dimension, on ``self.device``.

        """
        # Stack before moving so the batch is sent to the device in a single transfer
        # instead of one transfer per image.
        batch = torch.stack(
            [
                self.preprocess(p)  # type: ignore[reportCallIssue]
                for p in pil_images
            ],
        )
        return batch.to(self.device)

    def preprocess_input_tile(self) -> Callable[[Image.Image], torch.Tensor]:
        """Per-image CPU-only preprocessing callable for Dataset workers.

        Returns:
            The transform built in ``__init__``; unlike ``_preprocess_batch`` it performs no
            device transfer.

        """
        return self.preprocess  # type: ignore[reportReturnType]

    def _run_inference(self, batch_tensor: torch.Tensor, **kwargs) -> torch.Tensor:
        """Run model inference on preprocessed batch.

        Args:
            batch_tensor: Output of ``_preprocess_batch``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The raw output of the wrapped timm model.

        """
        return self.model(batch_tensor)

    def _extract_embeddings(
        self,
        model_output: torch.Tensor,
        **kwargs,
    ) -> tuple[torch.Tensor, None]:
        """Extract embeddings from model output.

        The model output is already the embedding, so it is passed through unchanged.

        Args:
            model_output: Tensor returned by ``_run_inference``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            ``model_output`` unchanged, paired with ``None``.
            NOTE(review): the second slot is presumably for auxiliary outputs produced by
            other models -- confirm against ``EmbeddingModel``.

        """
        return model_output, None
