"""H-optimus foundation models.

H-optimus: Histopathology Vision Foundation Models
https://huggingface.co/bioptimus/H-optimus-0
https://huggingface.co/bioptimus/H-optimus-1
"""

from __future__ import annotations

import logging
from abc import ABC
from typing import TYPE_CHECKING

import timm
import torch
from torchvision import transforms

from pathfmtools.embedding_models.embedding_model import EmbeddingModel
from pathfmtools.embedding_models.registry import register_model

if TYPE_CHECKING:
    from collections.abc import Callable

    import numpy as np
    from PIL import Image

logger = logging.getLogger(__name__)


def _supports_bf16(device: torch.device) -> bool:
    """Return True if the current device supports bf16 autocast."""
    if device.type != "cuda" or not torch.cuda.is_available():
        return False
    try:
        major, _ = torch.cuda.get_device_capability(device)
    except Exception:
        return False
    return major >= 8  # Ampere+


class HOptimusModel(EmbeddingModel, ABC):
    """Shared implementation for the H-optimus family of embedding models.

    Concrete subclasses are expected to assign ``model``, ``transform`` and
    ``embedding_dim`` in their ``__init__``; the methods here only read them.
    These models produce image embeddings only: the text-embedding and
    zero-shot hooks always raise ``NotImplementedError``.
    """

    # Network called on each preprocessed batch.
    model: torch.nn.Module
    # Per-image preprocessing returned by ``preprocess_input_tile``.
    transform: Callable[[Image.Image], torch.Tensor]
    # Width of one embedding vector; sizes the results buffer.
    embedding_dim: int

    def __init__(self, device: torch.device) -> None:
        """Forward ``device`` to the ``EmbeddingModel`` initializer."""
        super().__init__(device)

    def _create_embeddings_tensors(
        self,
        n_patches: int,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Allocate the buffer that will hold the embeddings.

        Args:
            n_patches: Number of rows to allocate, one per patch.
            **kwargs: Accepted for interface compatibility; ignored here.

        Returns:
            A pair ``(embeddings, None)`` where ``embeddings`` is an
            uninitialized float16 tensor of shape
            ``(n_patches, embedding_dim)`` on ``self.device``. The second
            slot is always ``None`` for this model family.
        """
        embeddings = torch.empty(
            (n_patches, self.embedding_dim),
            device=self.device,
            dtype=torch.float16,
        )
        return embeddings, None

    def _run_inference(
        self,
        batch_tensor: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Run the model on a preprocessed batch without tracking gradients.

        Args:
            batch_tensor: Preprocessed input batch passed straight to the model.
            **kwargs: Accepted for interface compatibility; ignored here.

        Returns:
            Whatever ``self.model`` returns for ``batch_tensor``.
        """
        if self.device.type != "cuda":
            # No autocast off CUDA: run in the model's native precision.
            with torch.inference_mode():
                return self.model(batch_tensor)

        # HF repo recommends AMP; prefer bf16 where the GPU supports it and
        # fall back to fp16 otherwise.
        amp_dtype = torch.bfloat16 if _supports_bf16(self.device) else torch.float16
        with torch.autocast(device_type="cuda", dtype=amp_dtype), torch.inference_mode():
            return self.model(batch_tensor)

    def _extract_embeddings(
        self,
        model_output: torch.Tensor,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Return the model output unchanged as the embeddings.

        Args:
            model_output: Tensor produced by ``_run_inference``.
            **kwargs: Accepted for interface compatibility; ignored here.

        Returns:
            ``(model_output, None)``; the second slot is always ``None``.
        """
        return model_output, None

    def preprocess_input_tile(self) -> Callable[[Image.Image], torch.Tensor]:
        """Per-image CPU-only preprocessing callable for Dataset workers."""
        return self.transform

    def embed_text(self, text_descriptors: list[str]) -> torch.Tensor:
        """Reject text embedding, which this model family does not provide.

        Args:
            text_descriptors: Ignored.

        Raises:
            NotImplementedError: Always.
        """
        # Report the concrete model's name; this base class serves several variants.
        model_name = getattr(self, "NAME", "H-optimus")
        msg = f"Text embedding is not supported for {model_name}"
        # logger.error rather than logger.exception: no exception is being
        # handled here, so there is no traceback to attach to the record.
        logger.error(msg)
        raise NotImplementedError(msg)

    @classmethod
    def scale_zeroshot_classification_logits(cls, logits: np.ndarray) -> np.ndarray:
        """Reject zero-shot logit scaling, which this model family does not provide.

        Args:
            logits: Ignored.

        Raises:
            NotImplementedError: Always.
        """
        # Report the concrete model's name; this base class serves several variants.
        model_name = getattr(cls, "NAME", "H-optimus")
        msg = f"Zero-shot classification is not supported for {model_name}"
        # logger.error rather than logger.exception: no exception is being
        # handled here, so there is no traceback to attach to the record.
        logger.error(msg)
        raise NotImplementedError(msg)


@register_model(
    "h-optimus-0",
    embedding_dim=1536,
    supports_zeroshot=False,
    supports_text=False,
)
class HOptimus0Model(HOptimusModel):
    """H-optimus-0 image embedding model loaded through timm.

    Tiles are converted to tensors and normalized with fixed per-channel
    statistics; no resizing is applied by this class.
    """

    NAME = "h-optimus-0"
    # From https://huggingface.co/bioptimus/H-optimus-0:
    # "H-optimus-0 expects images of size 224x224 that were extracted at 0.5 microns per pixel."
    EXPECTED_MAGNIFICATION = 20
    EXPECTED_PATCH_SIZE = 224
    SUPPORTS_TEXT = False
    SUPPORTS_ZEROSHOT = False
    POOLING_RULE = "global"

    def __init__(self, device: torch.device) -> None:
        """Load the pretrained weights, place them on ``device`` and build the transform."""
        super().__init__(device)

        # Fetch the pretrained network from the Hugging Face hub via timm.
        backbone = timm.create_model(
            "hf-hub:bioptimus/H-optimus-0",
            pretrained=True,
            init_values=1e-5,
            dynamic_img_size=False,
        )
        self.model = backbone
        self.model.eval()
        self.model.to(self.device)

        # Normalization statistics as specified by the model card.
        channel_mean = (0.707, 0.579, 0.704)
        channel_std = (0.212, 0.230, 0.178)
        preprocessing_steps = [
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std),
        ]
        self.transform = transforms.Compose(preprocessing_steps)

        self.embedding_dim = 1536


@register_model(
    "h-optimus-1",
    embedding_dim=1536,
    supports_zeroshot=False,
    supports_text=False,
)
class HOptimus1Model(HOptimusModel):
    """H-optimus-1 image embedding model loaded through timm.

    Tiles are resized to 224x224, converted to tensors and normalized with
    fixed per-channel statistics.
    """

    NAME = "h-optimus-1"
    # From https://huggingface.co/bioptimus/H-optimus-1
    # "H-optimus-1 expects images of size 224x224 that were extracted at 0.5 microns per pixel."
    EXPECTED_MAGNIFICATION = 20
    EXPECTED_PATCH_SIZE = 224
    SUPPORTS_TEXT = False
    SUPPORTS_ZEROSHOT = False
    POOLING_RULE = "global"

    def __init__(self, device: torch.device) -> None:
        """Load the pretrained weights, place them on ``device`` and build the transform."""
        super().__init__(device)

        # Fetch the pretrained network from the Hugging Face hub via timm.
        backbone = timm.create_model(
            "hf-hub:bioptimus/H-optimus-1",
            pretrained=True,
            init_values=1e-5,
            dynamic_img_size=False,
        )
        self.model = backbone
        self.model.eval()
        self.model.to(self.device)

        # Normalization statistics as specified by the model card.
        channel_mean = (0.707, 0.579, 0.704)
        channel_std = (0.212, 0.230, 0.178)
        preprocessing_steps = [
            # Force every tile to 224x224 before it becomes a tensor.
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std),
        ]
        self.transform = transforms.Compose(preprocessing_steps)

        self.embedding_dim = 1536
