#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Video2EEG-SGGN-Diffusion 推理示例

本示例展示如何使用训练好的SGGN模型进行推理，包括：
1. 基础推理流程
2. 批量推理处理
3. 质量评估分析
4. 结果可视化
5. 性能基准测试

作者: 算法工程师
日期: 2025年1月12日
"""

import os
import sys
import json
import time
import numpy as np
import torch
import matplotlib.pyplot as plt
from pathlib import Path

# 添加项目根目录到Python路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from inference_sggn_diffusion import SGGNInferenceEngine, SGGNModelLoader, EEGQualityEvaluator
from train_sggn_diffusion import SGGNEEGVideoDataset

def example_basic_inference():
    """
    基础推理示例
    """
    print("=== 基础推理示例 ===")
    
    # 模型和数据路径 - 请根据实际情况修改
    model_path = "./example_training_output/best_model.pth"
    data_dir = "/data0/GYF-projects/EEG2Video/Vidoe2EEG/Video2EEG-diffusion/enhanced_processed_data"
    output_dir = "./example_inference_output"
    
    # 检查模型文件是否存在
    if not os.path.exists(model_path):
        print(f"警告: 模型文件不存在 {model_path}")
        print("请先运行训练示例生成模型，或修改model_path为正确的模型路径")
        return
    
    # 检查数据目录是否存在
    if not os.path.exists(data_dir):
        print(f"警告: 数据目录不存在 {data_dir}")
        print("请修改data_dir为正确的数据路径")
        return
    
    try:
        # 创建推理引擎
        inference_engine = SGGNInferenceEngine(
            model_path=model_path,
            data_dir=data_dir,
            output_dir=output_dir,
            device='auto'
        )
        
        print(f"开始基础推理测试...")
        print(f"模型路径: {model_path}")
        print(f"数据目录: {data_dir}")
        print(f"输出目录: {output_dir}")
        
        # 运行推理（少量样本用于快速测试）
        results, analysis = inference_engine.run_complete_evaluation(
            num_samples=10,
            num_inference_steps=20,  # 减少步数加快测试
            guidance_scale=1.0
        )
        
        # 打印关键结果
        print("\n=== 推理结果摘要 ===")
        print(f"测试样本数: {len(results['metrics'])}")
        print(f"平均推理时间: {analysis.get('inference_time_mean', 0):.3f}s")
        print(f"平均MSE: {analysis.get('mse_mean', 0):.6f}")
        print(f"平均相关性: {analysis.get('mean_correlation_mean', 0):.4f}")
        print(f"整体质量评分: {analysis.get('overall_quality_score', 0):.4f}")
        
        print(f"\n详细结果已保存到: {output_dir}")
        print("基础推理示例完成！")
        
    except Exception as e:
        print(f"推理过程中发生错误: {e}")
        import traceback
        traceback.print_exc()

def example_single_sample_inference():
    """
    单样本推理示例
    """
    print("\n=== 单样本推理示例 ===")
    
    model_path = "./example_training_output/best_model.pth"
    data_dir = "/data0/GYF-projects/EEG2Video/Vidoe2EEG/Video2EEG-diffusion/enhanced_processed_data"
    
    if not os.path.exists(model_path):
        print(f"模型文件不存在: {model_path}")
        return
    
    try:
        # 加载模型
        model_loader = SGGNModelLoader(model_path, device='auto')
        
        # 创建测试数据集
        test_dataset = SGGNEEGVideoDataset(
            data_dir, 
            split='test',
            use_graph_da=False
        )
        
        if len(test_dataset) == 0:
            print("测试数据集为空")
            return
        
        # 获取一个样本
        sample = test_dataset[0]
        video_frames = sample['video'].unsqueeze(0)  # 添加batch维度
        reference_eeg = sample['eeg'].numpy()
        
        print(f"输入视频形状: {video_frames.shape}")
        print(f"参考EEG形状: {reference_eeg.shape}")
        
        # 测试不同推理参数
        inference_configs = [
            {'steps': 10, 'guidance': 1.0, 'name': '快速推理'},
            {'steps': 50, 'guidance': 1.0, 'name': '标准推理'},
            {'steps': 50, 'guidance': 1.5, 'name': '高质量推理'}
        ]
        
        results = []
        
        for config in inference_configs:
            print(f"\n运行 {config['name']}...")
            
            start_time = time.time()
            
            # 生成EEG
            generated_eeg = model_loader.generate_eeg(
                video_frames,
                num_inference_steps=config['steps'],
                guidance_scale=config['guidance']
            )
            
            inference_time = time.time() - start_time
            generated_eeg_np = generated_eeg.cpu().numpy()[0]
            
            # 评估质量
            evaluator = EEGQualityEvaluator()
            metrics = evaluator.evaluate_quality(generated_eeg_np, reference_eeg)
            
            result = {
                'config': config,
                'inference_time': inference_time,
                'metrics': metrics,
                'generated_eeg': generated_eeg_np
            }
            results.append(result)
            
            print(f"  推理时间: {inference_time:.3f}s")
            print(f"  MSE: {metrics['mse']:.6f}")
            print(f"  相关性: {metrics['mean_correlation']:.4f}")
        
        # 对比结果
        print("\n=== 推理配置对比 ===")
        print(f"{'配置':<12} {'时间(s)':<8} {'MSE':<12} {'相关性':<8}")
        print("-" * 45)
        
        for result in results:
            config_name = result['config']['name']
            inference_time = result['inference_time']
            mse = result['metrics']['mse']
            corr = result['metrics']['mean_correlation']
            
            print(f"{config_name:<12} {inference_time:<8.3f} {mse:<12.6f} {corr:<8.4f}")
        
        # 可视化对比
        visualize_inference_comparison(reference_eeg, results)
        
        print("单样本推理示例完成！")
        
    except Exception as e:
        print(f"单样本推理过程中发生错误: {e}")
        import traceback
        traceback.print_exc()

def visualize_inference_comparison(reference_eeg, results):
    """
    可视化推理结果对比
    """
    try:
        # 选择几个代表性通道
        channels_to_plot = [0, 15, 30, 45] if reference_eeg.shape[0] > 45 else [0, 1, 2, 3]
        time_axis = np.arange(reference_eeg.shape[1]) / 250.0  # 假设250Hz采样率
        
        fig, axes = plt.subplots(len(channels_to_plot), 1, figsize=(12, 2*len(channels_to_plot)))
        if len(channels_to_plot) == 1:
            axes = [axes]
        
        fig.suptitle('不同推理配置的EEG生成对比', fontsize=14)
        
        colors = ['blue', 'red', 'green', 'orange']
        
        for i, ch in enumerate(channels_to_plot):
            # 绘制参考EEG
            axes[i].plot(time_axis, reference_eeg[ch], 
                        label='真实EEG', color='black', linewidth=2, alpha=0.8)
            
            # 绘制不同配置的生成EEG
            for j, result in enumerate(results):
                generated_eeg = result['generated_eeg']
                config_name = result['config']['name']
                
                axes[i].plot(time_axis, generated_eeg[ch], 
                           label=config_name, color=colors[j % len(colors)], 
                           alpha=0.7, linestyle='--')
            
            axes[i].set_title(f'通道 {ch+1}')
            axes[i].set_xlabel('时间 (s)')
            axes[i].set_ylabel('幅值 (μV)')
            axes[i].legend()
            axes[i].grid(True, alpha=0.3)
        
        plt.tight_layout()
        
        # 保存图像
        output_path = './example_inference_comparison.png'
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()
        
        print(f"对比图像已保存: {output_path}")
        
    except Exception as e:
        print(f"可视化过程中发生错误: {e}")

def example_batch_inference():
    """
    批量推理示例
    """
    print("\n=== 批量推理示例 ===")
    
    model_path = "./example_training_output/best_model.pth"
    data_dir = "/data0/GYF-projects/EEG2Video/Vidoe2EEG/Video2EEG-diffusion/enhanced_processed_data"
    output_dir = "./example_batch_inference_output"
    
    if not os.path.exists(model_path):
        print(f"模型文件不存在: {model_path}")
        return
    
    try:
        # 创建推理引擎
        inference_engine = SGGNInferenceEngine(
            model_path=model_path,
            data_dir=data_dir,
            output_dir=output_dir,
            device='auto'
        )
        
        print("开始批量推理测试...")
        
        # 运行批量推理
        results = inference_engine.run_inference(
            num_samples=50,  # 中等数量样本
            num_inference_steps=30,
            guidance_scale=1.0
        )
        
        # 分析结果
        analysis = inference_engine.analyze_results(results)
        
        # 打印统计信息
        print("\n=== 批量推理统计 ===")
        print(f"处理样本数: {len(results['metrics'])}")
        print(f"总推理时间: {sum(results['inference_times']):.2f}s")
        print(f"平均推理时间: {analysis.get('inference_time_mean', 0):.3f}s")
        print(f"推理时间标准差: {analysis.get('inference_time_std', 0):.3f}s")
        
        print("\n=== 质量统计 ===")
        print(f"MSE - 均值: {analysis.get('mse_mean', 0):.6f}, 标准差: {analysis.get('mse_std', 0):.6f}")
        print(f"MAE - 均值: {analysis.get('mae_mean', 0):.6f}, 标准差: {analysis.get('mae_std', 0):.6f}")
        print(f"相关性 - 均值: {analysis.get('mean_correlation_mean', 0):.4f}, 标准差: {analysis.get('mean_correlation_std', 0):.4f}")
        
        # 频段相似性
        print("\n=== 频段相似性 ===")
        for band in ['delta', 'theta', 'alpha', 'beta', 'gamma']:
            similarity_key = f'{band}_similarity_mean'
            if similarity_key in analysis:
                print(f"{band.capitalize()}: {analysis[similarity_key]:.4f}")
        
        # 生成简单可视化
        generate_batch_analysis_plots(results, analysis, output_dir)
        
        print(f"\n批量推理结果已保存到: {output_dir}")
        print("批量推理示例完成！")
        
    except Exception as e:
        print(f"批量推理过程中发生错误: {e}")
        import traceback
        traceback.print_exc()

def generate_batch_analysis_plots(results, analysis, output_dir):
    """
    生成批量分析图表
    """
    try:
        os.makedirs(output_dir, exist_ok=True)
        
        # 1. 推理时间分布
        plt.figure(figsize=(10, 6))
        
        plt.subplot(2, 2, 1)
        plt.hist(results['inference_times'], bins=20, alpha=0.7, color='skyblue', edgecolor='black')
        plt.title('Inference Time Distribution')
        plt.xlabel('Time (s)')
        plt.ylabel('Frequency')
        plt.grid(True, alpha=0.3)
        
        # 2. MSE分布
        plt.subplot(2, 2, 2)
        mse_values = [m['mse'] for m in results['metrics']]
        plt.hist(mse_values, bins=20, alpha=0.7, color='lightcoral', edgecolor='black')
        plt.title('MSE Distribution')
        plt.xlabel('MSE')
        plt.ylabel('Frequency')
        plt.grid(True, alpha=0.3)
        
        # 3. 相关性分布
        plt.subplot(2, 2, 3)
        corr_values = [m['mean_correlation'] for m in results['metrics'] 
                      if not np.isnan(m['mean_correlation'])]
        plt.hist(corr_values, bins=20, alpha=0.7, color='lightgreen', edgecolor='black')
        plt.title('Correlation Distribution')
        plt.xlabel('Correlation Coefficient')
        plt.ylabel('Frequency')
        plt.grid(True, alpha=0.3)
        
        # 4. 质量随时间变化
        plt.subplot(2, 2, 4)
        plt.plot(mse_values, 'o-', alpha=0.7, label='MSE')
        plt.plot([c*10 for c in corr_values], 's-', alpha=0.7, label='相关性×10')
        plt.title('Quality Metrics vs Sample Index')
        plt.xlabel('Sample Index')
        plt.ylabel('Metric Value')
        plt.legend()
        plt.grid(True, alpha=0.3)
        
        plt.tight_layout()
        
        # 保存图像
        plot_path = os.path.join(output_dir, 'batch_analysis.png')
        plt.savefig(plot_path, dpi=300, bbox_inches='tight')
        plt.close()
        
        print(f"批量分析图表已保存: {plot_path}")
        
    except Exception as e:
        print(f"生成批量分析图表时发生错误: {e}")

def example_performance_benchmark():
    """
    性能基准测试示例
    """
    print("\n=== 性能基准测试示例 ===")
    
    model_path = "./example_training_output/best_model.pth"
    
    if not os.path.exists(model_path):
        print(f"模型文件不存在: {model_path}")
        return
    
    try:
        # 加载模型
        model_loader = SGGNModelLoader(model_path, device='auto')
        
        # 创建不同大小的测试数据
        test_configs = [
            {'batch_size': 1, 'name': '单样本'},
            {'batch_size': 2, 'name': '小批次'},
            {'batch_size': 4, 'name': '中批次'}
        ]
        
        inference_steps_list = [10, 20, 50]
        
        print("开始性能基准测试...")
        print(f"{'配置':<10} {'批次大小':<8} {'推理步数':<8} {'时间(s)':<10} {'内存(MB)':<10}")
        print("-" * 60)
        
        for config in test_configs:
            batch_size = config['batch_size']
            config_name = config['name']
            
            for steps in inference_steps_list:
                # 创建测试数据
                video_frames = torch.randn(batch_size, 200, 3, 224, 224)
                
                # 清理GPU内存
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    torch.cuda.reset_peak_memory_stats()
                
                # 预热
                with torch.no_grad():
                    _ = model_loader.generate_eeg(
                        video_frames[:1], 
                        num_inference_steps=5,
                        guidance_scale=1.0
                    )
                
                # 基准测试
                start_time = time.time()
                
                with torch.no_grad():
                    generated_eeg = model_loader.generate_eeg(
                        video_frames,
                        num_inference_steps=steps,
                        guidance_scale=1.0
                    )
                
                end_time = time.time()
                inference_time = end_time - start_time
                
                # 内存使用
                if torch.cuda.is_available():
                    memory_used = torch.cuda.max_memory_allocated() / 1024 / 1024  # MB
                else:
                    memory_used = 0
                
                print(f"{config_name:<10} {batch_size:<8} {steps:<8} {inference_time:<10.3f} {memory_used:<10.1f}")
        
        print("\n性能基准测试完成！")
        
        # 设备信息
        print("\n=== 设备信息 ===")
        if torch.cuda.is_available():
            print(f"GPU: {torch.cuda.get_device_name()}")
            print(f"GPU内存: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
        else:
            print("使用CPU进行推理")
        
    except Exception as e:
        print(f"性能基准测试过程中发生错误: {e}")
        import traceback
        traceback.print_exc()

def example_quality_analysis():
    """
    质量分析示例
    """
    print("\n=== 质量分析示例 ===")
    
    # 检查是否有推理结果
    results_path = "./example_inference_output/inference_results.json"
    
    if not os.path.exists(results_path):
        print(f"推理结果文件不存在: {results_path}")
        print("请先运行基础推理示例生成结果")
        return
    
    try:
        # 加载推理结果
        with open(results_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        
        metrics = data['metrics']
        analysis = data['analysis']
        
        print(f"加载了 {len(metrics)} 个样本的推理结果")
        
        # 详细质量分析
        print("\n=== 详细质量分析 ===")
        
        # 基础指标
        print("基础指标:")
        for metric in ['mse', 'mae', 'mean_correlation']:
            mean_key = f'{metric}_mean'
            std_key = f'{metric}_std'
            if mean_key in analysis:
                print(f"  {metric.upper()}: {analysis[mean_key]:.6f} ± {analysis.get(std_key, 0):.6f}")
        
        # 频段相似性
        print("\n频段相似性:")
        for band in ['delta', 'theta', 'alpha', 'beta', 'gamma']:
            similarity_key = f'{band}_similarity_mean'
            if similarity_key in analysis:
                print(f"  {band.capitalize()}: {analysis[similarity_key]:.4f}")
        
        # 时域特征相似性
        print("\n时域特征相似性:")
        temporal_features = ['mean', 'std', 'var', 'skewness', 'kurtosis', 'energy', 'rms']
        for feature in temporal_features:
            similarity_key = f'{feature}_similarity_mean'
            if similarity_key in analysis:
                print(f"  {feature.capitalize()}: {analysis[similarity_key]:.4f}")
        
        # 空间特征误差
        print("\n空间特征相对误差:")
        spatial_features = ['mean_correlation', 'std_correlation', 'mean_gfp', 'std_gfp']
        for feature in spatial_features:
            error_key = f'{feature}_error_mean'
            if error_key in analysis:
                print(f"  {feature}: {analysis[error_key]:.4f}")
        
        # 质量等级评估
        overall_score = analysis.get('overall_quality_score', 0)
        print(f"\n整体质量评分: {overall_score:.4f}")
        
        if overall_score >= 0.8:
            quality_level = "优秀"
        elif overall_score >= 0.6:
            quality_level = "良好"
        elif overall_score >= 0.4:
            quality_level = "一般"
        else:
            quality_level = "需要改进"
        
        print(f"质量等级: {quality_level}")
        
        # 改进建议
        print("\n=== 改进建议 ===")
        
        avg_mse = analysis.get('mse_mean', 0)
        avg_corr = analysis.get('mean_correlation_mean', 0)
        
        if avg_mse > 0.01:
            print("- MSE较高，建议增加训练轮数或调整学习率")
        
        if avg_corr < 0.5:
            print("- 相关性较低，建议增强数据增强或调整模型架构")
        
        if overall_score < 0.6:
            print("- 整体质量需要提升，建议检查数据质量和模型配置")
        
        print("质量分析示例完成！")
        
    except Exception as e:
        print(f"质量分析过程中发生错误: {e}")
        import traceback
        traceback.print_exc()

def main():
    """
    主函数 - 运行所有推理示例
    """
    print("Video2EEG-SGGN-Diffusion 推理示例")
    print("=" * 50)
    
    # 检查CUDA可用性
    if torch.cuda.is_available():
        print(f"CUDA可用，设备: {torch.cuda.get_device_name()}")
    else:
        print("CUDA不可用，将使用CPU推理")
    
    print("\n选择要运行的示例:")
    print("1. 基础推理示例")
    print("2. 单样本推理示例")
    print("3. 批量推理示例")
    print("4. 性能基准测试示例")
    print("5. 质量分析示例")
    print("6. 运行所有示例")
    
    try:
        choice = input("\n请输入选择 (1-6): ").strip()
        
        if choice == '1':
            example_basic_inference()
        elif choice == '2':
            example_single_sample_inference()
        elif choice == '3':
            example_batch_inference()
        elif choice == '4':
            example_performance_benchmark()
        elif choice == '5':
            example_quality_analysis()
        elif choice == '6':
            example_basic_inference()
            example_single_sample_inference()
            example_batch_inference()
            example_performance_benchmark()
            example_quality_analysis()
        else:
            print("无效选择，请输入1-6之间的数字")
            
    except KeyboardInterrupt:
        print("\n用户中断操作")
    except Exception as e:
        print(f"运行示例时发生错误: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    main()