#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Video2EEG-SGGN-Diffusion 训练示例

本示例展示如何使用SGGN模型进行训练，包括：
1. 基础训练配置
2. 自定义参数设置
3. 从检查点恢复训练
4. 多GPU分布式训练（打印torchrun启动命令）
5. 训练监控和可视化

作者: 算法工程师
日期: 2025年1月12日
"""

import os
import sys
import json
import torch
from pathlib import Path

# 添加项目根目录到Python路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from train_sggn_diffusion import SGGNModelTrainer, create_default_config

def example_basic_training():
    """
    Run a short, single-process training job using the default configuration.

    The default config from ``create_default_config()`` is shrunk (few epochs,
    tiny batches) so the run finishes quickly. Checkpoints and logs go to
    ``./example_training_output``. If the hard-coded data directory does not
    exist, a hint is printed and the function returns without training.
    Any exception raised while building or running the trainer is printed
    together with its traceback instead of propagating.
    """
    print("=== 基础训练示例 ===")

    # Adjust data_dir to where the preprocessed dataset lives on your machine.
    data_dir = "/data0/GYF-projects/EEG2Video/Vidoe2EEG/Video2EEG-diffusion/enhanced_processed_data"
    output_dir = "./example_training_output"

    # Guard clause: nothing to train on without the dataset.
    if not os.path.exists(data_dir):
        print(f"警告: 数据目录不存在 {data_dir}")
        print("请修改data_dir为正确的数据路径")
        return

    config = create_default_config()

    # Quick-smoke-test overrides applied on top of the defaults.
    quick_run_overrides = {
        'num_epochs': 5,
        'batch_size': 2,
        'eval_interval': 2,
        'save_interval': 3,
    }
    config['training'].update(quick_run_overrides)

    try:
        trainer = SGGNModelTrainer(
            config=config,
            data_dir=data_dir,
            output_dir=output_dir,
            use_distributed=False,
        )

        print(f"开始训练，输出目录: {output_dir}")
        print(f"配置: {json.dumps(config, indent=2)}")

        trainer.train()

        print("训练完成！")
        print(f"检查点保存在: {output_dir}")
    except Exception as exc:
        print(f"训练过程中发生错误: {exc}")
        import traceback
        traceback.print_exc()

def example_custom_config_training():
    """
    Train with a hand-written configuration instead of ``create_default_config()``.

    The dict below is passed to ``SGGNModelTrainer`` as-is; it is NOT merged
    with the defaults, so presumably it must contain every key the trainer
    reads (TODO: confirm against SGGNModelTrainer). Output goes to
    ``./example_custom_training_output``.

    Like the basic example, the function returns early with a hint when the
    data directory is missing. Without that check a missing dataset only
    shows up as an opaque trainer error. Exceptions raised while building or
    running the trainer are printed with a traceback instead of propagating.
    """
    print("\n=== 自定义配置训练示例 ===")

    # Dataset and output locations - adjust data_dir to your environment.
    data_dir = "/data0/GYF-projects/EEG2Video/Vidoe2EEG/Video2EEG-diffusion/enhanced_processed_data"
    output_dir = "./example_custom_training_output"

    # Fail fast with an actionable message, consistent with the basic example.
    if not os.path.exists(data_dir):
        print(f"警告: 数据目录不存在 {data_dir}")
        print("请修改data_dir为正确的数据路径")
        return

    # Fully custom configuration.
    config = {
        'model': {
            'eeg_channels': 62,
            'signal_length': 200,
            'video_feature_dim': 512,
            'hidden_dim': 128,  # smaller hidden dimension
            'num_diffusion_steps': 500,  # fewer diffusion steps
            'frequency_bands': [
                (0.5, 4),    # Delta
                (4, 8),      # Theta
                (8, 13),     # Alpha
                (13, 30),    # Beta
                (30, 50)     # Gamma (upper edge reduced to 50 Hz)
            ]
        },
        'dataset': {
            'use_graph_da': True,
            'augmentation_ratio': 0.5  # higher data-augmentation ratio
        },
        'training': {
            'num_epochs': 10,
            'batch_size': 1,  # batch of 1 to fit tight memory budgets
            'num_workers': 2,
            'save_interval': 5,
            'eval_interval': 3,
            'use_mixed_precision': True
        },
        'optimizer': {
            'type': 'AdamW',
            'learning_rate': 5e-5,  # smaller learning rate
            'weight_decay': 1e-4
        },
        'scheduler': {
            'type': 'CosineAnnealingLR'
        }
    }

    try:
        trainer = SGGNModelTrainer(
            config=config,
            data_dir=data_dir,
            output_dir=output_dir,
            use_distributed=False
        )

        print(f"开始自定义配置训练，输出目录: {output_dir}")

        trainer.train()

        print("自定义配置训练完成！")

    except Exception as e:
        print(f"训练过程中发生错误: {e}")
        import traceback
        traceback.print_exc()

def example_resume_training():
    """
    Resume training from a checkpoint written by the basic training example.

    Loads ``./example_training_output/checkpoint_epoch_3.pth`` and extends the
    epoch budget by 5. It then restores model and optimizer state, plus
    scheduler state and best validation loss when present, into a fresh
    ``SGGNModelTrainer`` and continues training into
    ``./example_resume_training_output``.

    Returns early with a hint when the checkpoint file or the data directory
    is missing. Optional checkpoint entries (``scheduler_state_dict``,
    ``val_loss``) are read with ``.get`` so older checkpoints lacking them
    can still be resumed. Any other exception is printed with a traceback
    instead of propagating.
    """
    print("\n=== 恢复训练示例 ===")

    checkpoint_path = "./example_training_output/checkpoint_epoch_3.pth"
    data_dir = "/data0/GYF-projects/EEG2Video/Vidoe2EEG/Video2EEG-diffusion/enhanced_processed_data"
    output_dir = "./example_resume_training_output"

    if not os.path.exists(checkpoint_path):
        print(f"检查点文件不存在: {checkpoint_path}")
        print("请先运行基础训练示例生成检查点")
        return

    # Same dataset guard as the basic example, checked before the (possibly
    # slow) checkpoint load.
    if not os.path.exists(data_dir):
        print(f"警告: 数据目录不存在 {data_dir}")
        print("请修改data_dir为正确的数据路径")
        return

    try:
        # NOTE(security): torch.load unpickles the file - only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        config = checkpoint['config']

        print(f"从检查点恢复训练: {checkpoint_path}")
        print(f"上次训练到第 {checkpoint['epoch']} 轮")

        # Train 5 more epochs beyond the checkpoint.
        config['training']['num_epochs'] = checkpoint['epoch'] + 5

        trainer = SGGNModelTrainer(
            config=config,
            data_dir=data_dir,
            output_dir=output_dir,
            use_distributed=False
        )

        # Restore weights and optimizer state.
        trainer.model.load_state_dict(checkpoint['model_state_dict'])
        trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Scheduler state is optional: the checkpoint may have been saved
        # without a scheduler, in which case the key can be absent or None.
        scheduler_state = checkpoint.get('scheduler_state_dict')
        if trainer.scheduler and scheduler_state:
            trainer.scheduler.load_state_dict(scheduler_state)

        trainer.current_epoch = checkpoint['epoch']
        # Keep the trainer's own initial best loss if none was recorded.
        val_loss = checkpoint.get('val_loss')
        if val_loss is not None:
            trainer.best_val_loss = val_loss

        print("开始恢复训练...")

        trainer.train()

        print("恢复训练完成！")

    except Exception as e:
        print(f"恢复训练过程中发生错误: {e}")
        import traceback
        traceback.print_exc()

def example_distributed_training_setup():
    """
    Print example ``torchrun`` command lines for distributed training.

    Covers a single machine with 4 GPUs and two machines with 4 GPUs each
    (one command per node rank), followed by a short checklist. Only prints
    to stdout; nothing is launched.
    """
    print("\n=== 分布式训练设置示例 ===")

    # Script arguments shared by every launch command below.
    common_script_args = [
        "    --data_dir /path/to/your/data \\",
        "    --output_dir ./distributed_training_output \\",
        "    --config config.json \\",
        "    --distributed",
    ]

    def two_node_command(node_rank):
        # Multi-node launcher lines for the given node rank.
        return [
            f"torchrun --nnodes=2 --node_rank={node_rank} --master_addr=192.168.1.100 \\",
            "    --master_port=12345 --nproc_per_node=4 train_sggn_diffusion.py \\",
            *common_script_args,
        ]

    output_lines = [
        "分布式训练命令示例:",
        "",
        "# 单机多GPU (4个GPU)",
        "torchrun --nproc_per_node=4 train_sggn_diffusion.py \\",
        *common_script_args,
        "",
        "# 多机多GPU (2台机器，每台4个GPU)",
        "# 在主节点 (rank 0) 运行:",
        *two_node_command(0),
        "",
        "# 在从节点 (rank 1) 运行:",
        *two_node_command(1),
        "",
        "注意事项:",
        "1. 确保所有节点都能访问相同的数据目录",
        "2. 调整batch_size以适应多GPU内存",
        "3. 使用相同的随机种子确保一致性",
        "4. 监控所有节点的GPU使用情况",
    ]
    for text_line in output_lines:
        print(text_line)

def _checkpoint_epoch_sort_key(path):
    """Return a sort key that orders ``checkpoint_epoch_<N>.pth`` paths by integer N.

    Files whose stem does not end in ``_<digits>`` sort after all numbered
    checkpoints, alphabetically among themselves.
    """
    suffix = path.stem.rpartition('_')[2]
    if suffix.isdecimal():
        return (0, int(suffix), path.name)
    return (1, 0, path.name)


def example_monitoring_and_visualization():
    """
    Print how to inspect the artifacts of the basic training example.

    Looks under ``./example_training_output`` (relative to the current
    working directory):

    - ``training_history.json``: prints its top-level keys and, when present,
      the last (up to) five ``train_total_loss`` values.
    - ``checkpoint_epoch_*.pth``: lists checkpoint files ordered by epoch
      number. The sort is numeric, so epoch 10 follows epoch 3 rather than
      preceding it as a plain string sort would.

    Read-only; missing files only produce an informational message.
    """
    print("\n=== 训练监控和可视化示例 ===")

    print("1. TensorBoard监控:")
    print("   训练过程中会自动保存TensorBoard日志")
    print("   启动TensorBoard查看训练曲线:")
    print("   tensorboard --logdir ./example_training_output/tensorboard")
    print("")

    print("2. 训练历史分析:")
    training_history_path = "./example_training_output/training_history.json"
    if os.path.exists(training_history_path):
        # JSON text is UTF-8; don't depend on the platform default encoding.
        with open(training_history_path, 'r', encoding='utf-8') as f:
            history = json.load(f)

        print(f"   训练历史文件: {training_history_path}")
        print(f"   可用指标: {list(history.keys())}")

        if 'train_total_loss' in history:
            # Slicing already handles histories shorter than five entries.
            print(f"   训练损失趋势: {history['train_total_loss'][-5:]}")
    else:
        print(f"   训练历史文件不存在: {training_history_path}")
        print("   请先运行训练示例生成历史数据")

    print("")
    print("3. 模型检查点分析:")
    checkpoint_dir = Path("./example_training_output")
    if checkpoint_dir.exists():
        checkpoints = sorted(
            checkpoint_dir.glob("checkpoint_epoch_*.pth"),
            key=_checkpoint_epoch_sort_key,
        )
        if checkpoints:
            print(f"   找到 {len(checkpoints)} 个检查点:")
            for cp in checkpoints:
                print(f"     - {cp.name}")
        else:
            print("   未找到检查点文件")
    else:
        print("   输出目录不存在")

def main():
    """
    Interactive entry point for the training examples.

    Reports CUDA availability and the visible GPUs, shows a numbered menu and
    runs the selected example. Option 6 runs all five examples in menu order.
    A KeyboardInterrupt, or any other exception raised while reading the
    choice or running an example, is reported on stdout instead of
    propagating.
    """
    print("Video2EEG-SGGN-Diffusion 训练示例")
    print("=" * 50)

    # Tell the user which accelerators are visible before training starts.
    if not torch.cuda.is_available():
        print("CUDA不可用，将使用CPU训练（速度较慢）")
    else:
        gpu_count = torch.cuda.device_count()
        print(f"CUDA可用，设备数量: {gpu_count}")
        for gpu_idx in range(gpu_count):
            print(f"  GPU {gpu_idx}: {torch.cuda.get_device_name(gpu_idx)}")

    menu_labels = (
        "基础训练示例",
        "自定义配置训练示例",
        "恢复训练示例",
        "分布式训练设置示例",
        "训练监控和可视化示例",
        "运行所有示例",
    )
    print("\n选择要运行的示例:")
    for number, label in enumerate(menu_labels, start=1):
        print(f"{number}. {label}")

    try:
        choice = input("\n请输入选择 (1-6): ").strip()

        every_example = (
            example_basic_training,
            example_custom_config_training,
            example_resume_training,
            example_distributed_training_setup,
            example_monitoring_and_visualization,
        )
        # Keys "1"-"5" map to a single example; "6" runs all of them in order.
        dispatch = {str(pos): (fn,) for pos, fn in enumerate(every_example, start=1)}
        dispatch['6'] = every_example

        selected = dispatch.get(choice)
        if selected is None:
            print("无效选择，请输入1-6之间的数字")
        else:
            for run_example in selected:
                run_example()

    except KeyboardInterrupt:
        print("\n用户中断操作")
    except Exception as err:
        print(f"运行示例时发生错误: {err}")
        import traceback
        traceback.print_exc()

if __name__ == '__main__':
    # Launch the interactive menu only when run as a script, not on import.
    main()