"""EEG data sampler for Neural Process training without wbml/neuralprocesses dependencies."""

import numpy as np
import pandas as pd
import torch
from typing import List, Tuple, Optional, Dict, Any, Set
from pathlib import Path
import random

from src.utils import DataAttr


# IDs of every EEG subject available to the sampler.
# `EEGSampler._get_subset_subjects` shuffles a copy of this list with a fixed
# seed and slices it into the eval / cv / train subsets, so adding, removing
# or reordering entries changes which subjects land in each split.
_EEG_ALL_SUBJECTS = [
    337, 338, 339, 340, 341, 342, 344, 345, 346, 347, 348, 351, 352, 354, 355,
    356, 357, 359, 362, 363, 364, 365, 367, 368, 369, 370, 371, 372, 373, 374,
    375, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390,
    391, 392, 393, 394, 395, 396, 397, 398, 400, 402, 403, 404, 405, 406, 407,
    409, 410, 411, 412, 414, 415, 416, 417, 418, 419, 421, 422, 423, 424, 425,
    426, 427, 428, 429, 430, 432, 433, 434, 435, 436, 437, 438, 439, 440, 443,
    444, 445, 447, 448, 450, 451, 453, 454, 455, 456, 457, 458, 459, 460, 461,
    1000367,
]

# Number of buffer points in every task.  It is constant, so for a given
# total_points only the context size nc (and therefore nt) varies.
BUFFER_SIZE = 8

# Candidate context sizes (nc).  Restricting nc to this short list limits the
# number of distinct (nc, nb, nt) combinations a sampler can emit -- per the
# original author, to limit compilation (presumably shape-specialised model
# compilation; confirm with the training code).
# Total points per trial is nc + BUFFER_SIZE + nt, so nt is always the
# remainder once nc is chosen.
# NOTE: the list is deliberately not sorted.  EEGSampler._get_nc_nt filters
# it (keeping this order) and indexes the result with nc_idx, so reordering
# entries changes which nc a given nc_idx selects.
PREDEFINED_NC_VALUES = [
    8, 12, 16, 24, 32, 48, 64, 96, 128, 192, 256,
    # Extra in-between values for flexibility
    20, 40, 80, 160, 320
]


