#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
改进的Video2EEG-SGGN-Diffusion模型推理脚本
基于test_forward_pass.py的成功经验进行改进

核心改进:
1. 使用训练模式进行推理（更稳定）
2. 正确的参数传递顺序
3. 改进的数值稳定性处理
4. 增强的可视化和评估

作者: 算法工程师
日期: 2025年1月12日
"""

import os
import sys
import json
import time
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import signal, stats
from scipy.spatial.distance import cosine
from sklearn.metrics import mean_squared_error, mean_absolute_error
from pathlib import Path
import logging
from typing import Dict, List, Tuple, Optional, Union
import warnings
from collections import defaultdict
import pandas as pd
from tqdm import tqdm

# Silence every Python warning process-wide (from this script and all imported
# libraries). NOTE(review): this also hides potentially useful diagnostics from
# the numeric/plotting libraries used below; consider narrowing it with
# `category=` / `module=` filters.
warnings.filterwarnings('ignore')

# 导入模型和数据集
from video2eeg_sggn_diffusion_model import Video2EEGSGGNDiffusion, create_video2eeg_sggn_diffusion
from train_sggn_diffusion import SGGNEEGVideoDataset

# Logging: INFO level on the root logger; this module logs through `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Matplotlib: prefer fonts that can render the Chinese plot labels, and draw
# minus signs with an ASCII hyphen (commonly needed when a CJK font is active).
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'Arial Unicode MS', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

class ImprovedSGGNModelLoader:
    """
    Load a Video2EEG-SGGN-Diffusion checkpoint and generate EEG from video.

    Follows the conventions of ``test_forward_pass.py``: the model is built
    from a default config (overridden by ``checkpoint['config']['model']`` when
    present), weights are loaded non-strictly, and the model is left in
    training mode for inference.
    """

    def __init__(self, model_path: str, device: str = 'auto'):
        """
        Resolve the device and load the model.

        Args:
            model_path: Path to the checkpoint file.
            device: ``'auto'`` selects CUDA when available and CPU otherwise;
                any other value is passed straight to ``torch.device``.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        self.model_path = Path(model_path)

        if device == 'auto':
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        logger.info(f"使用设备: {self.device}")

        self.model, self.config = self.load_model()

    def load_model(self) -> Tuple[nn.Module, Dict]:
        """
        Build the model and load the checkpoint weights into it.

        Weight loading is best-effort. It is non-strict, and key mismatches are
        logged. If loading raises (e.g. no ``model_state_dict`` entry or a
        tensor size mismatch), a warning is logged and the model keeps its
        random initialisation.

        Returns:
            ``(model, model_config)``, where ``model_config`` is the dict that
            was passed to ``create_video2eeg_sggn_diffusion``.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"模型文件不存在: {self.model_path}")

        # SECURITY: weights_only=False unpickles arbitrary Python objects, so
        # only load checkpoints that come from a trusted source.
        checkpoint = torch.load(self.model_path, map_location=self.device, weights_only=False)
        # `or {}` also covers a checkpoint that stores `config: None`.
        config = checkpoint.get('config') or {}

        # Defaults match the configuration used by test_forward_pass.py.
        model_config = {
            'video_feature_dim': 512,
            'eeg_channels': 14,
            'signal_length': 250,
            'num_diffusion_steps': 1000,
            'hidden_dim': 256,
            'dropout': 0.1
        }

        # Model settings stored in the checkpoint take precedence.
        if 'model' in config and config['model']:
            model_config.update(config['model'])

        model = create_video2eeg_sggn_diffusion(model_config)

        try:
            load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        except Exception as e:
            # Deliberate best-effort: continue with random weights.
            logger.warning(f"模型权重加载失败: {e}")
            logger.info("使用随机初始化的权重继续推理")
        else:
            # strict=False silently skips mismatched keys. Surface them so that
            # a largely untrained model is not mistaken for a trained one.
            missing = list(load_result.missing_keys)
            unexpected = list(load_result.unexpected_keys)
            if missing or unexpected:
                logger.warning(
                    f"权重键不匹配: 缺失 {len(missing)} 个 {missing[:5]}, "
                    f"多余 {len(unexpected)} 个 {unexpected[:5]}"
                )
            logger.info("模型权重加载成功（部分权重可能不匹配）")

        model = model.to(self.device)
        # Deliberately kept in training mode (the authors found it more stable).
        # This keeps dropout active, so outputs are stochastic. Any BatchNorm
        # layers would also keep updating their running statistics.
        model.train()

        logger.info(f"模型加载成功: {self.model_path}")
        logger.info(f"模型参数数量: {sum(p.numel() for p in model.parameters()):,}")

        return model, model_config

    def generate_eeg_stable(self, video_frames: torch.Tensor) -> torch.Tensor:
        """
        Generate EEG with a single training-style forward pass.

        A random timestep in ``[0, num_diffusion_steps)`` and a standard-normal
        EEG tensor are fed to the model together with the video. If the model
        returns a tuple, its first element is treated as the predicted noise and
        subtracted from the random input; otherwise the output is used directly.
        Non-finite results are replaced by small Gaussian noise, and the result
        is clamped to [-10, 10].

        Args:
            video_frames: Video batch, documented by the caller as (B, T, C, H, W).

        Returns:
            The generated EEG. When the model's noise prediction matches the
            input's shape, this is (B, eeg_channels, signal_length).
        """
        with torch.no_grad():
            video_frames = video_frames.to(self.device)
            batch_size = video_frames.shape[0]

            # Sample timesteps inside the configured diffusion schedule. A
            # checkpoint may override num_diffusion_steps, so do not hard-code it.
            num_steps = max(1, int(self.config.get('num_diffusion_steps', 1000)))
            t = torch.randint(0, num_steps, (batch_size,), device=self.device)

            # Random EEG input with the configured channel count / length.
            noisy_eeg = torch.randn(
                batch_size,
                self.config['eeg_channels'],
                self.config['signal_length'],
                device=self.device,
            )

            # Training-mode forward; argument order is (video_frames, eeg_data, timesteps).
            output = self.model(video_frames, noisy_eeg, t)

            if isinstance(output, tuple):
                # Only the predicted noise (first element) is needed. Indexing
                # avoids depending on how many auxiliary outputs follow it.
                predicted_noise = output[0]
                # NOTE(review): this is not a DDPM reverse step (which would
                # rescale by the noise schedule). Whether the model noises its
                # `eeg_data` input internally cannot be determined here; confirm.
                generated_eeg = noisy_eeg - predicted_noise
            else:
                generated_eeg = output

            if torch.isnan(generated_eeg).any() or torch.isinf(generated_eeg).any():
                logger.warning("生成的EEG包含NaN或无穷大值，使用随机噪声替代")
                generated_eeg = torch.randn_like(generated_eeg) * 0.1

            # Clip outliers.
            generated_eeg = torch.clamp(generated_eeg, -10.0, 10.0)

            return generated_eeg

class EEGQualityEvaluator:
    """
    Compare a generated EEG segment against a reference segment.

    Reports time-domain errors (MSE/MAE), per-channel Pearson correlation and,
    for each frequency band, the cosine similarity between the per-channel
    Welch band powers of the two signals.
    """

    def __init__(self, sampling_rate: float = 250.0):
        """
        Args:
            sampling_rate: Sampling rate in Hz used for the Welch PSD estimate.
        """
        self.sampling_rate = sampling_rate

        # Band edges in Hz; compute_band_power treats both edges as inclusive.
        self.frequency_bands = {
            'Delta': (0.5, 4),
            'Theta': (4, 8),
            'Alpha': (8, 13),
            'Beta': (13, 30),
            'Gamma': (30, 100)
        }

    def compute_power_spectral_density(self, eeg_data: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Estimate the per-channel power spectral density with Welch's method.

        Non-finite input values are replaced (NaN -> 0, +inf -> 1, -inf -> -1)
        before the estimate, and non-finite PSD values are replaced afterwards.
        If ``signal.welch`` raises, a flat 1e-10 PSD over 129 bins spanning
        [0, fs/2] is returned instead.

        Args:
            eeg_data: EEG array (n_channels, n_samples).

        Returns:
            ``(freqs, psd)`` with ``psd`` of shape (n_channels, len(freqs)).
        """
        if np.isnan(eeg_data).any() or np.isinf(eeg_data).any():
            logger.warning("EEG数据包含NaN或无穷大值，进行清理")
            eeg_data = np.nan_to_num(eeg_data, nan=0.0, posinf=1.0, neginf=-1.0)

        try:
            # Segment length: a quarter of the signal, at least 4, at most 256.
            # NOTE(review): a 250-sample input gives 62-sample segments (~4 Hz
            # bins), so no bin falls inside the 0.5-4 Hz Delta band.
            freqs, psd = signal.welch(
                eeg_data,
                fs=self.sampling_rate,
                nperseg=min(256, max(4, eeg_data.shape[1]//4)),
                axis=1
            )

            if np.isnan(psd).any() or np.isinf(psd).any():
                logger.warning("功率谱密度包含NaN或无穷大值，进行清理")
                psd = np.nan_to_num(psd, nan=1e-10, posinf=1.0, neginf=1e-10)

        except Exception as e:
            logger.warning(f"功率谱密度计算失败: {e}，使用默认值")
            freqs = np.linspace(0, self.sampling_rate/2, 129)
            psd = np.ones((eeg_data.shape[0], len(freqs))) * 1e-10

        return freqs, psd

    def compute_band_power(self, eeg_data: np.ndarray) -> Dict[str, np.ndarray]:
        """
        Compute the mean PSD inside each frequency band, per channel.

        Args:
            eeg_data: EEG array (n_channels, n_samples).

        Returns:
            Mapping band name -> array (n_channels,). A band that contains no
            PSD bin gets all zeros.
        """
        freqs, psd = self.compute_power_spectral_density(eeg_data)

        band_powers = {}
        for band_name, (low_freq, high_freq) in self.frequency_bands.items():
            freq_mask = (freqs >= low_freq) & (freqs <= high_freq)

            if np.any(freq_mask):
                band_powers[band_name] = np.mean(psd[:, freq_mask], axis=1)
            else:
                band_powers[band_name] = np.zeros(eeg_data.shape[0])

        return band_powers

    @staticmethod
    def _band_power_similarity(ref_power: np.ndarray, gen_power: np.ndarray) -> float:
        """Cosine similarity of two band-power vectors; 0.0 if either is ~zero or the result is not finite."""
        if np.linalg.norm(ref_power) <= 1e-8 or np.linalg.norm(gen_power) <= 1e-8:
            return 0.0
        try:
            similarity = 1 - cosine(ref_power, gen_power)
        except (ValueError, FloatingPointError):
            return 0.0
        return float(similarity) if np.isfinite(similarity) else 0.0

    def evaluate_quality(self, generated_eeg: np.ndarray,
                        reference_eeg: np.ndarray) -> Dict[str, float]:
        """
        Evaluate how closely generated EEG matches the reference.

        Args:
            generated_eeg: Generated EEG (n_channels, n_samples).
            reference_eeg: Reference EEG with the same shape.

        Returns:
            Dict with ``mse``, ``mae``, ``mean_correlation``, ``std_correlation``
            and ``<band>_similarity`` for every configured band. Channels whose
            correlation cannot be computed (or is NaN) are excluded from the
            correlation statistics; with no usable channel both are 0.0.

        Raises:
            ValueError: If the two arrays have different shapes. Comparing
                flattened arrays of different layout would give meaningless metrics.
        """
        if generated_eeg.shape != reference_eeg.shape:
            raise ValueError(
                f"生成EEG与参考EEG形状不一致: {generated_eeg.shape} vs {reference_eeg.shape}"
            )

        metrics = {}

        if np.isnan(generated_eeg).any() or np.isinf(generated_eeg).any():
            logger.warning("生成的EEG包含NaN或无穷大值")
            generated_eeg = np.nan_to_num(generated_eeg, nan=0.0, posinf=1.0, neginf=-1.0)

        if np.isnan(reference_eeg).any() or np.isinf(reference_eeg).any():
            logger.warning("参考EEG包含NaN或无穷大值")
            reference_eeg = np.nan_to_num(reference_eeg, nan=0.0, posinf=1.0, neginf=-1.0)

        # Time-domain errors over all channels and samples.
        diff = reference_eeg - generated_eeg
        metrics['mse'] = float(np.mean(diff ** 2))
        metrics['mae'] = float(np.mean(np.abs(diff)))

        # Per-channel Pearson correlation.
        correlation_per_channel = []
        for ch in range(generated_eeg.shape[0]):
            try:
                corr, _ = stats.pearsonr(reference_eeg[ch], generated_eeg[ch])
            except Exception as e:
                # Best-effort: skip channels where correlation is undefined.
                logger.debug(f"通道 {ch} 相关性计算失败: {e}")
                continue
            if not np.isnan(corr):
                correlation_per_channel.append(float(corr))

        if correlation_per_channel:
            metrics['mean_correlation'] = float(np.mean(correlation_per_channel))
            metrics['std_correlation'] = float(np.std(correlation_per_channel))
        else:
            metrics['mean_correlation'] = 0.0
            metrics['std_correlation'] = 0.0

        # Frequency-domain similarity of the per-channel band-power profiles.
        ref_band_powers = self.compute_band_power(reference_eeg)
        gen_band_powers = self.compute_band_power(generated_eeg)

        for band_name in self.frequency_bands:
            if band_name in ref_band_powers and band_name in gen_band_powers:
                metrics[f'{band_name.lower()}_similarity'] = self._band_power_similarity(
                    ref_band_powers[band_name], gen_band_powers[band_name]
                )

        return metrics

class ImprovedInferenceEngine:
    """
    End-to-end inference and evaluation driver.

    Combines the model loader, the quality evaluator and a test DataLoader. It
    runs generation and scoring, then saves results and plots under
    ``output_dir``.
    """

    def __init__(self,
                 model_path: str,
                 data_dir: str,
                 output_dir: str = "./improved_inference_output",
                 device: str = 'auto'):
        """
        Args:
            model_path: Checkpoint path forwarded to ``ImprovedSGGNModelLoader``.
            data_dir: Dataset root forwarded to ``SGGNEEGVideoDataset``.
            output_dir: Output directory (created, with parents, if missing).
            device: Device spec forwarded to the model loader.
        """
        self.model_path = model_path
        self.data_dir = data_dir
        self.output_dir = Path(output_dir)
        self.device = device

        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.model_loader = ImprovedSGGNModelLoader(model_path, device)

        self.evaluator = EEGQualityEvaluator()

        # Built after the model loader so pin_memory can follow the resolved device.
        self.test_loader = self.create_test_dataloader()

        logger.info("改进的推理引擎初始化完成")

    def create_test_dataloader(self) -> DataLoader:
        """
        Build the DataLoader over the ``'test'`` split.

        Returns:
            A non-shuffled DataLoader with ``batch_size=1``. Memory is pinned
            only when the model runs on CUDA.
        """
        test_dataset = SGGNEEGVideoDataset(
            self.data_dir,
            split='test',
            use_graph_da=False
        )

        # Pinned host memory only speeds up host->CUDA copies. On CPU it is
        # useless, and PyTorch warns about it.
        model_loader = getattr(self, 'model_loader', None)
        if model_loader is not None:
            pin_memory = model_loader.device.type == 'cuda'
        else:
            pin_memory = torch.cuda.is_available()

        test_loader = DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=2,
            pin_memory=pin_memory
        )

        logger.info(f"测试数据加载器创建完成: {len(test_dataset)} 样本")

        return test_loader

    @staticmethod
    def _to_scalar(value):
        """
        Unwrap one collated metadata field of a batch-size-1 batch.

        Tensors and numpy scalars become Python scalars. Lists/tuples (e.g. a
        collated batch of strings) yield their first element. ``None`` and empty
        containers give ``None``. Anything else is returned unchanged.
        """
        if value is None:
            return None
        if isinstance(value, torch.Tensor):
            return value.reshape(-1)[0].item() if value.numel() > 0 else None
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, (list, tuple)):
            return value[0] if value else None
        return value

    def run_inference(self, num_samples: int = 10) -> Dict[str, List]:
        """
        Generate and evaluate EEG for up to ``num_samples`` test samples.

        A sample that raises at any stage (reading the batch, generation,
        evaluation, metadata extraction) is logged with its traceback and
        skipped. Nothing is recorded for it, so the returned lists stay
        index-aligned.

        Args:
            num_samples: Maximum number of successfully processed samples.

        Returns:
            Dict with keys ``metrics``, ``inference_times``, ``generated_eegs``,
            ``reference_eegs``, ``video_frames`` and ``metadata``. Entry ``i``
            of every list belongs to the same sample.
        """
        results = {
            'metrics': [],
            'inference_times': [],
            'generated_eegs': [],
            'reference_eegs': [],
            'video_frames': [],
            'metadata': []
        }

        logger.info(f"开始改进的推理测试，样本数: {num_samples}")

        sample_count = 0
        for batch in tqdm(self.test_loader, desc="推理进度"):
            if sample_count >= num_samples:
                break

            try:
                video_frames = batch['video']  # documented upstream as (1, T, C, H, W)
                reference_eeg = batch['eeg']   # documented upstream as (1, C, T)

                logger.info(f"处理样本 {sample_count + 1}/{num_samples}")
                logger.info(f"视频帧形状: {video_frames.shape}")
                logger.info(f"参考EEG形状: {reference_eeg.shape}")

                start_time = time.perf_counter()
                generated_eeg = self.model_loader.generate_eeg_stable(video_frames)
                inference_time = time.perf_counter() - start_time

                logger.info(f"生成EEG形状: {generated_eeg.shape}")
                logger.info(f"推理时间: {inference_time:.3f}s")

                # Drop the batch dimension (the loader uses batch_size=1).
                generated_eeg_np = generated_eeg.cpu().numpy()[0]
                reference_eeg_np = reference_eeg.cpu().numpy()[0]
                video_frames_np = video_frames.cpu().numpy()[0]

                metrics = self.evaluator.evaluate_quality(
                    generated_eeg_np, reference_eeg_np
                )

                logger.info(f"质量指标 - MSE: {metrics['mse']:.6f}, 相关性: {metrics['mean_correlation']:.4f}")

                # Build metadata before recording anything, so a failure here
                # cannot leave the result lists with different lengths.
                metadata = {
                    'subject_id': self._to_scalar(batch.get('subject_id')),
                    'video_id': self._to_scalar(batch.get('video_id')),
                    'sample_idx': sample_count
                }

            except Exception as e:
                logger.exception(f"处理样本 {sample_count} 时发生错误: {e}")
                continue

            # Everything that can fail has succeeded; record the sample atomically.
            results['metrics'].append(metrics)
            results['inference_times'].append(inference_time)
            results['generated_eegs'].append(generated_eeg_np)
            results['reference_eegs'].append(reference_eeg_np)
            results['video_frames'].append(video_frames_np)
            results['metadata'].append(metadata)

            sample_count += 1

        logger.info(f"推理测试完成，共处理 {len(results['metrics'])} 个样本")

        return results

    def generate_visualizations(self, results: Dict[str, List]):
        """
        Write all plots into ``<output_dir>/visualizations``.

        Args:
            results: Output of ``run_inference``; nothing is drawn if it holds no metrics.
        """
        if not results['metrics']:
            logger.warning("没有结果可供可视化")
            return

        viz_dir = self.output_dir / "visualizations"
        viz_dir.mkdir(exist_ok=True)

        self.plot_quality_distributions(results, viz_dir)
        self.plot_eeg_comparisons(results, viz_dir)
        self.plot_spectral_analysis(results, viz_dir)

        logger.info(f"可视化图表已保存到: {viz_dir}")

    def plot_quality_distributions(self, results: Dict[str, List], output_dir: Path):
        """
        Save histograms of MSE, MAE and mean correlation across samples.
        """
        metrics_df = pd.DataFrame(results['metrics'])

        main_metrics = ['mse', 'mae', 'mean_correlation']
        available_metrics = [m for m in main_metrics if m in metrics_df.columns]

        if not available_metrics:
            logger.warning("没有可用的质量指标进行可视化")
            return

        fig, axes = plt.subplots(1, len(available_metrics), figsize=(15, 5))
        if len(available_metrics) == 1:
            # plt.subplots returns a bare Axes (not an array) for a single column.
            axes = [axes]

        for i, metric in enumerate(available_metrics):
            axes[i].hist(metrics_df[metric], bins=20, alpha=0.7, edgecolor='black')
            axes[i].set_title(f'{metric.upper()} 分布')
            axes[i].set_xlabel(metric.upper())
            axes[i].set_ylabel('频次')
            axes[i].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(output_dir / "quality_distributions.png", dpi=300, bbox_inches='tight')
        plt.close()

        logger.info("质量指标分布图已保存")

    def plot_eeg_comparisons(self, results: Dict[str, List], output_dir: Path, num_examples: int = 3):
        """
        Save reference-vs-generated waveform plots (first 4 channels) for the
        first ``num_examples`` samples.
        """
        num_examples = min(num_examples, len(results['generated_eegs']))
        # Use the evaluator's sampling rate rather than a hard-coded value.
        sampling_rate = self.evaluator.sampling_rate

        for i in range(num_examples):
            generated = results['generated_eegs'][i]
            reference = results['reference_eegs'][i]

            num_channels = min(4, generated.shape[0])

            fig, axes = plt.subplots(num_channels, 1, figsize=(15, 2*num_channels))
            if num_channels == 1:
                axes = [axes]

            time_axis = np.arange(generated.shape[1]) / sampling_rate

            for ch in range(num_channels):
                axes[ch].plot(time_axis, reference[ch], label='参考EEG', alpha=0.7, linewidth=1)
                axes[ch].plot(time_axis, generated[ch], label='生成EEG', alpha=0.7, linewidth=1)
                axes[ch].set_title(f'通道 {ch+1}')
                axes[ch].set_xlabel('时间 (s)')
                axes[ch].set_ylabel('幅值')
                axes[ch].legend()
                axes[ch].grid(True, alpha=0.3)

            plt.tight_layout()
            plt.savefig(output_dir / f"eeg_comparison_sample_{i+1}.png", dpi=300, bbox_inches='tight')
            plt.close()

        logger.info(f"EEG对比图已保存 ({num_examples} 个样本)")

    def plot_spectral_analysis(self, results: Dict[str, List], output_dir: Path, num_examples: int = 2):
        """
        Save channel-averaged PSD comparison plots (0-50 Hz, log scale) for the
        first ``num_examples`` samples.
        """
        num_examples = min(num_examples, len(results['generated_eegs']))

        for i in range(num_examples):
            generated = results['generated_eegs'][i]
            reference = results['reference_eegs'][i]

            freqs_ref, psd_ref = self.evaluator.compute_power_spectral_density(reference)
            freqs_gen, psd_gen = self.evaluator.compute_power_spectral_density(generated)

            # Average over channels.
            psd_ref_mean = np.mean(psd_ref, axis=0)
            psd_gen_mean = np.mean(psd_gen, axis=0)

            plt.figure(figsize=(12, 6))
            plt.semilogy(freqs_ref, psd_ref_mean, label='参考EEG', alpha=0.7, linewidth=2)
            plt.semilogy(freqs_gen, psd_gen_mean, label='生成EEG', alpha=0.7, linewidth=2)
            plt.xlabel('Frequency (Hz)')
            plt.ylabel('Power Spectral Density')
            plt.title(f'Power Spectral Density Comparison for Sample {i+1}')
            plt.legend()
            plt.grid(True, alpha=0.3)
            plt.xlim(0, 50)

            plt.tight_layout()
            plt.savefig(output_dir / f"spectral_analysis_sample_{i+1}.png", dpi=300, bbox_inches='tight')
            plt.close()

        logger.info(f"频谱分析图已保存 ({num_examples} 个样本)")

    def save_results(self, results: Dict[str, List], analysis: Dict[str, float]):
        """
        Persist metrics/metadata/analysis as JSON and the arrays as a compressed .npz.

        Args:
            results: Output of ``run_inference``.
            analysis: Aggregate statistics computed from ``results``.
        """
        results_file = self.output_dir / "inference_results.json"

        def convert_to_serializable(obj):
            """Recursively convert numpy scalars/arrays into JSON-friendly Python values."""
            if isinstance(obj, dict):
                return {k: convert_to_serializable(v) for k, v in obj.items()}
            elif isinstance(obj, list):
                return [convert_to_serializable(v) for v in obj]
            elif isinstance(obj, np.floating):
                return float(obj)
            elif isinstance(obj, np.integer):
                return int(obj)
            elif isinstance(obj, np.ndarray):
                return obj.tolist()
            else:
                return obj

        serializable_results = {
            'metrics': convert_to_serializable(results['metrics']),
            'inference_times': convert_to_serializable(results['inference_times']),
            'metadata': convert_to_serializable(results['metadata']),
            'analysis': convert_to_serializable(analysis)
        }

        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump(serializable_results, f, indent=2, ensure_ascii=False)

        # NOTE(review): np.array assumes every sample has the same shape; ragged
        # inputs would not stack into a regular array.
        eeg_data_file = self.output_dir / "generated_eeg_data.npz"
        np.savez_compressed(
            eeg_data_file,
            generated_eegs=np.array(results['generated_eegs']),
            reference_eegs=np.array(results['reference_eegs']),
            video_frames=np.array(results['video_frames'])
        )

        logger.info(f"结果已保存到: {results_file}")
        logger.info(f"EEG数据已保存到: {eeg_data_file}")

    def run_complete_evaluation(self, num_samples: int = 10):
        """
        Run inference, aggregate statistics, save results, draw plots and log a summary.

        Results are saved before plotting, so a plotting failure does not lose them.

        Args:
            num_samples: Maximum number of test samples to process.
        """
        logger.info("开始完整的改进推理评估")

        results = self.run_inference(num_samples)

        if not results['metrics']:
            logger.error("推理失败，没有生成任何结果")
            return

        # Per-metric mean / std / median across samples (pandas std uses ddof=1,
        # so it is NaN when only one sample was processed).
        analysis = {}
        metrics_df = pd.DataFrame(results['metrics'])

        for col in metrics_df.columns:
            values = metrics_df[col].dropna()
            if len(values) > 0:
                analysis[f'{col}_mean'] = float(values.mean())
                analysis[f'{col}_std'] = float(values.std())
                analysis[f'{col}_median'] = float(values.median())

        inference_times = results['inference_times']
        analysis['inference_time_mean'] = float(np.mean(inference_times))
        analysis['inference_time_std'] = float(np.std(inference_times))

        # Persist first: the plots are optional, the numbers are not.
        self.save_results(results, analysis)

        self.generate_visualizations(results)

        logger.info("\n=== 推理评估总结 ===")
        logger.info(f"处理样本数: {len(results['metrics'])}")
        logger.info(f"平均推理时间: {analysis['inference_time_mean']:.3f}±{analysis['inference_time_std']:.3f}s")

        if 'mse_mean' in analysis:
            logger.info(f"平均MSE: {analysis['mse_mean']:.6f}±{analysis['mse_std']:.6f}")
        if 'mean_correlation_mean' in analysis:
            logger.info(f"平均相关性: {analysis['mean_correlation_mean']:.4f}±{analysis['mean_correlation_std']:.4f}")

        logger.info(f"结果保存在: {self.output_dir}")
        logger.info("推理评估完成！")

