# core/samplers/base.py
"""
Abstract base classes for keyframe samplers.
Every sampler should inherit from one of these classes and implement ``select_keyframes``.
"""

from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
from PIL import Image


class BaseSampler(ABC):
    """Abstract root of the sampler hierarchy.

    Concrete samplers implement :meth:`select_keyframes`; the intermediate
    base classes in this module pin down its exact signature.
    """

    def __init__(self, **kwargs):
        """Store all keyword arguments as the sampler's configuration dict."""
        self.config = kwargs

    @abstractmethod
    def select_keyframes(self, *args, **kwargs) -> List[int]:
        """
        Select keyframes.

        Returns:
            List of selected frame indices.
        """
        ...

    def get_sampler_name(self) -> str:
        """Return a short sampler name derived from the class name.

        The class name is lower-cased and every occurrence of ``"sampler"``
        is removed, e.g. ``UniformSampler`` -> ``"uniform"``.
        """
        return type(self).__name__.lower().replace("sampler", "")

    def get_metadata(self) -> Dict[str, Any]:
        """Return a dict describing this sampler.

        Returns:
            ``{"sampler_type": <name>, "config": <config>}``. ``config`` is a
            shallow copy, so mutating the returned dict does not change the
            sampler's own configuration.
        """
        return {
            "sampler_type": self.get_sampler_name(),
            # Copy so callers cannot mutate self.config through the metadata.
            "config": dict(self.config),
        }


class VideoBasedSampler(BaseSampler):
    """Base for samplers that read the video file itself."""

    @abstractmethod
    def select_keyframes(self, video_path: str, num_keyframes: int, **kwargs) -> List[int]:
        """
        Choose keyframe indices by reading the video at ``video_path``.

        Args:
            video_path: Path to the video to sample.
            num_keyframes: How many frame indices to return.
            **kwargs: Implementation-specific options.

        Returns:
            The chosen frame indices.
        """
        ...


class QueryBasedSampler(VideoBasedSampler):
    """Video sampler whose selection is guided by a text query."""

    @abstractmethod
    def select_keyframes(self, video_path: str, num_keyframes: int, query: str, **kwargs) -> List[int]:
        """
        Choose keyframe indices from a video, conditioned on ``query``.

        Args:
            video_path: Path to the video to sample.
            num_keyframes: How many frame indices to return.
            query: Text used to guide the selection.
            **kwargs: Implementation-specific options.

        Returns:
            The chosen frame indices.
        """
        ...


class FrameBasedSampler(BaseSampler):
    """Base for samplers that work on frames already extracted from a video."""

    @abstractmethod
    def select_keyframes(
        self,
        frames: List[Image.Image],
        original_indices: List[int],
        query: str,
        num_keyframes: int,
        **kwargs,
    ) -> List[int]:
        """
        Choose keyframes among pre-extracted frames.

        Args:
            frames: The extracted frame images.
            original_indices: Position of each frame in the source video.
            query: Text used to guide the selection.
            num_keyframes: How many frame indices to return.
            **kwargs: Implementation-specific options.

        Returns:
            The chosen indices, drawn from ``original_indices``.
        """
        ...


class SimpleFrameSampler(BaseSampler):
    """Base for samplers that need nothing but the total frame count."""

    @abstractmethod
    def select_keyframes(self, total_frames: int, num_keyframes: int, **kwargs) -> List[int]:
        """
        Choose keyframe indices knowing only how many frames exist.

        Args:
            total_frames: Number of frames in the video.
            num_keyframes: How many frame indices to return.
            **kwargs: Implementation-specific options.

        Returns:
            The chosen frame indices.
        """
        ...
