# core/samplers/topk.py
"""
TopK sampler.
Selects Top-K frames by vision-text similarity (CLIP/BLIP).
"""

import os
import torch
import numpy as np
from PIL import Image
from typing import List, Optional, Dict
from collections import OrderedDict

from .base import FrameBasedSampler
from core.models.vision_factory import create_vision_text_handler
from config.settings import CLIP_CONFIG, TOPK_CONFIG


class CLIPTopKSampler(FrameBasedSampler):
    """
    TopK sampler (historically named CLIPTopKSampler).
    - Defaults to CLIP.
    - Switches to BLIP when clip_model_path points to a BLIP checkpoint (or backend=blip).

    Feature encoding and caching are delegated to the handler returned by
    ``create_vision_text_handler``. This class only ranks frames by the dot
    product between image features and text features.
    """

    def __init__(self, clip_model_path: Optional[str] = None, device: Optional[str] = None,
                 max_video_cache: Optional[int] = None, max_text_cache: Optional[int] = None,
                 default_k: Optional[int] = None, **kwargs):
        """
        Args:
            clip_model_path: Model checkpoint path; falls back to
                ``CLIP_CONFIG["default_model_path"]``.
            device: Device string; falls back to ``CLIP_CONFIG["default_device"]``.
            max_video_cache: Video-feature cache size, forwarded to the handler.
            max_text_cache: Text-feature cache size, forwarded to the handler.
            default_k: Stored as ``self.default_k``; falls back to
                ``TOPK_CONFIG["default_k"]``.
            **kwargs: Forwarded to the base sampler. ``vision_backend`` (or
                ``backend``) is additionally read here to pick the handler backend.
        """
        super().__init__(**kwargs)

        backend = kwargs.get("vision_backend") or kwargs.get("backend")
        # Use a unified vision-text similarity handler (the factory is noted as
        # returning a singleton).
        self.vision_handler = create_vision_text_handler(
            model_path=clip_model_path or CLIP_CONFIG["default_model_path"],
            device=device or CLIP_CONFIG["default_device"],
            backend=backend,
            max_video_cache=max_video_cache,
            max_text_cache=max_text_cache,
        )

        self.default_k = default_k if default_k is not None else TOPK_CONFIG["default_k"]

        # Backward-compatible attributes mirrored from the handler.
        self.clip_model_path = self.vision_handler.model_path
        self.device = self.vision_handler.device
        self.max_video_cache = getattr(self.vision_handler, "max_video_cache", None)
        self.max_text_cache = getattr(self.vision_handler, "max_text_cache", None)

    # Delegate cache ops to handler (compat layer)
    def _lru_put_video(self, key: str, value: np.ndarray):
        """Insert into the handler's LRU video cache; no-op if the handler lacks one."""
        if hasattr(self.vision_handler, "_lru_put_video"):
            self.vision_handler._lru_put_video(key, value)

    def _lru_put_text(self, key: str, value: np.ndarray):
        """Insert into the handler's LRU text cache; no-op if the handler lacks one."""
        if hasattr(self.vision_handler, "_lru_put_text"):
            self.vision_handler._lru_put_text(key, value)

    def _encode_images(self, frames: List[Image.Image], video_path: Optional[str] = None,
                       frame_indices: Optional[List[int]] = None) -> np.ndarray:
        """Encode image features via the handler (which may cache by video_path/frame_indices)."""
        return self.vision_handler.encode_images(frames, video_path=video_path, frame_indices=frame_indices)

    def _encode_text(self, text: str) -> np.ndarray:
        """Encode a single text via the handler."""
        return self.vision_handler.encode_text(text)

    def get_or_build_image_feats(self, video_key: str, frames: List[Image.Image],
                                 frame_indices: Optional[List[int]] = None) -> np.ndarray:
        """Return image features for ``frames``, using ``video_key`` as the cache key."""
        return self._encode_images(frames, video_path=video_key, frame_indices=frame_indices)

    def encode_texts(self, texts: List[str]) -> np.ndarray:
        """Encode multiple texts via the handler."""
        return self.vision_handler.encode_texts(texts)

    def _get_image_feats(self, video_key: str, frames: List[Image.Image]) -> np.ndarray:
        """Legacy alias of ``get_or_build_image_feats`` without frame indices."""
        return self.get_or_build_image_feats(video_key, frames)

    def select_keyframes(self, frames: List[Image.Image], original_indices: List[int],
                         query: str, num_keyframes: int, video_key: Optional[str] = None,
                         **kwargs) -> List[int]:
        """
        Select the Top-K frames most similar to ``query``.

        Args:
            frames: Frame images.
            original_indices: Original frame index of each entry in ``frames``
                (must be the same length as ``frames``).
            query: Query text.
            num_keyframes: Number of frames to select.
            video_key: Video identifier (for caching).
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            A new list holding the selected entries of ``original_indices``,
            sorted ascending. Empty if ``frames`` is empty or
            ``num_keyframes <= 0``. All indices are returned when
            ``num_keyframes >= len(frames)``.

        Raises:
            ValueError: If ``frames`` and ``original_indices`` differ in length.
        """
        if not frames or num_keyframes <= 0:
            return []

        if len(original_indices) != len(frames):
            raise ValueError(
                "frames and original_indices must have the same length "
                f"(got {len(frames)} and {len(original_indices)})"
            )

        if num_keyframes >= len(frames):
            # Return a sorted copy: never hand back the caller's own list, and
            # keep the ordering consistent with the Top-K path below.
            return sorted(original_indices)

        # Encode image/text features (with caching)
        image_features = self._encode_images(frames, video_key, original_indices)  # (N, D)
        text_features = self._encode_text(query)  # (1, D)

        # reshape(-1) (rather than squeeze) always yields a 1-D (N,) vector.
        similarities = np.asarray(image_features @ text_features.T).reshape(-1)

        # NumPy orders NaN above +inf when partitioning/sorting, so a
        # degenerate score would otherwise be preferred; rank NaNs last instead.
        similarities = np.where(np.isnan(similarities), -np.inf, similarities)

        # Positions of the num_keyframes largest scores (in arbitrary order).
        # No descending sort is needed because the result is re-sorted by
        # original frame index.
        top_k_positions = np.argpartition(similarities, -num_keyframes)[-num_keyframes:]
        return sorted(original_indices[i] for i in top_k_positions)

    def select_topk(self, frames: List[Image.Image], original_indices: List[int],
                    query_text: str, k: int, video_key: Optional[str] = None) -> List[int]:
        """Legacy TopK API; see ``select_keyframes``."""
        return self.select_keyframes(frames, original_indices, query_text, k, video_key)

    def score_all(self, frames: List[Image.Image], video_key: str, texts: List[str]) -> Dict[str, np.ndarray]:
        """
        Compute similarity scores for every text against every frame.

        Args:
            frames: Frame images.
            video_key: Video identifier (for caching).
            texts: Texts to score.

        Returns:
            Mapping of text to a 1-D array of per-frame scores (one entry per
            frame). Duplicate texts collapse to one entry (the last wins).
            Empty dict if ``frames`` or ``texts`` is empty.
        """
        if not frames or not texts:
            return {}

        image_features = self._encode_images(frames, video_key)  # presumably (N, D)
        text_features = self.encode_texts(texts)  # assumed (T, D): one row per text

        # A single (T, N) matmul instead of T separate ones. Row i holds the
        # scores of texts[i] and stays 1-D even when there is only one frame.
        score_matrix = np.asarray(text_features @ image_features.T)
        return {text: score_matrix[i] for i, text in enumerate(texts)}

    def clear_text_cache(self):
        """Clear the handler's text cache."""
        self.vision_handler.clear_text_cache()

    def clear_image_cache(self, video_key: Optional[str] = None):
        """Clear the handler's video and frame caches, passing ``video_key`` through.

        Whether ``None`` means "clear all" is decided by the handler
        (presumably so; verify against the handler implementation).
        """
        self.vision_handler.clear_video_cache(video_key)
        self.vision_handler.clear_frame_cache(video_key)

    def set_cache_limits(self, max_video_cache: Optional[int] = None, max_text_cache: Optional[int] = None):
        """Update cache limits on the handler and mirror any non-None values locally."""
        self.vision_handler.set_cache_limits(max_video_cache, max_text_cache)
        if max_video_cache is not None:
            self.max_video_cache = max_video_cache
        if max_text_cache is not None:
            self.max_text_cache = max_text_cache

    def get_metadata(self):
        """Return base sampler metadata extended with model, device and cache info."""
        metadata = super().get_metadata()
        metadata.update({
            "description": "VisionText-based TopK frame selection",
            "requires_content": True,
            "requires_query": True,
            "clip_model_path": self.clip_model_path,
            "device": self.device,
            "cache_config": self.vision_handler.get_cache_stats(),
            "handler_info": self.vision_handler.get_model_info()
        })
        return metadata
