"""EEG data sampler for Neural Process training without wbml/neuralprocesses dependencies."""

import os
import glob
import pickle
import warnings
import multiprocessing
import tarfile

import numpy as np
import pandas as pd
import torch
from typing import List, Tuple, Optional, Dict, Any, Set
from pathlib import Path
import random

from src.utils import DataAttr


# All EEG subject IDs.
# Within this file the list is only read by EEGSampler._get_subset_subjects,
# which shuffles a copy with a fixed seed to derive the eval/cv/train split.
# EEGSampler._load_trials does not use it: it splits the subject IDs that are
# actually present in the parsed data instead.
_EEG_ALL_SUBJECTS = [
    337, 338, 339, 340, 341, 342, 344, 345, 346, 347, 348, 351, 352, 354, 355,
    356, 357, 359, 362, 363, 364, 365, 367, 368, 369, 370, 371, 372, 373, 374,
    375, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390,
    391, 392, 393, 394, 395, 396, 397, 398, 400, 402, 403, 404, 405, 406, 407,
    409, 410, 411, 412, 414, 415, 416, 417, 418, 419, 421, 422, 423, 424, 425,
    426, 427, 428, 429, 430, 432, 433, 434, 435, 436, 437, 438, 439, 440, 443,
    444, 445, 447, 448, 450, 451, 453, 454, 455, 456, 457, 458, 459, 460, 461,
    1000367,
]

# Fixed buffer size: the number of buffer points (nb) in every generated task.
BUFFER_SIZE = 8

# Predefined nc values to limit compilation (nt will be remainder)
# Total points per trial will be nc + BUFFER_SIZE + nt
# Using powers of 2 and nearby values for computational efficiency
#
# NOTE: the list is not sorted. EEGSampler._get_nc_nt first drops every value
# with nc + BUFFER_SIZE >= total_points and then indexes the remaining values
# in the order written here (nc_idx modulo the filtered length), so reordering
# or inserting entries changes which nc a given nc_idx selects.
PREDEFINED_NC_VALUES = [
    4, 8, 12, 16, 24, 32, 48, 64, 96, 128, 192, 256,
    # Also include some in-between values for flexibility
    20, 40, 80, 160, 320
]


# Data loading utilities
def _data_path(base_data_path: Path, *xs: str) -> str:
    """Join ``xs`` onto ``base_data_path`` and return the result as a string.

    Args:
        base_data_path: Root directory of the data.
        *xs: Path components appended below the root, in order.

    Returns:
        The combined path rendered with ``str``.
    """
    relative = Path(*xs)
    combined = base_data_path / relative
    return str(combined)


def _parse_trial(path: str) -> Optional[Dict[str, Any]]:
    """Parse a single EEG trial file (.rd format), supporting optional .gz compression.

    The file is read as whitespace-separated columns: column 1 is the channel
    (site) name, column 2 the sample index within the channel and column 3 the
    measured value. Channels are stored one after another, so a drop in the
    sample index marks the start of the next channel.

    Args:
        path: Path to the trial file; a trailing ``.gz`` selects gzip decoding.

    Returns:
        ``{"df": DataFrame}`` with one column per channel, indexed by time in
        seconds (sample index / 256), or ``None`` when the file has no data rows.

    Raises:
        ValueError: If the channels do not all have the same number of samples
            (raised by ``np.stack``).
    """
    import gzip

    def _open(p):
        return gzip.open(p, 'rt') if p.endswith('.gz') else open(p, 'r')

    with warnings.catch_warnings():
        # genfromtxt warns on empty input; that case is handled explicitly below.
        warnings.simplefilter("ignore")
        with _open(path) as fh:
            # atleast_1d: a file with a single data row yields a 0-d array.
            sites = np.atleast_1d(np.genfromtxt(fh, usecols=1, dtype=str))
        if sites.size == 0:
            return None
        with _open(path) as fh:
            # atleast_2d: a single data row yields a 1-d array of two values.
            data = np.atleast_2d(np.genfromtxt(fh, usecols=(2, 3)))

    x, y = data[:, 0], data[:, 1]

    # Positions where the sample index restarts, i.e. where a new channel begins.
    inds = np.where(np.diff(x) < 0)[0] + 1
    sites = [sites[0]] + [sites[i] for i in inds]

    # With a single channel there is no restart, so the channel spans all rows.
    n_samples = inds[0] if len(inds) else len(x)
    x = x[:n_samples] / 256.0  # Sampled at 256 Hz
    y = np.stack(np.split(y, inds), axis=1)

    return {"df": pd.DataFrame(y, index=pd.Index(x, name="time"), columns=sites)}


