#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
SEED-VD数据集生成运行脚本
简化的启动脚本，支持配置文件和命令行参数

使用方法:
1. 仅处理数据: python run_seed_vd_generation.py --mode process
2. 仅生成数据集: python run_seed_vd_generation.py --mode generate
3. 完整流程: python run_seed_vd_generation.py --mode both
4. 使用配置文件: python run_seed_vd_generation.py --config seed_vd_config.yaml
5. 快速测试: python run_seed_vd_generation.py --quick_test

作者: 算法工程师
日期: 2025年1月12日
"""

import os
import sys
import yaml
import argparse
import logging
from pathlib import Path
from datetime import datetime

# 添加当前目录到Python路径
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from seed_vd_dataset_generator import SEEDVDDataProcessor, SEEDVDDatasetGenerator

def load_config(config_path: str) -> dict:
    """
    加载配置文件
    
    Args:
        config_path: 配置文件路径
    
    Returns:
        配置字典
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    return config

def setup_logging(level: str = "INFO"):
    """
    设置日志
    
    Args:
        level: 日志级别
    """
    logging.basicConfig(
        level=getattr(logging, level.upper()),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(sys.stdout),
            logging.FileHandler(f'seed_vd_generation_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
        ]
    )

def run_data_processing(config: dict):
    """
    运行数据处理
    
    Args:
        config: 配置字典
    """
    logger = logging.getLogger(__name__)
    logger.info("=== 开始处理SEED-VD原始数据 ===")
    
    # 创建数据处理器
    processor = SEEDVDDataProcessor(
        video_dir=config['data_paths']['video_dir'],
        eeg_dir=config['data_paths']['eeg_dir'],
        output_dir=config['data_paths']['processed_data_dir']
    )
    
    # 创建视频-EEG配对
    pairs = processor.create_video_eeg_pairs(
        subject_ids=config['data_processing']['subject_ids'],
        video_ids=config['data_processing']['video_ids'],
        segment_duration=config['data_processing']['segment_duration'],
        samples_per_video=config['data_processing']['samples_per_video']
    )
    
    # 保存处理后的数据
    processor.save_processed_data(
        pairs,
        train_ratio=config['data_processing']['split_ratios']['train'],
        val_ratio=config['data_processing']['split_ratios']['val'],
        test_ratio=config['data_processing']['split_ratios']['test']
    )
    
    logger.info("SEED-VD数据处理完成")
    return True

def run_dataset_generation(config: dict):
    """
    运行数据集生成
    
    Args:
        config: 配置字典
    """
    logger = logging.getLogger(__name__)
    logger.info("=== 开始生成EEG数据集 ===")
    
    # 检查模型文件是否存在
    model_path = config['model']['model_path']
    if not os.path.exists(model_path):
        logger.error(f"模型文件不存在: {model_path}")
        logger.info("请先训练模型或提供正确的模型路径")
        return False
    
    # 创建数据集生成器
    generator = SEEDVDDatasetGenerator(
        model_path=model_path,
        processed_data_dir=config['data_paths']['processed_data_dir'],
        output_dir=config['data_paths']['output_dir'],
        device=config['model']['device']
    )
    
    # 生成EEG数据集
    results = generator.generate_eeg_dataset(
        split=config['generation']['split'],
        num_samples=config['generation']['num_samples'],
        batch_size=config['generation']['batch_size']
    )
    
    # 保存生成的数据集
    generator.save_generated_dataset(
        results,
        dataset_name=config['generation']['dataset_name']
    )
    
    logger.info("EEG数据集生成完成")
    return True

def run_quick_test():
    """
    运行快速测试
    """
    logger = logging.getLogger(__name__)
    logger.info("=== 运行快速测试 ===")
    
    # 快速测试配置
    test_config = {
        'data_paths': {
            'video_dir': '/data0/GYF-projects/EEG2Video/dataset/Video',
            'eeg_dir': '/data0/GYF-projects/EEG2Video/data/Rawf_200Hz',
            'processed_data_dir': './seed_vd_test_data',
            'output_dir': './seed_vd_test_output'
        },
        'data_processing': {
            'subject_ids': [1, 2],  # 仅使用前两个被试
            'video_ids': [1, 2],    # 仅使用前两个视频
            'segment_duration': 5.0,  # 更短的片段
            'samples_per_video': 2,   # 更少的样本
            'overlap_ratio': 0.5,
            'split_ratios': {
                'train': 0.6,
                'val': 0.2,
                'test': 0.2
            }
        },
        'model': {
            'model_path': './sggn_training_output/best_model.pth',
            'device': 'auto'
        },
        'generation': {
            'num_samples': 2,  # 仅生成2个样本
            'batch_size': 1,
            'split': 'test',
            'dataset_name': 'seed_vd_test'
        }
    }
    
    try:
        # 运行数据处理
        success = run_data_processing(test_config)
        if not success:
            return False
        
        # 检查是否有模型文件
        if os.path.exists(test_config['model']['model_path']):
            # 运行数据集生成
            success = run_dataset_generation(test_config)
            if not success:
                return False
        else:
            logger.warning("模型文件不存在，跳过数据集生成步骤")
            logger.info("请先训练模型后再运行完整的生成流程")
        
        logger.info("快速测试完成")
        return True
        
    except Exception as e:
        logger.error(f"快速测试失败: {e}")
        import traceback
        traceback.print_exc()
        return False

