#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
增强数据预处理管道
实现视频-EEG特征对齐的数据预处理
包括时间对齐、特征标准化和数据增强

作者: 算法工程师
日期: 2025年1月12日
"""

import os
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import cv2
from scipy import signal
from scipy.interpolate import interp1d
from sklearn.preprocessing import StandardScaler, RobustScaler
import matplotlib.pyplot as plt
from typing import Tuple, Dict, List, Optional, Union
import logging
import json
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Configure logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class EnhancedVideoEEGDataProcessor:
    """
    增强的视频-EEG数据预处理器
    实现精确的时间对齐、特征标准化和数据增强
    """
    
    def __init__(self, 
                 eeg_data_path: str = "/data0/GYF-projects/EEG2Video/data/Rawf_200Hz",
                 video_data_path: str = "/data0/GYF-projects/EEG2Video/EEG2Video/data/final_data/video",
                 demographic_path: str = "/data0/GYF-projects/EEG2Video/data/demographic information.xlsx",
                 meta_info_path: str = "/data0/GYF-projects/EEG2Video/dataset/meta_info",
                 target_fps: int = 25,
                 target_eeg_freq: int = 200,
                 time_window: float = 8.0,
                 overlap_ratio: float = 0.5):
        """
        Set up the preprocessing pipeline and load its metadata.

        Args:
            eeg_data_path: Directory holding the per-subject EEG ``.npy`` files.
            video_data_path: Directory holding the ``<video_id>.mp4`` files.
            demographic_path: Excel file with demographic information.
            meta_info_path: Directory holding the video meta ``.npy`` files.
            target_fps: Video frame rate the pipeline works at.
            target_eeg_freq: EEG sampling frequency the pipeline works at.
            time_window: Window length in seconds.
            overlap_ratio: Fraction of a window shared with the next one.
        """
        # Keep every location as a Path so `/` joins and .exists() work later.
        self.eeg_data_path = Path(eeg_data_path)
        self.video_data_path = Path(video_data_path)
        self.demographic_path = Path(demographic_path)
        self.meta_info_path = Path(meta_info_path)

        self.target_fps = target_fps
        self.target_eeg_freq = target_eeg_freq
        self.time_window = time_window
        self.overlap_ratio = overlap_ratio

        # Window lengths in samples / frames; the stride is expressed in
        # EEG samples and shrinks as the overlap grows.
        samples_per_window = int(target_eeg_freq * time_window)
        frames_per_window = int(target_fps * time_window)
        self.eeg_window_size = samples_per_window
        self.video_window_size = frames_per_window
        self.step_size = int(samples_per_window * (1 - overlap_ratio))

        # Scalers used by feature_normalization.
        self.eeg_scaler = RobustScaler()
        self.video_scaler = StandardScaler()

        # Demographics and video meta information.
        self._load_metadata()

        logger.info("数据预处理器初始化完成")
        logger.info(f"EEG窗口大小: {samples_per_window}, 视频窗口大小: {frames_per_window}")
    
    def _load_metadata(self):
        """
        Load demographic information and per-video meta arrays.

        Sets ``self.demographic_info`` (a DataFrame, or None when the file is
        missing or unreadable) and ``self.video_meta`` (dict mapping the
        ``.npy`` file stem, with the ``All_video_`` prefix removed, to the
        loaded array).

        Each source is loaded independently: an unreadable meta file is
        logged and skipped instead of discarding everything that was already
        loaded successfully. This method never raises.
        """
        self.demographic_info = None
        self.video_meta = {}

        # --- demographic information -------------------------------------
        try:
            if self.demographic_path.exists():
                self.demographic_info = pd.read_excel(self.demographic_path)
                logger.info(f"加载人口统计信息: {len(self.demographic_info)} 条记录")
            else:
                logger.warning(f"人口统计信息文件不存在: {self.demographic_path}")
        except Exception as e:
            logger.warning(f"加载人口统计信息失败 (可能缺少openpyxl): {e}")
            self.demographic_info = None

        # --- video meta information --------------------------------------
        try:
            if not self.meta_info_path.exists():
                return
            # Sorted so the load order does not depend on the filesystem.
            meta_files = sorted(self.meta_info_path.glob("*.npy"))
        except Exception as e:
            logger.error(f"加载元数据失败: {e}")
            return

        for meta_file in meta_files:
            meta_name = meta_file.stem.replace("All_video_", "")
            try:
                meta_array = np.load(meta_file)
            except Exception as e:
                # One bad file must not wipe out the other entries.
                logger.error(f"加载元数据失败: {meta_file}: {e}")
                continue
            self.video_meta[meta_name] = meta_array
            logger.info(f"加载视频元信息: {meta_name}, 形状: {meta_array.shape}")
    
    def load_eeg_data(self, subject_id: int, video_id: int = 1) -> Optional[np.ndarray]:
        """
        Load the EEG segment recorded for one subject and one video.

        Reads ``<eeg_data_path>/sub<subject_id>.npy``, which must be a 3-D
        array whose first axis indexes the video segment, and returns the
        slice for ``video_id``.

        Args:
            subject_id: Subject identifier used in the file name.
            video_id: 1-based index of the video segment to select.

        Returns:
            The selected 2-D segment (the remaining two axes of the file),
            or None when the file is missing, is not 3-D, ``video_id`` is
            outside ``1..n_segments``, or loading fails.
        """
        try:
            eeg_path = self.eeg_data_path / f"sub{subject_id}.npy"

            if not eeg_path.exists():
                logger.error(f"EEG文件不存在: {eeg_path}")
                return None

            eeg_data = np.load(eeg_path)
            logger.info(f"加载EEG数据: {eeg_path}, 原始形状: {eeg_data.shape}")

            # The original author notes Rawf_200Hz files are (7, 62, 104000):
            # video segments x EEG channels x time points (~520 s at 200 Hz).
            # The lower bound matters: without it video_id 0 or a negative
            # value would silently pick a segment via negative indexing.
            if eeg_data.ndim != 3 or not (1 <= video_id <= eeg_data.shape[0]):
                logger.error(f"EEG数据维度不匹配或视频ID超出范围: {eeg_data.shape}, video_id: {video_id}")
                return None

            selected_eeg = eeg_data[video_id - 1]  # video_id is 1-based
            logger.info(f"选择视频{video_id}对应的EEG数据，形状: {selected_eeg.shape}")
            return selected_eeg

        except Exception as e:
            logger.error(f"加载EEG数据失败: {e}")
            return None
    
    def load_video_data(self, video_id: int) -> Optional[np.ndarray]:
        """
        Decode ``<video_data_path>/<video_id>.mp4`` into an array of RGB frames.

        Args:
            video_id: Identifier used as the video file name.

        Returns:
            Array stacking every decoded frame along axis 0 (frames are
            converted from OpenCV's BGR order to RGB), or None when the file
            is missing, cannot be opened, contains no frames, or decoding
            fails.
        """
        try:
            video_path = self.video_data_path / f"{video_id}.mp4"

            if not video_path.exists():
                logger.error(f"视频文件不存在: {video_path}")
                return None

            frames = []
            cap = cv2.VideoCapture(str(video_path))
            try:
                if not cap.isOpened():
                    logger.error(f"无法打开视频文件: {video_path}")
                    return None
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break
                    frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            finally:
                # Release the capture handle even if decoding raised.
                cap.release()

            if len(frames) == 0:
                logger.error(f"视频文件为空: {video_path}")
                return None

            video_data = np.array(frames)
            logger.info(f"加载视频数据: {video_path}, 形状: {video_data.shape}")

            return video_data

        except Exception as e:
            logger.error(f"加载视频数据失败: {e}")
            return None
    
    def temporal_alignment(self, 
                          eeg_data: np.ndarray, 
                          video_data: np.ndarray,
                          eeg_sampling_rate: int = 200,
                          video_fps: int = 25) -> Tuple[np.ndarray, np.ndarray]:
        """
        Bring EEG and video onto a common duration and the target rates.

        Both streams are first cropped to the shorter of the two durations
        (measured with the source rates), so that sample ``t`` seconds into
        the EEG still corresponds to ``t`` seconds into the video. Only then
        is each stream resampled, and only if its length differs from what
        the target rate requires. Cropping first is essential: resampling the
        full-length stream would squeeze the longer recording into the
        shorter duration and destroy the alignment.

        Args:
            eeg_data: EEG array indexed as (channels, time_points).
            video_data: Video array whose first axis is the frame index.
            eeg_sampling_rate: Sampling rate of ``eeg_data`` in Hz.
            video_fps: Frame rate of ``video_data``.

        Returns:
            ``(eeg_aligned, video_aligned)``. If anything fails the inputs
            are returned unchanged and the error is logged.
        """
        try:
            eeg_duration = eeg_data.shape[1] / eeg_sampling_rate
            video_duration = video_data.shape[0] / video_fps

            # Common duration shared by both recordings.
            target_duration = min(eeg_duration, video_duration)

            def _count(rate):
                # Small epsilon so float error (e.g. 6.999999) cannot drop a
                # sample when duration * rate is mathematically an integer.
                return int(target_duration * rate + 1e-9)

            # Crop both streams to the common duration at their source rate.
            eeg_common = eeg_data[:, :min(eeg_data.shape[1], _count(eeg_sampling_rate))]
            video_common = video_data[:min(video_data.shape[0], _count(video_fps))]

            target_eeg_points = _count(self.target_eeg_freq)
            target_video_frames = _count(self.target_fps)

            # EEG: cubic interpolation over a normalised time axis.
            if eeg_common.shape[1] != target_eeg_points:
                source_axis = np.linspace(0, 1, eeg_common.shape[1])
                resample = interp1d(source_axis, eeg_common, kind='cubic', axis=1)
                eeg_aligned = resample(np.linspace(0, 1, target_eeg_points))
            else:
                eeg_aligned = eeg_common

            # Video: nearest-lower frame selection.
            if video_common.shape[0] != target_video_frames:
                indices = np.linspace(0, video_common.shape[0] - 1,
                                      target_video_frames).astype(int)
                video_aligned = video_common[indices]
            else:
                video_aligned = video_common

            logger.info(f"时间对齐完成: EEG {eeg_aligned.shape}, Video {video_aligned.shape}")

            return eeg_aligned, video_aligned

        except Exception as e:
            logger.error(f"时间对齐失败: {e}")
            return eeg_data, video_data
    
    def feature_normalization(self, 
                            eeg_data: np.ndarray, 
                            video_data: np.ndarray,
                            fit_scalers: bool = True) -> Tuple[np.ndarray, np.ndarray]:
        """
        Scale EEG per channel and video per colour channel.

        EEG is transposed so every channel becomes a feature column for
        ``self.eeg_scaler``; video is flattened so every pixel is a row whose
        columns are the entries of the last axis for ``self.video_scaler``.

        Args:
            eeg_data: EEG array indexed as (channels, time_points).
            video_data: Video array; its last axis is treated as the feature axis.
            fit_scalers: Fit both scalers on this data before transforming.

        Returns:
            ``(eeg_normalized, video_normalized)`` with the input shapes. If
            anything fails the inputs are returned unchanged and the error
            is logged.
        """
        try:
            # EEG: (channels, time) -> (time, channels) so channels are features.
            eeg_samples = eeg_data.T
            if fit_scalers:
                self.eeg_scaler.fit(eeg_samples)
            eeg_normalized = self.eeg_scaler.transform(eeg_samples).T

            # Video: collapse everything but the last axis into rows.
            frame_shape = video_data.shape
            pixel_rows = video_data.reshape(-1, frame_shape[-1])
            if fit_scalers:
                self.video_scaler.fit(pixel_rows)
            video_normalized = self.video_scaler.transform(pixel_rows).reshape(frame_shape)

            logger.info("特征标准化完成")

            return eeg_normalized, video_normalized

        except Exception as e:
            logger.error(f"特征标准化失败: {e}")
            return eeg_data, video_data
    
    def eeg_data_augmentation(self, eeg_data: np.ndarray) -> List[np.ndarray]:
        """
        Produce augmented variants of one EEG window.

        The returned list starts with the untouched input, followed by (in
        this order): a 0.1 s time shift, amplitude scaling by 0.9 and 1.1,
        additive Gaussian noise, a 40 Hz low-pass filtered copy, and a copy
        with 10% of the channels zeroed. Every variant has the same shape as
        the input so the samples can be batched together.

        Randomness comes from the global ``np.random`` state. If a step
        fails, the error is logged and the variants produced so far are
        returned.

        Args:
            eeg_data: EEG window indexed as (channels, time_points).

        Returns:
            List of arrays; element 0 is ``eeg_data`` itself.
        """
        augmented_data = [eeg_data]  # the original always comes first

        try:
            # 1. Time shift: advance the signal by 0.1 s and hold the last
            #    sample to fill the tail, so the window length is preserved.
            shift_samples = int(0.1 * self.target_eeg_freq)
            if eeg_data.shape[1] > 2 * shift_samples:
                shifted_data = np.pad(eeg_data[:, shift_samples:],
                                      ((0, 0), (0, shift_samples)),
                                      mode='edge')
                augmented_data.append(shifted_data)

            # 2. Amplitude scaling.
            for scale in (0.9, 1.1):
                augmented_data.append(eeg_data * scale)

            # 3. Gaussian noise at 5% of the window's standard deviation.
            noise_std = 0.05 * np.std(eeg_data)
            noise = np.random.normal(0, noise_std, eeg_data.shape)
            augmented_data.append(eeg_data + noise)

            # 4. Zero-phase 4th-order Butterworth low-pass at 40 Hz, applied
            #    along the time axis of every channel at once. The result is
            #    floating point regardless of the input dtype.
            b, a = signal.butter(4, 40, btype='low', fs=self.target_eeg_freq)
            augmented_data.append(signal.filtfilt(b, a, eeg_data, axis=1))

            # 5. Channel dropout: zero a random 10% of the channels.
            dropout_ratio = 0.1
            num_channels = eeg_data.shape[0]
            num_dropout = int(num_channels * dropout_ratio)
            dropout_channels = np.random.choice(num_channels, num_dropout, replace=False)

            dropout_data = eeg_data.copy()
            dropout_data[dropout_channels] = 0
            augmented_data.append(dropout_data)

            logger.info(f"EEG数据增强完成，生成 {len(augmented_data)} 个样本")

        except Exception as e:
            logger.error(f"EEG数据增强失败: {e}")

        return augmented_data
    
    def video_data_augmentation(self, video_data: np.ndarray) -> List[np.ndarray]:
        """
        Produce augmented variants of one video window.

        The returned list starts with the untouched input, followed by (in
        this order): brightness x0.8 and x1.2, contrast x0.8 and x1.2,
        additive Gaussian noise, a half-frame-rate copy (only when the window
        has more than 20 frames) and a centre crop resized back to the
        original frame size.

        The input range is respected: integer frames are treated as 8-bit
        images (clipped to 0..255 and returned as uint8), while floating
        point frames -- e.g. already standardized data -- are neither clipped
        nor cast, and the noise level is derived from their own standard
        deviation. Every variant keeps the input's frame count so it stays
        aligned with the paired EEG window.

        Randomness comes from the global ``np.random`` state. If a step
        fails, the error is logged and the variants produced so far are
        returned.

        Args:
            video_data: Frames indexed as (frames, height, width, channels).

        Returns:
            List of arrays; element 0 is ``video_data`` itself.
        """
        augmented_data = [video_data]  # the original always comes first

        try:
            is_integer = video_data.dtype.kind in 'iu'

            def _finalize(frames):
                # 8-bit images: clamp to the displayable range. Float data
                # has no fixed range, so only restore the input dtype.
                if is_integer:
                    return np.clip(frames, 0, 255).astype(np.uint8)
                return frames.astype(video_data.dtype, copy=False)

            # 1. Brightness.
            for factor in (0.8, 1.2):
                augmented_data.append(_finalize(video_data * factor))

            # 2. Contrast around the global mean.
            mean_val = np.mean(video_data)
            for factor in (0.8, 1.2):
                augmented_data.append(
                    _finalize((video_data - mean_val) * factor + mean_val))

            # 3. Gaussian noise: 5 grey levels for 8-bit frames, otherwise
            #    5% of the data's own standard deviation.
            noise_std = 5.0 if is_integer else 0.05 * float(np.std(video_data))
            noise = np.random.normal(0, noise_std, video_data.shape)
            augmented_data.append(_finalize(video_data + noise))

            # 4. Temporal subsampling: keep every other frame (evenly spread)
            #    and hold each kept frame for two slots, halving the
            #    effective frame rate without changing the frame count.
            num_frames = video_data.shape[0]
            if num_frames > 20:
                kept = np.linspace(0, num_frames - 1, num_frames // 2).astype(int)
                held = [int(kept[min(j // 2, len(kept) - 1)]) for j in range(num_frames)]
                augmented_data.append(video_data[held])

            # 5. Centre crop to 80% of the shorter side, resized back.
            h, w = video_data.shape[1:3]
            crop_size = int(min(h, w) * 0.8)
            start_h = (h - crop_size) // 2
            start_w = (w - crop_size) // 2

            cropped_data = video_data[:, start_h:start_h + crop_size,
                                      start_w:start_w + crop_size]
            resized_data = np.zeros_like(video_data)
            for i, frame in enumerate(cropped_data):
                resized_data[i] = cv2.resize(frame, (w, h))
            augmented_data.append(resized_data)

            logger.info(f"视频数据增强完成，生成 {len(augmented_data)} 个样本")

        except Exception as e:
            logger.error(f"视频数据增强失败: {e}")

        return augmented_data
    
    def create_aligned_windows(self, 
                              eeg_data: np.ndarray, 
                              video_data: np.ndarray) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """
        Cut time-aligned sliding windows out of EEG and video.

        Windows advance by ``self.step_size`` EEG samples. The start frame of
        every video window is derived directly from the EEG start sample
        (``eeg_start * target_fps // target_eeg_freq``) rather than from a
        truncated per-window frame step, so rounding never accumulates and
        strides shorter than one frame still work.

        Short inputs are handled as follows:
        * EEG shorter than one window: the whole EEG becomes the only window,
          paired with the first video window (or the whole video).
        * Video shorter than one window: one pair is produced whose EEG
          length is scaled to the available frames; if the EEG is too short
          even for that, both lists are empty.

        Args:
            eeg_data: EEG array indexed as (channels, time_points).
            video_data: Video array whose first axis is the frame index.

        Returns:
            ``(eeg_windows, video_windows)``, two lists of equal length.
        """
        eeg_windows = []
        video_windows = []

        try:
            logger.info(f"输入数据形状 - EEG: {eeg_data.shape}, Video: {video_data.shape}")
            logger.info(f"窗口参数 - EEG窗口: {self.eeg_window_size}, Video窗口: {self.video_window_size}, 步长: {self.step_size}")

            num_samples = eeg_data.shape[1]
            num_frames = video_data.shape[0]

            # EEG too short for a single window.
            if num_samples < self.eeg_window_size:
                logger.warning(f"EEG数据长度 {num_samples} 小于窗口大小 {self.eeg_window_size}")
                eeg_windows.append(eeg_data)
                if num_frames >= self.video_window_size:
                    video_windows.append(video_data[:self.video_window_size])
                else:
                    video_windows.append(video_data)
                logger.info("创建 1 个完整数据窗口")
                return eeg_windows, video_windows

            # Video too short for a single window: shrink the pair to fit.
            if num_frames < self.video_window_size:
                logger.warning(f"视频帧数 {num_frames} 小于窗口大小 {self.video_window_size}")
                actual_video_window = num_frames
                actual_eeg_window = int(actual_video_window * self.target_eeg_freq / self.target_fps)

                if num_samples >= actual_eeg_window:
                    eeg_windows.append(eeg_data[:, :actual_eeg_window])
                    video_windows.append(video_data)
                    logger.info(f"创建 1 个调整后的窗口 - EEG: {actual_eeg_window}, Video: {actual_video_window}")
                return eeg_windows, video_windows

            # A stride below one sample (e.g. overlap_ratio == 1) would never
            # advance; fall back to the smallest valid stride.
            step = max(1, int(self.step_size))
            if step != self.step_size:
                logger.warning(f"步长 {self.step_size} 无效，改用 {step}")

            max_eeg_windows = (num_samples - self.eeg_window_size) // step + 1
            for i in range(max_eeg_windows):
                eeg_start = i * step
                # Frame that contains EEG sample `eeg_start`.
                video_start = int(eeg_start * self.target_fps // self.target_eeg_freq)
                video_end = video_start + self.video_window_size
                if video_end > num_frames:
                    break  # video exhausted; later windows cannot fit either

                eeg_windows.append(eeg_data[:, eeg_start:eeg_start + self.eeg_window_size])
                video_windows.append(video_data[video_start:video_end])

            logger.info(f"创建 {len(eeg_windows)} 个对齐窗口")

        except Exception as e:
            logger.error(f"创建对齐窗口失败: {e}")
            # Best effort: hand back one leading window instead of nothing.
            if eeg_data.size > 0 and video_data.size > 0:
                eeg_end = min(self.eeg_window_size, eeg_data.shape[1])
                video_end = min(self.video_window_size, video_data.shape[0])
                eeg_windows.append(eeg_data[:, :eeg_end])
                video_windows.append(video_data[:video_end])
                logger.info("创建 1 个fallback窗口")

        return eeg_windows, video_windows
    
    def process_subject_data(self, 
                           subject_id: int, 
                           video_id: int,
                           apply_augmentation: bool = True) -> Dict:
        """
        Run the full pipeline for one (subject, video) pair.

        Steps: load EEG and video, align them in time, normalize, cut
        aligned windows, optionally augment every window pair, then attach
        demographic and video meta information when available.

        Args:
            subject_id: Subject identifier.
            video_id: Video identifier.
            apply_augmentation: Also generate augmented sample pairs.

        Returns:
            Dict with the windows and bookkeeping fields, or an empty dict
            when either input could not be loaded.
        """
        logger.info(f"开始处理被试 {subject_id}, 视频 {video_id}")

        # Load the raw inputs; both are required.
        raw_eeg = self.load_eeg_data(subject_id, video_id)
        raw_video = self.load_video_data(video_id)
        if raw_eeg is None or raw_video is None:
            logger.error("数据加载失败")
            return {}

        # Align -> normalize -> window.
        eeg_aligned, video_aligned = self.temporal_alignment(raw_eeg, raw_video)
        eeg_normalized, video_normalized = self.feature_normalization(
            eeg_aligned, video_aligned, fit_scalers=True)
        eeg_windows, video_windows = self.create_aligned_windows(
            eeg_normalized, video_normalized)

        result = {
            'subject_id': subject_id,
            'video_id': video_id,
            'eeg_windows': eeg_windows,
            'video_windows': video_windows,
            'eeg_shape': eeg_normalized.shape,
            'video_shape': video_normalized.shape,
            'num_windows': len(eeg_windows)
        }

        # Augmentation: variants are paired positionally, and the longer
        # variant list is truncated to the shorter one.
        if apply_augmentation and len(eeg_windows) > 0:
            paired_eeg = []
            paired_video = []

            for eeg_win, video_win in zip(eeg_windows, video_windows):
                eeg_variants = self.eeg_data_augmentation(eeg_win)
                video_variants = self.video_data_augmentation(video_win)

                for eeg_variant, video_variant in zip(eeg_variants, video_variants):
                    paired_eeg.append(eeg_variant)
                    paired_video.append(video_variant)

            result['augmented_eeg'] = paired_eeg
            result['augmented_video'] = paired_video
            result['total_samples'] = len(paired_eeg)

            logger.info(f"数据增强完成，总样本数: {len(paired_eeg)}")

        # Demographics: the first column is assumed to hold the subject id.
        if self.demographic_info is not None:
            try:
                id_column = self.demographic_info.iloc[:, 0]
                matching_rows = self.demographic_info[id_column == subject_id]
                if not matching_rows.empty:
                    result['demographic'] = matching_rows.iloc[0].to_dict()
            except Exception as e:
                logger.warning(f"添加人口统计信息失败: {e}")

        # Video meta information (the original author notes it exists for
        # 7 videos); entries are indexed by video_id - 1.
        try:
            if self.video_meta and video_id <= 7:
                per_video_meta = {}
                result['video_meta'] = per_video_meta
                for key, values in self.video_meta.items():
                    if len(values) >= video_id:
                        per_video_meta[key] = values[video_id - 1]
        except Exception as e:
            logger.warning(f"添加视频元信息失败: {e}")

        logger.info(f"被试 {subject_id} 数据处理完成")
        return result
    
    def batch_process(self, 
                     subject_video_pairs: List[Tuple[int, int]],
                     output_dir: str = "./processed_data",
                     apply_augmentation: bool = True) -> Dict:
        """
        Process many (subject, video) pairs and write the results to disk.

        For every pair that yields windows this writes
        ``subject_<s>_video_<v>.npz`` (one ``eeg_window_<i>`` /
        ``video_window_<i>`` entry per window, plus ``augmented_eeg_<i>`` /
        ``augmented_video_<i>`` and ``num_augmented`` when augmentation ran)
        and ``subject_<s>_video_<v>_meta.json``. If saving the arrays fails,
        a small ``..._info.json`` is written instead and the pair counts as
        failed. ``processing_statistics.json`` is written at the end.

        Processed arrays are released as soon as a pair has been saved, so
        memory use does not grow with the number of pairs.

        Args:
            subject_video_pairs: List of (subject_id, video_id) tuples.
            output_dir: Directory for all output files (created if needed).
            apply_augmentation: Forwarded to ``process_subject_data``.

        Returns:
            Statistics dict with ``total_pairs``, ``successful``, ``failed``,
            ``total_windows`` and ``total_samples``.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        total_pairs = len(subject_video_pairs)
        statistics = {
            'total_pairs': total_pairs,
            'successful': 0,
            'failed': 0,
            'total_windows': 0,
            'total_samples': 0
        }

        for pair_no, (subject_id, video_id) in enumerate(subject_video_pairs, start=1):
            try:
                logger.info(f"处理进度: {pair_no}/{total_pairs}")

                result = self.process_subject_data(
                    subject_id, video_id, apply_augmentation)

                if not result:
                    statistics['failed'] += 1
                    continue

                output_file = output_path / f"subject_{subject_id}_video_{video_id}.npz"

                try:
                    eeg_windows = result['eeg_windows']
                    video_windows = result['video_windows']

                    if len(eeg_windows) == 0 or len(video_windows) == 0:
                        logger.warning(f"被试 {subject_id}, 视频 {video_id}: 没有有效的窗口数据")
                        statistics['failed'] += 1
                        continue

                    logger.info(f"保存数据 - EEG窗口数: {len(eeg_windows)}, Video窗口数: {len(video_windows)}")
                    logger.info(f"EEG窗口形状示例: {eeg_windows[0].shape}")
                    logger.info(f"Video窗口形状示例: {video_windows[0].shape}")

                    save_data = {
                        'subject_id': subject_id,
                        'video_id': video_id,
                        'num_windows': len(eeg_windows)
                    }

                    # One entry per window: windows may differ in shape, so
                    # they cannot be stacked into a single array.
                    for win_idx, (eeg_win, video_win) in enumerate(zip(eeg_windows, video_windows)):
                        save_data[f'eeg_window_{win_idx}'] = eeg_win
                        save_data[f'video_window_{win_idx}'] = video_win

                    if apply_augmentation and 'augmented_eeg' in result:
                        augmented_pairs = zip(result['augmented_eeg'], result['augmented_video'])
                        for aug_idx, (aug_eeg, aug_video) in enumerate(augmented_pairs):
                            save_data[f'augmented_eeg_{aug_idx}'] = aug_eeg
                            save_data[f'augmented_video_{aug_idx}'] = aug_video
                        save_data['num_augmented'] = len(result['augmented_eeg'])

                    np.savez_compressed(output_file, **save_data)

                except Exception as save_error:
                    logger.error(f"保存数据失败: {save_error}")
                    # Leave at least a small description of what was lost.
                    try:
                        basic_info = {
                            'subject_id': subject_id,
                            'video_id': video_id,
                            'num_eeg_windows': len(result['eeg_windows']),
                            'num_video_windows': len(result['video_windows']),
                            'eeg_shape': result['eeg_windows'][0].shape if len(result['eeg_windows']) > 0 else None,
                            'video_shape': result['video_windows'][0].shape if len(result['video_windows']) > 0 else None
                        }

                        info_file = output_path / f"subject_{subject_id}_video_{video_id}_info.json"
                        with open(info_file, 'w', encoding='utf-8') as f:
                            json.dump(basic_info, f, indent=2, default=str)

                        logger.info(f"保存基本信息到: {info_file}")

                    except Exception as info_error:
                        logger.error(f"保存基本信息也失败: {info_error}")

                    statistics['failed'] += 1
                    continue

                # Everything except the bulky arrays goes to the meta file;
                # default=str stringifies values JSON cannot represent.
                bulky_keys = ('eeg_windows', 'video_windows', 'augmented_eeg', 'augmented_video')
                meta_data = {k: v for k, v in result.items() if k not in bulky_keys}
                meta_file = output_path / f"subject_{subject_id}_video_{video_id}_meta.json"
                with open(meta_file, 'w', encoding='utf-8') as f:
                    json.dump(meta_data, f, indent=2, ensure_ascii=False, default=str)

                statistics['successful'] += 1
                statistics['total_windows'] += result['num_windows']
                if 'total_samples' in result:
                    statistics['total_samples'] += result['total_samples']

                logger.info(f"保存成功: {output_file}")

            except Exception as e:
                logger.error(f"处理失败 - 被试 {subject_id}, 视频 {video_id}: {e}")
                import traceback
                logger.error(f"详细错误信息: {traceback.format_exc()}")
                statistics['failed'] += 1

        # Persist the run summary next to the data.
        stats_file = output_path / "processing_statistics.json"
        with open(stats_file, 'w', encoding='utf-8') as f:
            json.dump(statistics, f, indent=2, ensure_ascii=False)

        logger.info(f"批量处理完成: 成功 {statistics['successful']}, 失败 {statistics['failed']}")
        logger.info(f"总窗口数: {statistics['total_windows']}, 总样本数: {statistics['total_samples']}")

        return statistics
    
    def create_training_dataset(self, 
                               processed_data_dir: str,
                               train_ratio: float = 0.8,
                               val_ratio: float = 0.1) -> Dict:
        """
        Split the processed ``.npz`` files into train / val / test sets.

        Files matching ``subject_*_video_*.npz`` are sorted, shuffled with
        the global ``np.random`` state (so seeding it makes the split
        reproducible) and partitioned. The split is written to
        ``dataset_split.json`` inside ``processed_data_dir``.

        The reported statistics always describe the lists actually produced:
        if the ratios ask for more files than exist, the validation share is
        reduced and the test set is empty rather than reported as negative.

        Args:
            processed_data_dir: Directory containing the processed files.
            train_ratio: Fraction of files assigned to training.
            val_ratio: Fraction of files assigned to validation.

        Returns:
            Dict with ``train``, ``val``, ``test`` (lists of path strings)
            and ``statistics``; an empty dict when no files are found.
        """
        data_path = Path(processed_data_dir)

        # Sorted first: glob order depends on the filesystem, which would
        # make a seeded shuffle irreproducible.
        data_files = sorted(data_path.glob("subject_*_video_*.npz"))

        if not data_files:
            logger.error("未找到处理后的数据文件")
            return {}

        np.random.shuffle(data_files)

        n_total = len(data_files)
        # Clamp so the three parts never overlap or exceed the file count.
        n_train = min(max(int(n_total * train_ratio), 0), n_total)
        n_val = min(max(int(n_total * val_ratio), 0), n_total - n_train)

        train_files = data_files[:n_train]
        val_files = data_files[n_train:n_train + n_val]
        test_files = data_files[n_train + n_val:]
        n_test = len(test_files)

        dataset_split = {
            'train': [str(f) for f in train_files],
            'val': [str(f) for f in val_files],
            'test': [str(f) for f in test_files],
            'statistics': {
                'total_files': n_total,
                'train_files': n_train,
                'val_files': n_val,
                'test_files': n_test
            }
        }

        split_file = data_path / "dataset_split.json"
        with open(split_file, 'w', encoding='utf-8') as f:
            json.dump(dataset_split, f, indent=2, ensure_ascii=False)

        logger.info(f"数据集划分完成: 训练 {n_train}, 验证 {n_val}, 测试 {n_test}")

        return dataset_split

class VideoEEGDataset(Dataset):
    """
    Paired video-EEG dataset backed by processed ``.npz`` files.

    Two file layouts are understood:

    * per-sample entries ``eeg_window_<i>`` / ``video_window_<i>`` with a
      ``num_windows`` count, and ``augmented_eeg_<i>`` /
      ``augmented_video_<i>`` with a ``num_augmented`` count;
    * stacked arrays ``eeg_windows`` / ``video_windows`` and
      ``augmented_eeg`` / ``augmented_video``.

    Files are opened on demand and closed again after every read.
    """

    def __init__(self, data_files: List[str], use_augmentation: bool = True):
        """
        Build the sample index.

        Args:
            data_files: Paths of the ``.npz`` files to read.
            use_augmentation: Index the augmented samples of a file when it
                has any; otherwise (or when False) index its plain windows.
        """
        self.data_files = data_files
        self.use_augmentation = use_augmentation

        # One (file_path, 'augmented' | 'original', sample_idx) per sample.
        self.data_index = []
        for file_path in data_files:
            try:
                with np.load(file_path) as data:
                    data_type, num_samples = self._count_samples(data, use_augmentation)
                self.data_index.extend(
                    (file_path, data_type, i) for i in range(num_samples))
            except Exception as e:
                # An unreadable file is skipped, not fatal.
                logger.error(f"加载数据文件失败: {file_path}, {e}")

        logger.info(f"数据集初始化完成，总样本数: {len(self.data_index)}")

    @staticmethod
    def _count_samples(data, use_augmentation):
        """Return (data_type, sample_count) for an opened npz archive."""
        if use_augmentation:
            if 'num_augmented' in data:
                return 'augmented', int(data['num_augmented'])
            if 'augmented_eeg' in data:
                return 'augmented', len(data['augmented_eeg'])
        if 'num_windows' in data:
            return 'original', int(data['num_windows'])
        return 'original', len(data['eeg_windows'])

    @staticmethod
    def _read_pair(data, data_type, sample_idx):
        """Return the (eeg, video) arrays of one sample from an opened archive."""
        if data_type == 'augmented':
            eeg_key = f'augmented_eeg_{sample_idx}'
            video_key = f'augmented_video_{sample_idx}'
            stacked_keys = ('augmented_eeg', 'augmented_video')
        else:
            eeg_key = f'eeg_window_{sample_idx}'
            video_key = f'video_window_{sample_idx}'
            stacked_keys = ('eeg_windows', 'video_windows')

        if eeg_key in data:
            return data[eeg_key], data[video_key]
        # Stacked layout: index into the combined arrays.
        return data[stacked_keys[0]][sample_idx], data[stacked_keys[1]][sample_idx]

    def __len__(self):
        return len(self.data_index)

    def __getitem__(self, idx):
        """
        Return one sample as a dict with ``eeg`` and ``video`` float tensors
        (video permuted from (T, H, W, C) to (T, C, H, W)), ``file_path``
        and ``sample_idx``.

        If the sample cannot be read, the error is logged and zero tensors of
        the default window shape are returned so one bad file does not abort
        an epoch.
        """
        file_path, data_type, sample_idx = self.data_index[idx]

        try:
            with np.load(file_path) as data:
                eeg_sample, video_sample = self._read_pair(data, data_type, sample_idx)

            eeg_tensor = torch.FloatTensor(eeg_sample)
            video_tensor = torch.FloatTensor(video_sample).permute(0, 3, 1, 2)  # (T, C, H, W)

            return {
                'eeg': eeg_tensor,
                'video': video_tensor,
                'file_path': file_path,
                'sample_idx': sample_idx
            }

        except Exception as e:
            logger.error(f"获取样本失败: {idx}, {e}")
            return {
                'eeg': torch.zeros(62, 1600),  # 62 channels, 8 s * 200 Hz
                'video': torch.zeros(200, 3, 224, 224),  # 8 s * 25 fps, 3 x 224 x 224
                'file_path': file_path,
                'sample_idx': sample_idx
            }

def create_data_loaders(dataset_split: Dict, 
                       batch_size: int = 4,
                       num_workers: int = 4) -> Dict[str, DataLoader]:
    """
    Build one DataLoader per non-empty split found in ``dataset_split``.

    The 'train' split is shuffled, is built with ``use_augmentation=True``
    and drops its last incomplete batch. The 'val' and 'test' splits are
    neither shuffled nor truncated, so evaluation always covers every sample
    (and small splits still yield at least one batch).

    Args:
        dataset_split: Mapping from split name ('train', 'val', 'test') to
            the collection of data files passed to ``VideoEEGDataset``.
            Missing or empty splits are skipped.
        batch_size: Number of samples per batch.
        num_workers: Number of worker processes used by each loader.

    Returns:
        Dict mapping each created split name to its DataLoader.
    """
    data_loaders = {}

    # Pinned (page-locked) host memory only speeds up host-to-GPU copies;
    # requesting it without CUDA buys nothing.
    pin_memory = torch.cuda.is_available()

    for split in ('train', 'val', 'test'):
        if split not in dataset_split or len(dataset_split[split]) == 0:
            continue

        is_train = split == 'train'
        dataset = VideoEEGDataset(
            dataset_split[split],
            use_augmentation=is_train
        )

        data_loaders[split] = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=is_train,
            num_workers=num_workers,
            pin_memory=pin_memory,
            # Dropping the ragged last batch is only wanted for training;
            # for val/test it would silently discard evaluation samples.
            drop_last=is_train
        )
        logger.info(f"{split} 数据加载器创建完成，样本数: {len(dataset)}")

    return data_loaders

def visualize_alignment(eeg_data: np.ndarray, 
                       video_data: np.ndarray,
                       save_path: str = "alignment_visualization.png",
                       eeg_freq: float = 200.0,
                       video_fps: float = 25.0):
    """
    Plot EEG signals and video brightness on a shared time axis and save it.

    Three stacked panels are drawn: the first five EEG channels, the mean
    brightness of every video frame, and both signals overlaid against time
    in seconds (brightness rescaled to the peak of the channel-averaged EEG).

    Args:
        eeg_data: EEG array indexed as (channels, time points).
        video_data: Video array with frames on axis 0 and at least three
            further axes, which are averaged to get per-frame brightness.
        save_path: Destination of the saved figure.
        eeg_freq: EEG sampling rate in Hz used to build the time axis.
        video_fps: Video frame rate used to build the time axis.
    """
    fig, axes = plt.subplots(3, 1, figsize=(15, 10))

    try:
        # EEG signals
        axes[0].plot(eeg_data[:5].T)  # show the first 5 channels
        axes[0].set_title('EEG Signals (First 5 Channels)')
        axes[0].set_xlabel('Time Points')
        axes[0].set_ylabel('Amplitude')

        # Per-frame brightness
        frame_brightness = np.mean(video_data, axis=(1, 2, 3))
        axes[1].plot(frame_brightness)
        axes[1].set_title('Video Frame Brightness')
        axes[1].set_xlabel('Frame Number')
        axes[1].set_ylabel('Average Brightness')

        # Sample k is taken at k / rate seconds, so the last sample sits at
        # (N - 1) / rate; a linspace ending at N / rate would stretch the axis.
        eeg_time = np.arange(len(eeg_data[0])) / eeg_freq
        video_time = np.arange(len(video_data)) / video_fps

        eeg_mean = np.mean(eeg_data, axis=0)

        # Guard against all-black (or empty) video, which would otherwise
        # divide by zero and fill the curve with NaN/inf.
        peak_brightness = np.max(frame_brightness) if frame_brightness.size else 0.0
        if peak_brightness != 0:
            scaled_brightness = frame_brightness / peak_brightness * np.max(eeg_mean)
        else:
            scaled_brightness = np.zeros_like(frame_brightness)

        # Temporal alignment overlay
        axes[2].plot(eeg_time, eeg_mean, label='EEG (avg)', alpha=0.7)
        axes[2].plot(video_time, scaled_brightness,
                     label='Video (normalized)', alpha=0.7)
        axes[2].set_title('Temporal Alignment')
        axes[2].set_xlabel('Time (seconds)')
        axes[2].set_ylabel('Normalized Amplitude')
        axes[2].legend()

        fig.tight_layout()
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release this specific figure, even when plotting or saving
        # fails, so repeated calls do not accumulate open figures.
        plt.close(fig)

    logger.info(f"对齐可视化保存至: {save_path}")

def _count_training_samples(train_files) -> int:
    """Sum the number of 'eeg_windows' entries over the given data files."""
    total_samples = 0
    for file_path in train_files:
        try:
            # Close each archive right after inspecting it so the check does
            # not leave one open file handle per training file.
            with np.load(file_path) as data:
                if 'eeg_windows' in data:
                    total_samples += len(data['eeg_windows'])
        except Exception as e:
            # Best-effort check: an unreadable file is reported and skipped.
            logger.warning(f"检查文件失败: {file_path}, {e}")
    return total_samples


def main():
    """
    Demonstrate the preprocessing pipeline end to end.

    Processes a handful of subject/video pairs, splits the results into
    train/val/test, builds the data loaders and logs the shapes of the first
    three training batches. Each failure point logs a message and returns
    early instead of raising.
    """
    logger.info("开始增强数据预处理流程")

    # Initialise the processor
    processor = EnhancedVideoEEGDataProcessor(
        target_fps=25,
        target_eeg_freq=200,
        time_window=8.0,
        overlap_ratio=0.5
    )

    # The EEG recordings contain 7 video segments while 10 video files exist,
    # so pair subjects 1-10 with videos 1-7 only.
    subject_video_pairs = [
        (subject_id, video_id)
        for subject_id in range(1, 11)
        for video_id in range(1, 8)
    ]
    logger.info(f"创建了 {len(subject_video_pairs)} 个被试-视频配对")

    # Only process the first 5 pairs as a smoke test
    test_pairs = subject_video_pairs[:5]

    # Batch processing
    logger.info("开始批量处理数据...")
    statistics = processor.batch_process(
        test_pairs,
        output_dir="./enhanced_processed_data",
        apply_augmentation=True
    )

    if not (statistics['successful'] > 0):
        logger.error("没有成功处理的数据")
        return

    # Create the train/val/test split
    logger.info("创建训练数据集...")
    dataset_split = processor.create_training_dataset(
        "./enhanced_processed_data",
        train_ratio=0.7,
        val_ratio=0.15
    )

    if not dataset_split:
        logger.error("数据集划分失败")
        return

    logger.info("创建数据加载器...")

    # Make sure there is something to train on before building loaders
    train_files = dataset_split.get('train', [])
    if len(train_files) == 0:
        logger.warning("没有训练文件")
        logger.info(f"处理统计: {statistics}")
        return

    total_samples = _count_training_samples(train_files)
    logger.info(f"训练集总样本数: {total_samples}")

    if total_samples <= 0:
        logger.warning("训练集中没有有效样本，跳过数据加载器创建")
        logger.info(f"处理统计: {statistics}")
        return

    data_loaders = create_data_loaders(dataset_split, batch_size=2)

    if 'train' not in data_loaders:
        logger.warning("未创建训练数据加载器")
        return

    # Smoke-test the training loader
    logger.info("测试数据加载...")
    train_loader = data_loaders['train']

    for i, batch in enumerate(train_loader):
        logger.info(f"批次 {i}: EEG {batch['eeg'].shape}, Video {batch['video'].shape}")
        if i >= 2:  # only look at the first 3 batches
            break

    logger.info("数据预处理流程完成！")
    logger.info(f"处理统计: {statistics}")

# Run the demo pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()