"""
BLIP model handler.

Extracts image/text features from BLIP-family models with a CLIPHandler-compatible interface.

Notes:
- TopK/Scope relies on image-vs-text feature similarity.
- Prefer retrieval/contrastive checkpoints (e.g., blip-itm-* / image-text-retrieval).
- Pure captioning checkpoints (e.g., blip-image-captioning-*) may not provide stable contrastive features;
  this handler raises a clearer error during loading in that case.

Compared to LAVIS `BlipITM`: this handler primarily uses the "ITC / dual-encoder" path
(equivalent to `match_head="itc"` in LAVIS) and computes similarity via `image_feat @ text_feat.T`.
It does not use `match_head="itm"` cross-attention-style ITM logits.
"""

from __future__ import annotations

import hashlib
import threading
from collections import OrderedDict
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from PIL import Image

from config.settings import CLIP_CONFIG


def _resolve_device(device: Optional[str]) -> str:
    if not device or device == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    return device


class BLIPModelHandler:
    """BLIP model handler (one shared instance per ``(model_path, device)``).

    Produces L2-normalised image and text embeddings through the ITC / dual-encoder
    path and keeps LRU caches for text, per-frame and per-video features, exposing
    a CLIPHandler-compatible interface.
    """

    # Shared instances keyed by "<model_path>#<resolved device>".
    _instances: Dict[str, "BLIPModelHandler"] = {}
    _lock = threading.Lock()

    # Projection heads each loader class must find in the checkpoint. transformers'
    # ``from_pretrained`` randomly initialises weights a checkpoint lacks instead of
    # failing, so their presence is checked explicitly after loading.
    _RETRIEVAL_HEADS = ("vision_proj", "text_proj")
    _BASE_HEADS = ("visual_projection", "text_projection")

    def __new__(cls, model_path: str = None, device: str = None, **kwargs):
        """Return the shared instance for ``(model_path, device)``, creating it on first use."""
        model_path = model_path or CLIP_CONFIG["default_model_path"]
        device = _resolve_device(device or CLIP_CONFIG["default_device"])
        key = f"{model_path}#{device}"
        if key not in cls._instances:
            # Re-checked under the lock so concurrent first calls create only one instance.
            with cls._lock:
                if key not in cls._instances:
                    instance = super().__new__(cls)
                    cls._instances[key] = instance
        return cls._instances[key]

    def __init__(
        self,
        model_path: str = None,
        device: str = None,
        max_video_cache: int = None,
        max_text_cache: int = None,
        **kwargs,
    ):
        """Initialise configuration and empty caches (no-op on an already-initialised instance).

        Args:
            model_path: checkpoint path/ID; defaults to ``CLIP_CONFIG["default_model_path"]``.
            device: torch device string or ``"auto"``; defaults to ``CLIP_CONFIG["default_device"]``.
            max_video_cache: capacity of ``video_cache``; defaults to ``CLIP_CONFIG["max_video_cache"]``.
            max_text_cache: capacity of ``text_cache``; defaults to ``CLIP_CONFIG["max_text_cache"]``.
            **kwargs: stored verbatim in ``self.config``.
        """
        # __init__ runs on every constructor call, even when __new__ returned an existing instance.
        if hasattr(self, "_initialized"):
            return

        self.model_path = model_path or CLIP_CONFIG["default_model_path"]
        self.device = _resolve_device(device or CLIP_CONFIG["default_device"])
        self.max_video_cache = max_video_cache if max_video_cache is not None else CLIP_CONFIG["max_video_cache"]
        self.max_text_cache = max_text_cache if max_text_cache is not None else CLIP_CONFIG["max_text_cache"]
        self.max_frame_cache = int(CLIP_CONFIG.get("max_frame_cache", 0) or 0)

        self.model = None
        self.processor = None
        self.backend = None  # "retrieval" | "base"
        self.config = kwargs

        self.video_cache = OrderedDict()
        self.text_cache = OrderedDict()
        self.frame_cache = OrderedDict()

        self.cache_hits = {"text": 0, "video": 0, "frame": 0}
        self.cache_misses = {"text": 0, "video": 0, "frame": 0}

        self._initialized = True

    @staticmethod
    def _missing_modules(loading_info: Optional[Dict[str, Any]], modules) -> List[str]:
        """Return the entries of ``modules`` with at least one key in ``loading_info["missing_keys"]``.

        A key matches a module when the module name appears as a whole dotted path
        segment (e.g. ``"vision_proj.weight"`` or ``"blip.vision_proj.bias"``).
        """
        missing_keys = list((loading_info or {}).get("missing_keys") or [])
        return [m for m in modules if any(f".{m}." in f".{key}" for key in missing_keys)]

    def _load_checked(self, model_cls, required_heads):
        """Load ``model_cls`` from ``self.model_path`` and verify its projection heads were loaded.

        Raises:
            RuntimeError: when the checkpoint has no weights for one of ``required_heads``
                (``from_pretrained`` would otherwise have filled them with random values).
        """
        model, loading_info = model_cls.from_pretrained(self.model_path, output_loading_info=True)
        missing = self._missing_modules(loading_info, required_heads)
        if missing:
            raise RuntimeError(
                f"{model_cls.__name__}: checkpoint has no weights for {missing} "
                "(they would be randomly initialised)"
            )
        return model

    def load_model(self):
        """Load the BLIP model and processor once and return ``(model, processor)``.

        ``BlipForImageTextRetrieval`` is tried first (backend ``"retrieval"``), then
        ``BlipModel`` (backend ``"base"``). A load is accepted only if the checkpoint
        provides that class's projection heads, so a checkpoint lacking them (e.g. a
        pure captioning checkpoint) falls through to the next loader and finally
        raises, instead of silently producing random-projection features.

        Raises:
            RuntimeError: if neither class can be loaded with its projection heads.
        """
        if self.model is not None and self.processor is not None:
            return self.model, self.processor

        from transformers import BlipProcessor

        # Prefer retrieval/matching checkpoints (better for similarity).
        retrieval_exc = None
        model = None
        backend = None

        try:
            from transformers import BlipForImageTextRetrieval  # type: ignore

            model = self._load_checked(BlipForImageTextRetrieval, self._RETRIEVAL_HEADS)
            backend = "retrieval"
        except Exception as e:  # noqa: BLE001
            retrieval_exc = e

        if model is None:
            try:
                from transformers import BlipModel  # type: ignore

                model = self._load_checked(BlipModel, self._BASE_HEADS)
                backend = "base"
            except Exception as e:  # noqa: BLE001
                msg = (
                    f"[BLIP] Failed to load {self.model_path} as a BLIP retrieval/base model.\n"
                    f"  - BlipForImageTextRetrieval failed: {repr(retrieval_exc)}\n"
                    f"  - BlipModel failed: {repr(e)}\n\n"
                    "Recommendation: use a retrieval/contrastive checkpoint (e.g. 'blip-itm-base-coco' / "
                    "'blip-itm-large-coco') instead of a pure captioning checkpoint (e.g. "
                    "'blip-image-captioning-base')."
                )
                raise RuntimeError(msg) from e

        self.processor = BlipProcessor.from_pretrained(self.model_path)
        self.model = model.to(self.device)
        self.model.eval()
        self.backend = backend

        if self.backend == "base":
            print(
                "[BLIP][WARN] Loaded BlipModel (base). It is typically not fine-tuned for image-text "
                "retrieval/matching, so similarity performance may be significantly worse than "
                "blip-itm / image-text-retrieval checkpoints."
            )
            if retrieval_exc is not None:
                print(f"[BLIP][WARN] BlipForImageTextRetrieval load error: {repr(retrieval_exc)}")

        print(f"[BLIP] Loaded ({self.backend}) from {self.model_path} on {self.device}")
        return self.model, self.processor

    @staticmethod
    def _trim_cache(cache: OrderedDict, capacity: int) -> None:
        """Evict least-recently-used entries until ``cache`` holds at most ``max(capacity, 0)`` items."""
        limit = max(0, int(capacity))
        while len(cache) > limit:
            cache.popitem(last=False)

    def _lru_put(self, cache: OrderedDict, key, value, capacity: int) -> None:
        """Insert/refresh ``key`` in ``cache`` and evict down to ``capacity``; ``capacity <= 0`` disables it."""
        if capacity <= 0:
            return
        if key in cache:
            # Existing entries are only refreshed, not overwritten.
            cache.move_to_end(key)
        else:
            cache[key] = value
        # A loop (not a single pop) so a lowered limit is fully enforced.
        self._trim_cache(cache, capacity)

    def _frame_cache_capacity(self) -> int:
        """Return the enforced frame-cache capacity (``<= 0`` means frame caching is off).

        Backward-compat: when ``max_frame_cache`` is unset/non-positive, the capacity
        is derived as ``50 * max_video_cache``.
        """
        cap = int(self.max_frame_cache or 0)
        if cap <= 0:
            cap = max(0, int(self.max_video_cache) * 50)
        return cap

    def _lru_put_text(self, key: str, value: np.ndarray):
        """Store a text embedding in the text LRU cache."""
        self._lru_put(self.text_cache, key, value, self.max_text_cache)

    def _lru_put_frame(self, key: str, value: np.ndarray):
        """Store a frame embedding in the frame LRU cache."""
        self._lru_put(self.frame_cache, key, value, self._frame_cache_capacity())

    @staticmethod
    def _video_hash(video_path: str) -> str:
        """Return the first 8 hex chars of the MD5 of ``video_path`` (frame-key prefix)."""
        return hashlib.md5(video_path.encode()).hexdigest()[:8]

    def _get_frame_key(self, video_path: str, frame_idx: int) -> str:
        """Build the frame-cache key ``"<video hash>_<frame index>"``."""
        return f"{self._video_hash(video_path)}_{frame_idx}"

    def _extract_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Return un-normalised image embeddings for a batch of preprocessed pixels.

        Raises:
            RuntimeError: if the loaded model exposes no way to compute image features.
        """
        self.load_model()

        if hasattr(self.model, "get_image_features"):
            feats = self.model.get_image_features(pixel_values=pixel_values)  # type: ignore[attr-defined]
            return feats

        if hasattr(self.model, "vision_model"):
            vision_outputs = self.model.vision_model(pixel_values=pixel_values, return_dict=True)  # type: ignore[attr-defined]
            if hasattr(self.model, "vision_proj"):
                # ITC path: project the CLS token of last_hidden_state, as the ITC branch of
                # BlipForImageTextRetrieval does. BlipVisionModel's pooler_output is a separately
                # layer-normed CLS vector, so using it would leave the trained ITC space.
                return self.model.vision_proj(vision_outputs.last_hidden_state[:, 0])  # type: ignore[attr-defined]
            pooled = getattr(vision_outputs, "pooler_output", None)
            if pooled is None:
                pooled = vision_outputs.last_hidden_state[:, 0]
            return pooled

        raise RuntimeError(
            "[BLIP] The current model does not support extracting image features "
            "(missing get_image_features/vision_model)"
        )

    def _extract_text_features(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return un-normalised text embeddings for a batch of tokenised texts.

        Raises:
            RuntimeError: if the loaded model exposes no way to compute text features.
        """
        self.load_model()

        if hasattr(self.model, "get_text_features"):
            feats = self.model.get_text_features(input_ids=input_ids, attention_mask=attention_mask)  # type: ignore[attr-defined]
            return feats

        if hasattr(self.model, "text_encoder"):
            text_outputs = self.model.text_encoder(  # type: ignore[attr-defined]
                input_ids=input_ids, attention_mask=attention_mask, return_dict=True
            )
            if hasattr(self.model, "text_proj"):
                # ITC path: project the CLS token, mirroring the image side.
                return self.model.text_proj(text_outputs.last_hidden_state[:, 0])  # type: ignore[attr-defined]
            pooled = getattr(text_outputs, "pooler_output", None)
            if pooled is None:
                pooled = text_outputs.last_hidden_state[:, 0]
            return pooled

        raise RuntimeError(
            "[BLIP] The current model does not support extracting text features "
            "(missing get_text_features/text_encoder)"
        )

    @torch.no_grad()
    def _encode_image_batches(self, images, batch_size: int) -> np.ndarray:
        """Encode ``images`` in chunks of ``batch_size`` and return L2-normalised ``(len(images), D)`` features."""
        step = max(1, int(batch_size))  # guard against 0/negative batch sizes
        chunks = []
        for start in range(0, len(images), step):
            inputs = self.processor(images=images[start : start + step], return_tensors="pt").to(self.device)
            feats = self._extract_image_features(inputs["pixel_values"])
            feats = feats / feats.norm(dim=-1, keepdim=True)
            chunks.append(feats.cpu().numpy())
        return np.vstack(chunks) if chunks else np.array([])

    @torch.no_grad()
    def encode_images(
        self,
        frames: List[Image.Image],
        batch_size: int = 64,
        video_path: str = None,
        frame_indices: List[int] = None,
    ) -> np.ndarray:
        """Encode frames into L2-normalised image embeddings.

        When ``video_path`` is given and ``frame_indices`` has one entry per frame,
        features are served from / stored in ``frame_cache`` keyed by
        ``(video_path, frame_index)`` (aligned with CLIPHandler behaviour); otherwise
        every frame is encoded.

        Returns:
            A ``(len(frames), D)`` array in input order, or an empty array for no frames.
        """
        if not frames:
            return np.array([])

        self.load_model()

        use_cache = (
            bool(video_path)
            and frame_indices is not None
            and len(frame_indices) > 0  # len() rather than truthiness so numpy arrays work
            and len(frames) == len(frame_indices)
        )
        if not use_cache:
            return self._encode_image_batches(frames, batch_size)

        result: List[Optional[np.ndarray]] = [None] * len(frames)
        misses = []  # (position in `frames`, frame-cache key)
        for pos, frame_idx in enumerate(frame_indices):
            key = self._get_frame_key(video_path, int(frame_idx))
            if key in self.frame_cache:
                self.frame_cache.move_to_end(key)
                result[pos] = self.frame_cache[key]
                self.cache_hits["frame"] += 1
            else:
                misses.append((pos, key))
                self.cache_misses["frame"] += 1

        if misses:
            feats = self._encode_image_batches([frames[pos] for pos, _ in misses], batch_size)
            for (pos, key), row in zip(misses, feats):
                # Copy so a cached row does not keep the whole batch array alive.
                feature = row.copy()
                self._lru_put_frame(key, feature)
                result[pos] = feature

        return np.array(result)

    @torch.no_grad()
    def encode_text(self, text: str) -> np.ndarray:
        """Return the L2-normalised ``(1, D)`` embedding of ``text``.

        The returned array is the object stored in ``text_cache`` (when caching is
        enabled), so callers should not mutate it.
        """
        if text in self.text_cache:
            self.text_cache.move_to_end(text)
            self.cache_hits["text"] += 1
            return self.text_cache[text]

        self.cache_misses["text"] += 1
        self.load_model()

        inputs = self.processor(text=text, return_tensors="pt", padding=True, truncation=True).to(self.device)
        feats = self._extract_text_features(inputs["input_ids"], inputs["attention_mask"])
        feats = feats / feats.norm(dim=-1, keepdim=True)
        feats_np = feats.cpu().numpy()
        self._lru_put_text(text, feats_np)
        return feats_np

    @torch.no_grad()
    def encode_texts(self, texts: List[str]) -> np.ndarray:
        """Return L2-normalised embeddings ``(len(texts), D)`` for ``texts`` in input order.

        Cached texts are reused; the remaining distinct texts are encoded in a single
        forward pass and added to the text cache.
        """
        if not texts:
            return np.array([])

        out: List[Optional[np.ndarray]] = [None] * len(texts)
        to_compute: List[str] = []
        text_to_indices: Dict[str, List[int]] = {}

        for i, text in enumerate(texts):
            if text in self.text_cache:
                self.text_cache.move_to_end(text)
                self.cache_hits["text"] += 1
                out[i] = self.text_cache[text]
                continue

            self.cache_misses["text"] += 1
            if text not in text_to_indices:
                to_compute.append(text)
                text_to_indices[text] = [i]
            else:
                text_to_indices[text].append(i)

        if to_compute:
            self.load_model()
            inputs = self.processor(text=to_compute, return_tensors="pt", padding=True, truncation=True).to(self.device)
            feats = self._extract_text_features(inputs["input_ids"], inputs["attention_mask"])
            feats = feats / feats.norm(dim=-1, keepdim=True)
            feats_np = feats.cpu().numpy()  # (M, D)

            for text, row in zip(to_compute, feats_np):
                # (1, D) copy, so the cache does not keep the whole (M, D) batch alive.
                feat = row[None, :].copy()
                self._lru_put_text(text, feat)
                for idx in text_to_indices[text]:
                    out[idx] = feat

        # Every slot is filled by either a cache hit or the batch above.
        return np.vstack(out)  # type: ignore[arg-type]

    def compute_similarity(self, image_features: np.ndarray, text_features: np.ndarray) -> np.ndarray:
        """Return the dot-product similarity matrix ``image_features @ text_features.T``."""
        return image_features @ text_features.T

    def clear_text_cache(self):
        """Empty the text cache (hit/miss counters are kept)."""
        self.text_cache.clear()

    def clear_video_cache(self, video_key: str = None):
        """Empty the video cache, or drop only ``video_key`` when given."""
        if video_key is None:
            self.video_cache.clear()
        else:
            self.video_cache.pop(video_key, None)

    def clear_frame_cache(self, video_path: str = None):
        """Empty the frame cache, or drop only the frames belonging to ``video_path``."""
        if video_path is None:
            self.frame_cache.clear()
            return
        # Match on "<hash>_" so the prefix can only hit keys built for this exact video hash.
        prefix = f"{self._video_hash(video_path)}_"
        for key in [k for k in self.frame_cache if k.startswith(prefix)]:
            del self.frame_cache[key]

    def set_cache_limits(self, max_video_cache: int = None, max_text_cache: int = None, max_frame_cache: int = None):
        """Update cache capacities and immediately evict LRU entries beyond them.

        Only caches whose limit was passed are trimmed, keeping at most
        ``max(limit, 0)`` entries. The frame cache is also re-trimmed when
        ``max_video_cache`` changes, because its fallback capacity derives from it.
        """
        if max_video_cache is not None:
            self.max_video_cache = max_video_cache
            self._trim_cache(self.video_cache, self.max_video_cache)
        if max_text_cache is not None:
            self.max_text_cache = max_text_cache
            self._trim_cache(self.text_cache, self.max_text_cache)
        if max_frame_cache is not None:
            self.max_frame_cache = int(max_frame_cache)
        if max_frame_cache is not None or max_video_cache is not None:
            self._trim_cache(self.frame_cache, self._frame_cache_capacity())

    def get_cache_stats(self) -> Dict[str, Any]:
        """Return cache sizes, hit/miss counters and per-cache hit rates."""
        return {
            "cache_sizes": {"text": len(self.text_cache), "video": len(self.video_cache), "frame": len(self.frame_cache)},
            "cache_hits": self.cache_hits.copy(),
            "cache_misses": self.cache_misses.copy(),
            "hit_rates": {
                cache_type: hits / max(1, hits + self.cache_misses[cache_type]) for cache_type, hits in self.cache_hits.items()
            },
        }

    def get_model_info(self) -> Dict[str, Any]:
        """Return a description of the model, device, cache configuration and cache statistics."""
        return {
            "backend": "blip",
            "model_path": self.model_path,
            "device": self.device,
            "impl": self.backend,
            "cache_config": {
                "max_video_cache": self.max_video_cache,
                "max_text_cache": self.max_text_cache,
                "max_frame_cache": self.max_frame_cache,
                # Capacity actually enforced (falls back to 50 * max_video_cache when max_frame_cache <= 0).
                "effective_max_frame_cache": self._frame_cache_capacity(),
                "current_video_cache_size": len(self.video_cache),
                "current_text_cache_size": len(self.text_cache),
                "current_frame_cache_size": len(self.frame_cache),
            },
            "cache_stats": self.get_cache_stats(),
            "config": self.config,
        }
