"""
Base acquisition function classes for active learning.

This module contains the base classes and interfaces for acquisition functions
used in the active learning framework.
"""

import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
import numpy as np


class BaseAcquisition(ABC):
    """
    Abstract base class for acquisition functions.

    All acquisition functions should inherit from this class and implement
    the required methods (``get_candidate_batch`` and ``acquisition_fn``)
    for selecting samples from the unlabeled pool.
    """

    def __init__(
        self,
        acquisition_size: int,
        pool_loader_batch_size: int,
        acquisition_pool_fraction: float,
        num_workers: int = 4,
        device: str = "cpu",
        pin_memory: bool = True,
    ):
        """
        Initialize the base acquisition function.

        Args:
            acquisition_size: Number of samples to acquire in each iteration;
                must be positive.
            pool_loader_batch_size: Batch size for processing the pool; must
                be positive.
            acquisition_pool_fraction: Fraction of pool to consider for
                acquisition; must lie in the half-open interval (0, 1].
            num_workers: Number of workers for data loading; must be
                non-negative.
            device: Device to run computations on.
            pin_memory: Value stored on ``self.pin_memory``. Defaults to
                ``True``, the value that was previously hard-coded. This base
                class never reads it; presumably subclasses forward it to
                their pool data loader (confirm in subclasses).

        Raises:
            ValueError: If any numeric argument is outside its valid range.
        """
        if acquisition_size <= 0:
            raise ValueError(
                f"acquisition_size must be positive, got {acquisition_size!r}"
            )
        if pool_loader_batch_size <= 0:
            raise ValueError(
                "pool_loader_batch_size must be positive, "
                f"got {pool_loader_batch_size!r}"
            )
        # Chained comparison: NaN fails both comparisons, so it is rejected too.
        if not (0 < acquisition_pool_fraction <= 1.0):
            raise ValueError(
                "acquisition_pool_fraction must be between 0 and 1 "
                f"(0 exclusive, 1 inclusive), got {acquisition_pool_fraction!r}"
            )
        if num_workers < 0:
            raise ValueError(
                f"num_workers must be non-negative, got {num_workers!r}"
            )

        self.acquisition_size = acquisition_size
        self.pool_loader_batch_size = pool_loader_batch_size
        self.acquisition_pool_fraction = acquisition_pool_fraction
        self.num_workers = num_workers
        self.device = device
        self.pin_memory = pin_memory

    @abstractmethod
    def get_candidate_batch(
        self,
        model: nn.Module,
        active_data: 'ActiveLearningData',
        **kwargs
    ) -> np.ndarray:
        """
        Get candidate samples for acquisition.

        Args:
            model: The trained model
            active_data: Active learning data manager
            **kwargs: Additional arguments specific to the acquisition function

        Returns:
            Array of indices of selected samples
        """
        raise NotImplementedError('get_candidate_batch not implemented')

    @abstractmethod
    def acquisition_fn(self, *args, **kwargs) -> torch.Tensor:
        """
        Compute acquisition scores for a batch of samples.

        Args:
            *args: Arguments specific to the acquisition function
            **kwargs: Additional keyword arguments

        Returns:
            Acquisition scores for the batch
        """
        raise NotImplementedError('acquisition_fn not implemented')

    def set_device(self, device: str) -> None:
        """Set the device for computations (only updates ``self.device``)."""
        self.device = device

    def get_config(self) -> Dict[str, Any]:
        """
        Get configuration parameters.

        Returns:
            A new dict with the constructor's core settings. ``pin_memory`` is
            intentionally left out so the key set (and ``__repr__``) stays
            the same as before.
        """
        return {
            "acquisition_size": self.acquisition_size,
            "pool_loader_batch_size": self.pool_loader_batch_size,
            "acquisition_pool_fraction": self.acquisition_pool_fraction,
            "num_workers": self.num_workers,
            "device": self.device
        }

    def __repr__(self) -> str:
        """String representation built from ``get_config()``."""
        config = self.get_config()
        return f"{self.__class__.__name__}({', '.join(f'{k}={v}' for k, v in config.items())})"


class AcquisitionConfig:
    """
    Configuration class for acquisition functions.

    This class provides a standardized way to configure acquisition functions
    with validation. Values are checked once, at construction time. Any extra
    keyword arguments are kept verbatim in ``extra_params`` and are included
    in ``to_dict()``.
    """

    def __init__(
        self,
        acquisition_size: int,
        pool_loader_batch_size: int,
        acquisition_pool_fraction: float,
        num_workers: int = 4,
        device: str = "cpu",
        **kwargs
    ):
        """
        Initialize acquisition configuration.

        Args:
            acquisition_size: Number of samples to acquire; must be positive.
            pool_loader_batch_size: Batch size for pool processing; must be
                positive.
            acquisition_pool_fraction: Fraction of pool to consider; must lie
                in the half-open interval (0, 1].
            num_workers: Number of data loading workers; must be non-negative.
            device: Computation device.
            **kwargs: Additional configuration parameters, stored in
                ``extra_params``.

        Raises:
            ValueError: If any numeric argument is outside its valid range.
        """
        self.acquisition_size = acquisition_size
        self.pool_loader_batch_size = pool_loader_batch_size
        self.acquisition_pool_fraction = acquisition_pool_fraction
        self.num_workers = num_workers
        self.device = device
        self.extra_params = kwargs

        self._validate()

    @classmethod
    def from_dict(cls, config: Dict[str, Any]) -> "AcquisitionConfig":
        """
        Build a configuration from a mapping, e.g. the output of ``to_dict()``.

        Keys other than the named constructor parameters end up in
        ``extra_params``.

        Raises:
            TypeError: If a required key is missing.
            ValueError: If a value fails validation.
        """
        return cls(**config)

    def _validate(self) -> None:
        """Validate configuration parameters, raising ``ValueError`` on failure."""
        if self.acquisition_size <= 0:
            raise ValueError(
                "acquisition_size must be positive, "
                f"got {self.acquisition_size!r}"
            )
        if self.pool_loader_batch_size <= 0:
            raise ValueError(
                "pool_loader_batch_size must be positive, "
                f"got {self.pool_loader_batch_size!r}"
            )
        # Chained comparison: NaN fails both comparisons, so it is rejected too.
        if not (0 < self.acquisition_pool_fraction <= 1.0):
            raise ValueError(
                "acquisition_pool_fraction must be between 0 and 1 "
                "(0 exclusive, 1 inclusive), "
                f"got {self.acquisition_pool_fraction!r}"
            )
        if self.num_workers < 0:
            raise ValueError(
                f"num_workers must be non-negative, got {self.num_workers!r}"
            )

    def to_dict(self) -> Dict[str, Any]:
        """Convert configuration to a new dictionary (core keys, then extras)."""
        return {
            "acquisition_size": self.acquisition_size,
            "pool_loader_batch_size": self.pool_loader_batch_size,
            "acquisition_pool_fraction": self.acquisition_pool_fraction,
            "num_workers": self.num_workers,
            "device": self.device,
            **self.extra_params
        }

    def __repr__(self) -> str:
        """Return ``ClassName(key=value, ...)`` in ``to_dict()`` order."""
        params = ", ".join(f"{key}={value!r}" for key, value in self.to_dict().items())
        return f"{type(self).__name__}({params})"