def main():
    """
    主函数
    """
    parser = argparse.ArgumentParser(
        description='SEED-VD数据集生成运行脚本',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
使用示例:
  python run_seed_vd_generation.py --mode process
  python run_seed_vd_generation.py --mode generate
  python run_seed_vd_generation.py --mode both
  python run_seed_vd_generation.py --config seed_vd_config.yaml
  python run_seed_vd_generation.py --quick_test
        """
    )
    
    parser.add_argument('--mode', type=str, 
                       choices=['process', 'generate', 'both'],
                       default='both',
                       help='运行模式: process(仅处理), generate(仅生成), both(完整流程)')
    
    parser.add_argument('--config', type=str,
                       default='seed_vd_config.yaml',
                       help='配置文件路径')
    
    parser.add_argument('--quick_test', action='store_true',
                       help='运行快速测试（使用少量数据）')
    
    parser.add_argument('--log_level', type=str,
                       choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
                       default='INFO',
                       help='日志级别')
    
    # 覆盖配置的命令行参数
    parser.add_argument('--video_dir', type=str, help='视频数据目录')
    parser.add_argument('--eeg_dir', type=str, help='EEG数据目录')
    parser.add_argument('--model_path', type=str, help='模型路径')
    parser.add_argument('--output_dir', type=str, help='输出目录')
    parser.add_argument('--num_samples', type=int, help='生成样本数')
    parser.add_argument('--device', type=str, help='设备类型')
    
    args = parser.parse_args()
    
    # 设置日志
    setup_logging(args.log_level)
    logger = logging.getLogger(__name__)
    
    try:
        # 快速测试模式
        if args.quick_test:
            success = run_quick_test()
            return 0 if success else 1
        
        # 加载配置文件
        if not os.path.exists(args.config):
            logger.error(f"配置文件不存在: {args.config}")
            logger.info("请创建配置文件或使用 --quick_test 进行快速测试")
            return 1
        
        config = load_config(args.config)
        
        # 命令行参数覆盖配置文件
        if args.video_dir:
            config['data_paths']['video_dir'] = args.video_dir
        if args.eeg_dir:
            config['data_paths']['eeg_dir'] = args.eeg_dir
        if args.model_path:
            config['model']['model_path'] = args.model_path
        if args.output_dir:
            config['data_paths']['output_dir'] = args.output_dir
        if args.num_samples:
            config['generation']['num_samples'] = args.num_samples
        if args.device:
            config['model']['device'] = args.device
        
        logger.info(f"使用配置文件: {args.config}")
        logger.info(f"运行模式: {args.mode}")
        
        # 检查数据目录
        video_dir = Path(config['data_paths']['video_dir'])
        eeg_dir = Path(config['data_paths']['eeg_dir'])
        
        if not video_dir.exists():
            logger.error(f"视频数据目录不存在: {video_dir}")
            return 1
        
        if not eeg_dir.exists():
            logger.error(f"EEG数据目录不存在: {eeg_dir}")
            return 1
        
        # 执行任务
        success = True
        
        if args.mode in ['process', 'both']:
            success = run_data_processing(config)
            if not success:
                return 1
        
        if args.mode in ['generate', 'both']:
            success = run_dataset_generation(config)
            if not success:
                return 1
        
        logger.info("=== 所有任务完成 ===")
        return 0
        
    except KeyboardInterrupt:
        logger.info("用户中断执行")
        return 1
    except Exception as e:
        logger.error(f"执行过程中发生错误: {e}")
        import traceback
        traceback.print_exc()
        return 1

if __name__ == "__main__":
    exit_code = main()
    sys.exit(exit_code)