def main(argv: Optional[List[str]] = None) -> int:
    """
    Parse command-line arguments and run the complete evaluation.

    Args:
        argv: Arguments to parse. ``None`` (the default) lets ``argparse`` read
            ``sys.argv[1:]``, so existing callers are unaffected. Passing a list
            makes the entry point callable from code and tests.

    Returns:
        0 if the engine was built and the evaluation returned without raising
        (including the case where it produced no results). 1 if an exception
        was raised; the exception is logged with its traceback. Invalid
        arguments make ``argparse`` exit with status 2.
    """
    parser = argparse.ArgumentParser(description='改进的Video2EEG-SGGN-Diffusion推理脚本')
    parser.add_argument('--model_path', type=str, required=True, help='模型检查点路径')
    parser.add_argument('--data_dir', type=str, required=True, help='测试数据目录')
    parser.add_argument('--output_dir', type=str, default='./improved_inference_output', help='输出目录')
    parser.add_argument('--num_samples', type=int, default=10, help='测试样本数')
    parser.add_argument('--device', type=str, default='auto', help='设备类型')

    args = parser.parse_args(argv)

    try:
        engine = ImprovedInferenceEngine(
            model_path=args.model_path,
            data_dir=args.data_dir,
            output_dir=args.output_dir,
            device=args.device
        )

        engine.run_complete_evaluation(num_samples=args.num_samples)

    except Exception as e:
        # logger.exception records the traceback through the configured logging
        # handlers instead of writing it separately to stderr.
        logger.exception(f"推理过程中发生错误: {e}")
        return 1

    return 0

if __name__ == "__main__":
    # Use main()'s return value directly as the process exit status.
    sys.exit(main())