#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Video2EEG-SGGN-Diffusion模型推理脚本
基于MGIF框架的推理测试和质量评估

核心功能:
1. 模型加载和推理
2. EEG生成质量评估
3. 多维度质量对比
4. 可视化分析
5. 性能基准测试

作者: 算法工程师
日期: 2025年1月12日
"""

import os
import sys
import json
import time
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import signal, stats
from scipy.spatial.distance import cosine
from sklearn.metrics import mean_squared_error, mean_absolute_error
from pathlib import Path
import logging
from typing import Dict, List, Tuple, Optional, Union
import warnings
from collections import defaultdict
import pandas as pd
from tqdm import tqdm

warnings.filterwarnings('ignore')

# 导入模型和数据集
from video2eeg_sggn_diffusion_model import Video2EEGSGGNDiffusion, create_video2eeg_sggn_diffusion
from train_sggn_diffusion import SGGNEEGVideoDataset

# Logging setup: INFO-level root configuration plus this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Font fallback list used for Chinese text in figures; unicode_minus is
# disabled so the minus sign is drawn with the ASCII hyphen.
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'Arial Unicode MS', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

class SGGNModelLoader:
    """
    SGGN模型加载器
    """
    
    def __init__(self, model_path: str, device: str = 'auto'):
        """
        Resolve the compute device and eagerly load the checkpointed model.

        Args:
            model_path: Path to the model checkpoint file.
            device: Device string handed to ``torch.device``; ``'auto'``
                selects CUDA when it is available and the CPU otherwise.
        """
        self.model_path = Path(model_path)

        # 'auto' defers the choice to runtime CUDA availability.
        if device != 'auto':
            resolved_device = device
        else:
            resolved_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(resolved_device)

        logger.info(f"使用设备: {self.device}")

        # Loading happens in the constructor, so a missing checkpoint fails here.
        self.model, self.config = self.load_model()
        
    def load_model(self) -> Tuple[nn.Module, Dict]:
        """
        Build the model from the checkpoint's config and load its weights.

        The checkpoint is read as a mapping with an optional ``'config'`` entry
        (whose ``'model'`` sub-dict is passed to
        ``create_video2eeg_sggn_diffusion``) and a ``'model_state_dict'`` entry.
        Weights are loaded non-strictly; any keys that could not be matched are
        logged so a partially initialised model is visible in the logs.
        If loading the weights fails altogether, the method deliberately keeps
        the randomly initialised model and only warns (best-effort behaviour).

        Returns:
            Tuple of the model (moved to ``self.device``, in eval mode) and the
            configuration dict stored in the checkpoint.

        Raises:
            FileNotFoundError: If ``self.model_path`` does not exist.
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"模型文件不存在: {self.model_path}")

        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects,
        # which can execute code. Only load checkpoints from trusted sources.
        checkpoint = torch.load(self.model_path, map_location=self.device, weights_only=False)
        config = checkpoint.get('config', {})

        # Instantiate the architecture from the stored model configuration.
        model_config = config.get('model', {})
        model = create_video2eeg_sggn_diffusion(model_config)

        # Non-strict load: mismatching keys are tolerated but reported below.
        try:
            load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        except Exception as e:
            logger.warning(f"模型权重加载失败: {e}")
            logger.info("使用随机初始化的权重继续推理")
        else:
            logger.info("模型权重加载成功（部分权重可能不匹配）")
            # load_state_dict reports the keys it could not match; surface them
            # instead of leaving the mismatch silent.
            missing_keys = list(getattr(load_result, 'missing_keys', []))
            unexpected_keys = list(getattr(load_result, 'unexpected_keys', []))
            if missing_keys:
                logger.warning(
                    f"检查点中缺失 {len(missing_keys)} 个模型权重键（保持随机初始化），"
                    f"例如: {missing_keys[:5]}"
                )
            if unexpected_keys:
                logger.warning(
                    f"检查点中有 {len(unexpected_keys)} 个未被模型使用的权重键，"
                    f"例如: {unexpected_keys[:5]}"
                )

        model = model.to(self.device)
        model.eval()

        logger.info(f"模型加载成功: {self.model_path}")
        logger.info(f"模型参数数量: {sum(p.numel() for p in model.parameters()):,}")

        return model, config
    
    def generate_eeg(self, video_frames: torch.Tensor, 
                     num_inference_steps: int = 50,
                     guidance_scale: float = 1.0) -> torch.Tensor:
        """
        生成EEG信号
        
        Args:
            video_frames: 视频帧 (B, T, C, H, W)
            num_inference_steps: 推理步数
            guidance_scale: 引导尺度
        
        Returns:
            生成的EEG信号 (B, C, T)
        """
        with torch.no_grad():
            video_frames = video_frames.to(self.device)
            
            # 使用模型的前向传播进行推理（不提供eeg_data和timesteps参数）
            generated_eeg = self.model(
                video_frames
            )
            
            # 检查并处理数值稳定性问题
            if torch.isnan(generated_eeg).any() or torch.isinf(generated_eeg).any():
                logger.warning("生成的EEG包含NaN或无穷大值，使用随机噪声替代")
                generated_eeg = torch.randn_like(generated_eeg) * 0.1
            
            # 裁剪异常值
            generated_eeg = torch.clamp(generated_eeg, -10.0, 10.0)
            
            return generated_eeg

