#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Video2EEG-SGGN-Diffusion vs Neural Timeseries Diffusion 脑电生成质量对比分析
对比Video2EEG-SGGN-Diffusion与neural_timeseries_diffusion的生成质量

作者: 算法工程师
日期: 2025年1月12日
"""

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import signal, stats
from scipy.fft import fft, fftfreq
import pandas as pd
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import warnings
import time
import os
import sys
import importlib.util
from pathlib import Path
import json
from datetime import datetime

# Suppress all warnings process-wide (this also hides NumPy/SciPy
# RuntimeWarnings, e.g. from correlating constant signals).
warnings.filterwarnings('ignore')

# Matplotlib: prefer fonts that can render the Chinese labels, and draw the
# minus sign as ASCII '-' so it is not rendered as a missing glyph.
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'Arial Unicode MS', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

# Global RNG seeds for reproducible runs.
_GLOBAL_SEED = 42
torch.manual_seed(_GLOBAL_SEED)
np.random.seed(_GLOBAL_SEED)

class EEGQualityEvaluator:
    """
    Quality evaluator for generated EEG signals.

    Scores a generated EEG batch against a reference ("real") batch using
    time-domain and frequency-domain metrics. Both inputs must be 3-D arrays
    indexed as (batch, channel, sample). If their shapes differ, both are
    cropped to the common leading sub-array before any metric is computed, so
    every metric is evaluated on exactly the same traces.
    """

    # EEG frequency bands in Hz (inclusive bounds). A band that contains no
    # Welch frequency bin (e.g. one lying above the Nyquist frequency) simply
    # produces no metrics.
    FREQ_BANDS = {
        'delta': (0.5, 4),
        'theta': (4, 8),
        'alpha': (8, 13),
        'beta': (13, 30),
        'gamma': (30, 50),
    }

    def __init__(self, sampling_rate=250):
        """
        Args:
            sampling_rate: Sampling rate of the signals in Hz, used as ``fs``
                for Welch PSD estimation.
        """
        self.sampling_rate = sampling_rate
        # method_name -> {'temporal': {...}, 'frequency': {...}}; filled by
        # comprehensive_evaluation().
        self.evaluation_results = {}

    @staticmethod
    def _align(real_eeg, generated_eeg):
        """
        Crop two (batch, channel, sample) arrays to their common shape.

        Raises:
            ValueError: if either input is not 3-D or they share no data.
        """
        real_eeg = np.asarray(real_eeg)
        generated_eeg = np.asarray(generated_eeg)
        if real_eeg.ndim != 3 or generated_eeg.ndim != 3:
            raise ValueError(
                "EEG arrays must be 3-D (batch, channel, sample); got shapes "
                f"{real_eeg.shape} and {generated_eeg.shape}"
            )
        common = tuple(min(a, b) for a, b in zip(real_eeg.shape, generated_eeg.shape))
        if 0 in common:
            raise ValueError(
                f"No overlapping data between shapes {real_eeg.shape} and {generated_eeg.shape}"
            )
        window = tuple(slice(0, n) for n in common)
        return real_eeg[window], generated_eeg[window]

    @staticmethod
    def _paired_correlations(real_rows, gen_rows):
        """Pearson r between matching rows of two 2-D arrays; NaN results are dropped."""
        correlations = []
        for real_row, gen_row in zip(real_rows, gen_rows):
            corr = np.corrcoef(real_row, gen_row)[0, 1]
            # Constant rows give NaN (zero variance); skip them.
            if not np.isnan(corr):
                correlations.append(corr)
        return correlations

    @staticmethod
    def _mean_std(values):
        """Return ``(mean, std)`` of ``values``, or ``(nan, nan)`` if it is empty."""
        if not values:
            return float('nan'), float('nan')
        return float(np.mean(values)), float(np.std(values))

    def temporal_domain_analysis(self, real_eeg, generated_eeg, method_name):
        """
        Time-domain comparison.

        Args:
            real_eeg: Reference array, (batch, channel, sample).
            generated_eeg: Generated array, (batch, channel, sample).
            method_name: Label of the evaluated method (not used in the maths).

        Returns:
            dict with 'mean_mse', 'mean_mae', 'amplitude_correlation',
            'amplitude_mse', 'range_correlation', 'temporal_correlation_mean'
            and 'temporal_correlation_std'. Correlations may be NaN when an
            input is constant.
        """
        real_eeg, generated_eeg = self._align(real_eeg, generated_eeg)
        results = {}

        # Sample-wise errors, accumulated in float64 for accuracy.
        diff = real_eeg.astype(np.float64) - generated_eeg.astype(np.float64)
        results['mean_mse'] = float(np.mean(diff ** 2))
        results['mean_mae'] = float(np.mean(np.abs(diff)))

        # Per-trace amplitude (standard deviation over time).
        real_amplitude = np.std(real_eeg, axis=-1).ravel()
        gen_amplitude = np.std(generated_eeg, axis=-1).ravel()
        results['amplitude_correlation'] = np.corrcoef(real_amplitude, gen_amplitude)[0, 1]
        results['amplitude_mse'] = float(np.mean((real_amplitude - gen_amplitude) ** 2))

        # Per-trace dynamic range (peak-to-peak).
        real_range = np.ptp(real_eeg, axis=-1).ravel()
        gen_range = np.ptp(generated_eeg, axis=-1).ravel()
        results['range_correlation'] = np.corrcoef(real_range, gen_range)[0, 1]

        # Waveform correlation for every (batch, channel) trace.
        n_samples = real_eeg.shape[-1]
        correlations = self._paired_correlations(
            real_eeg.reshape(-1, n_samples), generated_eeg.reshape(-1, n_samples)
        )
        (results['temporal_correlation_mean'],
         results['temporal_correlation_std']) = self._mean_std(correlations)

        return results

    def _compute_psd(self, signal_data):
        """
        Welch PSD of every (batch, channel) trace.

        Returns:
            (psd, freqs) where psd has shape (batch * channel, n_freqs), rows
            ordered batch-major, and freqs has shape (n_freqs,).
        """
        n_samples = signal_data.shape[-1]
        # Segment length: about a quarter of the trace, capped at 256 and
        # kept >= 1 so very short traces do not make welch() fail.
        nperseg = max(1, min(256, n_samples // 4))
        freqs, psd = signal.welch(signal_data, fs=self.sampling_rate, nperseg=nperseg, axis=-1)
        return psd.reshape(-1, psd.shape[-1]), freqs

    def frequency_domain_analysis(self, real_eeg, generated_eeg, method_name):
        """
        Frequency-domain comparison based on Welch power spectral densities.

        Args:
            real_eeg: Reference array, (batch, channel, sample).
            generated_eeg: Generated array, (batch, channel, sample).
            method_name: Label of the evaluated method (not used in the maths).

        Returns:
            dict with 'psd_correlation_mean', 'psd_correlation_std' and, for
            every band in FREQ_BANDS that contains at least one Welch bin,
            '<band>_power_correlation' (NaN replaced by 0) and
            '<band>_power_mse'.
        """
        real_eeg, generated_eeg = self._align(real_eeg, generated_eeg)
        results = {}

        real_psd, freqs = self._compute_psd(real_eeg)
        gen_psd, _ = self._compute_psd(generated_eeg)

        # Shape similarity of the spectra, trace by trace.
        psd_correlations = self._paired_correlations(real_psd, gen_psd)
        (results['psd_correlation_mean'],
         results['psd_correlation_std']) = self._mean_std(psd_correlations)

        # Mean band power per trace, compared across traces.
        for band_name, (low_freq, high_freq) in self.FREQ_BANDS.items():
            band_mask = (freqs >= low_freq) & (freqs <= high_freq)
            if not np.any(band_mask):
                continue  # no Welch bin in this band (e.g. above Nyquist)
            real_band_power = real_psd[:, band_mask].mean(axis=1)
            gen_band_power = gen_psd[:, band_mask].mean(axis=1)

            band_corr = np.corrcoef(real_band_power, gen_band_power)[0, 1]
            results[f'{band_name}_power_correlation'] = band_corr if not np.isnan(band_corr) else 0
            results[f'{band_name}_power_mse'] = float(np.mean((real_band_power - gen_band_power) ** 2))

        return results

    def comprehensive_evaluation(self, real_eeg, generated_eeg, method_name):
        """
        Run every analysis, store the result under ``method_name`` and print a summary.

        Returns:
            {'temporal': <temporal metrics>, 'frequency': <frequency metrics>}
        """
        print(f"\n=== {method_name} 质量评估 ===")

        temporal_results = self.temporal_domain_analysis(real_eeg, generated_eeg, method_name)
        frequency_results = self.frequency_domain_analysis(real_eeg, generated_eeg, method_name)

        all_results = {
            'temporal': temporal_results,
            'frequency': frequency_results
        }

        self.evaluation_results[method_name] = all_results
        self.print_key_metrics(method_name, all_results)

        return all_results

    def print_key_metrics(self, method_name, results):
        """Print the headline metrics (temporal corr., PSD corr., MSE) of one method."""
        print(f"\n{method_name} 关键指标:")
        print(f"  时域相关性: {results['temporal'].get('temporal_correlation_mean', 0):.3f}")
        print(f"  频域相关性: {results['frequency'].get('psd_correlation_mean', 0):.3f}")
        print(f"  均方误差: {results['temporal'].get('mean_mse', 0):.3f}")

class ModelLoader:
    """
    Build the two models under comparison on a single device.

    Both loaders are best-effort: on failure they print a diagnostic and
    return None instead of raising, leaving the decision to the caller.
    """

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"使用设备: {self.device}")

    def _load_weights(self, model, model_path):
        """
        Best-effort, non-strict load of checkpoint weights into ``model``.

        Args:
            model: nn.Module whose parameters/buffers are overwritten in place.
            model_path: Checkpoint file holding either a raw state_dict or a
                dict with a ``'model_state_dict'`` entry.

        Returns:
            True if at least one entry of ``model.state_dict()`` was loaded;
            False if the file is missing, unreadable, or matched nothing.
        """
        if not os.path.exists(model_path):
            print(f"⚠️ 权重文件不存在，使用随机初始化: {model_path}")
            return False
        try:
            # SECURITY: weights_only=False unpickles arbitrary Python objects
            # and can execute code; only load checkpoints from trusted sources.
            checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                checkpoint = checkpoint['model_state_dict']
            # strict=False silently skips missing/unexpected keys, so inspect
            # the returned report rather than assuming the load worked.
            incompatible = model.load_state_dict(checkpoint, strict=False)
        except Exception as e:
            print(f"⚠️ 权重加载失败，使用随机初始化: {e}")
            return False

        total = len(model.state_dict())
        n_missing = len(incompatible.missing_keys)
        loaded = total - n_missing
        if loaded <= 0:
            print(f"⚠️ 权重文件与模型没有匹配的参数，实际为随机初始化: {model_path}")
            return False
        print(f"✓ SGGN模型权重部分加载成功: {model_path}")
        print(f"  已加载 {loaded}/{total} 个参数, 缺失 {n_missing}, "
              f"多余 {len(incompatible.unexpected_keys)}")
        return True

    def load_sggn_model(self, model_path="./sggn_training_output/best_model.pth"):
        """
        Build Video2EEG-SGGN-Diffusion and load its weights if available.

        Weight loading is best-effort: a missing or incompatible checkpoint
        leaves the model randomly initialised and is only reported.

        Returns:
            The model on ``self.device``, or None if it could not be built.
        """
        try:
            from video2eeg_sggn_diffusion_model import Video2EEGSGGNDiffusion

            model = Video2EEGSGGNDiffusion(
                eeg_channels=62,
                signal_length=200,
                video_feature_dim=512,
                hidden_dim=256,
                num_diffusion_steps=1000
            ).to(self.device)

            self._load_weights(model, model_path)

            print("✓ Video2EEG-SGGN-Diffusion模型加载成功")
            return model

        except Exception as e:
            print(f"✗ Video2EEG-SGGN-Diffusion模型加载失败: {e}")
            import traceback
            traceback.print_exc()
            return None

    def load_ntd_model(self, ntd_root="../neural_timeseries_diffusion-main"):
        """
        Build a Neural Timeseries Diffusion model.

        Args:
            ntd_root: Directory containing the ``ntd`` package. It is resolved
                relative to the current working directory and put on
                ``sys.path`` (once) if it exists.

        NOTE(review): no trained weights are loaded here, so unless weights
        are loaded elsewhere, samples come from an untrained network.

        Returns:
            The model on ``self.device``, or None if it could not be built.
        """
        try:
            ntd_path = Path(ntd_root)
            # Avoid growing sys.path with duplicates on repeated calls.
            if ntd_path.exists() and str(ntd_path) not in sys.path:
                sys.path.insert(0, str(ntd_path))

            from ntd.diffusion_model import Diffusion
            from ntd.networks import CatConv
            from ntd.utils.kernels_and_diffusion_utils import WhiteNoiseProcess

            network = CatConv(
                signal_length=200,
                signal_channel=62,
                time_dim=128,
                hidden_channel=32,
                in_kernel_size=7,
                out_kernel_size=7,
                slconv_kernel_size=7,
                num_scales=3,
                num_off_diag=10,
                use_pos_emb=True
            )

            noise_sampler = WhiteNoiseProcess(sigma_squared=1.0, signal_length=200)

            model = Diffusion(
                network=network,
                diffusion_time_steps=1000,
                noise_sampler=noise_sampler,
                mal_dist_computer=noise_sampler,
                schedule="linear"
            ).to(self.device)

            print("✓ Neural Timeseries Diffusion模型加载成功")
            return model

        except Exception as e:
            print(f"✗ Neural Timeseries Diffusion模型加载失败: {e}")
            import traceback
            traceback.print_exc()
            return None

class DataGenerator:
    """
    Synthetic inputs for the comparison: EEG-like signals and random video frames.
    """

    def __init__(self, device):
        """
        Args:
            device: torch.device that every generated tensor is moved to.
        """
        self.device = device

    def create_realistic_eeg_data(self, batch_size=4, n_channels=62, n_samples=200, seed=42):
        """
        Build a synthetic EEG-like batch.

        Each trace is Gaussian noise (std 0.5) plus one theta, one alpha and
        one beta sinusoid with randomly drawn frequencies. Neighbouring
        channels are then coupled.

        Args:
            batch_size: Number of trials.
            n_channels: Number of channels.
            n_samples: Samples per trace; the time axis spans 0-4 s inclusive.
            seed: Seed for a private RNG. The same seed always gives the same data.

        Returns:
            float32 tensor of shape (batch_size, n_channels, n_samples) on
            ``self.device``.
        """
        # Private RNG: yields exactly the stream that np.random.seed(seed)
        # followed by global draws would, but without resetting NumPy's global
        # RNG for the rest of the process.
        rng = np.random.RandomState(seed)

        base_signal = rng.randn(batch_size, n_channels, n_samples) * 0.5

        # 0..4 s inclusive, i.e. a sample spacing of 4 / (n_samples - 1) s.
        # NOTE(review): at ~50 Hz sampling the Nyquist limit is ~25 Hz, so beta
        # components drawn above it alias to lower frequencies. Confirm this is intended.
        time_axis = np.linspace(0, 4, n_samples)

        # Three uniform draws per (batch, channel), in the order alpha, beta,
        # theta. This is the same sequence as drawing them in nested b/c loops.
        draws = rng.rand(batch_size, n_channels, 3)
        alpha_freq = 8 + draws[..., 0:1] * 5     # alpha: 8-13 Hz
        beta_freq = 13 + draws[..., 1:2] * 17    # beta: 13-30 Hz
        theta_freq = 4 + draws[..., 2:3] * 4     # theta: 4-8 Hz

        base_signal += (
            2 * np.sin(2 * np.pi * alpha_freq * time_axis)
            + 1 * np.sin(2 * np.pi * beta_freq * time_axis)
            + 1.5 * np.sin(2 * np.pi * theta_freq * time_axis)
        )

        # Spatial coupling, applied sequentially: channel c receives 0.3x the
        # already-coupled channel c-1, so earlier channels leak in with
        # geometrically decaying weight.
        correlation_strength = 0.3
        for c in range(1, n_channels):
            base_signal[:, c, :] += correlation_strength * base_signal[:, c - 1, :]

        return torch.as_tensor(base_signal, dtype=torch.float32, device=self.device)

    def create_video_data(self, batch_size=4, num_frames=200):
        """
        Random "video" tensor of shape (batch_size, num_frames, 3, 224, 224).

        Values come from torch's global RNG. Memory grows quickly: the
        defaults allocate 4*200*3*224*224 float32 values (~480 MB).
        """
        video_frames = torch.randn(batch_size, num_frames, 3, 224, 224)
        return video_frames.to(self.device)

class QualityComparator:
    """
    End-to-end comparison driver.

    Loads both models, builds synthetic inputs, generates EEG with each model,
    scores the outputs against the synthetic reference, and writes a Markdown
    report plus a JSON dump of the metrics to ``output_dir``.
    """

    # (table label, metric key, higher_is_better) for the report tables.
    _TEMPORAL_METRICS = (
        ('时域相关性', 'temporal_correlation_mean', True),
        ('幅值相关性', 'amplitude_correlation', True),
        ('均方误差', 'mean_mse', False),
        ('平均绝对误差', 'mean_mae', False),
    )
    _FREQUENCY_METRICS = (
        ('功率谱相关性', 'psd_correlation_mean', True),
        ('Delta波段相关性', 'delta_power_correlation', True),
        ('Theta波段相关性', 'theta_power_correlation', True),
        ('Alpha波段相关性', 'alpha_power_correlation', True),
        ('Beta波段相关性', 'beta_power_correlation', True),
        ('Gamma波段相关性', 'gamma_power_correlation', True),
    )

    def __init__(self, output_dir="./comparison_output"):
        """
        Args:
            output_dir: Directory for the report and JSON results. It is
                created, including any missing parents.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_loader = ModelLoader()
        self.data_generator = DataGenerator(self.device)
        self.evaluator = EEGQualityEvaluator(sampling_rate=50)  # 200样本/4秒 = 50Hz
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def _synchronize(self):
        """Wait for queued CUDA work, so that wall-clock timings cover the GPU kernels."""
        if self.device.type == 'cuda':
            torch.cuda.synchronize()

    def _timed(self, fn):
        """Run ``fn()`` under ``torch.no_grad()``; return ``(output, elapsed_seconds)``."""
        with torch.no_grad():
            self._synchronize()
            start = time.perf_counter()
            output = fn()
            self._synchronize()
            elapsed = time.perf_counter() - start
        return output, elapsed

    def run_comparison(self, sggn_model_path="./sggn_training_output/best_model.pth"):
        """
        Run the full comparison pipeline.

        Args:
            sggn_model_path: Checkpoint for the SGGN model.

        Returns:
            True on success, False if either model failed to load.
        """
        print("Video2EEG-SGGN-Diffusion vs Neural Timeseries Diffusion 脑电生成质量对比")
        print("=" * 80)

        # 1. Load models
        print("\n=== 加载模型 ===")
        sggn_model = self.model_loader.load_sggn_model(sggn_model_path)
        ntd_model = self.model_loader.load_ntd_model()

        if sggn_model is None or ntd_model is None:
            print("模型加载失败，无法进行对比")
            return False

        # 2. Synthetic test data
        print("\n=== 生成测试数据 ===")
        real_eeg = self.data_generator.create_realistic_eeg_data(batch_size=4, n_channels=62, n_samples=200)
        video_data = self.data_generator.create_video_data(batch_size=4, num_frames=200)
        batch_size, n_samples = real_eeg.shape[0], real_eeg.shape[-1]

        print(f"真实EEG数据形状: {real_eeg.shape}")
        print(f"视频数据形状: {video_data.shape}")

        # 3. Generate EEG with both models
        print("\n=== 生成EEG信号 ===")

        sggn_model.eval()
        # NOTE(review): this is a single training-style forward call at a
        # random diffusion timestep, and it receives real_eeg as input. NTD
        # below runs its unconditional sample(). The timings and the
        # similarity-to-real_eeg scores are therefore not like-for-like.
        timesteps = torch.randint(0, 1000, (batch_size,), device=self.device)
        sggn_eeg, sggn_time = self._timed(lambda: sggn_model(video_data, real_eeg, timesteps))

        print(f"SGGN模型生成时间: {sggn_time:.4f}秒")
        print(f"SGGN生成EEG形状: {sggn_eeg.shape}")

        ntd_model.eval()
        ntd_eeg, ntd_time = self._timed(
            lambda: ntd_model.sample(num_samples=batch_size, sample_length=n_samples)
        )

        print(f"NTD模型生成时间: {ntd_time:.4f}秒")
        print(f"NTD生成EEG形状: {ntd_eeg.shape}")

        # 4. To numpy for analysis
        real_eeg_np = real_eeg.detach().cpu().numpy()
        sggn_eeg_np = sggn_eeg.detach().cpu().numpy()
        ntd_eeg_np = ntd_eeg.detach().cpu().numpy()

        # 5. Quality evaluation
        print("\n=== 质量评估 ===")
        sggn_results = self.evaluator.comprehensive_evaluation(
            real_eeg_np, sggn_eeg_np, "Video2EEG-SGGN-Diffusion"
        )

        ntd_results = self.evaluator.comprehensive_evaluation(
            real_eeg_np, ntd_eeg_np, "Neural Timeseries Diffusion"
        )

        # 6. Report
        print("\n=== 生成对比报告 ===")
        self.generate_detailed_report(sggn_results, ntd_results, sggn_time, ntd_time)

        # 7. Persist raw results
        self.save_results({
            'sggn_results': sggn_results,
            'ntd_results': ntd_results,
            'sggn_time': sggn_time,
            'ntd_time': ntd_time,
            'timestamp': datetime.now().isoformat()
        })

        return True

    @staticmethod
    def _as_number(value):
        """Return ``value`` as a float, or None if it is missing, non-numeric or NaN."""
        if value is None:
            return None
        try:
            number = float(value)
        except (TypeError, ValueError):
            return None
        return None if np.isnan(number) else number

    @classmethod
    def _fmt(cls, value):
        """Format a metric with 3 decimals, or 'N/A' when it is unavailable."""
        number = cls._as_number(value)
        return 'N/A' if number is None else f"{number:.3f}"

    @classmethod
    def _winner(cls, sggn_value, ntd_value, higher_is_better=True):
        """
        Decide which model wins a metric.

        Returns:
            'SGGN' or 'NTD' for a strict win, '持平' for a tie, and 'N/A'
            when either value is missing or NaN.
        """
        a = cls._as_number(sggn_value)
        b = cls._as_number(ntd_value)
        if a is None or b is None:
            return 'N/A'
        if a == b:
            return '持平'
        sggn_better = a > b if higher_is_better else a < b
        return 'SGGN' if sggn_better else 'NTD'

    @staticmethod
    def _finding(winner, sggn_text, ntd_text):
        """Pick the conclusion sentence for a winner returned by _winner()."""
        if winner == 'SGGN':
            return sggn_text
        if winner == 'NTD':
            return ntd_text
        if winner == '持平':
            return '两种模型表现持平'
        return '指标缺失，无法比较'

    def _metric_rows(self, metrics, section, sggn_results, ntd_results):
        """Render Markdown table rows for ``metrics`` from results[section]."""
        sggn_section = sggn_results.get(section, {})
        ntd_section = ntd_results.get(section, {})
        rows = []
        for label, key, higher_is_better in metrics:
            s_val = sggn_section.get(key)
            n_val = ntd_section.get(key)
            winner = self._winner(s_val, n_val, higher_is_better)
            rows.append(f"| {label} | {self._fmt(s_val)} | {self._fmt(n_val)} | {winner} |")
        return "\n".join(rows)

    def generate_detailed_report(self, sggn_results, ntd_results, sggn_time, ntd_time):
        """
        Write ``comparison_report.md`` into ``output_dir``.

        Missing or NaN metrics (e.g. a band above the Nyquist frequency) are
        shown as N/A, not 0, and ties are reported as 持平 rather than
        being credited to NTD.
        """
        now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

        speed_winner = self._winner(sggn_time, ntd_time, higher_is_better=False)
        if speed_winner in ('SGGN', 'NTD'):
            speed_line = f"{speed_winner}模型快 {abs(sggn_time - ntd_time):.4f}秒"
        elif speed_winner == '持平':
            speed_line = "两者持平"
        else:
            speed_line = "无法比较"

        temporal_rows = self._metric_rows(self._TEMPORAL_METRICS, 'temporal', sggn_results, ntd_results)
        frequency_rows = self._metric_rows(self._FREQUENCY_METRICS, 'frequency', sggn_results, ntd_results)

        speed_finding = self._finding(
            speed_winner, 'SGGN模型在生成速度上具有优势', 'NTD模型在生成速度上具有优势'
        )
        temporal_finding = self._finding(
            self._winner(sggn_results.get('temporal', {}).get('temporal_correlation_mean'),
                         ntd_results.get('temporal', {}).get('temporal_correlation_mean')),
            'SGGN模型在时域相关性方面表现更好', 'NTD模型在时域相关性方面表现更好'
        )
        frequency_finding = self._finding(
            self._winner(sggn_results.get('frequency', {}).get('psd_correlation_mean'),
                         ntd_results.get('frequency', {}).get('psd_correlation_mean')),
            'SGGN模型在频域分析方面表现更好', 'NTD模型在频域分析方面表现更好'
        )

        report = f"""# Video2EEG-SGGN-Diffusion vs Neural Timeseries Diffusion 质量对比报告

## 评估概述

本报告对比了两种脑电信号生成方法的质量：
- **Video2EEG-SGGN-Diffusion**: 基于视频输入的SGGN扩散模型
- **Neural Timeseries Diffusion**: 无条件神经时间序列扩散模型

评估时间: {now}

## 性能对比

### 生成时间
- **SGGN模型**: {sggn_time:.4f}秒
- **NTD模型**: {ntd_time:.4f}秒
- **速度优势**: {speed_line}

### 详细指标对比

#### 时域分析
| 指标 | SGGN模型 | NTD模型 | 优势 |
|------|----------|---------|------|
{temporal_rows}

#### 频域分析
| 指标 | SGGN模型 | NTD模型 | 优势 |
|------|----------|---------|------|
{frequency_rows}

## 结论与建议

### 主要发现
1. **生成速度**: {speed_finding}
2. **时域特性**: {temporal_finding}
3. **频域特性**: {frequency_finding}

### 改进建议
1. **对于SGGN模型**: 可以考虑优化扩散过程的噪声调度策略
2. **对于NTD模型**: 可以考虑引入条件信息来提高生成质量
3. **通用建议**: 增加训练数据量和多样性，优化网络架构

---
*报告生成时间: {now}*
"""

        report_path = self.output_dir / "comparison_report.md"
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write(report)

        print(f"详细报告已保存至: {report_path}")

    def save_results(self, results):
        """
        Write ``results`` to ``comparison_results.json`` in ``output_dir``.

        NumPy scalars and arrays are converted to built-in types. NaN and
        Inf, which strict JSON cannot represent, are written as null.
        """
        def convert_to_serializable(obj):
            """Recursively convert NumPy values into JSON-safe Python values."""
            if isinstance(obj, dict):
                return {key: convert_to_serializable(value) for key, value in obj.items()}
            if isinstance(obj, (list, tuple)):
                return [convert_to_serializable(item) for item in obj]
            if isinstance(obj, np.ndarray):
                return convert_to_serializable(obj.tolist())
            if isinstance(obj, np.bool_):
                return bool(obj)
            if isinstance(obj, np.integer):
                return int(obj)
            if isinstance(obj, (float, np.floating)):
                value = float(obj)
                return value if np.isfinite(value) else None
            return obj

        serializable_results = convert_to_serializable(results)

        results_path = self.output_dir / "comparison_results.json"
        with open(results_path, 'w', encoding='utf-8') as f:
            json.dump(serializable_results, f, indent=2, ensure_ascii=False)

        print(f"对比结果已保存至: {results_path}")

def main():
    """
    Script entry point.

    Runs the SGGN vs NTD comparison with the default paths and prints where
    the outputs were written, or a hint if the run failed.
    """
    print("Video2EEG-SGGN-Diffusion vs Neural Timeseries Diffusion 质量对比分析")
    print("=" * 80)

    comparator = QualityComparator(output_dir="./sggn_vs_ntd_comparison_output")
    finished = comparator.run_comparison(sggn_model_path="./sggn_training_output/best_model.pth")

    if not finished:
        print("\n=== 对比分析失败 ===")
        print("请检查模型路径和依赖项")
        return

    print("\n=== 对比分析完成 ===")
    print(f"结果保存在: {comparator.output_dir}")
    print("\n生成的文件:")
    for produced in ("  - comparison_report.md: 详细对比报告",
                     "  - comparison_results.json: 数值结果"):
        print(produced)


if __name__ == "__main__":
    main()