"""
Mean-Standard Deviation acquisition function for active learning.

This module contains the MeanStd acquisition function that selects samples
based on the standard deviation of model predictions.
"""

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
import numpy as np
from typing import List, Dict, Any, Optional

from .base import BaseAcquisition
from ...data.datasets import MySubsetRandomSampler, pool_collate_fn


class MeanStd(BaseAcquisition):
    """
    Mean-Standard Deviation acquisition function.

    Each candidate is scored by the standard deviation of the model's
    predictions along one axis (``dim=1`` by default), averaged over the last
    axis. The highest-scoring candidates (the ones whose predictions vary the
    most) are acquired.
    """

    def __init__(
        self,
        output_dim: int,
        acquisition_size: int,
        pool_loader_batch_size: int,
        acquisition_pool_fraction: float,
        num_workers: int = 4,
        device: str = "cpu"
    ):
        """
        Initialize the MeanStd acquisition function.

        Args:
            output_dim: Dimension of model output
            acquisition_size: Number of samples to acquire
            pool_loader_batch_size: Batch size for processing pool
            acquisition_pool_fraction: Fraction of pool to consider
            num_workers: Number of data loading workers
            device: Computation device

        Raises:
            ValueError: If ``output_dim`` is not positive.
        """
        super().__init__(
            acquisition_size=acquisition_size,
            pool_loader_batch_size=pool_loader_batch_size,
            acquisition_pool_fraction=acquisition_pool_fraction,
            num_workers=num_workers,
            device=device
        )

        if output_dim <= 0:
            raise ValueError("output_dim must be positive")

        self.output_dim = output_dim

    def _build_pool_loader(
        self,
        active_data: 'ActiveLearningData',
        pool_size: int,
    ) -> DataLoader:
        """Return a DataLoader over a random subset of ``pool_size`` pool items."""
        pool_dataset = active_data.pool_dataset

        # Random subset of positions in the pool dataset (global numpy RNG).
        random_indices = np.random.permutation(len(pool_dataset))[:pool_size].tolist()

        # Pass the pool's own index mapping to the sampler when the dataset
        # exposes one (e.g. a ``Subset``-like pool).
        pool_indices = getattr(pool_dataset, 'indices', None)
        if pool_indices is not None:
            sampler = MySubsetRandomSampler(random_indices, pool_indices)
        else:
            sampler = MySubsetRandomSampler(random_indices)

        return DataLoader(
            pool_dataset,
            shuffle=False,  # ordering is delegated to the sampler
            batch_size=self.pool_loader_batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
            collate_fn=pool_collate_fn,
            # NOTE(review): ``pin_memory`` is not set in this class; presumably
            # provided by BaseAcquisition.
            pin_memory=self.pin_memory,
        )

    @torch.no_grad()
    def get_candidate_batch(
        self,
        model: nn.Module,
        active_data: 'ActiveLearningData',
        **kwargs
    ) -> np.ndarray:
        """
        Get candidate samples using mean-std acquisition.

        A random fraction (``acquisition_pool_fraction``) of the pool is
        scored with :meth:`acquisition_fn`. The indices of the
        ``acquisition_size`` highest-scoring items are returned.

        Args:
            model: The trained model. ``model.predict(x, xt, y0)`` is used if
                present; otherwise ``model(x, xt, y0)`` is called and the first
                element of its returned pair is used.
            active_data: Active learning data manager (must expose
                ``pool_dataset``)
            **kwargs: Additional arguments (unused)

        Returns:
            ``int64`` array of the selected samples' indices, as reported by
            ``pool_collate_fn``, ordered from highest to lowest score. Empty if
            nothing was scored. NOTE(review): confirm whether these indices are
            pool-relative or base-dataset indices, whichever callers expect.
        """
        model.to(self.device)
        model.eval()

        # Number of pool items to score this round.
        pool_size = int(self.acquisition_pool_fraction * len(active_data.pool_dataset))
        pool_loader = self._build_pool_loader(active_data, pool_size)

        # Accumulate per batch rather than writing into preallocated slots: no
        # assumption that every batch is full, and no zero-filled dummy
        # entries if the loader yields fewer than ``pool_size`` items.
        batch_scores = []
        batch_indices = []
        for x, xt, y0, indices in pool_loader:
            x = x.to(self.device)
            xt = xt.to(self.device)
            y0 = y0.to(self.device)

            if hasattr(model, 'predict'):
                yp = model.predict(x, xt, y0)
            else:
                # Fallback to forward pass; the second output is ignored.
                yp, _ = model(x, xt, y0)

            score = self.acquisition_fn(yp, dim=1)
            batch_scores.append(score.to(device="cpu", dtype=torch.float32))
            # ``as_tensor`` also accepts a list/array from the collate function.
            batch_indices.append(torch.as_tensor(indices).cpu())

        if not batch_scores:
            return np.empty(0, dtype=np.int64)

        scores = torch.cat(batch_scores)
        candidate_indices = torch.cat(batch_indices).numpy().astype(np.int64, copy=False)

        # ``argsort`` yields positions in loader order, not sample indices;
        # translate them through the indices collected alongside each score.
        best_positions = torch.argsort(scores, dim=0, descending=True)[:self.acquisition_size]
        return candidate_indices[best_positions.numpy()]

    def acquisition_fn(self, y_pred: torch.Tensor, dim: int = 1) -> torch.Tensor:
        """
        Compute acquisition scores based on prediction standard deviation.

        Args:
            y_pred: Model predictions [batch_size, seq_len, output_dim]
            dim: Dimension along which to compute standard deviation

        Returns:
            Acquisition scores [batch_size] (for 3-D input with ``dim=1``)

        Raises:
            ValueError: If ``y_pred`` has fewer than 3 dimensions.
        """
        if y_pred.dim() < 3:
            raise ValueError("y_pred must have at least 3 dimensions")

        # Standard deviation along ``dim`` (torch default: unbiased estimator).
        std_scores = torch.std(y_pred, dim=dim)

        # Average over the last axis (output dimensions for 3-D input).
        mean_scores = std_scores.mean(dim=-1)

        return mean_scores

    def get_config(self) -> Dict[str, Any]:
        """Return the base configuration extended with ``output_dim``."""
        config = super().get_config()
        config["output_dim"] = self.output_dim
        return config
