"""
Active learning dataset management for GLEAM-AI.

This module contains classes for managing active learning datasets,
including splitting data into training and pool sets.
"""

import torch.utils.data as data
import numpy as np
from typing import List, Optional, Union
import torch


class ActiveLearningData:
    """
    Manages active learning dataset splits.

    This class splits a dataset into an active training dataset and an
    available pool dataset, allowing for dynamic sample acquisition.

    Every sample of the underlying dataset belongs to exactly one split.
    Membership is tracked with two boolean masks and exposed through two
    ``torch.utils.data.Subset`` views whose ``indices`` are kept in sync
    with the masks.
    """

    def __init__(self, dataset: data.Dataset):
        """
        Initialize active learning data manager.

        All samples start in the pool; the training set starts empty.

        Args:
            dataset: The complete dataset to split

        Raises:
            TypeError: If ``dataset`` is not a ``torch.utils.data.Dataset``.
        """
        super().__init__()

        if not isinstance(dataset, data.Dataset):
            raise TypeError("dataset must be a torch.utils.data.Dataset")

        self.dataset = dataset
        self.training_mask = np.full((len(dataset),), False, dtype=bool)
        self.pool_mask = np.full((len(dataset),), True, dtype=bool)

        # Subset views over the same dataset; their indices are derived from
        # the masks by _update_indices().
        self.training_dataset = data.Subset(self.dataset, [])
        self.pool_dataset = data.Subset(self.dataset, [])

        self._update_indices()

    @property
    def train_size(self) -> int:
        """Get the number of samples in the training set."""
        return np.count_nonzero(self.training_mask)

    @property
    def pool_size(self) -> int:
        """Get the number of samples in the pool set."""
        return np.count_nonzero(self.pool_mask)

    @property
    def total_size(self) -> int:
        """Get the total number of samples."""
        return len(self.dataset)

    def _update_indices(self) -> None:
        """Update the indices for training and pool datasets."""
        self.training_dataset.indices = np.nonzero(self.training_mask)[0].tolist()
        self.pool_dataset.indices = np.nonzero(self.pool_mask)[0].tolist()

    def _move_to_training(self, dataset_indices: List[int]) -> None:
        """Move samples, given as original dataset indices, from pool to training."""
        if not dataset_indices:
            return
        self.training_mask[dataset_indices] = True
        self.pool_mask[dataset_indices] = False
        self._update_indices()

    def get_dataset_indices(
        self, pool_indices: Union[List[int], np.ndarray]
    ) -> List[int]:
        """
        Transform pool indices to original dataset indices.

        Args:
            pool_indices: Positions within the *current* pool dataset

        Returns:
            Corresponding indices in the original dataset (empty list if
            ``pool_indices`` is empty)

        Raises:
            IndexError: If any position is negative or ``>= pool_size``
                (including any request against an empty pool).
        """
        positions = np.asarray(pool_indices)
        if positions.size == 0:
            return []

        pool = self.pool_dataset.indices
        # Negative positions would silently wrap to the end of the pool, so
        # they are rejected along with too-large ones.
        if np.any(positions < 0) or np.any(positions >= len(pool)):
            raise IndexError("Pool indices out of range")

        return [pool[i] for i in positions.tolist()]

    def acquire(self, pool_indices: Union[List[int], np.ndarray]) -> None:
        """
        Acquire samples from the pool dataset into the training dataset.

        All positions are interpreted relative to the pool as it is before
        this call.

        Args:
            pool_indices: Indices of samples to acquire from the pool

        Raises:
            IndexError: If any position is out of range for the pool.
        """
        self._move_to_training(self.get_dataset_indices(pool_indices))

    def acquire_random(self, num_samples: int) -> List[int]:
        """
        Randomly acquire samples from the pool.

        Sampling uses NumPy's global RNG without replacement. Requests larger
        than the pool are clamped to the pool size.

        Args:
            num_samples: Number of samples to acquire

        Returns:
            Indices of acquired samples in the original dataset
        """
        num_samples = min(num_samples, self.pool_size)
        if num_samples <= 0:
            return []

        pool_indices = np.random.choice(
            self.pool_size,
            size=num_samples,
            replace=False,
        )

        # Translate *before* acquiring: acquisition shrinks the pool, after
        # which the same pool positions refer to different samples.
        dataset_indices = self.get_dataset_indices(pool_indices)
        self._move_to_training(dataset_indices)
        return dataset_indices

    def reset(self) -> None:
        """Reset all samples to the pool."""
        self.training_mask.fill(False)
        self.pool_mask.fill(True)
        self._update_indices()

    def get_training_indices(self) -> List[int]:
        """Get a copy of the original-dataset indices of all training samples."""
        return self.training_dataset.indices.copy()

    def get_pool_indices(self) -> List[int]:
        """Get a copy of the original-dataset indices of all pool samples."""
        return self.pool_dataset.indices.copy()

    def get_stats(self) -> dict:
        """Get statistics about the dataset splits (ratios are 0 for an empty dataset)."""
        return {
            "total_samples": self.total_size,
            "training_samples": self.train_size,
            "pool_samples": self.pool_size,
            "training_ratio": self.train_size / self.total_size if self.total_size > 0 else 0,
            "pool_ratio": self.pool_size / self.total_size if self.total_size > 0 else 0
        }

    def __len__(self) -> int:
        """Get total number of samples."""
        return self.total_size

    def __repr__(self) -> str:
        """String representation of the active learning data."""
        stats = self.get_stats()
        return (f"ActiveLearningData(total={stats['total_samples']}, "
                f"train={stats['training_samples']}, "
                f"pool={stats['pool_samples']})")