def _extract_trials(subject_path: str) -> Dict[int, Dict[str, Any]]:
    """Extract all trials for a subject (recursively), supporting .rd.* and .rd.*.gz.

    Args:
        subject_path: Directory searched recursively for ``*.rd.*`` files.

    Returns:
        Mapping from trial number (the integer following ``.rd.`` in the file
        name) to the parsed trial; trials that parse to ``None`` are dropped.

    Raises:
        ValueError: If the text after ``.rd.`` (minus an optional ``.gz``) is
            not an integer.
    """
    pattern = os.path.join(subject_path, "**", "*.rd.*")
    # Sort so the resulting dict order does not depend on filesystem order,
    # and ignore directories that happen to match the pattern.
    paths = sorted(p for p in glob.glob(pattern, recursive=True) if os.path.isfile(p))

    if not paths:
        # Nothing to parse; avoid spawning a worker pool for no work.
        return {}

    trial_numbers = []
    for p in paths:
        # Use the file name only so a ".rd." in a parent directory cannot interfere.
        suffix = os.path.basename(p).split(".rd.", 1)[1]
        if suffix.endswith(".gz"):
            # "000.gz" -> "000"; int() would otherwise fail on compressed trials.
            suffix = suffix[:-len(".gz")]
        trial_numbers.append(int(suffix))

    with multiprocessing.Pool(processes=max(1, multiprocessing.cpu_count() // 2)) as pool:
        parsed_trials = pool.map(_parse_trial, paths)

    return {k: v for k, v in zip(trial_numbers, parsed_trials) if v is not None}


def _extract_subjects(path: str) -> Dict[int, Dict[str, Any]]:
    """Extract all subjects from a directory.

    Handles both already-extracted subject folders (co*/) and subject archives
    (co*.tar.gz or co*.tar) by extracting archives into co*/ subfolders on demand.

    Args:
        path: Directory containing the ``co*`` subject folders and/or archives.

    Returns:
        Mapping from subject number to ``{"type": ..., "trials": ...}``, where
        ``type`` is derived from the four-character folder prefix and ``trials``
        comes from ``_extract_trials``. Folders with an unknown prefix or
        without any digits in their name are skipped.
    """
    entries = sorted(glob.glob(os.path.join(path, "co*")))
    archive_suffixes = (".tar.gz", ".tar")

    # First, ensure any subject archives are extracted into folders
    for entry in entries:
        if os.path.isdir(entry):
            continue
        base = os.path.basename(entry)
        lower = base.lower()
        suffix = next((s for s in archive_suffixes if lower.endswith(s)), None)
        if suffix is None:
            continue
        # Strip the extension by length so the match stays case-insensitive,
        # consistent with the suffix test above.
        name = base[:-len(suffix)]
        dest_dir = os.path.join(path, name)
        if os.path.isdir(dest_dir):
            continue
        print(f"Extracting subject archive {base} -> {name}/")
        os.makedirs(dest_dir, exist_ok=True)
        with tarfile.open(entry, "r:*") as tar:
            if hasattr(tarfile, "data_filter"):
                # Refuse members that would escape dest_dir (absolute paths,
                # "..", links pointing outside) where the interpreter supports it.
                tar.extractall(dest_dir, filter="data")
            else:
                tar.extractall(dest_dir)

    # Refresh list now that archives may be extracted
    subject_dirs = sorted(
        p for p in glob.glob(os.path.join(path, "co*")) if os.path.isdir(p)
    )

    type_map = {"co2a": "2a", "co2c": "2c", "co3a": "3a", "co3c": "3c"}

    subjects = {}
    for subject_path in subject_dirs:
        dirname = os.path.split(subject_path)[-1]
        # Skip if name doesn't have expected prefix
        prefix = dirname[:4]
        if prefix not in type_map:
            continue
        # Extract trailing 7-digit identifier if present
        try:
            number = int(dirname[-7:])
        except ValueError:
            # Fallback: try to parse any trailing digits
            digits = ''.join(ch for ch in dirname if ch.isdigit())
            if not digits:
                continue
            number = int(digits[-7:])

        subjects[number] = {
            "type": type_map[prefix],
            "trials": _extract_trials(subject_path),
        }

    return subjects


def _load_full_eeg_data(base_data_path: Path) -> Dict[int, Dict[str, Any]]:
    """Load the full EEG dataset, using a pickle cache when available.

    Looks for ``full.pickle`` below ``base_data_path``. A non-empty dict found
    there is returned directly; otherwise the raw files under ``full/`` are
    parsed with ``_extract_subjects`` and the result is written to the cache.

    Args:
        base_data_path: Directory holding ``full.pickle`` and/or ``full/``.

    Returns:
        Mapping from subject number to that subject's parsed data.

    Raises:
        FileNotFoundError: If no usable cache exists and the raw data
            directory is missing or empty.
    """
    cache_file = _data_path(base_data_path, "full.pickle")

    # Check if cached parsed data exists
    if os.path.exists(cache_file):
        print("Loading cached EEG data...")
        try:
            # NOTE: pickle executes arbitrary code on load; the cache file must
            # only ever be one this module wrote itself.
            with open(cache_file, "rb") as f:
                data = pickle.load(f)
        except (EOFError, pickle.UnpicklingError) as exc:
            # A truncated cache (e.g. an interrupted earlier run) should not be
            # fatal; fall through and rebuild it from the raw files.
            print(f"Cached EEG data is unreadable ({exc}); will re-parse from raw files...")
        else:
            # Guard against an empty/invalid cache (e.g., created before extraction)
            if isinstance(data, dict) and len(data) > 0:
                return data
            print("Cached EEG data is empty; will re-parse from raw files...")

    # Check if raw data directory exists and is non-empty
    full_data_path = _data_path(base_data_path, "full")
    if not os.path.isdir(full_data_path) or not os.listdir(full_data_path):
        raise FileNotFoundError(
            f"EEG raw data directory missing or empty: {full_data_path}\n"
            f"Run 'uv run python download_eeg_data.py' to extract into data/eeg/full."
        )

    # Parse data
    print("Parsing EEG data for first time use. This may take a while...")
    data = _extract_subjects(full_data_path)

    if not data:
        # Caching an empty result would only be discarded again on the next load.
        print("No EEG subjects were parsed; not writing a cache file.")
        return data

    # Cache the parsed data. Write to a temporary file and rename so an
    # interrupted run never leaves a truncated cache behind.
    tmp_file = cache_file + ".tmp"
    with open(tmp_file, "wb") as f:
        pickle.dump(data, f)
    os.replace(tmp_file, cache_file)

    print("EEG data parsing complete and cached.")
    return data


class EEGSampler:
    """
    EEG data sampler that loads and processes EEG trials for neural process training.

    This implementation is standalone and doesn't depend on wbml or neuralprocesses libraries.

    Iterating over the sampler yields ``num_batches`` batches produced by
    ``generate_batch``.

    Args:
        data_path: Path to EEG data directory
        subset: Data subset - "train", "cv", or "eval"
        mode: Task mode - "interpolation", "forecasting", or "random" (reconstruction currently disabled)
        batch_size: Number of trials per batch
        num_tasks: Total number of tasks to generate per epoch
        total_points: Total number of points per trial (nc + 8 + nt)
        nc_idx: Index into PREDEFINED_NC_VALUES, or None for random
        device: Computation device
        dtype: Tensor data type
        seed: Random seed for reproducibility

    Raises:
        ValueError: If ``subset`` is unknown, ``data_path`` is not given, or no
            valid trial is found for the subset.
    """

    # Names accepted for ``subset``.
    _SUBSET_NAMES = ("eval", "cv", "train")

    def __init__(
        self,
        data_path: Optional[str] = None,
        subset: str = "train",
        mode: str = "random",
        batch_size: int = 16,
        num_tasks: int = 2**14,
        total_points: int = 256,
        nc_idx: Optional[int] = None,
        device: str = "cpu",
        dtype: torch.dtype = torch.float32,
        seed: int = 0,
    ):
        self.data_path = Path(data_path) if data_path else None
        self.subset = subset
        self.mode = mode
        self.batch_size = batch_size
        self.num_tasks = num_tasks
        self.num_batches = num_tasks // batch_size
        self.total_points = total_points
        self.nc_idx = nc_idx
        self.device = device
        self.dtype = dtype
        self.seed = seed

        # Set up random state. Note that torch.manual_seed seeds torch's
        # global generator, which torch.randperm in _apply_task draws from.
        self.rng = np.random.RandomState(seed)
        torch.manual_seed(seed)

        # EEG channels we're interested in
        self.channels = ["FZ", "F1", "F2", "F3", "F4", "F5", "F6"]
        self.dim_x = 1  # Time dimension
        self.dim_y = len(self.channels)  # Number of channels

        # Load trials for this subset
        self.trials = self._load_trials(self.data_path)
        self._trials_i = 0

        # Track used combinations for debugging
        self.used_combinations: Set[Tuple[int, int, int]] = set()

    def _get_nc_nt(self) -> Tuple[int, int]:
        """Get nc and nt values. nb is always BUFFER_SIZE.

        Returns:
            ``(nc, nt)`` with ``nc + BUFFER_SIZE + nt == total_points``.

        Raises:
            ValueError: If no predefined nc leaves room for a target point.
        """
        # Filter nc values that would leave room for at least 1 target point
        valid_nc_values = [
            nc for nc in PREDEFINED_NC_VALUES
            if nc + BUFFER_SIZE < self.total_points
        ]

        if not valid_nc_values:
            raise ValueError(f"total_points ({self.total_points}) too small for any nc + buffer={BUFFER_SIZE}")

        # Select nc: fixed (wrapping around the valid list) or drawn at random
        if self.nc_idx is not None:
            nc = valid_nc_values[self.nc_idx % len(valid_nc_values)]
        else:
            nc = valid_nc_values[self.rng.randint(0, len(valid_nc_values))]

        # Calculate nt as remainder
        nt = self.total_points - nc - BUFFER_SIZE

        return nc, nt

    def _select_subset(self, subjects: List[int]) -> List[int]:
        """Return the part of ``subjects`` that belongs to ``self.subset``.

        A copy of ``subjects`` is shuffled with a fixed seed (99) so the split
        is the same for every sampler, then sliced: the first 10 entries are
        "eval", the next 10 are "cv" and the rest are "train". The input list
        is left untouched.

        Raises:
            ValueError: If ``self.subset`` is not one of the known names.
        """
        shuffled = list(subjects)
        np.random.RandomState(99).shuffle(shuffled)
        if self.subset == "eval":
            return shuffled[:10]
        if self.subset == "cv":
            return shuffled[10:20]
        if self.subset == "train":
            return shuffled[20:]
        raise ValueError(f"Unknown subset: {self.subset}")

    def _load_trials(self, base_data_path: Path) -> List[Dict[str, Any]]:
        """
        Load real EEG trials from the UCI EEG dataset.

        Trials are kept only if they contain every channel in
        ``self.channels``, have at least ``total_points`` samples, contain no
        NaN and are not identically zero.

        Returns:
            Shuffled list of trial dictionaries with 'subject_id',
            'trial_num', 'time' and 'data' keys

        Raises:
            ValueError: If the subset name is unknown, ``base_data_path`` is
                None, or no valid trial remains after filtering.
        """
        # Fail before the (potentially slow) data load when the arguments
        # cannot work anyway.
        if self.subset not in self._SUBSET_NAMES:
            raise ValueError(f"Unknown subset: {self.subset}")
        if base_data_path is None:
            raise ValueError("data_path must be provided to load EEG data")

        # Load the full EEG dataset
        print(f"Loading EEG data for {self.subset} subset...")
        full_data = _load_full_eeg_data(base_data_path)

        # Deterministic split over the subjects that are actually available
        subjects = self._select_subset(sorted(full_data.keys()))

        # Collect trials from selected subjects
        trials = []
        for subject_id in subjects:
            for trial_num, trial_info in full_data[subject_id]["trials"].items():
                trial_df = trial_info["df"]

                # Skip trials that don't have all required channels
                available_channels = set(trial_df.columns)
                if any(ch not in available_channels for ch in self.channels):
                    continue

                # Extract only the channels we need
                channel_data = trial_df[self.channels]

                # Check if there's enough data and no NaN values
                if len(channel_data) < self.total_points or channel_data.isna().any().any():
                    continue

                # Check if data is non-zero (some trials have all zeros)
                if np.abs(channel_data.values).sum() == 0:
                    continue

                trials.append({
                    'subject_id': subject_id,
                    'trial_num': trial_num,
                    'time': np.array(channel_data.index),
                    'data': channel_data
                })

        if not trials:
            raise ValueError(f"No valid trials found for {self.subset} subset")

        print(f"Loaded {len(trials)} valid trials for {self.subset} subset")

        # Shuffle trials
        self.rng.shuffle(trials)
        return trials

    def _get_subset_subjects(self) -> List[int]:
        """Get subject IDs for the specified subset from the full subject list.

        Uses the same fixed-seed split as ``_load_trials``, but applied to
        ``_EEG_ALL_SUBJECTS`` rather than to the subjects present in the data.
        """
        return self._select_subset(_EEG_ALL_SUBJECTS)

    def generate_batch(self) -> "DataAttr":
        """
        Generate a batch of EEG data with context, buffer, and target sets.

        Each batch element is a random contiguous window of ``total_points``
        samples from the next trial, with time rescaled to [0, 1]. When the
        trial list is exhausted it is reshuffled and restarted.

        Returns:
            DataAttr object with xc, yc, xb, yb, xt, yt and mask attributes
        """
        # Get split sizes for this batch
        nc, nt = self._get_nc_nt()
        nb = BUFFER_SIZE
        self.used_combinations.add((nc, nb, nt))

        # Collect batch data
        batch_x = []
        batch_y = []

        for _ in range(self.batch_size):
            if self._trials_i >= len(self.trials):
                # Reached end of trials, shuffle and reset
                self.rng.shuffle(self.trials)
                self._trials_i = 0

            trial = self.trials[self._trials_i]
            self._trials_i += 1

            trial_length = len(trial['time'])
            if trial_length < self.total_points:
                # Skip trials that are too short (shouldn't happen with our filtering)
                continue

            # Random contiguous segment of exactly total_points samples
            start_idx = self.rng.randint(0, trial_length - self.total_points + 1)
            stop_idx = start_idx + self.total_points

            time = trial['time'][start_idx:stop_idx]
            data = trial['data'].iloc[start_idx:stop_idx].values  # (T, channels)

            # Normalize time to [0, 1] range
            time_normalized = (time - time[0]) / (time[-1] - time[0])

            batch_x.append(torch.tensor(time_normalized, dtype=self.dtype, device=self.device))
            batch_y.append(torch.tensor(data.T, dtype=self.dtype, device=self.device))  # (channels, T)

        # Stack batch: x is (B, T), y is (B, channels, T)
        x = torch.stack(batch_x).unsqueeze(1)  # (B, 1, T)
        y = torch.stack(batch_y)  # (B, channels, T)

        # Apply task
        return self._apply_task(x, y, nc, nb, nt)

    def _apply_task(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        nc: int,
        nb: int,
        nt: int
    ) -> "DataAttr":
        """
        Apply the specified task to create context, buffer, and target sets.

        A per-element ordering of the T time points is built first and then
        cut into context / buffer / target:

        * "interpolation": a random permutation per batch element (drawn with
          ``torch.randperm``, one call per element).
        * "forecasting": ascending time order, so context is the earliest
          ``nc`` points, the buffer the next ``nb`` and the target the rest.
        * "random": one of the two above, chosen per batch from ``self.rng``
          so the choice follows the sampler's seed.

        Args:
            x: Time points tensor (B, 1, T) where T = nc + nb + nt
            y: EEG data tensor (B, channels, T)
            nc: Number of context points
            nb: Number of buffer points (always 8)
            nt: Number of target points

        Returns:
            DataAttr with xc/xb/xt of shape (B, n, 1), yc/yb/yt of shape
            (B, n, channels) and an all-ones int8 ``mask`` of shape (B, nt)

        Raises:
            ValueError: If the task mode is not supported.
        """
        B, C, T = y.shape

        # Determine task mode
        mode = self.mode
        if mode == "random":
            # Reconstruction is disabled for now due to dimension mismatch concerns.
            mode = ("interpolation", "forecasting")[self.rng.randint(2)]

        if mode == "interpolation":
            # Independent random permutation for each batch element
            order = torch.stack(
                [torch.randperm(T, device=x.device) for _ in range(B)]
            ).unsqueeze(1)  # (B, 1, T)
        elif mode == "forecasting":
            order = torch.argsort(x, dim=2)  # (B, 1, T)
        else:
            # Previously an unsupported mode silently produced a batch that
            # only contained the mask.
            raise ValueError(f"Unknown mode: {mode}")

        def _take(idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            """Gather the points selected by ``idx`` (B, 1, n) from x and y."""
            xs = torch.gather(x, 2, idx).transpose(1, 2)  # (B, n, 1)
            ys = torch.gather(y, 2, idx.expand(-1, C, -1)).transpose(1, 2)  # (B, n, C)
            return xs, ys

        batch = DataAttr()
        batch.xc, batch.yc = _take(order[:, :, :nc])
        batch.xb, batch.yb = _take(order[:, :, nc:nc + nb])
        batch.xt, batch.yt = _take(order[:, :, nc + nb:])

        # Add mask for compatibility
        batch.mask = torch.ones(B, nt, dtype=torch.int8, device=x.device)

        return batch

    def get_used_combinations(self) -> List[Tuple[int, int, int]]:
        """Return the (nc, nb, nt) combinations used so far as a sorted list."""
        return sorted(self.used_combinations)

    def __iter__(self):
        """Make the sampler iterable: yields ``num_batches`` batches."""
        for _ in range(self.num_batches):
            yield self.generate_batch()

    def __len__(self):
        """Return number of batches per epoch."""
        return self.num_batches