class EEGQualityEvaluator:
    """
    EEG质量评估器
    """
    
    def __init__(self, sampling_rate: float = 250.0):
        """
        Store the sampling rate and define the EEG frequency bands.

        Args:
            sampling_rate: Sampling rate of the EEG signals, passed as ``fs``
                to the spectral estimator.
        """
        self.sampling_rate = sampling_rate

        # (low, high) band edges keyed by band name, in ascending frequency order.
        band_names = ('Delta', 'Theta', 'Alpha', 'Beta', 'Gamma')
        band_edges = ((0.5, 4), (4, 8), (8, 13), (13, 30), (30, 100))
        self.frequency_bands = dict(zip(band_names, band_edges))
    
    def compute_power_spectral_density(self, eeg_data: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        计算功率谱密度
        
        Args:
            eeg_data: EEG数据 (n_channels, n_samples)
        
        Returns:
            频率和功率谱密度
        """
        # 检查输入数据的数值稳定性
        if np.isnan(eeg_data).any() or np.isinf(eeg_data).any():
            logger.warning("EEG数据包含NaN或无穷大值，进行清理")
            eeg_data = np.nan_to_num(eeg_data, nan=0.0, posinf=1.0, neginf=-1.0)
        
        try:
            freqs, psd = signal.welch(
                eeg_data, 
                fs=self.sampling_rate, 
                nperseg=min(256, max(4, eeg_data.shape[1]//4)),
                axis=1
            )
            
            # 检查输出的数值稳定性
            if np.isnan(psd).any() or np.isinf(psd).any():
                logger.warning("功率谱密度包含NaN或无穷大值，进行清理")
                psd = np.nan_to_num(psd, nan=1e-10, posinf=1.0, neginf=1e-10)
                
        except Exception as e:
            logger.warning(f"功率谱密度计算失败: {e}，使用默认值")
            freqs = np.linspace(0, self.sampling_rate/2, 129)
            psd = np.ones((eeg_data.shape[0], len(freqs))) * 1e-10
        
        return freqs, psd
    
    def compute_band_power(self, eeg_data: np.ndarray) -> Dict[str, np.ndarray]:
        """
        计算各频段功率
        
        Args:
            eeg_data: EEG数据 (n_channels, n_samples)
        
        Returns:
            各频段功率字典
        """
        freqs, psd = self.compute_power_spectral_density(eeg_data)
        
        band_powers = {}
        for band_name, (low_freq, high_freq) in self.frequency_bands.items():
            # 找到频段范围内的索引
            freq_mask = (freqs >= low_freq) & (freqs <= high_freq)
            
            if np.any(freq_mask):
                # 计算频段内的平均功率
                band_power = np.mean(psd[:, freq_mask], axis=1)
                band_powers[band_name] = band_power
            else:
                band_powers[band_name] = np.zeros(eeg_data.shape[0])
        
        return band_powers
    
    def compute_temporal_features(self, eeg_data: np.ndarray) -> Dict[str, np.ndarray]:
        """
        计算时域特征
        
        Args:
            eeg_data: EEG数据 (n_channels, n_samples)
        
        Returns:
            时域特征字典
        """
        features = {}
        
        # 统计特征
        features['mean'] = np.mean(eeg_data, axis=1)
        features['std'] = np.std(eeg_data, axis=1)
        features['var'] = np.var(eeg_data, axis=1)
        features['skewness'] = stats.skew(eeg_data, axis=1)
        features['kurtosis'] = stats.kurtosis(eeg_data, axis=1)
        
        # 能量特征
        features['energy'] = np.sum(eeg_data**2, axis=1)
        features['rms'] = np.sqrt(np.mean(eeg_data**2, axis=1))
        
        # 零交叉率
        zero_crossings = np.sum(np.diff(np.sign(eeg_data), axis=1) != 0, axis=1)
        features['zero_crossing_rate'] = zero_crossings / eeg_data.shape[1]
        
        return features
    
    def compute_spatial_features(self, eeg_data: np.ndarray) -> Dict[str, float]:
        """
        计算空间特征
        
        Args:
            eeg_data: EEG数据 (n_channels, n_samples)
        
        Returns:
            空间特征字典
        """
        features = {}
        
        # 通道间相关性
        correlation_matrix = np.corrcoef(eeg_data)
        features['mean_correlation'] = np.mean(correlation_matrix[np.triu_indices_from(correlation_matrix, k=1)])
        features['std_correlation'] = np.std(correlation_matrix[np.triu_indices_from(correlation_matrix, k=1)])
        
        # 全局场功率 (Global Field Power)
        gfp = np.std(eeg_data, axis=0)
        features['mean_gfp'] = np.mean(gfp)
        features['std_gfp'] = np.std(gfp)
        
        return features
    
    def evaluate_quality(self, generated_eeg: np.ndarray, 
                        reference_eeg: np.ndarray) -> Dict[str, float]:
        """
        评估EEG生成质量
        
        Args:
            generated_eeg: 生成的EEG (n_channels, n_samples)
            reference_eeg: 参考EEG (n_channels, n_samples)
        
        Returns:
            质量评估指标
        """
        metrics = {}
        
        # 检查输入数据的数值稳定性
        if np.isnan(generated_eeg).any() or np.isinf(generated_eeg).any():
            logger.warning("生成的EEG包含NaN或无穷大值")
            generated_eeg = np.nan_to_num(generated_eeg, nan=0.0, posinf=1.0, neginf=-1.0)
        
        if np.isnan(reference_eeg).any() or np.isinf(reference_eeg).any():
            logger.warning("参考EEG包含NaN或无穷大值")
            reference_eeg = np.nan_to_num(reference_eeg, nan=0.0, posinf=1.0, neginf=-1.0)
        
        # 基础指标
        metrics['mse'] = mean_squared_error(reference_eeg.flatten(), generated_eeg.flatten())
        metrics['mae'] = mean_absolute_error(reference_eeg.flatten(), generated_eeg.flatten())
        
        # 相关性
        correlation_per_channel = []
        for ch in range(generated_eeg.shape[0]):
            corr, _ = stats.pearsonr(reference_eeg[ch], generated_eeg[ch])
            if not np.isnan(corr):
                correlation_per_channel.append(corr)
        
        if correlation_per_channel:
            metrics['mean_correlation'] = np.mean(correlation_per_channel)
            metrics['std_correlation'] = np.std(correlation_per_channel)
        else:
            metrics['mean_correlation'] = 0.0
            metrics['std_correlation'] = 0.0
        
        # 频域相似性
        ref_band_powers = self.compute_band_power(reference_eeg)
        gen_band_powers = self.compute_band_power(generated_eeg)
        
        band_similarities = {}
        for band_name in self.frequency_bands.keys():
            if band_name in ref_band_powers and band_name in gen_band_powers:
                ref_power = ref_band_powers[band_name]
                gen_power = gen_band_powers[band_name]
                
                # 计算余弦相似度（安全版本）
                try:
                    if np.linalg.norm(ref_power) > 1e-8 and np.linalg.norm(gen_power) > 1e-8:
                        similarity = 1 - cosine(ref_power, gen_power)
                    else:
                        similarity = 0.0
                    if np.isnan(similarity) or np.isinf(similarity):
                        similarity = 0.0
                except:
                    similarity = 0.0
                band_similarities[f'{band_name.lower()}_similarity'] = similarity
        
        metrics.update(band_similarities)
        
        # 时域特征相似性
        ref_temporal = self.compute_temporal_features(reference_eeg)
        gen_temporal = self.compute_temporal_features(generated_eeg)
        
        temporal_similarities = {}
        for feature_name in ref_temporal.keys():
            ref_feature = ref_temporal[feature_name]
            gen_feature = gen_temporal[feature_name]
            
            # 计算余弦相似度（安全版本）
            try:
                if np.linalg.norm(ref_feature) > 1e-8 and np.linalg.norm(gen_feature) > 1e-8:
                    similarity = 1 - cosine(ref_feature, gen_feature)
                else:
                    similarity = 0.0
                if np.isnan(similarity) or np.isinf(similarity):
                    similarity = 0.0
            except:
                similarity = 0.0
            temporal_similarities[f'{feature_name}_similarity'] = similarity
        
        metrics.update(temporal_similarities)
        
        # 空间特征相似性
        ref_spatial = self.compute_spatial_features(reference_eeg)
        gen_spatial = self.compute_spatial_features(generated_eeg)
        
        spatial_similarities = {}
        for feature_name in ref_spatial.keys():
            ref_value = ref_spatial[feature_name]
            gen_value = gen_spatial[feature_name]
            
            # 计算相对误差
            if ref_value != 0:
                relative_error = abs(gen_value - ref_value) / abs(ref_value)
                spatial_similarities[f'{feature_name}_error'] = relative_error
        
        metrics.update(spatial_similarities)
        
        return metrics

class SGGNInferenceEngine:
    """
    SGGN推理引擎
    """
    
    def __init__(self, 
                 model_path: str,
                 data_dir: str,
                 output_dir: str = "./sggn_inference_output",
                 device: str = 'auto'):
        """
        Wire together the model loader, quality evaluator and test data loader.

        Args:
            model_path: Path to the model checkpoint.
            data_dir: Directory holding the dataset.
            output_dir: Directory for results; created if it does not exist.
            device: Device string forwarded to ``SGGNModelLoader``.
        """
        # Keep the raw constructor arguments; the report writer prints them.
        self.model_path, self.data_dir = model_path, data_dir
        self.output_dir = Path(output_dir)
        self.device = device

        # Ensure the output directory (and any parents) exists up front.
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Components, in dependency order: model, evaluator, data.
        self.model_loader = SGGNModelLoader(model_path, device)
        self.evaluator = EEGQualityEvaluator()
        self.test_loader = self.create_test_dataloader()

        logger.info("推理引擎初始化完成")
    
    def create_test_dataloader(self) -> DataLoader:
        """
        Build the data loader over the test split.

        The dataset is created with ``split='test'`` and ``use_graph_da=False``;
        the loader yields one sample per batch, in order, using two worker
        processes and pinned memory.

        Returns:
            The test ``DataLoader``.
        """
        dataset = SGGNEEGVideoDataset(self.data_dir, split='test', use_graph_da=False)

        loader_options = dict(
            batch_size=1,
            shuffle=False,
            num_workers=2,
            pin_memory=True,
        )
        loader = DataLoader(dataset, **loader_options)

        logger.info(f"测试数据加载器创建完成: {len(dataset)} 样本")

        return loader
    
    def run_inference(self, 
                     num_samples: int = 100,
                     num_inference_steps: int = 50,
                     guidance_scale: float = 1.0) -> Dict[str, List]:
        """
        Generate EEG for test samples and score each one against its reference.

        Each batch is read through the keys ``'video'``, ``'eeg'``,
        ``'subject_id'`` and ``'video_id'``. A sample that raises is logged and
        skipped; it does not count towards ``num_samples``.

        Args:
            num_samples: Maximum number of successfully processed samples.
            num_inference_steps: Forwarded to ``SGGNModelLoader.generate_eeg``.
            guidance_scale: Forwarded to ``SGGNModelLoader.generate_eeg``.

        Returns:
            Dict of parallel lists: ``metrics``, ``inference_times`` (seconds),
            ``generated_eegs``, ``reference_eegs``, ``video_frames`` (NumPy
            arrays of the first batch element) and ``metadata``.
        """
        results = {
            'metrics': [],
            'inference_times': [],
            'generated_eegs': [],
            'reference_eegs': [],
            'video_frames': [],
            'metadata': []
        }

        logger.info(f"开始推理测试，样本数: {num_samples}")

        # CUDA kernels are launched asynchronously, so wall-clock timing is
        # only meaningful when the device is synchronised around the call.
        loader_device = self.model_loader.device
        uses_cuda = loader_device.type == 'cuda'

        sample_count = 0
        for batch in tqdm(self.test_loader, desc="推理进度"):
            if sample_count >= num_samples:
                break

            try:
                # NOTE(review): layouts per the original author's comments -
                # video (1, T, C, H, W), eeg (1, C, T); confirm against dataset.
                video_frames = batch['video']
                reference_eeg = batch['eeg']

                # Time only the generation call, with a monotonic clock.
                if uses_cuda:
                    torch.cuda.synchronize(loader_device)
                start_time = time.perf_counter()

                generated_eeg = self.model_loader.generate_eeg(
                    video_frames,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale
                )

                if uses_cuda:
                    torch.cuda.synchronize(loader_device)
                inference_time = time.perf_counter() - start_time

                # Drop the batch dimension (batch size is 1) and go to NumPy.
                generated_eeg_np = generated_eeg.cpu().numpy()[0]
                reference_eeg_np = reference_eeg.cpu().numpy()[0]
                video_frames_np = video_frames.cpu().numpy()[0]

                metrics = self.evaluator.evaluate_quality(
                    generated_eeg_np, reference_eeg_np
                )

                results['metrics'].append(metrics)
                results['inference_times'].append(inference_time)
                results['generated_eegs'].append(generated_eeg_np)
                results['reference_eegs'].append(reference_eeg_np)
                results['video_frames'].append(video_frames_np)
                results['metadata'].append({
                    'subject_id': batch['subject_id'].item(),
                    'video_id': batch['video_id'].item(),
                    'sample_idx': sample_count
                })

                sample_count += 1

                # Periodic progress summary using running averages.
                if sample_count % 10 == 0:
                    avg_time = np.mean(results['inference_times'])
                    avg_mse = np.mean([m['mse'] for m in results['metrics']])
                    avg_corr = np.mean([m['mean_correlation'] for m in results['metrics']])

                    logger.info(
                        f"已完成 {sample_count}/{num_samples} 样本, "
                        f"平均推理时间: {avg_time:.3f}s, "
                        f"平均MSE: {avg_mse:.6f}, "
                        f"平均相关性: {avg_corr:.4f}"
                    )

            except Exception as e:
                # Best-effort: one failing sample must not abort the whole run.
                logger.error(f"处理样本 {sample_count} 时发生错误: {e}")
                continue

        logger.info(f"推理测试完成，共处理 {len(results['metrics'])} 个样本")

        return results
    
    def analyze_results(self, results: Dict[str, List]) -> Dict[str, float]:
        """
        分析推理结果
        
        Args:
            results: 推理结果
        
        Returns:
            分析统计
        """
        if not results['metrics']:
            return {}
        
        # 提取所有指标
        all_metrics = results['metrics']
        metric_names = all_metrics[0].keys()
        
        analysis = {}
        
        # 计算统计量
        for metric_name in metric_names:
            values = [m[metric_name] for m in all_metrics if not np.isnan(m[metric_name])]
            
            if values:
                analysis[f'{metric_name}_mean'] = np.mean(values)
                analysis[f'{metric_name}_std'] = np.std(values)
                analysis[f'{metric_name}_median'] = np.median(values)
                analysis[f'{metric_name}_min'] = np.min(values)
                analysis[f'{metric_name}_max'] = np.max(values)
        
        # 推理时间统计
        inference_times = results['inference_times']
        analysis['inference_time_mean'] = np.mean(inference_times)
        analysis['inference_time_std'] = np.std(inference_times)
        analysis['inference_time_median'] = np.median(inference_times)
        
        # 整体质量评分
        mse_scores = [m['mse'] for m in all_metrics]
        corr_scores = [m['mean_correlation'] for m in all_metrics]
        
        # 归一化分数 (越低越好的MSE，越高越好的相关性)
        normalized_mse = 1 / (1 + np.mean(mse_scores))  # 转换为越高越好
        normalized_corr = np.mean(corr_scores)
        
        analysis['overall_quality_score'] = (normalized_mse + normalized_corr) / 2
        
        return analysis
    
    def generate_visualizations(self, results: Dict[str, List]):
        """
        Render every figure into ``<output_dir>/visualizations``.

        Does nothing when ``results['metrics']`` is empty.

        Args:
            results: Inference results as produced by ``run_inference``.
        """
        if not results['metrics']:
            return

        viz_dir = self.output_dir / 'visualizations'
        viz_dir.mkdir(exist_ok=True)

        # Plotters run in a fixed order: metric distributions, signal
        # comparison, spectra, correlation, then timing.
        plotters = (
            self.plot_quality_distributions,
            self.plot_eeg_comparisons,
            self.plot_spectral_analysis,
            self.plot_correlation_analysis,
            self.plot_performance_analysis,
        )
        for plotter in plotters:
            plotter(results, viz_dir)

        logger.info(f"可视化图表已保存到: {viz_dir}")
    
    def plot_quality_distributions(self, results: Dict[str, List], output_dir: Path):
        """
        Save histograms of MSE, MAE and mean correlation side by side.

        Each panel marks the mean with a dashed red line. The figure is
        written to ``quality_distributions.png`` inside ``output_dir``.
        """
        per_sample_metrics = results['metrics']

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        fig.suptitle('EEG Generation Quality Metrics Distribution', fontsize=16)

        # One panel per headline metric.
        for ax, metric in zip(axes, ('mse', 'mae', 'mean_correlation')):
            values = [
                m[metric] for m in per_sample_metrics
                if metric in m and not np.isnan(m[metric])
            ]
            if not values:
                continue

            ax.hist(values, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
            ax.set_title(metric.upper())
            ax.set_xlabel('Value')
            ax.set_ylabel('Frequency')
            ax.grid(True, alpha=0.3)

            # Annotate the mean.
            mean_val = np.mean(values)
            ax.axvline(mean_val, color='red', linestyle='--',
                       label=f'Mean: {mean_val:.4f}')
            ax.legend()

        plt.tight_layout()
        plt.savefig(output_dir / 'quality_distributions.png', dpi=300, bbox_inches='tight')
        plt.close()
    
    def plot_eeg_comparisons(self, results: Dict[str, List], output_dir: Path, num_examples: int = 5):
        """
        Save overlay plots of reference vs. generated EEG for the first samples.

        For each of the first ``num_examples`` samples, up to five channels are
        drawn (channels 0/15/30/45/60 when more than 60 channels exist,
        otherwise the first five). The time axis uses the evaluator's sampling
        rate so the plots stay consistent with the spectral analysis. Each
        figure is written to ``eeg_comparison_<n>.png`` inside ``output_dir``.

        Args:
            results: Inference results holding ``generated_eegs``,
                ``reference_eegs`` and ``metadata``.
            output_dir: Directory that receives the PNG files.
            num_examples: Maximum number of samples to plot.
        """
        sampling_rate = self.evaluator.sampling_rate

        for i in range(min(num_examples, len(results['generated_eegs']))):
            generated = results['generated_eegs'][i]
            reference = results['reference_eegs'][i]
            metadata = results['metadata'][i]

            # Pick a few representative channels.
            n_channels = generated.shape[0]
            channels_to_plot = [0, 15, 30, 45, 60] if n_channels > 60 else list(range(min(5, n_channels)))
            if not channels_to_plot:
                # Nothing to draw for a sample without channels.
                continue

            # squeeze=False always yields a 2-D axes array, also for one row.
            fig, axes = plt.subplots(len(channels_to_plot), 1,
                                     figsize=(12, 2*len(channels_to_plot)), squeeze=False)
            axes = axes[:, 0]

            fig.suptitle(f'EEG Signal Comparison - Subject {metadata["subject_id"]}, Video {metadata["video_id"]}', fontsize=14)

            time_axis = np.arange(generated.shape[1]) / sampling_rate

            for ax, ch in zip(axes, channels_to_plot):
                ax.plot(time_axis, reference[ch], label='Real EEG', color='blue', alpha=0.7)
                ax.plot(time_axis, generated[ch], label='Generated EEG', color='red', alpha=0.7)
                ax.set_title(f'Channel {ch+1}')
                ax.set_xlabel('Time (s)')
                ax.set_ylabel('Amplitude (μV)')
                ax.legend()
                ax.grid(True, alpha=0.3)

            plt.tight_layout()
            plt.savefig(output_dir / f'eeg_comparison_{i+1}.png', dpi=300, bbox_inches='tight')
            plt.close()
    
    def plot_spectral_analysis(self, results: Dict[str, List], output_dir: Path, num_examples: int = 3):
        """
        Save spectral comparison figures for the first samples.

        Each figure has two panels: the channel-averaged power spectral density
        of reference and generated EEG on a log scale (x-axis limited to
        0-50 Hz), and a grouped bar chart of the channel-averaged band power.
        Files are written as ``spectral_analysis_<n>.png`` in ``output_dir``.
        """
        n_figures = min(num_examples, len(results['generated_eegs']))

        for idx in range(n_figures):
            gen_signal = results['generated_eegs'][idx]
            ref_signal = results['reference_eegs'][idx]
            meta = results['metadata'][idx]

            # Spectra per channel, then averaged across channels.
            freqs_ref, psd_ref = self.evaluator.compute_power_spectral_density(ref_signal)
            freqs_gen, psd_gen = self.evaluator.compute_power_spectral_density(gen_signal)
            mean_psd_ref = np.mean(psd_ref, axis=0)
            mean_psd_gen = np.mean(psd_gen, axis=0)

            fig, (psd_ax, band_ax) = plt.subplots(1, 2, figsize=(15, 6))
            fig.suptitle(f'Spectral Analysis - Subject {meta["subject_id"]}, Video {meta["video_id"]}', fontsize=14)

            # Left panel: PSD overlay.
            psd_ax.semilogy(freqs_ref, mean_psd_ref, label='Real EEG', color='blue')
            psd_ax.semilogy(freqs_gen, mean_psd_gen, label='Generated EEG', color='red')
            psd_ax.set_xlabel('Frequency (Hz)')
            psd_ax.set_ylabel('Power Spectral Density')
            psd_ax.set_title('Power Spectral Density Comparison')
            psd_ax.legend()
            psd_ax.grid(True, alpha=0.3)
            psd_ax.set_xlim(0, 50)

            # Right panel: band power, averaged across channels.
            ref_band_powers = self.evaluator.compute_band_power(ref_signal)
            gen_band_powers = self.evaluator.compute_band_power(gen_signal)

            bands = list(ref_band_powers)
            ref_means = [np.mean(ref_band_powers[name]) for name in bands]
            gen_means = [np.mean(gen_band_powers[name]) for name in bands]

            positions = np.arange(len(bands))
            bar_width = 0.35

            band_ax.bar(positions - bar_width/2, ref_means, bar_width,
                        label='Real EEG', color='blue', alpha=0.7)
            band_ax.bar(positions + bar_width/2, gen_means, bar_width,
                        label='Generated EEG', color='red', alpha=0.7)
            band_ax.set_xlabel('Frequency Band')
            band_ax.set_ylabel('Average Power')
            band_ax.set_title('Frequency Band Power Comparison')
            band_ax.set_xticks(positions)
            band_ax.set_xticklabels(bands)
            band_ax.legend()
            band_ax.grid(True, alpha=0.3)

            plt.tight_layout()
            plt.savefig(output_dir / f'spectral_analysis_{idx+1}.png', dpi=300, bbox_inches='tight')
            plt.close()
    
    def plot_correlation_analysis(self, results: Dict[str, List], output_dir: Path):
        """
        Save a histogram and a per-sample trace of the mean correlation.

        The left panel shows the distribution with the mean marked; the right
        panel plots the value per sample and, when at least two samples exist,
        a least-squares linear trend line. Nothing is written when no valid
        correlation value is available. The figure is saved as
        ``correlation_analysis.png`` in ``output_dir``.
        """
        # Collect the non-NaN mean correlations.
        correlations = [m['mean_correlation'] for m in results['metrics'] 
                       if 'mean_correlation' in m and not np.isnan(m['mean_correlation'])]

        if not correlations:
            return

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
        fig.suptitle('Correlation Analysis', fontsize=16)

        # Distribution of the correlation values.
        ax1.hist(correlations, bins=20, alpha=0.7, color='green', edgecolor='black')
        ax1.set_xlabel('Average Correlation Coefficient')
        ax1.set_ylabel('Frequency')
        ax1.set_title('Channel Correlation Distribution')
        ax1.grid(True, alpha=0.3)

        # Mark the mean.
        mean_corr = np.mean(correlations)
        ax1.axvline(mean_corr, color='red', linestyle='--', 
                   label=f'Mean: {mean_corr:.4f}')
        ax1.legend()

        # Correlation per sample.
        ax2.plot(correlations, marker='o', markersize=3, alpha=0.7)
        ax2.set_xlabel('Sample Index')
        ax2.set_ylabel('Average Correlation Coefficient')
        ax2.set_title('Correlation vs Sample Index')
        ax2.grid(True, alpha=0.3)

        # A straight-line fit needs at least two points; with a single sample
        # the fit is ill-posed, so the trend line is skipped.
        if len(correlations) >= 2:
            sample_indices = np.arange(len(correlations))
            z = np.polyfit(sample_indices, correlations, 1)
            p = np.poly1d(z)
            ax2.plot(sample_indices, p(sample_indices), 
                    "r--", alpha=0.8, label=f'Trend Line (slope: {z[0]:.6f})')
            ax2.legend()

        plt.tight_layout()
        plt.savefig(output_dir / 'correlation_analysis.png', dpi=300, bbox_inches='tight')
        plt.close()
    
    def plot_performance_analysis(self, results: Dict[str, List], output_dir: Path):
        """
        Save a histogram and a per-sample trace of the inference times.

        The trace panel gets a moving average whose window is a fifth of the
        sample count, capped at 10, and is drawn only when that window exceeds
        one sample. The figure is saved as ``performance_analysis.png`` in
        ``output_dir``.
        """
        timings = results['inference_times']
        n_timings = len(timings)

        fig, (hist_ax, trace_ax) = plt.subplots(1, 2, figsize=(12, 5))
        fig.suptitle('Inference Performance Analysis', fontsize=16)

        # Left panel: distribution of inference times with the mean marked.
        hist_ax.hist(timings, bins=20, alpha=0.7, color='orange', edgecolor='black')
        hist_ax.set_xlabel('Inference Time (s)')
        hist_ax.set_ylabel('Frequency')
        hist_ax.set_title('Inference Time Distribution')
        hist_ax.grid(True, alpha=0.3)

        mean_time = np.mean(timings)
        hist_ax.axvline(mean_time, color='red', linestyle='--',
                        label=f'Mean: {mean_time:.3f}s')
        hist_ax.legend()

        # Right panel: inference time per sample.
        trace_ax.plot(timings, marker='o', markersize=3, alpha=0.7)
        trace_ax.set_xlabel('Sample Index')
        trace_ax.set_ylabel('Inference Time (s)')
        trace_ax.set_title('Inference Time vs Sample Index')
        trace_ax.grid(True, alpha=0.3)

        # Moving average, aligned to the right edge of each window.
        window_size = min(10, n_timings//5)
        if window_size > 1:
            kernel = np.ones(window_size)/window_size
            moving_avg = np.convolve(timings, kernel, mode='valid')
            trace_ax.plot(range(window_size-1, n_timings), moving_avg,
                          'r-', alpha=0.8, label=f'Moving Average (window={window_size})')
            trace_ax.legend()

        plt.tight_layout()
        plt.savefig(output_dir / 'performance_analysis.png', dpi=300, bbox_inches='tight')
        plt.close()
    
    def generate_report(self, results: Dict[str, List], analysis: Dict[str, float]):
        """
        Write the machine-readable results and the human-readable report.

        ``inference_results.json`` receives the per-sample metrics, inference
        times, metadata and the analysis summary; ``inference_report.txt``
        receives a formatted text summary. Both go to ``self.output_dir``.
        NumPy scalars and arrays inside the data are converted to built-in
        types during serialisation, because the ``json`` module cannot encode
        them directly (e.g. ``np.float32`` metrics from float32 EEG arrays).

        Args:
            results: Inference results (``metrics``, ``inference_times``,
                ``metadata`` are read).
            analysis: Summary statistics to include in both files.
        """
        def _to_builtin(value):
            # json's fallback hook: called only for objects it cannot encode.
            if isinstance(value, np.generic):
                return value.item()
            if isinstance(value, np.ndarray):
                return value.tolist()
            raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

        # Save the result data.
        results_path = self.output_dir / 'inference_results.json'

        # Only the JSON-friendly parts of the results are persisted; the raw
        # EEG and video arrays are left out.
        serializable_results = {
            'metrics': results['metrics'],
            'inference_times': results['inference_times'],
            'metadata': results['metadata'],
            'analysis': analysis
        }

        with open(results_path, 'w', encoding='utf-8') as f:
            json.dump(serializable_results, f, indent=2, ensure_ascii=False, default=_to_builtin)

        # Write the text report.
        report_path = self.output_dir / 'inference_report.txt'
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write("Video2EEG-SGGN-Diffusion 推理测试报告\n")
            f.write("=" * 50 + "\n\n")

            f.write(f"测试配置:\n")
            f.write(f"  模型路径: {self.model_path}\n")
            f.write(f"  数据目录: {self.data_dir}\n")
            f.write(f"  测试样本数: {len(results['metrics'])}\n")
            f.write(f"  设备: {self.device}\n\n")

            f.write(f"性能指标:\n")
            f.write(f"  平均推理时间: {analysis.get('inference_time_mean', 0):.3f} ± {analysis.get('inference_time_std', 0):.3f} 秒\n")
            f.write(f"  推理时间中位数: {analysis.get('inference_time_median', 0):.3f} 秒\n\n")

            f.write(f"质量指标:\n")
            f.write(f"  平均MSE: {analysis.get('mse_mean', 0):.6f} ± {analysis.get('mse_std', 0):.6f}\n")
            f.write(f"  平均MAE: {analysis.get('mae_mean', 0):.6f} ± {analysis.get('mae_std', 0):.6f}\n")
            f.write(f"  平均相关性: {analysis.get('mean_correlation_mean', 0):.4f} ± {analysis.get('mean_correlation_std', 0):.4f}\n")
            f.write(f"  整体质量评分: {analysis.get('overall_quality_score', 0):.4f}\n\n")

            # Band similarities are listed only when present in the analysis.
            f.write(f"频段相似性:\n")
            for band in ['delta', 'theta', 'alpha', 'beta', 'gamma']:
                similarity_key = f'{band}_similarity_mean'
                if similarity_key in analysis:
                    f.write(f"  {band.capitalize()}: {analysis[similarity_key]:.4f}\n")

            # Static feature list; it is not derived from the loaded model.
            f.write(f"\n模型特性:\n")
            f.write(f"  ✓ Graph-DA数据增强\n")
            f.write(f"  ✓ E-Graph与S-Graph构建\n")
            f.write(f"  ✓ 滤波器组驱动的多图建模\n")
            f.write(f"  ✓ 自博弈融合策略\n")
            f.write(f"  ✓ 多尺度图卷积网络\n")
            f.write(f"  ✓ 空间图注意力机制\n")

        logger.info(f"推理报告已保存: {report_path}")
    
    def run_complete_evaluation(self, 
                               num_samples: int = 100,
                               num_inference_steps: int = 50,
                               guidance_scale: float = 1.0):
        """
        Run the whole pipeline: inference, analysis, figures and report.

        Args:
            num_samples: Maximum number of test samples to evaluate.
            num_inference_steps: Forwarded to ``run_inference``.
            guidance_scale: Forwarded to ``run_inference``.

        Returns:
            Tuple ``(results, analysis)`` from ``run_inference`` and
            ``analyze_results``.
        """
        logger.info("开始完整推理评估")

        inference_options = {
            'num_samples': num_samples,
            'num_inference_steps': num_inference_steps,
            'guidance_scale': guidance_scale,
        }
        results = self.run_inference(**inference_options)

        # Summarise first; the figures and the report are written afterwards.
        analysis = self.analyze_results(results)
        self.generate_visualizations(results)
        self.generate_report(results, analysis)

        logger.info(f"完整评估完成，结果保存在: {self.output_dir}")

        return results, analysis

def main():
    """
    Command-line entry point: parse arguments, run the evaluation, log a summary.

    Returns:
        0 on success, 1 if the evaluation raised an exception (the traceback
        is printed). Argument errors exit through argparse as usual.
    """
    parser = argparse.ArgumentParser(description='Video2EEG-SGGN-Diffusion模型推理')

    # (flag, add_argument keyword options) for every supported option.
    option_specs = (
        ('--model_path', dict(type=str, required=True, help='模型检查点路径')),
        ('--data_dir', dict(type=str, required=True, help='数据目录路径')),
        ('--output_dir', dict(type=str, default='./sggn_inference_output', help='输出目录')),
        ('--num_samples', dict(type=int, default=100, help='测试样本数')),
        ('--num_inference_steps', dict(type=int, default=50, help='推理步数')),
        ('--guidance_scale', dict(type=float, default=1.0, help='引导尺度')),
        ('--device', dict(type=str, default='auto', help='设备类型')),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    try:
        engine = SGGNInferenceEngine(
            model_path=args.model_path,
            data_dir=args.data_dir,
            output_dir=args.output_dir,
            device=args.device
        )

        results, analysis = engine.run_complete_evaluation(
            num_samples=args.num_samples,
            num_inference_steps=args.num_inference_steps,
            guidance_scale=args.guidance_scale
        )

        # Headline numbers; missing statistics are shown as 0.
        logger.info("\n=== 推理测试结果摘要 ===")
        logger.info(f"测试样本数: {len(results['metrics'])}")
        logger.info(f"平均推理时间: {analysis.get('inference_time_mean', 0):.3f}s")
        logger.info(f"平均MSE: {analysis.get('mse_mean', 0):.6f}")
        logger.info(f"平均相关性: {analysis.get('mean_correlation_mean', 0):.4f}")
        logger.info(f"整体质量评分: {analysis.get('overall_quality_score', 0):.4f}")

    except Exception as e:
        # Top-level boundary: report the failure and signal it via exit code.
        logger.error(f"推理过程中发生错误: {e}")
        import traceback
        traceback.print_exc()
        return 1

    return 0

if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())