class MySubsetRandomSampler(data.Sampler):
    """
    Custom sampler for active learning that works with subset indices.

    Each pass over the sampler yields every entry of ``subset_indices``
    exactly once, in a fresh random order drawn from NumPy's global RNG.
    Values are yielded as plain Python ints.

    NOTE(review): ``original_indices`` is only length-checked and stored; it
    is not used when sampling. Confirm whether callers expect the yielded
    values to be mapped through it.
    """

    def __init__(self, subset_indices: List[int], original_indices: List[int]):
        """
        Initialize the custom sampler.

        Args:
            subset_indices: Indices within the subset to sample from
            original_indices: Original dataset indices corresponding to the subset

        Raises:
            ValueError: If the two index lists differ in length.
        """
        # Validate before storing anything.
        if len(subset_indices) != len(original_indices):
            raise ValueError("subset_indices and original_indices must have the same length")

        self.subset_indices = subset_indices
        self.original_indices = original_indices

    def __iter__(self):
        """Yield ``subset_indices`` in a random order."""
        # tolist() turns NumPy int64 scalars into Python ints, which every
        # dataset's __getitem__ accepts (some reject NumPy scalars).
        return iter(np.random.permutation(self.subset_indices).tolist())

    def __len__(self) -> int:
        """Get number of samples."""
        return len(self.subset_indices)


def pool_collate_fn(batch):
    """
    Custom collate function for pool data loading.

    Each sample is either a list/tuple or a single tensor.

    - For list/tuple samples with at least three elements, only the first
      three are kept (presumably ``(x, xt, y0)``; extra fields are dropped).
    - List/tuple samples are stacked field by field.
    - Single-tensor samples are stacked directly.

    NOTE(review): the returned indices are positions within this batch
    (``0 .. len(batch) - 1``), not indices into the pool or the original
    dataset. Confirm with callers whether they offset/map these themselves.

    Args:
        batch: List of samples from the dataset

    Returns:
        - For list/tuple samples: a tuple of stacked tensors (one per kept
          field) followed by an int64 tensor of batch positions.
        - For tensor samples: ``(stacked, positions)``.
        - For an empty batch: ``(batch, positions)``, with an empty int64
          positions tensor.
    """
    # Every sample contributes its own position, so the indices are simply
    # 0..len(batch)-1. arange keeps them int64 even for an empty batch,
    # whereas torch.tensor([]) would be float32.
    positions = torch.arange(len(batch), dtype=torch.long)

    if not batch:
        return batch, positions

    data_items = [
        item[:3] if isinstance(item, (list, tuple)) and len(item) >= 3 else item
        for item in batch
    ]

    first = data_items[0]
    if isinstance(first, (list, tuple)):
        # Multiple tensors per sample: stack each field across the batch.
        # Indexing (rather than zip) keeps ragged batches an error instead
        # of silently truncating them.
        stacked = [
            torch.stack([item[k] for item in data_items])
            for k in range(len(first))
        ]
        stacked.append(positions)
        return tuple(stacked)

    # Single tensor per sample.
    return torch.stack(data_items), positions
