"""MUSK foundation model.

@article{xiang_visionlanguage_2025,
        title = {A vision–language foundation model for precision oncology},
        issn = {1476-4687},
        url = {https://www.nature.com/articles/s41586-024-08378-w},
        doi = {10.1038/s41586-024-08378-w},
        journal = {Nature},
        author = {Xiang, Jinxi and Wang, Xiyue and Zhang, Xiaoming and Xi, Yinghua and
                  Eweje, Feyisope and Chen, Yijiang and Li, Yuchen and Bergstrom, Colin and
                  Gopaulchan, Matthew and Kim, Ted and Yu, Kun-Hsing and Willens, Sierra and
                  Olguin, Francesca Maria and Nirschl, Jeffrey J. and Neal, Joel and
                  Diehn, Maximilian and Yang, Sen and Li, Ruijiang},
        month = feb,
        year = {2025},
}
https://github.com/lilab-stanford/MUSK
https://huggingface.co/xiangjx/musk
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any

import torch
import torchvision

# `modeling` is imported only for its side effect: it appears to register the MUSK
# architectures (e.g. "musk_large_patch16_384") with timm, even though it is never referenced directly.
from musk import modeling, utils  # noqa: F401
from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.models._factory import create_model
from transformers import XLMRobertaTokenizer

from pathfmtools.embedding_models.embedding_model import EmbeddingModel
from pathfmtools.embedding_models.registry import register_model

if TYPE_CHECKING:
    from collections.abc import Callable

    import numpy as np
    from PIL import Image

logger = logging.getLogger(__name__)


@register_model(
    "musk",
    embedding_dim=2048,
    zeroshot_dim=1024,
    supports_zeroshot=True,
    supports_text=True,
)
class MUSKModel(EmbeddingModel):
    """MUSK patch model.

    Wraps the timm ``musk_large_patch16_384`` architecture loaded with the
    ``hf_hub:xiangjx/musk`` checkpoint and run in half precision. It produces
    2048-d patch feature embeddings, 1024-d image embeddings for image-text
    (zero-shot) retrieval, and 1024-d text embeddings.
    """

    NAME = "musk"
    EXPECTED_MAGNIFICATION = 20
    EXPECTED_PATCH_SIZE = 224
    SUPPORTS_TEXT = True
    SUPPORTS_ZEROSHOT = True
    POOLING_RULE = "global"

    # Precision the weights are loaded in; inputs and output buffers use the same dtype.
    _MODEL_DTYPE = torch.float16
    # Square input side length (pixels) the network consumes; tiles are resized and
    # center-cropped to this size by `self.preprocess`.
    _INPUT_SIZE = 384
    # Token sequence length passed to `utils.xlm_tokenizer` (value from the MUSK demo notebook).
    _TEXT_MAX_LEN = 100
    # Multiplier applied to zero-shot logits before softmax (see scale_zeroshot_classification_logits).
    _ZEROSHOT_LOGIT_SCALE = 100

    def __init__(self, device: torch.device) -> None:
        """Load the pretrained MUSK model onto ``device`` and build preprocessing.

        Args:
            device: Device the model, its inputs and the output buffers live on.
        """
        super().__init__(device)

        self.model = create_model("musk_large_patch16_384")
        # Load the pretrained checkpoint from the HF Hub into the timm model.
        utils.load_model_and_may_interpolate("hf_hub:xiangjx/musk", self.model, "model|module", "")
        self.model.to(device=self.device, dtype=self._MODEL_DTYPE)
        self.model.eval()

        self.preprocess = torchvision.transforms.Compose(
            [
                # BICUBIC is the same mode as the legacy PIL integer code 3; the enum is the
                # documented type for `interpolation` and needs no type-ignore.
                torchvision.transforms.Resize(
                    self._INPUT_SIZE,
                    interpolation=torchvision.transforms.InterpolationMode.BICUBIC,
                    antialias=True,
                ),
                torchvision.transforms.CenterCrop((self._INPUT_SIZE, self._INPUT_SIZE)),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(
                    mean=IMAGENET_INCEPTION_MEAN,
                    std=IMAGENET_INCEPTION_STD,
                ),
            ],
        )
        self.patch_embedding_dim = 2_048
        self.patch_text_retrieval_embedding_dim = 1_024
        self.text_embedding_dim = 1_024
        self.text_tokenizer = XLMRobertaTokenizer(
            Path(__file__).parent / "musk_tokenizer.spm",
        )

    def _create_embeddings_tensors(
        self,
        n_patches: int,
        get_zeroshot_embeddings: bool = True,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Allocate uninitialized output buffers for ``n_patches`` patches.

        Args:
            n_patches: Number of rows to allocate.
            get_zeroshot_embeddings: Whether to also allocate the retrieval buffer.

        Returns:
            ``(feature_buffer, zeroshot_buffer)``. The feature buffer is
            ``(n_patches, patch_embedding_dim)``. The zero-shot buffer is
            ``(n_patches, patch_text_retrieval_embedding_dim)``, or ``None``
            when not requested. Both are half precision on ``self.device``.
        """
        feature_embeddings = torch.empty(
            (n_patches, self.patch_embedding_dim),
            device=self.device,
            dtype=self._MODEL_DTYPE,
        )
        if not get_zeroshot_embeddings:
            return feature_embeddings, None
        zeroshot_embeddings = torch.empty(
            (n_patches, self.patch_text_retrieval_embedding_dim),
            device=self.device,
            dtype=self._MODEL_DTYPE,
        )
        return feature_embeddings, zeroshot_embeddings

    def _run_inference(
        self,
        batch_tensor: torch.Tensor,
        get_zeroshot_embeddings: bool = True,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run the image tower on a preprocessed batch.

        Args:
            batch_tensor: Preprocessed images, already on device in model dtype.
            get_zeroshot_embeddings: Whether to also compute retrieval embeddings.

        Returns:
            ``(feature_embeddings, zeroshot_embeddings)``. The second element is
            ``None`` when zero-shot embeddings are not requested.
        """
        # The model returns a tuple of (image embeddings, text embeddings).
        # NOTE: There is a discrepancy between the flags in the example on
        # https://huggingface.co/xiangjx/musk and the demo notebook
        # https://github.com/lilab-stanford/MUSK/blob/main/demo.ipynb.
        # The HF example specifies out_norm=False and return_global=True, while the
        # demo notebook specifies out_norm=True and does not specify return_global.
        # This follows the demo notebook, because the HF example tells users to
        # refer to the notebook for a full implementation.
        feature_embeddings, _ = self.model(
            image=batch_tensor,
            with_head=False,
            out_norm=True,
            ms_aug=True,
        )

        if not get_zeroshot_embeddings:
            return feature_embeddings, None

        # A second forward pass is needed because the flags differ (head on, no ms_aug).
        # Flags follow the multimodal retrieval example in
        # https://github.com/lilab-stanford/MUSK/blob/main/demo.ipynb
        zeroshot_embeddings, _ = self.model(
            image=batch_tensor,
            with_head=True,
            out_norm=True,
        )
        return feature_embeddings, zeroshot_embeddings

    def _extract_embeddings(
        self,
        model_output: tuple[torch.Tensor, torch.Tensor | None],
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Return ``model_output`` unchanged; `_run_inference` already yields final embeddings."""
        return model_output

    def preprocess_input_tile(self) -> Callable[[Image.Image], torch.Tensor]:
        """Per-image CPU-only preprocessing callable for Dataset workers.

        The callable outputs a float tensor; the cast to half precision happens
        later in `prepare_batch_for_device`.
        """
        return self.preprocess

    def prepare_batch_for_device(self, batch: Any) -> Any:  # noqa: ANN401
        """Move a batch to ``self.device`` with the dtype MUSK expects (float16).

        A bare tensor is moved and cast to float16. In a dict, floating-point
        tensors are cast to float16 and moved. Other tensors (e.g. integer
        indices) are moved with their dtype unchanged, because float16 cannot
        represent integers above 2048 exactly. Non-tensor values, and batches
        of any other type, are returned as-is.
        """
        if isinstance(batch, torch.Tensor):
            return batch.to(self.device, dtype=self._MODEL_DTYPE, non_blocking=True)
        if isinstance(batch, dict):
            return {key: self._move_batch_value(value) for key, value in batch.items()}
        return batch

    def _move_batch_value(self, value: Any) -> Any:  # noqa: ANN401
        """Move one dict-batch value to the device, casting only floating-point tensors."""
        if not isinstance(value, torch.Tensor):
            return value
        if value.is_floating_point():
            return value.to(self.device, dtype=self._MODEL_DTYPE, non_blocking=True)
        return value.to(self.device, non_blocking=True)

    def embed_text(self, text_descriptors: list[str]) -> torch.Tensor:
        """Embed text prompts for image-text retrieval.

        Tokenization and forward-pass flags follow
        https://github.com/lilab-stanford/MUSK/blob/main/demo.ipynb.

        Args:
            text_descriptors: Prompts to embed.

        Returns:
            The text element (index 1) of the model output, presumably one
            ``text_embedding_dim`` row per prompt. An empty input returns an
            empty ``(0, text_embedding_dim)`` half-precision tensor instead of
            failing in tensor concatenation.
        """
        if not text_descriptors:
            return torch.empty(
                (0, self.text_embedding_dim),
                device=self.device,
                dtype=self._MODEL_DTYPE,
            )

        token_rows = []
        padding_rows = []
        for txt in text_descriptors:
            txt_ids, pad = utils.xlm_tokenizer(txt, self.text_tokenizer, max_len=self._TEXT_MAX_LEN)
            token_rows.append(torch.tensor(txt_ids))
            padding_rows.append(torch.tensor(pad))

        # One row per prompt. Stacking requires equal-length rows from xlm_tokenizer,
        # just as concatenating unsqueezed rows did.
        text_ids = torch.stack(token_rows)
        paddings = torch.stack(padding_rows)

        with torch.inference_mode():
            return self.model(
                text_description=text_ids.to(self.device),
                padding_mask=paddings.to(self.device),
                with_head=True,
                out_norm=True,
            )[1]  # index 0 holds image embeddings

    @classmethod
    def scale_zeroshot_classification_logits(cls, logits: np.ndarray) -> np.ndarray:
        """Scale zero-shot similarity logits by 100 before softmax.

        The scaling follows MUSK's zero-shot benchmark:
        https://github.com/lilab-stanford/MUSK/blob/e1699c27687f44bbf6d4adfcbb2abe89795d347f/benchmarks/clip_benchmark/metrics/zeroshot_classification.py#L204
        It is likely based on https://github.com/openai/CLIP/issues/48.
        """
        return logits * cls._ZEROSHOT_LOGIT_SCALE
