# core/samplers/aks.py
"""
AKS (Adaptive Key-frame Sampling) sampler.
Implements an adaptive keyframe sampling algorithm.
"""

import numpy as np
import heapq
from PIL import Image
from typing import List

from .base import FrameBasedSampler
from core.models.vision_factory import create_vision_text_handler
from config.settings import CLIP_CONFIG, AKS_CONFIG


class AKSSampler(FrameBasedSampler):
    """
    AKS (Adaptive Key-frame Sampling) sampler.

    Frames are scored against the text query with a vision-text model, the
    score sequence is recursively bisected into segments, and the
    highest-scoring frames of each final segment are kept (subject to the
    overall frame budget).
    """

    def __init__(self, clip_model_path: str = None, device: str = None, **kwargs):
        """
        Build the sampler and its vision-text backbone.

        Args:
            clip_model_path: Model path; falls back to
                CLIP_CONFIG["default_model_path"] when falsy.
            device: Device string; falls back to
                CLIP_CONFIG["default_device"] when falsy.
            **kwargs: Forwarded unchanged to the base class. The keys
                ``t1``, ``t2``, ``max_depth`` and ``batch_size`` override the
                matching AKS_CONFIG entries; ``vision_backend`` (or
                ``backend``) selects the backend passed to the handler
                factory.
        """
        super().__init__(**kwargs)

        self.clip_model_path = clip_model_path or CLIP_CONFIG["default_model_path"]
        self.device = device or CLIP_CONFIG["default_device"]

        def _param(name):
            # Treat a missing key and an explicit None alike, mirroring how
            # clip_model_path/device fall back to their configured defaults.
            # ``is None`` (not truthiness) so that 0 / 0.0 stay valid overrides.
            value = kwargs.get(name)
            return AKS_CONFIG[name] if value is None else value

        # AKS parameters
        self.t1 = _param('t1')
        self.t2 = _param('t2')
        self.max_depth = _param('max_depth')
        self.batch_size = _param('batch_size')

        backend = kwargs.get("vision_backend") or kwargs.get("backend")
        self.vision_handler = create_vision_text_handler(
            model_path=self.clip_model_path,
            device=self.device,
            backend=backend,
        )
        # Adopt whatever device the handler reports (original note: auto -> cuda/cpu)
        self.device = self.vision_handler.device
        print(f"AKS vision backbone: {self.vision_handler.get_model_info().get('backend', 'unknown')} @ {self.vision_handler.model_path}")

    def _get_relevance_scores(self, frames: List[Image.Image], question: str) -> List[float]:
        """
        Compute one relevance score per frame for the given question.

        The score is the product of the frame features with the transposed
        text features, flattened to one value per frame.

        Args:
            frames: Frames to score.
            question: Query text.

        Returns:
            A list of scores, flattened from the feature product.
        """
        print("AKS: Step 1/2 - computing relevance scores...")

        # Expected shapes per the handler's contract: text (1, D), images (N, D)
        # -- NOTE(review): not verifiable from this file.
        text_features = self.vision_handler.encode_text(question)
        image_features = self.vision_handler.encode_images(frames, batch_size=self.batch_size)
        similarity = (image_features @ text_features.T).reshape(-1)

        print("AKS: Relevance scoring complete.")
        return similarity.tolist()

    def _meanstd_recursive(self, segments_data: List[dict], num_total_frames: int,
                          num_keyframes: int, t1: float, t2: float, max_depth: int) -> List[dict]:
        """
        Recursive segmentation algorithm (AKS core logic).

        A segment is kept whole when it has fewer than two scores, when it
        has reached ``max_depth``, or when both the gap between the mean of
        its top scores and its overall mean exceeds ``t1`` and its standard
        deviation exceeds ``t2``. Any other segment is cut in half and the
        halves are processed by a recursive call.

        Args:
            segments_data: Dicts with keys 'scores', 'indices' and 'depth'.
            num_total_frames: Not used by this method; retained so the
                signature stays unchanged.
            num_keyframes: Caps how many top scores enter the mean-gap test.
            t1: Threshold on (mean of top scores - segment mean).
            t2: Threshold on the segment's standard deviation.
            max_depth: Segments at this depth or deeper are never split.

        Returns:
            Final segments: those kept at this level first, followed by the
            results of the recursive call on the split halves.
        """
        kept, to_split = [], []

        for seg in segments_data:
            scores = seg['scores']
            depth = seg['depth']

            # Stop splitting if segment too small or depth too large
            if len(scores) < 2 or depth >= max_depth:
                kept.append(seg)
                continue

            # Statistics
            seg_mean = np.mean(scores)
            seg_std = np.std(scores)
            top_scores = heapq.nlargest(min(len(scores), num_keyframes), scores)
            peak_gap = np.mean(top_scores) - seg_mean

            # Both thresholds exceeded: keep this segment as-is
            if peak_gap > t1 and seg_std > t2:
                kept.append(seg)
                continue

            # Otherwise split in half; both halves sit one level deeper
            mid = len(scores) // 2
            indices = seg['indices']
            to_split.append({
                'scores': scores[:mid],
                'indices': indices[:mid],
                'depth': depth + 1,
            })
            to_split.append({
                'scores': scores[mid:],
                'indices': indices[mid:],
                'depth': depth + 1,
            })

        if to_split:
            kept += self._meanstd_recursive(to_split, num_total_frames,
                                            num_keyframes, t1, t2, max_depth)

        return kept

    def _adaptive_sampling(self, scores: List[float], frame_indices: List[int],
                          num_keyframes: int) -> List[int]:
        """
        Run adaptive sampling given relevance scores.

        Args:
            scores: One relevance score per candidate frame.
            frame_indices: Index associated with each score (same order).
            num_keyframes: Maximum number of indices to return.

        Returns:
            At most ``num_keyframes`` indices. When the budget covers every
            frame, a copy of ``frame_indices`` in its given order; otherwise
            the selected indices in ascending order. An empty list when
            ``num_keyframes`` is not positive.
        """
        print("AKS: Step 2/2 - running adaptive sampling...")

        # A non-positive budget selects nothing. Without this guard a negative
        # value would turn the final ``[:num_keyframes]`` pruning slice into
        # "drop from the end", and zero would feed an empty list to np.mean.
        if num_keyframes <= 0:
            return []

        # If budget exceeds frames, return all (as a copy, so the caller's
        # list is never aliased by the result)
        if len(scores) <= num_keyframes:
            return list(frame_indices)

        # Min-max normalize scores; the epsilon avoids division by zero when
        # all scores are equal
        score_arr = np.array(scores)
        lowest = np.min(scores)
        highest = np.max(scores)
        normalized = (score_arr - lowest) / (highest - lowest + 1e-6)

        # Run recursive segmentation starting from one segment covering everything
        final_segments = self._meanstd_recursive(
            [{
                'scores': normalized.tolist(),
                'indices': frame_indices,
                'depth': 0,
            }],
            len(scores), num_keyframes, self.t1, self.t2, self.max_depth,
        )

        # Select frames from final segments
        best_score = {}
        for seg in final_segments:
            seg_scores = seg['scores']
            seg_indices = seg['indices']

            if not seg_scores:
                continue

            # Depth-based per-segment budget: halves with each level, never
            # below one and never above the segment's size
            quota = max(1, int(num_keyframes / (2 ** seg['depth'])))
            quota = min(quota, len(seg_scores))

            # Pick highest-score frames
            top_local = heapq.nlargest(
                quota, range(len(seg_scores)), key=seg_scores.__getitem__
            )
            for local_idx in top_local:
                frame_idx = seg_indices[local_idx]
                score = seg_scores[local_idx]
                # Keep best score per frame for later global pruning
                if frame_idx not in best_score or score > best_score[frame_idx]:
                    best_score[frame_idx] = score

        selected_indices = list(best_score.keys())
        if len(selected_indices) > num_keyframes:
            # Per-segment quotas can add up to more than the budget; prune by
            # score to respect it
            ranked = sorted(
                best_score.items(), key=lambda item: item[1], reverse=True
            )
            selected_indices = [idx for idx, _ in ranked[:num_keyframes]]

        # Return sorted indices
        selected_indices = sorted(selected_indices)

        print(f"AKS: Adaptive sampling complete; selected {len(selected_indices)} frames.")
        return selected_indices

    def select_keyframes(self, frames: List[Image.Image], original_indices: List[int],
                        query: str, num_keyframes: int, **kwargs) -> List[int]:
        """
        Run the full AKS pipeline.

        Args:
            frames: Frame images (typically extracted at 1fps).
            original_indices: Original frame indices; entry ``i`` is looked up
                for the frame at position ``i``.
            query: Query text.
            num_keyframes: Number of frames to select.
            **kwargs: Accepted but not used by this method.

        Returns:
            Selected frame indices (subset of original_indices), sorted
            ascending. Empty when ``frames`` is empty.
        """
        if not frames:
            return []

        # Stage 1: relevance scores
        scores = self._get_relevance_scores(frames, query)

        # Stage 2: adaptive sampling over positions 0..N-1, then map each
        # chosen position back to the caller's index space
        positions = self._adaptive_sampling(scores, list(range(len(scores))), num_keyframes)
        return sorted(original_indices[pos] for pos in positions)

    def get_metadata(self):
        """Return the base sampler metadata extended with AKS settings."""
        metadata = super().get_metadata()
        metadata.update({
            "description": "Adaptive Key-frame Sampling using CLIP",
            "requires_content": True,
            "requires_query": True,
            "clip_model_path": self.clip_model_path,
            "device": self.device,
            "algorithm_params": {
                "t1": self.t1,
                "t2": self.t2,
                "max_depth": self.max_depth,
                "batch_size": self.batch_size
            }
        })
        return metadata
