"""Zero-shot classification using foundation VLMs."""

from __future__ import annotations

import contextlib
import hashlib
import logging
from typing import TYPE_CHECKING

import h5py
import numpy as np
from scipy.special import softmax

from pathfmtools.embedding_models import get_embedding_model
from pathfmtools.io.schema import SCHEMA_VERSION
from pathfmtools.io.schema import StoreKeys as SK
from pathfmtools.utils.model_id import canon_model_id

if TYPE_CHECKING:
    from pathlib import Path

    import torch

    from pathfmtools.image import Slide

logger = logging.getLogger(__name__)


class ZeroShotPatchClassifier:
    """Performs zero-shot classification on patches using foundation VLMs.

    Text embeddings are memoized in memory and, when a cache file path is supplied, persisted
    to an HDF5 file. Within that file each model gets a group named by its canonical model id;
    every embedding is a dataset named by the SHA-1 of its text, and a sibling ``index`` group
    maps each hash back to the original text.
    """

    def __init__(
        self,
        text_embedding_cache_fpath: Path | None = None,
    ) -> None:
        """Initialize the zero-shot classifier.

        Args:
            text_embedding_cache_fpath (Path | None, optional): The path to a text embedding file.
                This class will check the embedding file for existing embeddings before computing
                new ones, and will write new embeddings to the file. Defaults to None.

        """
        # canonical model id -> {text -> embedding}
        self._text_embeddings: dict[str, dict[str, np.ndarray]] = {}

        self._text_embedding_cache_fpath = text_embedding_cache_fpath
        self._load_embeddings()  # Load cached text embeddings stored in the file (if any).

        # model name -> instantiated embedding model (created lazily on first cache miss).
        self._models = {}

    def _load_embeddings(self) -> None:
        """Load cached text embeddings from the cache file (if any).

        Logs a warning and returns without loading anything when no cache path is configured
        or the file does not exist yet.
        """
        cache_fpath = self._text_embedding_cache_fpath
        if cache_fpath is None:
            msg = "Text embedding cache file path is not set."
            logger.warning(msg)
            return
        if not cache_fpath.exists():
            msg = f"Text embedding cache file {cache_fpath} does not exist."
            logger.warning(msg)
            return

        with h5py.File(cache_fpath, "r") as f:
            for model_group in f:
                grp = f[model_group]
                if not isinstance(grp, h5py.Group):
                    continue

                # Rebuild the hash -> original text mapping from the "index" subgroup.
                hash_to_text: dict[str, str] = {}
                index_grp = grp.get("index", None)
                if isinstance(index_grp, h5py.Group):
                    for ds_name in index_grp:
                        raw_text = index_grp[ds_name][()]
                        hash_to_text[ds_name] = (
                            raw_text.decode("utf-8")
                            if isinstance(raw_text, bytes)
                            else str(raw_text)
                        )

                model_cache = self._text_embeddings.setdefault(model_group, {})
                for ds_name in grp:
                    if ds_name == "index":
                        continue
                    ds = grp[ds_name]
                    if not isinstance(ds, h5py.Dataset):
                        # Only datasets hold embeddings; ignore any unexpected nested groups.
                        continue
                    # Fall back to the dataset name when the index has no entry for it.
                    model_cache[hash_to_text.get(ds_name, ds_name)] = ds[()]

    def _write_embedding(self, model_name: str, text_val: str, embedding: np.ndarray) -> None:
        """Persist a single text embedding to the cache file.

        Does nothing when no cache path is configured. An existing entry for the same model and
        text is replaced rather than raising.

        Args:
            model_name (str): The name of the model that produced the embedding.
            text_val (str): The text that was embedded.
            embedding (np.ndarray): The embedding to store.

        """
        cache_fpath = self._text_embedding_cache_fpath
        if cache_fpath is None:
            return
        mode = "a" if cache_fpath.exists() else "w"
        with h5py.File(cache_fpath, mode) as f:
            if mode == "w":
                f.attrs["schema_version"] = SCHEMA_VERSION
            grp = f.require_group(canon_model_id(model_name))
            ds_name = _ds_name_for_text(text_val)
            # Assigning to an existing name raises in h5py, so drop any stale entry first.
            if ds_name in grp:
                del grp[ds_name]
            grp[ds_name] = embedding
            # Maintain reverse index: hash->original text
            idx = grp.require_group("index")
            if ds_name in idx:
                del idx[ds_name]
            # np.string_ was removed in NumPy 2.0; encode explicitly so non-ASCII text works.
            idx[ds_name] = np.bytes_(text_val.encode("utf-8"))

    def get_text_embedding(
        self,
        model_name: str,
        text_val: str,
        device: torch.device,
    ) -> np.ndarray:
        """Get the text embedding for a given text value.

        The embedding is served from the in-memory cache when available; otherwise it is
        computed with the model, cached in memory, and written to the cache file (if any).

        Args:
            model_name (str): The name of the model to use.
            text_val (str): The text value to get the embedding for.
            device (torch.device): The device to use for the embedding.

        Returns:
            np.ndarray: The text embedding.

        """
        # Key the in-memory cache by canonical model id so it matches the group names used by
        # `_write_embedding` / `_load_embeddings` (assumes canon_model_id is idempotent).
        model_cache = self._text_embeddings.setdefault(canon_model_id(model_name), {})

        if text_val not in model_cache:
            if model_name not in self._models:
                self._models[model_name] = get_embedding_model(model_name)(device=device)

            text_embedding = (
                self._models[model_name]
                .embed_text(text_descriptors=[text_val])
                .detach()
                .cpu()
                .numpy()
            )
            model_cache[text_val] = text_embedding
            self._write_embedding(model_name, text_val, text_embedding)

        return model_cache[text_val]

    def classify(
        self,
        model_name: str,
        classes: list[str],
        device: torch.device,
        slide: Slide | None = None,
        patch_embedding: np.ndarray | None = None,
    ) -> dict[str, dict[str, np.ndarray] | np.ndarray]:
        """Classify a patch or slide using a zero-shot classifier.

        Args:
            model_name (str): The name of the model to use.
            classes (list[str]): The classes to classify the patch or slide into.
            device (torch.device): The device to use for the classification.
            slide (Slide | None, optional): The slide to classify. Defaults to None.
            patch_embedding (np.ndarray | None, optional): The patch to classify. Defaults to
                None.

        Returns:
            dict: A dictionary with three entries: ``"logits"`` and ``"probabilities"`` map each
                class name to a flat per-patch array, and ``"class_predictions"`` is a flat
                array of indices into `classes`.

        """
        out_dict = {}
        logit_dict = self.get_patch_text_similarity_logits(
            model_name=model_name,
            text_list=classes,
            slide=slide,
            patch_embedding=patch_embedding,
            device=device,
        )

        out_dict["logits"] = {k: v.flatten() for k, v in logit_dict.items()}
        # Stack per-class logits along axis 1 so softmax/argmax run across classes.
        logit_arr = np.stack([logit_dict[class_name] for class_name in classes], axis=1)
        prob_arr = softmax(logit_arr, axis=1)
        out_dict["probabilities"] = {
            class_name: prob_arr[:, class_ix].flatten()
            for class_ix, class_name in enumerate(classes)
        }
        out_dict["class_predictions"] = np.argmax(prob_arr, axis=1).flatten()

        return out_dict

    @staticmethod
    def _normalize_patch_embeddings(embeddings: np.ndarray) -> np.ndarray:
        """Return a row-wise L2-normalized copy of a 2D patch embedding array.

        Logs a warning when any row is not already (approximately) unit length.

        Raises:
            ValueError: If any row has zero magnitude.

        """
        magnitudes = np.linalg.norm(embeddings, axis=1)
        if not np.allclose(magnitudes, 1.0, atol=1e-3):
            embedding_magnitude_range = (float(magnitudes.min()), float(magnitudes.max()))
            logger.warning(
                "Patch embedding magnitudes are not 1. Range: %s",
                embedding_magnitude_range,
            )
        if (magnitudes == 0.0).any():
            msg = "Patch embedding magnitudes are 0."
            logger.error(msg)
            raise ValueError(msg)
        return embeddings / magnitudes[:, np.newaxis]

    def get_patch_text_similarity_logits(
        self,
        model_name: str,
        text_list: list[str],
        device: torch.device,
        slide: Slide | None = None,
        patch_embedding: np.ndarray | None = None,
    ) -> dict[str, np.ndarray]:
        """Get text similarity logits for the slide.

        The main use case for this method is to compute patch-text similarity scores for all
        processed patches in a slide (via `slide`). However, this method also supports
        computing logits for a given patch embedding (via `patch_embedding`), which is useful for
        patch embeddings that are not directly produced by the model (e.g. cluster centroids when
        applying K-means clustering to patch embeddings produced by a patch embedding model).

        Patch and text embeddings are L2-normalized before the similarity is computed; a warning
        is logged when they were not already unit length.

        Args:
            model_name (str): The name of the model to use.
            text_list (list[str]): A list of text values for which patch embedding similarity scores
                will be computed.
            device (torch.device): The device to use for running inference.
            slide (Slide | None, optional): The slide corresponding to the embeddings for which
                patch-text similarity scores will be computed. Defaults to None.
            patch_embedding (np.ndarray | None, optional): The patch embedding for which text
                similarity scores will be computed. Defaults to None.

        Raises:
            ValueError: If not exactly one of `slide` and `patch_embedding` is provided.
            ValueError: If zero-shot embeddings cannot be found for the slide and model.
            ValueError: If `patch_embedding` is neither 1D nor 2D.
            ValueError: If the text embeddings and patch embeddings have different dimensions.
            ValueError: If any patch embedding has zero magnitude.
            ValueError: If a text embedding has zero magnitude.

        Returns:
            dict[str, np.ndarray]: A dictionary containing the logits for each text value.

        """
        # Exactly one of slide or patch_embedding must be provided
        if (slide is None) == (patch_embedding is None):
            msg = "Exactly one of (slide, patch_embedding) must be provided"
            logger.error(msg)
            raise ValueError(msg)

        zeroshot_embeddings: np.ndarray
        if slide is not None:
            try:
                zeroshot_embeddings = slide.store.read_embeddings(
                    slide_id=slide.id_,
                    model_id=model_name,
                    kind=SK.TILE_ZEROSHOT_EMBEDDINGS,
                )
            except KeyError as e:  # pragma: no cover - bubbled as ValueError with context
                msg = (
                    f"Zero-shot embeddings not found for slide '{slide.id_}' and model "
                    f"'{model_name}'. Ensure embeddings are computed and saved."
                )
                logger.exception(msg)
                raise ValueError(msg) from e
        else:
            zeroshot_embeddings = patch_embedding  # type: ignore[assignment]
            if zeroshot_embeddings.ndim == 1:
                # Promote a single embedding vector to a one-row matrix.
                zeroshot_embeddings = zeroshot_embeddings[np.newaxis, :]
            elif zeroshot_embeddings.ndim != 2:
                msg = f"Patch embedding must be 1D or 2D, but got shape {zeroshot_embeddings.shape}"
                logger.error(msg)
                raise ValueError(msg)

        model_cls = get_embedding_model(model_name)

        logit_dict: dict[str, np.ndarray] = {}
        if not text_list:
            return logit_dict

        # The patch embeddings do not depend on the text, so normalize them once up front.
        z = self._normalize_patch_embeddings(np.asarray(zeroshot_embeddings))

        for text_val in text_list:
            text_embedding = self.get_text_embedding(
                model_name=model_name,
                text_val=text_val,
                device=device,
            )
            if text_embedding.shape[1] != z.shape[1]:
                msg = (
                    f"Text embeddings for model {model_name} must have the same dimension as "
                    f"patch embeddings, got {text_embedding.shape[1]} (text) and "
                    f"{z.shape[1]} (patch)"
                )
                logger.error(msg)
                raise ValueError(msg)

            text_embedding_magnitude = float(np.linalg.norm(text_embedding))
            if not np.isclose(text_embedding_magnitude, 1.0, atol=1e-3):
                logger.warning(
                    "Text embedding magnitude is not 1 (%.6f). Normalizing.",
                    text_embedding_magnitude,
                )
            if text_embedding_magnitude == 0.0:
                msg = "Text embedding magnitude is 0."
                logger.error(msg)
                raise ValueError(msg)
            # Division allocates a new array, so the cached embedding is never mutated.
            t = np.asarray(text_embedding) / text_embedding_magnitude

            # (n_patches, embedding_dim) @ (embedding_dim, 1) -> (n_patches, 1)
            similarity_arr = z @ t.T
            logger.info(
                "Similarity range (before scaling): [%s, %s]",
                similarity_arr.min(),
                similarity_arr.max(),
            )
            # Some models do not support zero-shot scaling; keep raw cosine similarity in that case.
            with contextlib.suppress(NotImplementedError):
                similarity_arr = model_cls.scale_zeroshot_classification_logits(
                    logits=similarity_arr,
                )
            logger.info(
                "Similarity range (after scaling): [%s, %s]",
                similarity_arr.min(),
                similarity_arr.max(),
            )
            # Note that many foundation models require the similarity scores to be scaled before
            # they are considered logits.
            logit_dict[text_val] = similarity_arr

        return logit_dict


def _ds_name_for_text(text_val: str) -> str:
    return hashlib.sha1(text_val.encode("utf-8")).hexdigest()
