"""CONCH foundation model.

@article{lu_visual-language_2024,
        title = {A visual-language foundation model for computational pathology},
        url = {https://www.nature.com/articles/s41591-024-02856-4},
        doi = {10.1038/s41591-024-02856-4},
        journal = {Nature Medicine},
        author = {Lu, Ming Y. and Chen, Bowen and Williamson, Drew F. K. and Chen, Richard J. and
                  Liang, Ivy and Ding, Tong and Jaume, Guillaume and Odintsov, Igor and Le, Long Phi
                  and Gerber, Georg and Parwani, Anil V. and Zhang, Andrew and Mahmood, Faisal},
        month = mar,
        year = {2024},
}
https://github.com/mahmoodlab/CONCH
https://huggingface.co/MahmoodLab/CONCH
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import torch
from conch.open_clip_custom import create_model_from_pretrained, get_tokenizer, tokenize

from pathfmtools.embedding_models.embedding_model import EmbeddingModel
from pathfmtools.embedding_models.registry import register_model

if TYPE_CHECKING:
    from collections.abc import Callable

    import numpy as np
    from PIL import Image

logger = logging.getLogger(__name__)


@register_model(
    "conch",
    embedding_dim=512,
    zeroshot_dim=512,
    supports_zeroshot=True,
    supports_text=True,
)
class CONCHModel(EmbeddingModel):
    """CONCH visual-language embedding model.

    Wraps the ``conch_ViT-B-16`` checkpoint loaded from ``hf_hub:MahmoodLab/conch``.
    It produces two kinds of image embeddings:

    * main embeddings, from ``encode_image(proj_contrast=False, normalize=False)``;
    * zero-shot embeddings, from ``encode_image(proj_contrast=True, normalize=True)``.
      These are L2-normalized and are meant to be compared with the normalized text
      embeddings from :meth:`embed_text`.

    Both kinds are 512-dimensional, matching ``embedding_dim``/``zeroshot_dim`` in the
    registry decorator.
    """

    NAME = "conch"
    EXPECTED_MAGNIFICATION = 20
    # Size of the original patch the model expects. From the paper: "For all experiments, we
    # standardized the image input size to 224 X 224".
    # The demo notebook https://github.com/mahmoodlab/CONCH/blob/main/notebooks/basics_usage.ipynb
    # states that "By default, the model preprocessor uses 448 x 448 as the input size". That
    # refers to the size the preprocessor resizes each patch to before it is fed to the network,
    # not to the original patch size. See
    # https://github.com/mahmoodlab/CONCH/blob/main/conch/open_clip_custom/model_configs/conch_ViT-B-16.json
    # and the paper: "All images were resized to 448 X 448 for both training and inference".
    EXPECTED_PATCH_SIZE = 224

    SUPPORTS_TEXT = True
    SUPPORTS_ZEROSHOT = True
    POOLING_RULE = "global"

    def __init__(self, device: torch.device) -> None:
        """Load the pretrained CONCH model, its preprocessing transform and its tokenizer.

        Args:
            device: Device the model is moved to; inputs and output buffers also live there.
        """
        super().__init__(device)
        self.model, self.preprocess = create_model_from_pretrained(
            "conch_ViT-B-16",
            "hf_hub:MahmoodLab/conch",
        )  # pyright: ignore[reportGeneralTypeIssues]
        self.model.eval()
        self.model.to(self.device)
        # Dimension of the embeddings computed by CONCH. It is used to allocate empty tensors of
        # the right shape for storing embeddings as they are computed. The zero-shot embeddings
        # share this dimension (embedding_dim == zeroshot_dim == 512 in the decorator above).
        self.embedding_dim = 512
        self.tokenizer = get_tokenizer()

    def _create_embeddings_tensors(
        self,
        n_patches: int,
        get_zeroshot_embeddings: bool = True,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Allocate the output buffer(s) for storing embeddings.

        The buffers are created with ``torch.empty``, so they are uninitialized and every row
        must be written by the caller. They use float16 storage, so values written into them
        are cast to float16.

        Args:
            n_patches: Number of rows (patches) to allocate.
            get_zeroshot_embeddings: Whether to also allocate a zero-shot embedding buffer.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``(embeddings, zeroshot_embeddings)``. Each buffer has shape
            ``(n_patches, self.embedding_dim)``. ``zeroshot_embeddings`` is ``None`` when
            ``get_zeroshot_embeddings`` is false.
        """

        def _allocate() -> torch.Tensor:
            return torch.empty(
                (n_patches, self.embedding_dim),
                device=self.device,
                dtype=torch.float16,
            )

        embeddings = _allocate()
        zeroshot_embeddings = _allocate() if get_zeroshot_embeddings else None
        return embeddings, zeroshot_embeddings

    def _preprocess_batch(self, pil_images: list[Image.Image]) -> torch.Tensor:
        """Preprocess a batch of PIL images and move the batch to the model device.

        Each image goes through the CONCH preprocessing transform, which runs on the CPU
        (see :meth:`preprocess_input_tile`). The results are stacked on the CPU and the
        whole batch is transferred to ``self.device`` in a single copy, instead of one
        host-to-device copy per image.

        Args:
            pil_images: Images to preprocess.

        Returns:
            The per-image transform outputs stacked along a new leading (batch) dimension,
            on ``self.device``.
        """
        tiles = [
            self.preprocess(p)  # pyright: ignore[reportAssignmentType,reportAttributeAccessIssue]
            for p in pil_images
        ]
        return torch.stack(tiles).to(self.device)

    def preprocess_input_tile(self) -> Callable[[Image.Image], torch.Tensor]:
        """Return the per-image, CPU-only preprocessing callable for Dataset workers.

        This is the transform returned by ``create_model_from_pretrained`` in ``__init__``.
        """
        return self.preprocess

    def _run_inference(
        self,
        batch_tensor: torch.Tensor,
        get_zeroshot_embeddings: bool = True,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run CONCH image encoding on a preprocessed batch.

        Args:
            batch_tensor: Batch produced by :meth:`_preprocess_batch` (already on the device).
            get_zeroshot_embeddings: Whether to also compute zero-shot embeddings. These come
                from a second ``encode_image`` call, which runs the image encoder again.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``(main_embeddings, zeroshot_embeddings)``.
            ``main_embeddings`` uses ``proj_contrast=False, normalize=False``.
            ``zeroshot_embeddings`` uses ``proj_contrast=True, normalize=True``, or is
            ``None`` when zero-shot embeddings are not requested.
        """
        main_embeddings = self.model.encode_image(
            batch_tensor,
            proj_contrast=False,
            normalize=False,
        )

        if not get_zeroshot_embeddings:
            return main_embeddings, None

        # Projected and L2-normalized, so that dot products with the normalized text
        # embeddings from `embed_text` are cosine similarities.
        zeroshot_embeddings = self.model.encode_image(
            batch_tensor,
            proj_contrast=True,
            normalize=True,
        )
        return main_embeddings, zeroshot_embeddings

    def _extract_embeddings(
        self,
        model_output: tuple[torch.Tensor, torch.Tensor | None],
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Return the model output unchanged.

        :meth:`_run_inference` already yields the final ``(main, zeroshot)`` embedding pair,
        so there is nothing left to extract.
        """
        return model_output

    def embed_text(self, text_descriptors: list[str]) -> torch.Tensor:
        """Embed text prompts into the shared image-text space.

        The prompts are tokenized with the CONCH tokenizer, moved to ``self.device`` and
        encoded under ``torch.inference_mode()`` with ``normalize=True``.

        Args:
            text_descriptors: Text prompts to embed.

        Returns:
            The normalized text embeddings. Presumably shaped
            ``(len(text_descriptors), 512)``; TODO confirm against ``encode_text``.
        """
        text_tokens = tokenize(texts=text_descriptors, tokenizer=self.tokenizer).to(self.device)
        with torch.inference_mode():
            return self.model.encode_text(text_tokens, normalize=True)

    @classmethod
    def scale_zeroshot_classification_logits(cls, logits: np.ndarray) -> np.ndarray:
        """Scale image-text similarity logits before a softmax in zero-shot classification.

        Args:
            logits: Similarity scores between zero-shot image embeddings and text embeddings.

        Returns:
            ``logits`` multiplied by 100.
        """
        # It appears that logits should be scaled before softmax in zero-shot classification.
        # See https://github.com/mahmoodlab/CONCH/blob/c141475f5fa83891de67aa2d79f9fa74e232d2aa/conch/open_clip_custom/transformer.py#L93
        # Likely based on https://github.com/openai/CLIP/issues/48.
        return logits * 100