class EEGSampler:
    """
    EEG data sampler that loads and processes EEG trials for neural process training.

    This implementation is standalone and doesn't depend on wbml or
    neuralprocesses libraries.  Every batch is a set of ``batch_size`` trial
    segments of ``total_points`` samples each, split along time into a
    context set, a fixed-size buffer set and a target set.

    All randomness used by the sampler itself (synthetic data, trial order,
    segment start, nc choice, task choice, interpolation permutations) is
    drawn from ``self.rng``, so two samplers built with the same arguments
    produce the same sequence of batches.

    Args:
        data_path: Path to EEG data directory.  Stored on the instance but
            not read by the current placeholder loader.
        subset: Data subset - "train", "cv", or "eval"
        mode: Task mode - "interpolation", "forecasting", or "random"
            (reconstruction currently disabled)
        batch_size: Number of trials per batch (must be >= 1)
        num_tasks: Total number of tasks to generate per epoch; an epoch has
            ``num_tasks // batch_size`` batches
        total_points: Total number of points per trial (nc + 8 + nt)
        nc_idx: Index into the PREDEFINED_NC_VALUES entries that fit
            ``total_points`` (wraps around modulo their count), or None to
            pick one at random for every batch
        device: Computation device
        dtype: Tensor data type
        seed: Random seed for reproducibility

    Raises:
        ValueError: If ``batch_size`` is not positive, or ``mode`` / ``subset``
            is not one of the supported values.
    """

    # Concrete tasks that "random" mode chooses between.  Reconstruction is
    # left out on purpose: it is disabled for now because of dimension
    # compatibility concerns (its targets would carry a single channel).
    _CONCRETE_MODES = ("interpolation", "forecasting")

    def __init__(
        self,
        data_path: Optional[str] = None,
        subset: str = "train",
        mode: str = "random",
        batch_size: int = 16,
        num_tasks: int = 2**14,
        total_points: int = 256,
        nc_idx: Optional[int] = None,
        device: str = "cpu",
        dtype: torch.dtype = torch.float32,
        seed: int = 0,
    ):
        # Fail fast on configuration errors, before any data is generated.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        if mode != "random" and mode not in self._CONCRETE_MODES:
            raise ValueError(
                f"Unknown mode: {mode} "
                f"(expected one of {('random',) + self._CONCRETE_MODES})"
            )

        self.data_path = Path(data_path) if data_path else None
        self.subset = subset
        self.mode = mode
        self.batch_size = batch_size
        self.num_tasks = num_tasks
        self.num_batches = num_tasks // batch_size
        self.total_points = total_points
        self.nc_idx = nc_idx
        self.device = device
        self.dtype = dtype
        self.seed = seed

        # Set up random state.  The sampler only draws from self.rng; torch's
        # global RNG is still seeded because callers may rely on that side
        # effect of constructing a sampler.
        self.rng = np.random.RandomState(seed)
        torch.manual_seed(seed)

        # EEG channels we're interested in
        self.channels = ["FZ", "F1", "F2", "F3", "F4", "F5", "F6"]
        self.dim_x = 1  # Time dimension
        self.dim_y = len(self.channels)  # Number of channels

        # Load trials for this subset
        self.trials = self._load_trials()
        self._trials_i = 0  # cursor into self.trials for the next batch element

        # Track used combinations for debugging
        self.used_combinations: Set[Tuple[int, int, int]] = set()

    def _get_nc_nt(self) -> Tuple[int, int]:
        """
        Pick the context size nc and derive the target size nt (nb is always 8).

        Returns:
            ``(nc, nt)`` with ``nc + BUFFER_SIZE + nt == total_points`` and
            ``nt >= 1``.

        Raises:
            ValueError: If no predefined nc leaves room for a target point.
        """
        # Keep only nc values that leave room for at least 1 target point.
        # The order of PREDEFINED_NC_VALUES is preserved because nc_idx
        # indexes into this filtered list.
        valid_nc_values = [nc for nc in PREDEFINED_NC_VALUES
                           if nc + BUFFER_SIZE < self.total_points]

        if not valid_nc_values:
            raise ValueError(f"total_points ({self.total_points}) too small for any nc + buffer={BUFFER_SIZE}")

        # Select nc: fixed (wrapping) index if configured, otherwise random.
        if self.nc_idx is not None:
            nc = valid_nc_values[self.nc_idx % len(valid_nc_values)]
        else:
            nc = valid_nc_values[self.rng.randint(0, len(valid_nc_values))]

        # nt is whatever remains after context and buffer.
        nt = self.total_points - nc - BUFFER_SIZE

        return nc, nt

    def _load_trials(self) -> List[Dict[str, Any]]:
        """
        Load EEG trials for the specified subset.

        This is a placeholder implementation that synthesises EEG-like
        signals from ``self.rng``. In practice, you would:
        1. Load actual EEG data from files (CSV, HDF5, etc.)
        2. Split subjects into train/cv/eval sets
        3. Extract trials with the specified channels

        Returns:
            Shuffled list of trial dictionaries with 'subject_id',
            'trial_idx', 'time' (1-D array) and 'data' (DataFrame with one
            column per channel, indexed by time) keys.
        """
        subjects = self._get_subset_subjects()

        # NOTE: the order of the self.rng calls below determines the
        # synthetic data for a given seed; do not reorder them.
        trials: List[Dict[str, Any]] = []
        for subject_id in subjects:
            # Simulate multiple trials per subject
            num_trials = self.rng.randint(5, 20)
            for trial_idx in range(num_trials):
                # Always longer than total_points, so a contiguous segment fits.
                trial_length = self.rng.randint(
                    self.total_points + 100,
                    self.total_points + 500
                )
                time = np.linspace(0, 1, trial_length)

                data = {}
                for channel in self.channels:
                    # Two sinusoids with random phase plus white noise.  The
                    # band labels assume the [0, 1] time axis spans 1 second.
                    signal = np.zeros(trial_length)
                    # Alpha band (8-12 Hz)
                    signal += 0.5 * np.sin(2 * np.pi * 10 * time + self.rng.randn())
                    # Beta band (12-30 Hz)
                    signal += 0.3 * np.sin(2 * np.pi * 20 * time + self.rng.randn())
                    # Add noise
                    signal += 0.1 * self.rng.randn(trial_length)
                    data[channel] = signal

                trials.append({
                    'subject_id': subject_id,
                    'trial_idx': trial_idx,
                    'time': time,
                    'data': pd.DataFrame(data, index=time)
                })

        # Shuffle trials so batches mix subjects
        self.rng.shuffle(trials)
        return trials

    def _get_subset_subjects(self) -> List[int]:
        """
        Get subject IDs for the specified subset.

        The split uses its own fixed-seed RNG (independent of ``seed``) so
        that train / cv / eval never overlap across sampler instances.

        Raises:
            ValueError: If ``self.subset`` is not "train", "cv" or "eval".
        """
        rng = np.random.RandomState(99)
        shuffled_subjects = _EEG_ALL_SUBJECTS.copy()
        rng.shuffle(shuffled_subjects)

        # Split subjects: 10 eval, 10 cv, rest train
        if self.subset == "eval":
            return shuffled_subjects[:10]
        elif self.subset == "cv":
            return shuffled_subjects[10:20]
        elif self.subset == "train":
            return shuffled_subjects[20:]
        else:
            raise ValueError(f"Unknown subset: {self.subset}")

    def _segment_indices(self, trial_length: int) -> np.ndarray:
        """Return ``total_points`` time-ordered sample indices for one trial."""
        if trial_length < self.total_points:
            # Trial too short: resample with replacement.  Sorting keeps the
            # samples in time order, which the forecasting split relies on.
            picks = self.rng.choice(trial_length, self.total_points, replace=True)
            return np.sort(picks)
        # Random contiguous segment
        start_idx = self.rng.randint(0, trial_length - self.total_points + 1)
        return np.arange(start_idx, start_idx + self.total_points)

    def generate_batch(self) -> "DataAttr":
        """
        Generate a batch of EEG data with context, buffer, and target sets.

        Trials are consumed in order; when the list is exhausted it is
        reshuffled and iteration restarts from the beginning.

        Returns:
            DataAttr object with xc, yc, xb, yb, xt, yt and mask attributes

        Raises:
            IndexError: If no trials are loaded.
            ValueError: If ``total_points`` is too small for any split.
        """
        if not self.trials:
            raise IndexError("cannot generate a batch: no trials are loaded")

        # Get split sizes for this batch
        nc, nt = self._get_nc_nt()
        nb = BUFFER_SIZE
        self.used_combinations.add((nc, nb, nt))

        times = []   # one (T,) array per batch element
        values = []  # one (channels, T) array per batch element

        for _ in range(self.batch_size):
            if self._trials_i >= len(self.trials):
                # Reached end of trials, shuffle and reset
                self.rng.shuffle(self.trials)
                self._trials_i = 0

            trial = self.trials[self._trials_i]
            self._trials_i += 1

            indices = self._segment_indices(len(trial['time']))
            times.append(trial['time'][indices])
            values.append(trial['data'].iloc[indices][self.channels].values.T)

        # Stack on the host and convert once, instead of one small
        # host-to-device copy per trial.
        x = torch.as_tensor(np.stack(times), dtype=self.dtype, device=self.device)
        x = x.unsqueeze(1)  # (B, 1, T)
        y = torch.as_tensor(np.stack(values), dtype=self.dtype, device=self.device)  # (B, channels, T)

        return self._apply_task(x, y, nc, nb, nt)

    def _split_points(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        nc: int,
        nb: int,
        mode: str
    ) -> Dict[str, torch.Tensor]:
        """
        Split x and y along time into context / buffer / target tensors.

        Args:
            x: Time points tensor (B, 1, T)
            y: EEG data tensor (B, channels, T)
            nc: Number of context points
            nb: Number of buffer points
            mode: "interpolation" or "forecasting"

        Returns:
            Dict with keys xc, yc, xb, yb, xt, yt.  x-parts have shape
            (B, n, 1) and y-parts (B, n, channels); the target part holds
            the remaining T - nc - nb points.

        Raises:
            ValueError: If ``mode`` is not a supported concrete task.
        """
        B, C, T = y.shape

        if mode == "interpolation":
            # Argsorting i.i.d. uniform noise yields an independent uniform
            # random permutation per row, with no Python loop over the batch.
            order = np.argsort(self.rng.rand(B, T), axis=1)
            order = torch.as_tensor(order, dtype=torch.long, device=x.device)
            order = order.unsqueeze(1)  # (B, 1, T)
            # Apply the same per-trial permutation to times and all channels.
            x = torch.gather(x, 2, order)
            y = torch.gather(y, 2, order.expand(-1, C, -1))
        elif mode != "forecasting":
            raise ValueError(
                f"Unknown mode: {mode} (expected one of {self._CONCRETE_MODES})"
            )

        # After the optional shuffle both tasks are a sequential split:
        # context first, then the buffer, then the targets.
        bounds = (("c", 0, nc), ("b", nc, nc + nb), ("t", nc + nb, T))
        parts: Dict[str, torch.Tensor] = {}
        for tag, lo, hi in bounds:
            parts["x" + tag] = x[:, :, lo:hi].transpose(1, 2)  # (B, n, 1)
            parts["y" + tag] = y[:, :, lo:hi].transpose(1, 2)  # (B, n, C)
        return parts

    def _apply_task(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        nc: int,
        nb: int,
        nt: int
    ) -> "DataAttr":
        """
        Apply the specified task to create context, buffer, and target sets.
        All operations are vectorized - no loops over batch dimension.

        In "interpolation" mode the time points of every trial are randomly
        permuted before splitting; in "forecasting" mode the split is
        sequential (context, then buffer, then target).  "random" picks one
        of the two per call using ``self.rng``.

        Args:
            x: Time points tensor (B, 1, T) where T = nc + nb + nt
            y: EEG data tensor (B, channels, T)
            nc: Number of context points
            nb: Number of buffer points (always 8)
            nt: Number of target points

        Returns:
            DataAttr with context, buffer, and target sets plus an all-ones
            int8 ``mask`` of shape (B, nt)

        Raises:
            ValueError: If nc + nb + nt != T or the mode is unsupported.
        """
        B, _, T = y.shape
        if nc + nb + nt != T:
            # Otherwise the mask below would not match the target tensors.
            raise ValueError(
                f"nc + nb + nt ({nc} + {nb} + {nt}) must equal the number of "
                f"time points ({T})"
            )

        # Determine task mode
        mode = self.mode
        if mode == "random":
            # Drawn from self.rng (not the global `random` module) so the
            # choice is reproducible from the sampler's seed.
            mode = self._CONCRETE_MODES[self.rng.randint(len(self._CONCRETE_MODES))]

        parts = self._split_points(x, y, nc, nb, mode)

        batch = DataAttr()
        for name, tensor in parts.items():
            setattr(batch, name, tensor)

        # Add mask for compatibility
        batch.mask = torch.ones(B, nt, dtype=torch.int8, device=x.device)

        return batch

    def get_used_combinations(self) -> List[Tuple[int, int, int]]:
        """Return the sorted (nc, nb, nt) combinations used so far."""
        return sorted(self.used_combinations)

    def __iter__(self):
        """Yield ``num_batches`` freshly generated batches (one epoch)."""
        for _ in range(self.num_batches):
            yield self.generate_batch()

    def __len__(self):
        """Return number of batches per epoch."""
        return self.num_batches