"""EEG DataLoader wrapper for compatibility with existing training infrastructure."""

import warnings
from typing import Optional, Tuple, Union

import torch

from src.data.eeg_sampler import EEGSampler
from src.utils import DataAttr


class EEGDataLoader:
    """
    EEG data loader that wraps EEGSampler for compatibility with existing training code.

    This loader provides the same interface as GaussianProcessDataLoader but uses
    EEG data instead of synthetic GP samples.

    Args:
        subset: Data subset - "train", "cv", or "eval"
        mode: Task mode - "interpolation", "forecasting", or "random"
        total_points: Total number of points per trial
        device: Computation device
        seed: Base random seed. ``__call__`` offsets it by the number of
            batches already produced, so successive batches use different
            seeds while the whole sequence stays reproducible for a given
            base seed.
        data_path: Optional path to EEG data files
        **kwargs: Additional config entries; accepted and ignored.
    """

    # Buffer size the EEG pipeline expects; other values only trigger a warning.
    EXPECTED_BUFFER_POINTS = 8

    def __init__(
        self,
        subset: str = "train",
        mode: str = "random",
        total_points: int = 256,
        device: str = "cpu",
        seed: int = 0,
        data_path: Optional[str] = None,
        **kwargs  # Catch any additional args from config
    ):
        self.subset = subset
        self.mode = mode
        self.total_points = total_points
        self.device = device
        self.seed = seed
        self.data_path = data_path
        # Number of batches produced by __call__ so far; used to derive a
        # distinct seed for each freshly built sampler.
        self._num_calls = 0

    @staticmethod
    def _resolve_nc_idx(
        num_register_points: Union[int, str],
        min_register_points: int,
        max_register_points: int,
        nc_values,
    ) -> Optional[int]:
        """Map a requested context size to an index into ``nc_values``.

        Args:
            num_register_points: Requested context size, or ``"random"``.
            min_register_points: Smallest acceptable context size (inclusive).
            max_register_points: Largest acceptable context size (inclusive).
            nc_values: Sequence of supported context sizes.

        Returns:
            ``None`` (let the sampler choose randomly) when the request is
            ``"random"`` or when no supported size lies within
            ``[min_register_points, max_register_points]``. Otherwise, the
            index of the requested size if it is supported and in range,
            else the index of the in-range size closest to it (ties go to
            the one listed first).
        """
        if num_register_points == "random":
            return None
        values = list(nc_values)
        valid_nc = [
            nc for nc in values if min_register_points <= nc <= max_register_points
        ]
        if not valid_nc:
            # Nothing supported fits the requested range: fall back to random.
            return None
        if num_register_points in valid_nc:
            chosen = num_register_points
        else:
            chosen = min(valid_nc, key=lambda nc: abs(nc - num_register_points))
        return values.index(chosen)

    def __call__(
        self,
        problem,  # Not used for EEG, but kept for interface compatibility
        batch_size: int,
        num_register_points: Union[int, str],  # Maps to nc; may be "random"
        num_latent: int,  # Not used for EEG
        min_register_points: int,
        max_register_points: int,
        x_range: Tuple[float, float],  # Not used for EEG (time is always 0-1)
        max_buffer_points: int,  # Should be 8 for EEG
        num_target_partitions: int,  # Not used for EEG
        num_target_data_per_partition: int,  # Not used for EEG
    ) -> "Tuple[DataAttr, DataAttr]":
        """
        Generate one EEG data batch compatible with existing training code.

        Args:
            problem: Ignored; kept for interface compatibility.
            batch_size: Number of tasks in the batch.
            num_register_points: Requested context size, or ``"random"``.
                An unsupported value is snapped to the closest supported size
                within ``[min_register_points, max_register_points]``.
            min_register_points: Lower bound (inclusive) on the context size.
            max_register_points: Upper bound (inclusive) on the context size.
            max_buffer_points: Buffer size; a warning is emitted if it is not 8.
            num_latent, x_range, num_target_partitions,
            num_target_data_per_partition: Ignored for EEG.

        Returns:
            Tuple of (context_target_batch, buffer_batch)
        """
        if max_buffer_points != self.EXPECTED_BUFFER_POINTS:
            # warnings.warn (unlike print) is shown once per call site by
            # default instead of on every batch.
            warnings.warn(
                f"EEG expects buffer size {self.EXPECTED_BUFFER_POINTS}, "
                f"got {max_buffer_points}",
                stacklevel=2,
            )

        from src.data.eeg_sampler import PREDEFINED_NC_VALUES

        nc_idx = self._resolve_nc_idx(
            num_register_points,
            min_register_points,
            max_register_points,
            PREDEFINED_NC_VALUES,
        )

        # A new sampler is built on every call. Reusing self.seed each time
        # would presumably make it regenerate the same batch, so derive a
        # per-call seed that is still reproducible from the base seed.
        batch_seed = self.seed + self._num_calls
        self._num_calls += 1

        # Create sampler for a single batch
        sampler = EEGSampler(
            data_path=self.data_path,
            subset=self.subset,
            mode=self.mode,
            batch_size=batch_size,
            num_tasks=batch_size,  # Just one batch
            total_points=self.total_points,
            nc_idx=nc_idx,
            device=self.device,
            dtype=torch.float32,
            seed=batch_seed,
        )

        eeg_batch = sampler.generate_batch()

        # Context-target batch combines context and target
        context_target_batch = DataAttr()
        context_target_batch.xc = eeg_batch.xc
        context_target_batch.yc = eeg_batch.yc
        context_target_batch.xt = eeg_batch.xt
        context_target_batch.yt = eeg_batch.yt
        context_target_batch.loss_mask = eeg_batch.mask

        # Buffer batch: buffer data is exposed as the "context" of this batch
        buffer_batch = DataAttr()
        buffer_batch.xc = eeg_batch.xb
        buffer_batch.yc = eeg_batch.yb
        buffer_batch.xt = None
        buffer_batch.yt = None
        # NOTE(review): the context/target batch stores its mask as
        # `loss_mask` while this one uses `mask`; confirm which attribute
        # the training code reads.
        buffer_batch.mask = None

        return context_target_batch, buffer_batch

    def create_generator(
        self,
        batch_size: int,
        num_batches: int,
        num_register_points: Union[int, str] = "random",
        min_register_points: int = 8,
        max_register_points: int = 128,
        **kwargs
    ):
        """
        Create an EEG data generator for training.

        The context size is resolved exactly as in ``__call__``: ``"random"``
        leaves the choice to the sampler, and any other value is snapped to
        the closest supported size within
        ``[min_register_points, max_register_points]``.

        Args:
            batch_size: Number of tasks per batch.
            num_batches: Number of batches the sampler should provide.
            num_register_points: Requested context size, or ``"random"``.
            min_register_points: Lower bound (inclusive) on the context size.
            max_register_points: Upper bound (inclusive) on the context size.
            **kwargs: Accepted and ignored.

        Returns:
            EEGSampler instance that can be iterated
        """
        from src.data.eeg_sampler import PREDEFINED_NC_VALUES

        nc_idx = self._resolve_nc_idx(
            num_register_points,
            min_register_points,
            max_register_points,
            PREDEFINED_NC_VALUES,
        )

        return EEGSampler(
            data_path=self.data_path,
            subset=self.subset,
            mode=self.mode,
            batch_size=batch_size,
            num_tasks=batch_size * num_batches,
            total_points=self.total_points,
            nc_idx=nc_idx,
            device=self.device,
            dtype=torch.float32,
            seed=self.seed,
        )