#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
SEED-VD数据集生成器
基于Video2EEG-SGGN-Diffusion模型生成针对不同视频和真实脑电的数据集

核心功能:
1. 处理SEED-VD原始数据（视频和EEG）
2. 使用Video2EEG-SGGN-Diffusion模型生成对应的EEG数据
3. 构建完整的训练/验证/测试数据集
4. 支持多种数据增强策略
5. 生成质量评估和可视化

作者: 算法工程师
日期: 2025年1月12日
"""

import os
import sys
import json
import time
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import signal, stats
from pathlib import Path
import logging
from typing import Dict, List, Tuple, Optional, Union
import warnings
from collections import defaultdict
import pandas as pd
from tqdm import tqdm
import cv2
from datetime import datetime
import shutil

# Suppress every warning category for the whole process.
warnings.filterwarnings('ignore')

# 导入模型和推理引擎
from video2eeg_sggn_diffusion_model import Video2EEGSGGNDiffusion, create_video2eeg_sggn_diffusion
from improved_inference_sggn import ImprovedSGGNModelLoader, EEGQualityEvaluator

# Configure root logging at INFO level and create this module's logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Font fallback list so Chinese text can be rendered in matplotlib figures
plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial Unicode MS', 'DejaVu Sans']
# Render the minus sign as an ASCII hyphen instead of the Unicode minus glyph
plt.rcParams['axes.unicode_minus'] = False

class SEEDVDDataProcessor:
    """
    SEED-VD数据处理器
    负责处理原始的SEED-VD数据集
    """
    
    def __init__(self, 
                 video_dir: str = "/data0/GYF-projects/EEG2Video/dataset/Video",
                 eeg_dir: str = "/data0/GYF-projects/EEG2Video/data/Rawf_200Hz",
                 output_dir: str = "./seed_vd_processed_data"):
        """
        Record the data locations and the fixed SEED-VD recording parameters.

        The output directory (including missing parents) is created here; the
        video and EEG directories are only stored, not checked.

        Args:
            video_dir: Directory holding the stimulus videos.
            eeg_dir: Directory holding the per-subject EEG ``.npy`` files.
            output_dir: Directory that receives the processed data.
        """
        self.video_dir, self.eeg_dir, self.output_dir = (
            Path(video_dir),
            Path(eeg_dir),
            Path(output_dir),
        )
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Fixed SEED-VD recording parameters.
        self.video_fps = 24
        self.video_resolution = (1920, 1080)
        self.eeg_sampling_rate = 200
        self.eeg_channels = 62
        self.video_duration_per_block = 600  # seconds (10 minutes)
        self.eeg_duration_per_block = 520    # seconds (8 minutes 40 seconds)

        startup_messages = (
            "SEED-VD数据处理器初始化完成",
            f"视频目录: {self.video_dir}",
            f"EEG目录: {self.eeg_dir}",
            f"输出目录: {self.output_dir}",
        )
        for message in startup_messages:
            logger.info(message)
    
    def load_eeg_data(self, subject_id: int) -> np.ndarray:
        """
        Load the EEG recording of one subject from ``self.eeg_dir``.

        ``sub{subject_id}.npy`` is preferred; ``sub{subject_id}_session2.npy``
        is used as a fallback when the first file does not exist.

        Args:
            subject_id: Subject identifier used in the file name.

        Returns:
            The array stored in the file. The original documentation gives
            the expected shape as (7, 62, 104000); it is not checked here.

        Raises:
            FileNotFoundError: If neither candidate file exists.
        """
        candidates = (
            self.eeg_dir / f"sub{subject_id}.npy",
            self.eeg_dir / f"sub{subject_id}_session2.npy",
        )
        source = next((path for path in candidates if path.exists()), None)
        if source is None:
            raise FileNotFoundError(f"找不到被试{subject_id}的EEG数据")

        recording = np.load(source)
        logger.info(f"加载被试{subject_id}的EEG数据: {recording.shape}")
        return recording
    
    def extract_video_frames(self, video_id: int, 
                           start_time: float = 0, 
                           duration: float = 10,
                           target_fps: int = 8,
                           target_size: Tuple[int, int] = (224, 224)) -> np.ndarray:
        """
        Extract a temporally subsampled clip from one of the stimulus videos.

        Frames in ``[start_time, start_time + duration)`` are taken every
        ``int(source_fps / target_fps)`` source frames (at least every frame),
        resized, converted from BGR to RGB and scaled to ``[0, 1]``.

        Args:
            video_id: Video identifier (1-7), mapped to ``<n>_10min.mp4``.
            start_time: Clip start in seconds.
            duration: Clip length in seconds.
            target_fps: Desired frame rate of the returned clip; must be > 0.
            target_size: Size handed to ``cv2.resize`` (width, height).

        Returns:
            ``float32`` array of shape (T, H, W, C).

        Raises:
            ValueError: Unknown ``video_id``, non-positive ``target_fps``,
                a video that cannot be opened or reports no frame rate, or
                no frame could be extracted.
            FileNotFoundError: The video file does not exist.
        """
        video_files = {
            1: "1st_10min.mp4",
            2: "2nd_10min.mp4",
            3: "3rd_10min.mp4",
            4: "4th_10min.mp4",
            5: "5th_10min.mp4",
            6: "6th_10min.mp4",
            7: "7th_10min.mp4"
        }

        if video_id not in video_files:
            raise ValueError(f"无效的视频ID: {video_id}")
        if target_fps <= 0:
            raise ValueError(f"target_fps必须为正数: {target_fps}")

        video_path = self.video_dir / video_files[video_id]
        if not video_path.exists():
            raise FileNotFoundError(f"找不到视频文件: {video_path}")

        frames = []
        cap = cv2.VideoCapture(str(video_path))
        try:
            if not cap.isOpened():
                raise ValueError(f"无法打开视频文件: {video_path}")

            original_fps = cap.get(cv2.CAP_PROP_FPS)
            if not original_fps or original_fps <= 0:
                raise ValueError(f"视频帧率无效: {video_path}")

            # Clamp to 1 so a source slower than target_fps keeps every frame
            # instead of producing a zero step (modulo by zero below).
            frame_interval = max(1, int(original_fps / target_fps))
            start_frame = int(start_time * original_fps)
            end_frame = int((start_time + duration) * original_fps)

            frame_idx = 0
            while frame_idx < end_frame:
                # grab() advances the stream without decoding; only the frames
                # that are actually kept are decoded through retrieve().
                if not cap.grab():
                    break

                if frame_idx >= start_frame and (frame_idx - start_frame) % frame_interval == 0:
                    ret, frame = cap.retrieve()
                    if not ret:
                        break
                    frame = cv2.resize(frame, target_size)
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frames.append(frame)

                frame_idx += 1
        finally:
            # Release the capture handle even when an error interrupts decoding.
            cap.release()

        if not frames:
            raise ValueError(f"未能从视频{video_id}中提取到帧")

        video_frames = np.array(frames, dtype=np.float32) / 255.0
        logger.info(f"从视频{video_id}提取帧: {video_frames.shape}")

        return video_frames
    
    def segment_eeg_data(self, eeg_data: np.ndarray, 
                        block_id: int,
                        segment_duration: float = 10,
                        overlap_ratio: float = 0.5) -> List[np.ndarray]:
        """
        Cut one EEG block into fixed-length, optionally overlapping windows.

        Every window has 200 samples, the length the generation model emits,
        so generated and reference EEG share the same shape.
        ``segment_duration`` is accepted for interface compatibility but does
        not influence the window length.

        Args:
            eeg_data: EEG array indexed as (block, channel, sample).
            block_id: Index of the block to segment.
            segment_duration: Unused; kept for backward compatibility.
            overlap_ratio: Fraction of a window shared with its successor;
                must be below 1.

        Returns:
            List of (channels, 200) views into ``eeg_data``; empty when the
            block is shorter than one window.

        Raises:
            ValueError: ``block_id`` is past the last block or
                ``overlap_ratio`` is 1 or larger.
        """
        if block_id >= eeg_data.shape[0]:
            raise ValueError(f"块ID {block_id} 超出范围")
        if overlap_ratio >= 1:
            raise ValueError(f"重叠比例必须小于1: {overlap_ratio}")

        block_eeg = eeg_data[block_id]

        segment_samples = 200  # fixed to the model's output length
        # Never let the step collapse to 0: that made the former while-loop
        # spin forever for overlap ratios close to 1.
        step_samples = max(1, int(segment_samples * (1 - overlap_ratio)))

        last_start = block_eeg.shape[1] - segment_samples
        segments = [
            block_eeg[:, start:start + segment_samples]
            for start in range(0, last_start + 1, step_samples)
        ]

        logger.info(f"块{block_id}分割为{len(segments)}个片段（每片段{segment_samples}个样本点）")

        return segments
    
    def create_video_eeg_pairs(self, 
                              subject_ids: List[int] = None,
                              video_ids: List[int] = None,
                              segment_duration: float = 10,
                              samples_per_video: int = 20) -> List[Dict]:
        """
        Build randomly placed video/EEG training pairs.

        For every subject and video, ``samples_per_video`` start times are
        drawn uniformly. Each pair holds the video clip of
        ``segment_duration`` seconds beginning at that time and the 200 EEG
        samples (the model's output length) beginning at the same time in the
        EEG block ``video_id - 1``.

        The start time is limited to the range covered by BOTH the video and
        the EEG block, so no draw is discarded after the expensive frame
        extraction.

        NOTE(review): video time and EEG time are assumed to share the same
        origin — confirm against the recording protocol.
        NOTE(review): all clips are kept in memory in the returned list, which
        can become very large for many subjects.

        Args:
            subject_ids: Subjects to process; ``None`` auto-detects
                ``sub1``-``sub20`` files in ``self.eeg_dir``.
            video_ids: Videos to process; ``None`` means videos 1-7.
            segment_duration: Length of each video clip in seconds.
            samples_per_video: Number of pairs per (subject, video).

        Returns:
            List of dictionaries with the keys ``subject_id``, ``video_id``,
            ``sample_idx``, ``start_time``, ``video_frames``, ``eeg_data`` and
            ``metadata``. Subjects or videos that fail are logged and skipped.
        """
        if subject_ids is None:
            # Auto-detect available subjects (sub1-sub20, either session file).
            subject_ids = [
                i for i in range(1, 21)
                if (self.eeg_dir / f"sub{i}.npy").exists()
                or (self.eeg_dir / f"sub{i}_session2.npy").exists()
            ]

        if video_ids is None:
            video_ids = list(range(1, 8))

        segment_samples = 200  # fixed to the model's output length
        pairs = []

        for subject_id in tqdm(subject_ids, desc="处理被试"):
            try:
                eeg_data = self.load_eeg_data(subject_id)

                # Latest start (seconds) that still leaves a full EEG window
                # and a full video clip.
                eeg_limit = (eeg_data.shape[2] - segment_samples) / self.eeg_sampling_rate
                video_limit = self.video_duration_per_block - segment_duration
                max_start_time = min(video_limit, eeg_limit)
                if max_start_time < 0:
                    logger.warning(f"被试{subject_id}的数据长度不足，无法采样")
                    continue

                for video_id in video_ids:
                    try:
                        # Videos 1-7 correspond to EEG blocks 0-6.
                        block_id = video_id - 1
                        if not 0 <= block_id < eeg_data.shape[0]:
                            raise ValueError(f"块ID {block_id} 超出范围")

                        for sample_idx in range(samples_per_video):
                            start_time = np.random.uniform(0, max_start_time)

                            eeg_start_sample = int(start_time * self.eeg_sampling_rate)
                            eeg_end_sample = eeg_start_sample + segment_samples
                            # Defensive check before the costly frame decoding.
                            if eeg_end_sample > eeg_data.shape[2]:
                                continue

                            video_frames = self.extract_video_frames(
                                video_id, start_time, segment_duration
                            )
                            eeg_segment = eeg_data[block_id, :, eeg_start_sample:eeg_end_sample]

                            pairs.append({
                                'subject_id': subject_id,
                                'video_id': video_id,
                                'sample_idx': sample_idx,
                                'start_time': start_time,
                                'video_frames': video_frames,
                                'eeg_data': eeg_segment,
                                'metadata': {
                                    'video_shape': video_frames.shape,
                                    'eeg_shape': eeg_segment.shape,
                                    'duration': segment_duration,
                                    'sampling_rate': self.eeg_sampling_rate
                                }
                            })

                    except Exception as e:
                        logger.warning(f"处理视频{video_id}时出错: {e}")
                        continue

            except Exception as e:
                logger.warning(f"处理被试{subject_id}时出错: {e}")
                continue

        logger.info(f"创建了{len(pairs)}个视频-EEG配对")

        return pairs
    
    def save_processed_data(self, pairs: List[Dict], 
                          train_ratio: float = 0.7,
                          val_ratio: float = 0.15,
                          test_ratio: float = 0.15):
        """
        Shuffle the pairs, split them and write them below ``self.output_dir``.

        One compressed ``.npz`` file is written per pair into the ``train``,
        ``val`` or ``test`` sub-directory, plus ``dataset_split.json`` and
        ``dataset_stats.json``. The caller's list is left untouched; a
        shuffled copy is split instead.

        Args:
            pairs: Pair dictionaries as produced by ``create_video_eeg_pairs``.
            train_ratio: Fraction of samples for the training split.
            val_ratio: Fraction of samples for the validation split.
            test_ratio: Nominal fraction for the test split. The test split
                always receives every sample left after train and val; a
                warning is logged when the three ratios do not sum to 1.

        Raises:
            ValueError: A ratio is negative or train + val exceeds 1.
        """
        if min(train_ratio, val_ratio, test_ratio) < 0 or train_ratio + val_ratio > 1 + 1e-9:
            raise ValueError(
                f"无效的数据集划分比例: train={train_ratio}, val={val_ratio}, test={test_ratio}"
            )
        if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
            logger.warning("数据集划分比例之和不为1，测试集将使用剩余的全部样本")

        def _filename(pair: Dict) -> str:
            """Return the file name used for one pair."""
            return f"subject_{pair['subject_id']}_video_{pair['video_id']}_sample_{pair['sample_idx']}.npz"

        # Shuffle through an index permutation so the caller's list keeps its order.
        order = np.random.permutation(len(pairs))
        shuffled = [pairs[int(idx)] for idx in order]

        n_total = len(shuffled)
        n_train = int(n_total * train_ratio)
        n_val = int(n_total * val_ratio)

        splits = {
            'train': shuffled[:n_train],
            'val': shuffled[n_train:n_train + n_val],
            'test': shuffled[n_train + n_val:]
        }

        for split_name, split_pairs in splits.items():
            split_dir = self.output_dir / split_name
            split_dir.mkdir(parents=True, exist_ok=True)

            # Files from an earlier run are not removed; they would be picked
            # up together with the new ones by anything that globs the folder.
            if any(split_dir.glob("*.npz")):
                logger.warning(f"目录{split_dir}中已存在.npz文件，旧样本可能与本次划分混合")

            for pair in tqdm(split_pairs, desc=f"保存{split_name}数据"):
                np.savez_compressed(
                    split_dir / _filename(pair),
                    video_data=pair['video_frames'],
                    eeg_data=pair['eeg_data'],
                    subject_id=pair['subject_id'],
                    video_id=pair['video_id'],
                    sample_idx=pair['sample_idx'],
                    start_time=pair['start_time'],
                    metadata=json.dumps(pair['metadata'])
                )

        # File names per split.
        split_info = {
            split_name: [_filename(p) for p in split_pairs]
            for split_name, split_pairs in splits.items()
        }
        with open(self.output_dir / "dataset_split.json", 'w', encoding='utf-8') as f:
            json.dump(split_info, f, indent=2)

        # Summary statistics (named so it does not shadow scipy.stats).
        first = pairs[0] if pairs else None
        dataset_stats = {
            'total_samples': n_total,
            'train_samples': len(splits['train']),
            'val_samples': len(splits['val']),
            'test_samples': len(splits['test']),
            'subjects': sorted({p['subject_id'] for p in pairs}),
            'videos': sorted({p['video_id'] for p in pairs}),
            'data_shape': {
                'video': first['video_frames'].shape if first is not None else None,
                'eeg': first['eeg_data'].shape if first is not None else None
            },
            'processing_time': datetime.now().isoformat()
        }
        with open(self.output_dir / "dataset_stats.json", 'w', encoding='utf-8') as f:
            json.dump(dataset_stats, f, indent=2)

        logger.info(f"数据保存完成:")
        logger.info(f"  训练集: {len(splits['train'])} 样本")
        logger.info(f"  验证集: {len(splits['val'])} 样本")
        logger.info(f"  测试集: {len(splits['test'])} 样本")
        logger.info(f"  总计: {n_total} 样本")

class SEEDVDDatasetGenerator:
    """
    SEED-VD数据集生成器
    使用Video2EEG-SGGN-Diffusion模型生成EEG数据
    """
    
    def __init__(self,
                 model_path: str,
                 processed_data_dir: str,
                 output_dir: str = "./seed_vd_generated_dataset",
                 device: str = 'auto'):
        """
        Prepare the generator: output folder, model loader and evaluator.

        Args:
            model_path: Path of the trained model checkpoint.
            processed_data_dir: Directory containing the processed splits.
            output_dir: Directory for generated datasets; created if missing.
            device: Device specifier forwarded to the model loader.
        """
        self.model_path = model_path
        self.device = device
        self.processed_data_dir = Path(processed_data_dir)
        self.output_dir = Path(output_dir)

        # The output folder must exist before anything is generated.
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Model wrapper used for inference and the metric calculator.
        self.model_loader = ImprovedSGGNModelLoader(model_path, device)
        self.evaluator = EEGQualityEvaluator()

        logger.info("SEED-VD数据集生成器初始化完成")
    
    def load_processed_data(self, split: str = 'test') -> List[Dict]:
        """
        Load every ``.npz`` sample of one split into memory.

        Files are visited in sorted name order so that a later truncation of
        the list (``num_samples``) selects the same samples on every run.
        Files that cannot be read are logged and skipped.

        Args:
            split: Name of the split sub-directory below
                ``self.processed_data_dir``.

        Returns:
            List of dictionaries with ``file_path``, ``video_data``,
            ``eeg_data``, ``subject_id``, ``video_id``, ``sample_idx``,
            ``start_time`` and ``metadata``.

        Raises:
            FileNotFoundError: The split directory does not exist.
        """
        split_dir = self.processed_data_dir / split
        if not split_dir.exists():
            raise FileNotFoundError(f"找不到{split}数据目录: {split_dir}")

        data_files = sorted(split_dir.glob("*.npz"))
        data_list = []

        for file_path in tqdm(data_files, desc=f"加载{split}数据"):
            try:
                # NpzFile keeps the archive open until closed; the context
                # manager releases the handle once the arrays are read.
                with np.load(file_path) as data:
                    item = {
                        'file_path': file_path,
                        'video_data': data['video_data'],
                        'eeg_data': data['eeg_data'],
                        'subject_id': int(data['subject_id']),
                        'video_id': int(data['video_id']),
                        'sample_idx': int(data['sample_idx']),
                        'start_time': float(data['start_time']),
                        'metadata': json.loads(str(data['metadata'].item()))
                    }
            except Exception as e:
                logger.warning(f"加载文件{file_path}时出错: {e}")
                continue

            data_list.append(item)

        logger.info(f"加载{split}数据完成: {len(data_list)}个样本")

        return data_list
    
    def generate_eeg_dataset(self, 
                           split: str = 'test',
                           num_samples: int = None,
                           batch_size: int = 1) -> Dict[str, List]:
        """
        Generate EEG for every sample of a split and score it.

        All per-sample result lists (``generated_eegs``, ``reference_eegs``,
        ``video_data``, ``quality_metrics``, ``inference_times``,
        ``metadata``) are appended together and only for samples that
        succeeded, so index ``k`` refers to the same sample in each of them.
        Failed samples are described in ``failed_samples`` instead.

        Args:
            split: Split to load through ``load_processed_data``.
            num_samples: Maximum number of samples; ``None`` uses all.
            batch_size: Currently unused; samples are processed one by one.

        Returns:
            Dictionary with the lists named above plus ``failed_samples``.
        """
        data_list = self.load_processed_data(split)

        if num_samples is not None:
            data_list = data_list[:num_samples]

        results = {
            'generated_eegs': [],
            'reference_eegs': [],
            'video_data': [],
            'quality_metrics': [],
            'inference_times': [],
            'metadata': [],
            'failed_samples': []
        }

        logger.info(f"开始生成EEG数据集，样本数: {len(data_list)}")

        for i, item in enumerate(tqdm(data_list, desc="生成EEG")):
            try:
                # (T, H, W, C) -> (1, T, C, H, W)
                video_frames = torch.from_numpy(item['video_data']).unsqueeze(0)
                video_frames = video_frames.permute(0, 1, 4, 2, 3)

                started = time.perf_counter()
                generated_eeg = self.model_loader.generate_eeg_stable(video_frames)
                inference_time = time.perf_counter() - started

                # detach() keeps numpy() from failing on tensors that track gradients.
                generated_eeg_np = generated_eeg.detach().cpu().numpy()[0]
                reference_eeg_np = item['eeg_data']

                if generated_eeg_np.shape != reference_eeg_np.shape:
                    logger.warning(f"样本{i}维度不匹配: 生成EEG {generated_eeg_np.shape} vs 参考EEG {reference_eeg_np.shape}")
                    # Trim both signals to the shorter time axis.
                    min_length = min(generated_eeg_np.shape[1], reference_eeg_np.shape[1])
                    generated_eeg_np = generated_eeg_np[:, :min_length]
                    reference_eeg_np = reference_eeg_np[:, :min_length]
                    logger.info(f"调整后维度: {generated_eeg_np.shape}")

                metrics = self.evaluator.evaluate_quality(
                    generated_eeg_np, reference_eeg_np
                )

            except Exception as e:
                logger.exception(f"生成样本{i}时出错: {e}")
                results['failed_samples'].append({
                    'index': i,
                    'subject_id': item.get('subject_id'),
                    'video_id': item.get('video_id'),
                    'sample_idx': item.get('sample_idx'),
                    'error': str(e)
                })
                continue

            # Everything succeeded: append to all lists in one go so they can
            # never get out of step with each other.
            results['generated_eegs'].append(generated_eeg_np)
            results['reference_eegs'].append(reference_eeg_np)
            results['video_data'].append(item['video_data'])
            results['quality_metrics'].append(metrics)
            results['inference_times'].append(inference_time)
            results['metadata'].append({
                'subject_id': item['subject_id'],
                'video_id': item['video_id'],
                'sample_idx': item['sample_idx'],
                'start_time': item['start_time'],
                'original_metadata': item['metadata']
            })

            # Logging must not be able to fail a sample that was already stored.
            mse = metrics.get('mse') if isinstance(metrics, dict) else None
            mse_text = f"{mse:.6f}" if isinstance(mse, (int, float, np.floating)) else "n/a"
            logger.info(f"样本{i+1}/{len(data_list)} - MSE: {mse_text}, 推理时间: {inference_time:.3f}s")

        logger.info(f"EEG数据集生成完成，共{len(results['generated_eegs'])}个样本")

        return results
    
    def save_generated_dataset(self, results: Dict[str, List], 
                             dataset_name: str = "seed_vd_generated"):
        """
        Write the generation results below ``self.output_dir / dataset_name``.

        Creates ``generated_eeg_dataset.npz``, ``quality_metrics.json`` and
        ``metadata.json`` (plus ``failed_samples.json`` when ``results``
        carries such a key) and then calls ``generate_dataset_report``.

        Args:
            results: Result dictionary as returned by ``generate_eeg_dataset``.
            dataset_name: Name of the sub-directory to write into.
        """
        def _to_jsonable(value):
            """Convert NumPy scalars/arrays and paths for ``json.dump``."""
            if isinstance(value, np.generic):
                return value.item()
            if isinstance(value, np.ndarray):
                return value.tolist()
            if isinstance(value, Path):
                return str(value)
            raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

        def _stack(arrays, label: str) -> np.ndarray:
            """Stack equally shaped arrays; fall back to an object array."""
            try:
                return np.array(arrays)
            except ValueError:
                # Recent NumPy refuses ragged input; keep the data rather than
                # losing a finished run at save time.
                logger.warning(f"{label}的形状不一致，将以对象数组保存（读取时需要allow_pickle=True）")
                ragged = np.empty(len(arrays), dtype=object)
                for idx, arr in enumerate(arrays):
                    ragged[idx] = arr
                return ragged

        dataset_dir = self.output_dir / dataset_name
        dataset_dir.mkdir(parents=True, exist_ok=True)

        data_file = dataset_dir / "generated_eeg_dataset.npz"
        np.savez_compressed(
            data_file,
            generated_eegs=_stack(results['generated_eegs'], 'generated_eegs'),
            reference_eegs=_stack(results['reference_eegs'], 'reference_eegs'),
            video_data=_stack(results['video_data'], 'video_data'),
            inference_times=np.array(results['inference_times'])
        )

        metrics_file = dataset_dir / "quality_metrics.json"
        with open(metrics_file, 'w', encoding='utf-8') as f:
            json.dump(results['quality_metrics'], f, indent=2, default=_to_jsonable)

        metadata_file = dataset_dir / "metadata.json"
        with open(metadata_file, 'w', encoding='utf-8') as f:
            json.dump(results['metadata'], f, indent=2, default=_to_jsonable)

        failed = results.get('failed_samples')
        if failed:
            with open(dataset_dir / "failed_samples.json", 'w', encoding='utf-8') as f:
                json.dump(failed, f, indent=2, default=_to_jsonable)

        self.generate_dataset_report(results, dataset_dir)

        logger.info(f"生成的数据集已保存到: {dataset_dir}")
    
    def generate_dataset_report(self, results: Dict[str, List], output_dir: Path):
        """
        Write ``dataset_report.json`` and trigger the visualisations.

        The report contains mean/std/min/max of ``mse``, ``mae`` and
        ``mean_correlation`` over all metric entries that are not ``None`` and
        contain ``'mse'``, and the same summary of the inference times.
        Missing values are skipped; when nothing is available, or a standard
        deviation is undefined (single value), 0 is reported so the file
        stays valid JSON.

        Args:
            results: Result dictionary as returned by ``generate_eeg_dataset``.
            output_dir: Directory that receives the report and the plots.
        """
        def _summarize(values, ddof: int) -> Dict[str, float]:
            """Return mean/std/min/max of the non-missing values."""
            cleaned = np.array(
                [np.nan if v is None else v for v in values], dtype=float
            )
            cleaned = cleaned[~np.isnan(cleaned)]
            if cleaned.size == 0:
                return {'mean': 0, 'std': 0, 'min': 0, 'max': 0}
            spread = float(cleaned.std(ddof=ddof)) if cleaned.size > ddof else 0.0
            return {
                'mean': float(cleaned.mean()),
                'std': spread,
                'min': float(cleaned.min()),
                'max': float(cleaned.max())
            }

        num_samples = len(results['generated_eegs'])
        all_metrics = results['quality_metrics']

        valid_metrics = []
        if num_samples > 0:
            valid_metrics = [m for m in all_metrics if m is not None and 'mse' in m]

        # ddof=1 matches the sample standard deviation pandas used before.
        quality_statistics = {
            'mse': _summarize([m.get('mse') for m in valid_metrics], ddof=1),
            'mae': _summarize([m.get('mae') for m in valid_metrics], ddof=1),
            'correlation': _summarize(
                [m.get('mean_correlation') for m in valid_metrics], ddof=1
            )
        }

        # When timings and metrics are aligned, drop the timings that belong
        # to failed samples (metric entry None) so they do not skew the stats.
        inference_times = list(results['inference_times'])
        if len(inference_times) == len(all_metrics):
            inference_times = [
                t for t, m in zip(inference_times, all_metrics) if m is not None
            ]
        performance_statistics = {
            'inference_time': _summarize(inference_times, ddof=0)
        }

        metadata = results['metadata'] or []
        report = {
            'dataset_info': {
                'total_samples': num_samples,
                'subjects': sorted({m['subject_id'] for m in metadata}),
                'videos': sorted({m['video_id'] for m in metadata}),
                'generation_time': datetime.now().isoformat()
            },
            'quality_statistics': quality_statistics,
            'performance_statistics': performance_statistics
        }

        report_file = output_dir / "dataset_report.json"
        with open(report_file, 'w', encoding='utf-8') as f:
            json.dump(report, f, indent=2)

        self.create_dataset_visualizations(results, output_dir)

        logger.info(f"数据集报告已生成: {report_file}")
    
    def create_dataset_visualizations(self, results: Dict[str, List], output_dir: Path):
        """
        Save the metric histograms and a few generated-vs-reference plots.

        Written below ``output_dir / "visualizations"``:
        ``quality_distributions.png`` (MSE, MAE, correlation and inference
        time histograms) and up to three ``eeg_comparison_example_<k>.png``
        files showing at most the first four channels.

        Metric entries that are ``None`` are ignored, and a histogram with no
        data is left empty instead of raising.

        Args:
            results: Result dictionary as returned by ``generate_eeg_dataset``.
            output_dir: Directory that receives the ``visualizations`` folder.
        """
        viz_dir = output_dir / "visualizations"
        viz_dir.mkdir(parents=True, exist_ok=True)

        all_metrics = results['quality_metrics']
        metrics_df = pd.DataFrame([m for m in all_metrics if m is not None])

        # Keep only timings of successful samples when both lists are aligned.
        inference_times = list(results['inference_times'])
        if len(inference_times) == len(all_metrics):
            inference_times = [
                t for t, m in zip(inference_times, all_metrics) if m is not None
            ]

        # 1. Distribution of the quality metrics and inference time
        panels = [
            ('mse', 'MSE Distribution', 'MSE'),
            ('mae', 'MAE Distribution', 'MAE'),
            ('mean_correlation', 'Correlation Distribution', 'Correlation'),
        ]

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        try:
            for ax, (column, title, xlabel) in zip(axes.flat, panels):
                if column in metrics_df.columns:
                    values = pd.to_numeric(metrics_df[column], errors='coerce').dropna()
                    if len(values) > 0:
                        ax.hist(values, bins=20, alpha=0.7, edgecolor='black')
                ax.set_title(title)
                ax.set_xlabel(xlabel)
                ax.set_ylabel('Frequency')

            time_ax = axes[1, 1]
            if inference_times:
                time_ax.hist(inference_times, bins=20, alpha=0.7, edgecolor='black')
            time_ax.set_title('Inference Time Distribution')
            time_ax.set_xlabel('Time (s)')
            time_ax.set_ylabel('Frequency')

            fig.tight_layout()
            fig.savefig(viz_dir / "quality_distributions.png", dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)

        # 2. Generated vs. reference EEG examples
        num_examples = min(3, len(results['generated_eegs']))

        for i in range(num_examples):
            generated = np.asarray(results['generated_eegs'][i])
            reference = np.asarray(results['reference_eegs'][i])

            # Plot at most four channels, fewer when the signals have fewer.
            num_channels = min(4, generated.shape[0], reference.shape[0])
            if num_channels == 0:
                continue

            # 200 Hz sampling rate; each signal gets its own axis in case the
            # two lengths differ.
            generated_time = np.arange(generated.shape[1]) / 200.0
            reference_time = np.arange(reference.shape[1]) / 200.0

            # squeeze=False keeps a 2-D axes array even for a single channel.
            fig, axes = plt.subplots(num_channels, 1, figsize=(15, 8), squeeze=False)
            try:
                for ch in range(num_channels):
                    ax = axes[ch, 0]
                    ax.plot(reference_time, reference[ch], label='Reference EEG', alpha=0.7)
                    ax.plot(generated_time, generated[ch], label='Generated EEG', alpha=0.7)
                    ax.set_title(f'Channel {ch+1}')
                    ax.set_xlabel('Time (s)')
                    ax.set_ylabel('Amplitude')
                    ax.legend()
                    ax.grid(True, alpha=0.3)

                fig.tight_layout()
                fig.savefig(viz_dir / f"eeg_comparison_example_{i+1}.png", dpi=300, bbox_inches='tight')
            finally:
                plt.close(fig)

        logger.info(f"可视化图表已保存到: {viz_dir}")

def main(argv: Optional[List[str]] = None) -> int:
    """
    Command-line entry point: process the raw data, generate EEG, or both.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``; ``None``
            (the default) keeps the normal command-line behaviour.

    Returns:
        0 on success, 1 when any step raised an exception. Invalid arguments
        make ``argparse`` raise ``SystemExit``.
    """
    parser = argparse.ArgumentParser(description='SEED-VD数据集生成器')
    parser.add_argument('--mode', type=str, choices=['process', 'generate', 'both'], 
                       default='both', help='运行模式')
    parser.add_argument('--video_dir', type=str, 
                       default='/data0/GYF-projects/EEG2Video/dataset/Video',
                       help='视频数据目录')
    parser.add_argument('--eeg_dir', type=str,
                       default='/data0/GYF-projects/EEG2Video/data/Rawf_200Hz',
                       help='EEG数据目录')
    parser.add_argument('--model_path', type=str,
                       default='./sggn_training_output/best_model.pth',
                       help='训练好的模型路径')
    parser.add_argument('--processed_data_dir', type=str,
                       default='./seed_vd_processed_data',
                       help='处理后的数据目录')
    parser.add_argument('--output_dir', type=str,
                       default='./seed_vd_generated_dataset',
                       help='输出目录')
    parser.add_argument('--num_samples', type=int, default=None,
                       help='生成样本数（None表示全部）')
    parser.add_argument('--samples_per_video', type=int, default=10,
                       help='每个视频的样本数')
    parser.add_argument('--segment_duration', type=float, default=10.0,
                       help='片段持续时间（秒）')
    parser.add_argument('--device', type=str, default='auto',
                       help='设备类型')

    args = parser.parse_args(argv)

    try:
        if args.mode in ['process', 'both']:
            logger.info("=== 开始处理SEED-VD原始数据 ===")

            # Build the video/EEG pairs from the raw data and write the splits.
            processor = SEEDVDDataProcessor(
                video_dir=args.video_dir,
                eeg_dir=args.eeg_dir,
                output_dir=args.processed_data_dir
            )
            pairs = processor.create_video_eeg_pairs(
                segment_duration=args.segment_duration,
                samples_per_video=args.samples_per_video
            )
            processor.save_processed_data(pairs)

            logger.info("SEED-VD数据处理完成")

        if args.mode in ['generate', 'both']:
            logger.info("=== 开始生成EEG数据集 ===")

            # Run the model on the test split and store the results.
            generator = SEEDVDDatasetGenerator(
                model_path=args.model_path,
                processed_data_dir=args.processed_data_dir,
                output_dir=args.output_dir,
                device=args.device
            )
            results = generator.generate_eeg_dataset(
                split='test',
                num_samples=args.num_samples
            )
            generator.save_generated_dataset(results)

            logger.info("EEG数据集生成完成")

        logger.info("=== 所有任务完成 ===")

        return 0

    except Exception as e:
        # Top-level boundary: log the message with its traceback and report
        # the failure through the exit code.
        logger.exception(f"执行过程中发生错误: {e}")
        return 1

if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the process exit code.
    sys.exit(main())