# core/models/clip_handler.py
"""
CLIP model handler (optimized).

Responsible for CLIP model loading and feature extraction.
Supports a global singleton instance and enhanced caching.
"""

import torch
import numpy as np
from PIL import Image
from typing import List, Dict, Any, Optional
from collections import OrderedDict
from transformers import CLIPProcessor, CLIPModel
import threading
import hashlib
import os

from config.settings import CLIP_CONFIG


def _resolve_device(device: Optional[str]) -> str:
    """Turn a device request into a concrete device string.

    Args:
        device: Requested device. ``None``, an empty string and ``"auto"``
            all mean "pick one for me".

    Returns:
        ``device`` unchanged when it names an explicit device; otherwise
        ``"cuda"`` if ``torch.cuda.is_available()`` reports true, else
        ``"cpu"``.
    """
    # Explicit requests are passed through untouched.
    if device and device != "auto":
        return device

    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


class CLIPModelHandler:
    """CLIP model handler (singleton per model path and device).

    Constructing the class twice with the same resolved ``model_path`` and
    ``device`` returns the same object; only the first construction
    initialises it, so cache-size arguments passed on later constructions
    are ignored (use :meth:`set_cache_limits` instead).

    The model and processor are loaded lazily by :meth:`load_model`.
    Three LRU caches are kept per instance:

    * ``text_cache``  -- text -> (1, D) feature array
    * ``frame_cache`` -- frame key -> (D,) feature array
    * ``video_cache`` -- video key -> features (filled via ``_lru_put_video``)

    Arrays returned by the ``encode_*`` methods may be the very objects held
    in the caches; callers should not mutate them in place.
    """

    # Registry of created handlers, keyed by "<model_path>#<device>".
    _instances = {}
    # Guards registry insertion and first-time instance initialisation.
    _lock = threading.Lock()
    # The frame cache holds this many entries per video-cache slot.
    _FRAMES_PER_VIDEO = 50

    def __new__(cls, model_path: str = None, device: str = None, **kwargs):
        """Singleton: same model_path + device returns the same instance."""
        model_path = model_path or CLIP_CONFIG["default_model_path"]
        device = _resolve_device(device or CLIP_CONFIG["default_device"])

        key = f"{model_path}#{device}"

        # Check, lock, re-check: only one thread may create the instance.
        if key not in cls._instances:
            with cls._lock:
                if key not in cls._instances:
                    cls._instances[key] = super().__new__(cls)

        return cls._instances[key]

    def __init__(self, model_path: str = None, device: str = None,
                 max_video_cache: int = None, max_text_cache: int = None, **kwargs):
        """Initialise the handler on first construction only.

        Args:
            model_path: Model location; defaults to
                ``CLIP_CONFIG["default_model_path"]``.
            device: Device string or ``"auto"``; defaults to
                ``CLIP_CONFIG["default_device"]``.
            max_video_cache: Video cache capacity; defaults to
                ``CLIP_CONFIG["max_video_cache"]``.
            max_text_cache: Text cache capacity; defaults to
                ``CLIP_CONFIG["max_text_cache"]``.
            **kwargs: Stored verbatim in ``self.config``.
        """
        # ``__init__`` runs on every ``CLIPModelHandler(...)`` call, including
        # those that returned an existing singleton. The check-and-set is done
        # under the class lock so two threads cannot both initialise (and
        # thereby reset) the same instance.
        with self._lock:
            if getattr(self, "_initialized", False):
                return

            # Use config defaults
            self.model_path = model_path or CLIP_CONFIG["default_model_path"]
            self.device = _resolve_device(device or CLIP_CONFIG["default_device"])
            self.max_video_cache = (
                max_video_cache if max_video_cache is not None
                else CLIP_CONFIG["max_video_cache"]
            )
            self.max_text_cache = (
                max_text_cache if max_text_cache is not None
                else CLIP_CONFIG["max_text_cache"]
            )

            self.model = None
            self.processor = None
            self.config = kwargs

            # LRU caches (oldest entry first).
            self.video_cache = OrderedDict()  # video_key -> features
            self.text_cache = OrderedDict()   # text -> features
            self.frame_cache = OrderedDict()  # frame_key -> frame_features

            # Cache stats
            self.cache_hits = {"text": 0, "video": 0, "frame": 0}
            self.cache_misses = {"text": 0, "video": 0, "frame": 0}

            self._initialized = True

    def load_model(self):
        """Load the CLIP model and processor if they are not loaded yet.

        Returns:
            tuple: ``(model, processor)``.
        """
        if self.model is None or self.processor is None:
            print(f"[CLIP] Loading model from {self.model_path}...")
            # Load into locals first so a failure in either step does not
            # leave the handler with only one of the two attributes replaced.
            model = CLIPModel.from_pretrained(self.model_path).to(self.device)
            processor = CLIPProcessor.from_pretrained(self.model_path)
            self.model, self.processor = model, processor
            print("[CLIP] Model loaded successfully")

        return self.model, self.processor

    # ------------------------------------------------------------------
    # LRU helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _trim_lru(cache: OrderedDict, capacity: int) -> None:
        """Evict oldest entries until ``cache`` holds at most ``capacity``.

        A non-positive capacity empties the cache.
        """
        limit = max(capacity, 0)
        while len(cache) > limit:
            cache.popitem(last=False)

    def _lru_put(self, cache: OrderedDict, key: str, value: np.ndarray, capacity: int) -> None:
        """Insert or refresh ``key`` in ``cache`` as the most recent entry.

        Does nothing when ``capacity`` is non-positive (cache disabled).
        """
        if capacity <= 0:
            return
        cache[key] = value
        cache.move_to_end(key)
        # Loop-based trim: also recovers when the limit was lowered after
        # the cache had already grown past the new bound.
        self._trim_lru(cache, capacity)

    def _frame_cache_capacity(self) -> int:
        """Frame cache capacity, derived from the video cache limit."""
        return self.max_video_cache * self._FRAMES_PER_VIDEO

    def _lru_put_video(self, key: str, value: np.ndarray):
        """LRU cache for video-level entries."""
        self._lru_put(self.video_cache, key, value, self.max_video_cache)

    def _lru_put_text(self, key: str, value: np.ndarray):
        """LRU cache for text entries."""
        self._lru_put(self.text_cache, key, value, self.max_text_cache)

    def _lru_put_frame(self, key: str, value: np.ndarray):
        """LRU cache for per-frame entries.

        The capacity is ``max_video_cache * 50`` (up to 50 frames per video);
        the frame cache is disabled when ``max_video_cache`` is non-positive.
        """
        self._lru_put(self.frame_cache, key, value, self._frame_cache_capacity())

    @staticmethod
    def _video_hash(video_path: str) -> str:
        """Return the 8-hex-character MD5 prefix identifying ``video_path``."""
        # NOTE: the digest is truncated to 32 bits, so two different paths
        # can in principle map to the same prefix.
        return hashlib.md5(video_path.encode()).hexdigest()[:8]

    def _get_frame_key(self, video_path: str, frame_idx: int) -> str:
        """Build a frame cache key of the form ``"<video_hash>_<frame_idx>"``."""
        return f"{self._video_hash(video_path)}_{frame_idx}"

    # ------------------------------------------------------------------
    # Encoding
    # ------------------------------------------------------------------
    @torch.no_grad()
    def _encode_image_batch(self, batch) -> np.ndarray:
        """Encode one batch of images into L2-normalised features.

        The model and processor must already be loaded.

        Returns:
            np.ndarray: One feature row per input image.
        """
        inputs = self.processor(images=batch, return_tensors="pt", padding=True).to(self.device)
        image_features = self.model.get_image_features(**inputs)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features.cpu().numpy()

    @torch.no_grad()
    def encode_images(self, frames: "List[Image.Image]", batch_size: int = 64,
                      video_path: str = None, frame_indices: List[int] = None) -> np.ndarray:
        """
        Encode image features (supports frame-level caching).

        Frame-level caching is used only when ``video_path`` and
        ``frame_indices`` are both given and ``frame_indices`` has one entry
        per frame. In that mode the model is loaded only if at least one
        frame is missing from the cache.

        Args:
            frames: List of PIL images.
            batch_size: Batch size (must be positive).
            video_path: Optional video path (for caching).
            frame_indices: Optional frame indices (for caching).

        Returns:
            np.ndarray: Image feature matrix (N, D) in input order, or an
            empty array when ``frames`` is empty.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if not frames:
            return np.array([])
        if batch_size <= 0:
            # A negative step would make the batching ranges empty and
            # silently produce no features.
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        use_frame_cache = bool(
            video_path and frame_indices and len(frames) == len(frame_indices)
        )

        if not use_frame_cache:
            # Standard batch encoding
            self.load_model()
            batches = [
                self._encode_image_batch(frames[start:start + batch_size])
                for start in range(0, len(frames), batch_size)
            ]
            return np.vstack(batches)

        # Frame-level caching: fill known positions from the cache and
        # remember which positions still need encoding.
        results: List[Optional[np.ndarray]] = [None] * len(frames)
        pending = []  # (position in ``frames``, frame, cache key)
        for pos, (frame, frame_idx) in enumerate(zip(frames, frame_indices)):
            frame_key = self._get_frame_key(video_path, frame_idx)
            if frame_key in self.frame_cache:
                self.frame_cache.move_to_end(frame_key)
                results[pos] = self.frame_cache[frame_key]
                self.cache_hits["frame"] += 1
            else:
                pending.append((pos, frame, frame_key))
                self.cache_misses["frame"] += 1

        if pending:
            self.load_model()  # Only needed when something is uncached
            for start in range(0, len(pending), batch_size):
                chunk = pending[start:start + batch_size]
                features = self._encode_image_batch([frame for _, frame, _ in chunk])
                for (pos, _, frame_key), feature in zip(chunk, features):
                    # Copy the row: a view would keep the whole batch array
                    # alive for as long as this single entry stays cached.
                    feature = feature.copy()
                    self._lru_put_frame(frame_key, feature)
                    results[pos] = feature

        return np.array(results)

    @torch.no_grad()
    def encode_text(self, text: str) -> np.ndarray:
        """
        Encode text features.

        A cached result is returned without touching the model.

        Args:
            text: Input text.

        Returns:
            np.ndarray: Text feature vector (1, D).
        """
        # Cache lookup
        if text in self.text_cache:
            self.text_cache.move_to_end(text)
            self.cache_hits["text"] += 1
            return self.text_cache[text]

        self.cache_misses["text"] += 1
        self.load_model()  # Ensure the model is loaded

        inputs = self.processor(text=text, return_tensors="pt", padding=True, truncation=True).to(self.device)
        text_features = self.model.get_text_features(**inputs)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        features = text_features.cpu().numpy()

        self._lru_put_text(text, features)
        return features

    @torch.no_grad()
    def encode_texts(self, texts: List[str]) -> np.ndarray:
        """
        Encode multiple texts.

        Cached texts are served from the cache; the remaining distinct texts
        are encoded in a single batch.

        Args:
            texts: List of input texts.

        Returns:
            np.ndarray: Text feature matrix (N, D) in input order, or an
            empty array when ``texts`` is empty.
        """
        if not texts:
            return np.array([])

        results: List[Optional[np.ndarray]] = [None] * len(texts)
        # Uncached text -> positions where it occurs. Dict insertion order
        # gives the distinct texts in first-seen order.
        pending: Dict[str, List[int]] = {}

        for pos, text in enumerate(texts):
            if text in self.text_cache:
                self.text_cache.move_to_end(text)
                self.cache_hits["text"] += 1
                results[pos] = self.text_cache[text]
                continue

            self.cache_misses["text"] += 1
            pending.setdefault(text, []).append(pos)

        if pending:
            self.load_model()
            unique_texts = list(pending)
            inputs = self.processor(text=unique_texts, return_tensors="pt", padding=True, truncation=True).to(self.device)
            text_features = self.model.get_text_features(**inputs)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            feats_np = text_features.detach().cpu().numpy()  # (M, D)

            for row, text in enumerate(unique_texts):
                # (1, D) to match encode_text; copied so a cached entry does
                # not keep the whole (M, D) batch array alive.
                feat = feats_np[row:row + 1].copy()
                self._lru_put_text(text, feat)
                for pos in pending[text]:
                    results[pos] = feat

        # Defensive: if any entry is None, fall back to per-item encoding (should not happen).
        for pos, value in enumerate(results):
            if value is None:
                results[pos] = self.encode_text(texts[pos])

        return np.vstack(results)  # type: ignore[arg-type]

    def compute_similarity(self, image_features: np.ndarray, text_features: np.ndarray) -> np.ndarray:
        """
        Compute similarity between image and text features.

        Args:
            image_features: Image feature matrix (N, D).
            text_features: Text feature matrix (M, D).

        Returns:
            np.ndarray: Similarity matrix (N, M), the matrix product of
            ``image_features`` and ``text_features`` transposed.
        """
        return image_features @ text_features.T

    # ------------------------------------------------------------------
    # Cache management
    # ------------------------------------------------------------------
    def clear_text_cache(self):
        """Clear the text cache."""
        self.text_cache.clear()

    def clear_video_cache(self, video_key: str = None):
        """Clear the whole video cache, or only ``video_key`` when given."""
        if video_key is None:
            self.video_cache.clear()
        else:
            self.video_cache.pop(video_key, None)

    def clear_frame_cache(self, video_path: str = None):
        """Clear the whole frame cache, or only the frames of ``video_path``."""
        if video_path is None:
            self.frame_cache.clear()
            return

        # Same prefix that _get_frame_key puts in front of the frame index.
        prefix = f"{self._video_hash(video_path)}_"
        for key in [k for k in self.frame_cache if k.startswith(prefix)]:
            del self.frame_cache[key]

    def get_cache_stats(self) -> Dict[str, Any]:
        """Return cache sizes, hit/miss counters and hit rates per cache.

        A cache with no lookups so far reports a hit rate of 0.
        """
        return {
            "cache_sizes": {
                "text": len(self.text_cache),
                "video": len(self.video_cache),
                "frame": len(self.frame_cache)
            },
            "cache_hits": self.cache_hits.copy(),
            "cache_misses": self.cache_misses.copy(),
            "hit_rates": {
                cache_type: hits / max(1, hits + self.cache_misses[cache_type])
                for cache_type, hits in self.cache_hits.items()
            }
        }

    def set_cache_limits(self, max_video_cache: int = None, max_text_cache: int = None):
        """Set cache size limits and evict entries that exceed them.

        Args:
            max_video_cache: New video cache limit (also determines the
                frame cache capacity); ``None`` leaves it unchanged.
            max_text_cache: New text cache limit; ``None`` leaves it
                unchanged.
        """
        if max_video_cache is not None:
            self.max_video_cache = max_video_cache
            self._trim_lru(self.video_cache, max_video_cache)
            self._trim_lru(self.frame_cache, self._frame_cache_capacity())
        if max_text_cache is not None:
            self.max_text_cache = max_text_cache
            self._trim_lru(self.text_cache, max_text_cache)

    def get_model_info(self) -> Dict[str, Any]:
        """Return model path, device, cache configuration and statistics."""
        return {
            "model_path": self.model_path,
            "device": self.device,
            "cache_config": {
                "max_video_cache": self.max_video_cache,
                "max_text_cache": self.max_text_cache,
                "current_video_cache_size": len(self.video_cache),
                "current_text_cache_size": len(self.text_cache),
                "current_frame_cache_size": len(self.frame_cache)
            },
            "cache_stats": self.get_cache_stats(),
            "config": self.config
        }
