"""Abstract base class for embedding models.

Responsibilities:
- Provide a uniform embedding loop across patch models.
- Implement rescale/resize based on slide magnification and a model's expected
  magnification/patch size.
"""

from __future__ import annotations

import logging
import math
import os
from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import partial
from typing import TYPE_CHECKING, Any

import cv2
import torch
from rich.progress import track
from torch.utils.data import DataLoader

from pathfmtools.utils.torch import (
    TileDataset,
    TransformChain,
    collate_dict_or_tensor,
    cv2_resize_pil,
)

if TYPE_CHECKING:
    from collections.abc import Callable

    import numpy as np
    from PIL import Image

    from pathfmtools.image.tile_index import TileIndex
    from pathfmtools.io.slide_data_store import SlideDataStore
    from pathfmtools.io.slide_reader import SlideReader

logger = logging.getLogger(__name__)


class EmbeddingModel(ABC):
    """Abstract base class for all patch embedding models.

    Subclasses provide tensor allocation (``_create_embeddings_tensors``), the forward pass
    (``_run_inference``), output unpacking (``_extract_embeddings``) and a per-image CPU
    preprocessing callable (``preprocess_input_tile``). This base class supplies the shared
    DataLoader-driven embedding loop (``embed_tiles``) and magnification-based rescaling.

    Class attributes:
        NAME: Human-readable model name used in log and error messages.
        EXPECTED_MAGNIFICATION: Magnification the model expects. ``None`` disables texture
            rescaling and receptive-field validation.
        EXPECTED_PATCH_SIZE: Patch size (pixels) the model expects. ``None`` disables texture
            rescaling and receptive-field validation.
        SUPPORTS_TEXT: Flag advertising text-embedding support (not consulted in this class).
        SUPPORTS_ZEROSHOT: When True (and not skipped by the caller), ``embed_tiles`` requests
            zero-shot embeddings from the subclass.
        POOLING_RULE: Descriptive pooling label, e.g. 'cls', 'mean', 'cls+mean', 'global'
            (not consulted in this class).
    """

    NAME = "N/A"
    EXPECTED_MAGNIFICATION = None
    EXPECTED_PATCH_SIZE = None
    SUPPORTS_TEXT: bool = False
    SUPPORTS_ZEROSHOT: bool = False
    POOLING_RULE: str = "global"  # e.g., 'cls', 'mean', 'cls+mean', 'global'

    def __init__(
        self,
        device: torch.device,
    ) -> None:
        """Initialize the model.

        Args:
            device: Device that collated batches are moved to before inference.

        """
        self.device = device
        # Ensures the rescale / receptive-field mismatch warning is logged at most once per
        # instance, even when many slides are embedded with the same model.
        self._issued_rescale_warning = False
        logger.info("Using device: %s", self.device)

    @abstractmethod
    def _create_embeddings_tensors(
        self,
        n_patches: int,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Create the embeddings tensor(s) for storing results.

        Args:
            n_patches: Number of tiles that will be embedded (rows to allocate).
            **kwargs: Model-specific arguments; ``embed_tiles`` passes
                ``get_zeroshot_embeddings`` here.

        Returns:
            (main_embeddings, additional_embeddings_or_None)

        """

    @abstractmethod
    def _run_inference(
        self,
        batch_tensor: torch.Tensor,
        **kwargs,
    ) -> Any:  # noqa: ANN401
        """Run model inference on preprocessed batch.

        Args:
            batch_tensor: Preprocessed batch tensor (or dict of tensors), already on device
            **kwargs: Additional model-specific arguments

        Returns:
            Model output(s)

        """

    @abstractmethod
    def _extract_embeddings(
        self,
        model_output: Any,  # noqa: ANN401
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Extract embeddings from model output.

        Args:
            model_output: Output from _run_inference
            **kwargs: Additional arguments

        Returns:
            (main_embeddings, additional_embeddings_or_None)

        """

    @abstractmethod
    def preprocess_input_tile(self) -> Callable[[Image.Image], Any]:
        """Return a CPU-only, picklable per-image preprocessing callable.

        Implementations must not move tensors to device or change dtype; output must be either a
        torch.Tensor (CHW, float32) or a dict[str, torch.Tensor] with consistent keys across items.
        """

    def build_dataset_transform(
        self,
        slide_patch_size: int,
        slide_magnification: float,
        *,
        auto_rescale: bool = True,
    ) -> Callable[[Image.Image], Any]:
        """Build a CPU-only per-image transform for the Dataset.

        Composes an optional cv2 resize (to match the model's expected magnification) followed by
        the model's per-image preprocess callable.

        Args:
            slide_patch_size: Edge length (pixels) of the tiles read from the slide.
            slide_magnification: Magnification at which the tiles were read.
            auto_rescale: When True and the texture scale differs from 1, prepend a resize step.
                When False, only a one-time mismatch warning is logged.

        Returns:
            A ``TransformChain`` applying the steps in order.

        Raises:
            ValueError: Propagated from ``_compute_texture_scale`` for invalid inputs.

        """
        steps: list[Callable[[Any], Any]] = []

        scale_info = self._compute_texture_scale(slide_patch_size, slide_magnification)
        if scale_info is not None:
            scale, target_size = scale_info
            if scale != 1.0:
                if auto_rescale:
                    self._warn_scale_mismatch_once(
                        "Applying texture-scale resampling for %s: slide_patch=%s, slide_mag=%s"
                        ", expected_mag=%s, expected_patch=%s, scale=%.4f, target_size=%s",
                        slide_patch_size,
                        slide_magnification,
                        scale,
                        target_size,
                    )
                    steps.append(partial(cv2_resize_pil, target_size=(target_size, target_size)))
                else:
                    # No resampling requested: just surface the mismatch once.
                    self._warn_scale_mismatch_once(
                        "Receptive field mismatch for %s: slide_patch=%s, slide_mag=%s, expected_mag=%s, "
                        "expected_patch=%s, scale=%.4f, target_size=%s",
                        slide_patch_size,
                        slide_magnification,
                        scale,
                        target_size,
                    )

        steps.append(self.preprocess_input_tile())
        return TransformChain(steps)

    def _warn_scale_mismatch_once(
        self,
        message: str,
        slide_patch_size: int,
        slide_magnification: float,
        scale: float,
        target_size: int,
    ) -> None:
        """Log ``message`` (a %-style template) at WARNING level at most once per instance."""
        # getattr default tolerates subclasses that do not call super().__init__().
        if getattr(self, "_issued_rescale_warning", False):
            return
        logger.warning(
            message,
            self.NAME,
            slide_patch_size,
            slide_magnification,
            self.EXPECTED_MAGNIFICATION,
            self.EXPECTED_PATCH_SIZE,
            scale,
            target_size,
        )
        self._issued_rescale_warning = True

    def _compute_texture_scale(
        self,
        slide_patch_size: int,
        slide_magnification: float,
    ) -> tuple[float, int] | None:
        """Compute texture scaling factor and target size for resampling.

        The scale is ``EXPECTED_MAGNIFICATION / slide_magnification`` (s < 1 means downscale,
        s > 1 means upscale), and the target size is the slide patch size times that scale,
        rounded to the nearest integer.

        Returns:
            ``(scale, target_size)``, or ``None`` if the model's expectations are not defined.

        Raises:
            ValueError: If ``slide_magnification`` is not positive, or the resulting target size
                is less than 1 pixel.

        """
        if self.EXPECTED_MAGNIFICATION is None or self.EXPECTED_PATCH_SIZE is None:
            return None
        if not float(slide_magnification) > 0:
            msg = f"Slide magnification must be positive, got {slide_magnification}"
            raise ValueError(msg)
        scale = float(self.EXPECTED_MAGNIFICATION) / float(slide_magnification)
        target_size = round(slide_patch_size * scale)
        if target_size < 1:
            msg = "Target size is less than 1"
            raise ValueError(msg)
        return scale, target_size

    def _validate_receptive_field(
        self,
        slide_patch_size: int,
        slide_magnification: float,
    ) -> dict:
        """Validate that receptive field of the patch is compatible with model's expectations.

        The receptive field is measured as ``patch_size / magnification``; two receptive fields
        are compatible when they are equal within ``math.isclose`` default tolerances.

        Args:
            slide_patch_size: Size of patches extracted from the slide (in pixels)
            slide_magnification: Magnification level at which patches were captured

        Returns:
            Dictionary containing:
                - is_compatible: Boolean indicating if receptive fields match
                - slide_receptive_field: Calculated receptive field for the slide patch
                - model_receptive_field: Expected receptive field for the model
                - target_size: Model's expected patch size
            When the model defines no expectations, only ``is_compatible`` (True) and
            ``reason`` ("skip_validation") are returned.

        Raises:
            ValueError: If ``slide_magnification`` is not positive.

        """
        # Skip validation if model expectations are not defined
        if self.EXPECTED_MAGNIFICATION is None or self.EXPECTED_PATCH_SIZE is None:
            return {"is_compatible": True, "reason": "skip_validation"}
        if not float(slide_magnification) > 0:
            msg = f"Slide magnification must be positive, got {slide_magnification}"
            raise ValueError(msg)

        slide_rf = slide_patch_size / slide_magnification
        model_rf = self.EXPECTED_PATCH_SIZE / self.EXPECTED_MAGNIFICATION

        return {
            "is_compatible": math.isclose(slide_rf, model_rf),
            "slide_receptive_field": slide_rf,
            "model_receptive_field": model_rf,
            "target_size": self.EXPECTED_PATCH_SIZE,
        }

    @staticmethod
    def _batch_length(batch: Any) -> int:  # noqa: ANN401
        """Return the leading (batch) dimension of a collated Tensor or dict of Tensors.

        Raises:
            ValueError: If ``batch`` is an empty dict.
            TypeError: If ``batch`` is neither a Tensor nor a dict.

        """
        if isinstance(batch, torch.Tensor):
            return int(batch.shape[0])
        if isinstance(batch, dict):
            if not batch:
                msg = "Cannot determine batch length of an empty dict batch"
                logger.error(msg)
                raise ValueError(msg)
            first_value = next(iter(batch.values()))
            return int(first_value.shape[0])
        msg = f"Unsupported batch type: {type(batch)}"
        logger.error(msg)
        raise TypeError(msg)

    def embed_tiles(
        self,
        batch_size: int,
        tile_index: TileIndex,
        slide_reader: SlideReader,
        slide_data_store: SlideDataStore,
        verbose: bool = True,
        skip_feature_embeddings: bool = False,
        skip_zeroshot_embeddings: bool = False,
        auto_rescale: bool = True,
        num_workers: int | None = None,
        pin_memory: bool = True,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Embed every foreground tile of a slide.

        Args:
            batch_size: DataLoader batch size.
            tile_index: Tile index; ``n_foreground_tiles`` sets the number of output rows and
                ``tile_size`` the slide patch size.
            slide_reader: Reader whose ``slide_path.stem`` identifies the slide and whose
                ``magnification`` drives rescaling.
            slide_data_store: Store from which ``TileDataset`` reads tiles.
            verbose: Show a rich progress bar when True.
            skip_feature_embeddings: Accepted for interface compatibility; not consulted here.
            skip_zeroshot_embeddings: Disable zero-shot embeddings even if supported.
            auto_rescale: Forwarded to ``build_dataset_transform``.
            num_workers: DataLoader workers; defaults to ``os.cpu_count()`` (or 4).
            pin_memory: Forwarded to the DataLoader.
            **kwargs: Forwarded to the subclass hooks.

        Returns:
            ``(feature_embeddings, zeroshot_embeddings_or_None)``, both cast to float16.

        Raises:
            RuntimeError: If the loader yields more tiles than the allocated output rows.
            TypeError / ValueError: If a collated batch has an unsupported structure.

        """
        slide_id = slide_reader.slide_path.stem
        # Number of tiles comes from TileIndex (foreground mask already applied)
        n_patches = tile_index.n_foreground_tiles

        # Build dataset transform (cv2 resize if needed, then per-image preprocess)
        dataset_transform = self.build_dataset_transform(
            slide_patch_size=int(tile_index.tile_size),
            slide_magnification=float(slide_reader.magnification),
            auto_rescale=auto_rescale,
        )

        dataset = TileDataset(
            tile_index=tile_index,
            store=slide_data_store,
            slide_id=slide_id,
            transform=dataset_transform,
        )

        if num_workers is None:
            num_workers = os.cpu_count() or 4
        # shuffle=False keeps batches in tile-index order so rows line up with the index.
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            collate_fn=collate_dict_or_tensor,
        )

        get_zeroshot = bool(self.SUPPORTS_ZEROSHOT) and not bool(skip_zeroshot_embeddings)
        feature_embeddings, zeroshot_embeddings = self._create_embeddings_tensors(
            n_patches=n_patches,
            get_zeroshot_embeddings=get_zeroshot,
            **kwargs,
        )
        capacity = int(feature_embeddings.shape[0])

        processed_patches = 0
        description = f"Embedding slide {slide_id} ({self.NAME})"
        iterator = track(loader, description=description, style="cyan") if verbose else loader

        for batch_data in iterator:
            batch_tensor_or_dict = self.prepare_batch_for_device(batch_data)

            # Size the batch before inference so malformed/overflowing batches fail fast.
            batch_len = self._batch_length(batch_tensor_or_dict)
            batch_start_ix = processed_patches
            batch_end_ix = processed_patches + batch_len
            if batch_end_ix > capacity:
                msg = (
                    f"DataLoader yielded at least {batch_end_ix} tiles for slide {slide_id}, "
                    f"but only {capacity} embedding rows were allocated"
                )
                logger.error(msg)
                raise RuntimeError(msg)

            with torch.inference_mode():
                model_output = self._run_inference(
                    batch_tensor_or_dict,
                    get_zeroshot_embeddings=get_zeroshot,
                    **kwargs,
                )

                batch_feature_embeddings, batch_zeroshot_embeddings = self._extract_embeddings(
                    model_output,
                    **kwargs,
                )

                feature_embeddings[batch_start_ix:batch_end_ix, :] = batch_feature_embeddings
                if zeroshot_embeddings is not None and batch_zeroshot_embeddings is not None:
                    zeroshot_embeddings[batch_start_ix:batch_end_ix, :] = batch_zeroshot_embeddings

            processed_patches = batch_end_ix

        if processed_patches != n_patches:
            logger.warning(
                "Embedded %d tiles for slide %s but the tile index reports %d; "
                "unfilled rows keep their allocated values",
                processed_patches,
                slide_id,
                n_patches,
            )

        feature_embeddings = feature_embeddings.to(torch.float16)
        if zeroshot_embeddings is not None:
            zeroshot_embeddings = zeroshot_embeddings.to(torch.float16)

        return feature_embeddings, zeroshot_embeddings

    def embed_text(
        self,
        text_descriptors: list[str],  # noqa: ARG002
    ) -> torch.Tensor:
        """Embed the text.

        Raises:
            NotImplementedError: Always, in this base class; subclasses override.

        """
        msg = f"Text embedding is not supported for {self.NAME}"
        logger.error(msg)
        raise NotImplementedError(msg)

    @classmethod
    def scale_zeroshot_classification_logits(
        cls,
        logits: np.ndarray,  # noqa: ARG003
    ) -> np.ndarray:
        """Scale the logits for zero-shot classification.

        Raises:
            NotImplementedError: Always, in this base class; subclasses override.

        """
        msg = f"Zero-shot classification is not supported for {cls.NAME}"
        logger.error(msg)
        raise NotImplementedError(msg)

    def rescale(
        self,
        pixels: np.ndarray,
        target_size: tuple[int, int],
    ) -> np.ndarray:
        """Rescale the patch array to the target size.

        Args:
            pixels: Image array in OpenCV layout, indexed as (rows, cols[, channels]).
            target_size: Output size in OpenCV ``dsize`` order, i.e. ``(width, height)``.

        Returns:
            The resized array.

        """
        # From CV2 docs: "To shrink an image, it will generally look best with INTER_AREA
        # interpolation, whereas to enlarge an image, it will generally look best with INTER_CUBIC
        # (slow) or INTER_LINEAR (faster but still looks OK)."
        # dsize is (width, height) while ndarray shape is (height, width), so compare widths.
        current_width = pixels.shape[1]
        scaling_factor = target_size[0] / current_width
        interpolation_method = cv2.INTER_AREA if scaling_factor < 1 else cv2.INTER_CUBIC
        return cv2.resize(
            src=pixels,
            dsize=target_size,
            interpolation=interpolation_method,
        )

    def prepare_batch_for_device(self, batch: Any) -> Any:  # noqa: ANN401
        """Move a collated batch (Tensor or dict[str, Tensor]) onto the model's device.

        Non-tensor dict values and unsupported batch types are returned unchanged.
        Subclasses may override to adjust dtype requirements.
        """
        if isinstance(batch, torch.Tensor):
            return batch.to(self.device)
        if isinstance(batch, dict):
            return {
                k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()
            }
        return batch
