#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Video2EEG-SGGN-Diffusion模型
基于MGIF框架的多尺度图神经网络扩散模型，用于从视频生成EEG信号

核心特性:
1. Graph-DA数据增强 - 基于图结构的数据增强方法
2. E-Graph与S-Graph构建 - 空间先验和信号统计特性的双图建模
3. 滤波器组驱动的多图建模 - 多频段图结构建模
4. 自博弈融合策略 - 对抗博弈的信息融合机制

作者: 算法工程师
日期: 2025年1月12日
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Tuple, Optional, Union
import math
from scipy.spatial.distance import pdist, squareform
from scipy.stats import pearsonr
import warnings
warnings.filterwarnings('ignore')

# ==================== 图构建工具函数 ====================

def create_electrode_adjacency_matrix(electrode_positions: np.ndarray, 
                                     threshold: float = 0.3) -> np.ndarray:
    """
    Build the E-Graph adjacency matrix from electrode coordinates.

    Pairwise Euclidean distances are normalised by the largest distance and
    turned into edge weights ``exp(-d_norm / threshold)``, so closer
    electrodes get larger weights. Self-connections are set to 0.

    Args:
        electrode_positions: Electrode coordinates, (n_channels, 3).
        threshold: Length scale of the exponential decay applied to the
            normalised distances.

    Returns:
        Adjacency matrix, (n_channels, n_channels).
    """
    # Pairwise Euclidean distances as a square matrix.
    distances = squareform(pdist(electrode_positions, metric='euclidean'))

    # Normalise by the largest distance. When every electrode sits on the
    # same point the largest distance is 0; dividing by it would turn every
    # edge into NaN, so the normalised distances are taken as 0 instead.
    max_distance = np.max(distances)
    if max_distance > 0:
        normalized_distances = distances / max_distance
    else:
        normalized_distances = np.zeros_like(distances)

    # Closer electrodes -> larger weight.
    adjacency = np.exp(-normalized_distances / threshold)

    # No self-connections.
    np.fill_diagonal(adjacency, 0)

    return adjacency

def create_signal_correlation_matrix(eeg_signals: np.ndarray) -> np.ndarray:
    """
    Build an adjacency matrix from absolute Pearson correlations.

    Non-finite samples are replaced first (NaN -> 0, +inf -> 1, -inf -> -1).
    Channels whose variance is below 1e-10 are treated as constant and get
    zero correlation with every other channel. The diagonal is 0.

    The correlations are computed with one matrix product of the centred,
    unit-norm channels instead of one ``pearsonr`` call per ordered pair.

    Args:
        eeg_signals: EEG signals, (n_channels, n_samples).

    Returns:
        Absolute-correlation adjacency matrix, (n_channels, n_channels).
    """
    signals = np.nan_to_num(
        np.asarray(eeg_signals, dtype=np.float64),
        nan=0.0, posinf=1.0, neginf=-1.0,
    )
    n_channels = signals.shape[0]

    # A correlation needs a 2-D input with at least two samples per channel.
    if signals.ndim != 2 or signals.shape[1] < 2:
        return np.zeros((n_channels, n_channels))

    centered = signals - signals.mean(axis=1, keepdims=True)
    norms = np.linalg.norm(centered, axis=1)

    # Constant (or numerically degenerate) channels keep an all-zero row, so
    # their correlation with everything is 0.
    usable = (signals.var(axis=1) >= 1e-10) & np.isfinite(norms) & (norms > 0)
    unit_rows = np.zeros_like(centered)
    unit_rows[usable] = centered[usable] / norms[usable, None]

    # Pearson r between two channels is the dot product of their centred,
    # unit-norm versions; only the magnitude is kept.
    correlation_matrix = np.abs(unit_rows @ unit_rows.T)
    correlation_matrix = np.nan_to_num(correlation_matrix, nan=0.0, posinf=0.0, neginf=0.0)
    np.clip(correlation_matrix, 0.0, 1.0, out=correlation_matrix)
    np.fill_diagonal(correlation_matrix, 0.0)

    return correlation_matrix

def add_gaussian_white_noise(signal: np.ndarray, 
                           target_noise_db: float = 0, 
                           mode: str = 'noisePower') -> np.ndarray:
    """
    Add zero-mean Gaussian white noise to a signal (Graph-DA augmentation).

    The noise variance is the mean squared value of ``signal`` multiplied by
    ``10 ** (target_noise_db / 10)``; at 0 dB the noise power equals the
    signal power.

    Args:
        signal: Input signal.
        target_noise_db: Noise level in dB, applied relative to the mean
            power of ``signal``.
        mode: Noise mode. Accepted but not read by this implementation.

    Returns:
        A new array holding ``signal`` plus the sampled noise.
    """
    # dB -> linear power ratio.
    power_ratio = 10 ** (target_noise_db / 10)
    noise_std = np.sqrt(np.mean(signal ** 2) * power_ratio)

    return signal + np.random.normal(0, noise_std, signal.shape)

# ==================== 图神经网络组件 ====================

class GraphConvolution(nn.Module):
    """
    Dense graph convolution layer: ``adj @ (input @ weight) + bias``.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Storage is allocated uninitialised and filled by reset_parameters().
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.FloatTensor(out_features))
        self.reset_parameters()
    
    def reset_parameters(self):
        """Re-draw weight and bias from U(-b, b) with b = 1 / sqrt(out_features)."""
        bound = 1. / math.sqrt(self.weight.shape[1])
        with torch.no_grad():
            self.weight.uniform_(-bound, bound)
            if self.bias is not None:
                self.bias.uniform_(-bound, bound)
    
    def forward(self, input: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """
        Apply the graph convolution.
        
        Args:
            input: Node features (batch_size, n_nodes, in_features)
            adj: Adjacency matrix (batch_size, n_nodes, n_nodes) or (n_nodes, n_nodes)
        
        Returns:
            Node features (batch_size, n_nodes, out_features)
        """
        # Project the node features, then aggregate them over the graph.
        projected = input @ self.weight
        aggregated = adj @ projected
        if self.bias is None:
            return aggregated
        return aggregated + self.bias

class MultiScaleGraphConv(nn.Module):
    """
    Stack of graph convolutions whose per-layer outputs are concatenated and
    mixed by a final graph convolution.
    """
    def __init__(self, in_features: int, hidden_features: int, out_features: int, 
                 num_scales: int = 3, dropout: float = 0.1):
        super().__init__()
        self.num_scales = num_scales
        
        # One graph convolution per scale; only the first consumes the raw
        # input width, the rest consume the previous scale's hidden features.
        scale_layers = []
        for scale_idx in range(num_scales):
            layer_in = hidden_features if scale_idx else in_features
            scale_layers.append(GraphConvolution(layer_in, hidden_features))
        self.graph_convs = nn.ModuleList(scale_layers)
        
        # Output layer over the concatenated per-scale features.
        self.output_conv = GraphConvolution(hidden_features * num_scales, out_features)
        
        # Non-linearity and dropout applied after every scale.
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor, adj_matrices: List[torch.Tensor]) -> torch.Tensor:
        """
        Run the scales in sequence and fuse their outputs.
        
        Args:
            x: Node features (batch_size, n_nodes, in_features)
            adj_matrices: One adjacency matrix per scale; the first one is
                also used by the output layer.
        
        Returns:
            Node features (batch_size, n_nodes, out_features)
        """
        hidden = x
        per_scale = []
        
        # Each scale feeds on the previous scale's activations.
        for conv, adj in zip(self.graph_convs, adj_matrices):
            hidden = self.dropout(self.activation(conv(hidden, adj)))
            per_scale.append(hidden)
        
        # Concatenate along the feature axis and mix with the output layer.
        stacked = torch.cat(per_scale, dim=-1)
        return self.output_conv(stacked, adj_matrices[0])

class SpatialGraphAttention(nn.Module):
    """
    Multi-head self-attention over graph nodes, modulated by an adjacency matrix.
    """
    def __init__(self, in_features: int, out_features: int, num_heads: int = 8):
        super(SpatialGraphAttention, self).__init__()
        self.num_heads = num_heads
        self.out_features = out_features
        self.head_dim = out_features // num_heads
        
        assert out_features % num_heads == 0, "out_features must be divisible by num_heads"
        
        self.query = nn.Linear(in_features, out_features)
        self.key = nn.Linear(in_features, out_features)
        self.value = nn.Linear(in_features, out_features)
        self.output = nn.Linear(out_features, out_features)
        
        self.dropout = nn.Dropout(0.1)

    @staticmethod
    def _broadcastable_adjacency(adj: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Return ``adj`` shaped to broadcast against (batch, heads, n_nodes, n_nodes) scores."""
        # Strip leading singleton dimensions. The previous implementation
        # called squeeze(0) until the tensor was 2-D, which never terminates
        # when the leading dimension is a real batch axis (size > 1).
        while adj.dim() > 2 and adj.size(0) == 1:
            adj = adj.squeeze(0)

        if adj.dim() == 3:
            if adj.size(0) != batch_size:
                raise ValueError(
                    f"adj batch dimension {adj.size(0)} does not match input batch size {batch_size}"
                )
            # Per-sample adjacency: add a heads axis -> (batch, 1, n, n).
            return adj.unsqueeze(1)
        if adj.dim() > 3:
            raise ValueError(f"Unsupported adjacency shape: {tuple(adj.shape)}")

        if adj.dim() == 1:
            adj = adj.unsqueeze(0)
        # Shared adjacency: (1, 1, n, n) broadcasts over batch and heads.
        return adj.unsqueeze(0).unsqueeze(0)
        
    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """
        Apply adjacency-modulated multi-head attention.
        
        Args:
            x: Node features (batch_size, n_nodes, in_features)
            adj: Adjacency matrix, either shared (n_nodes, n_nodes) or
                per-sample (batch_size, n_nodes, n_nodes). Leading singleton
                dimensions are ignored.
        
        Returns:
            Node features (batch_size, n_nodes, out_features)

        Raises:
            ValueError: If a batched ``adj`` does not match the batch size of
                ``x`` or has more than three non-singleton leading dimensions.
        """
        batch_size, n_nodes, _ = x.shape
        
        # Project to Q, K, V and split into heads: (batch, heads, n_nodes, head_dim).
        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            return projected.view(batch_size, n_nodes, self.num_heads, self.head_dim).transpose(1, 2)

        Q = split_heads(self.query(x))
        K = split_heads(self.key(x))
        V = split_heads(self.value(x))
        
        # Scaled dot-product scores.
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        
        # Adjacency modulation: entries equal to 1 keep their score, entries
        # equal to 0 are pushed to -1e9. For weights strictly between 0 and 1
        # the score is scaled and a penalty proportional to (1 - weight) is added.
        adj_mask = self._broadcastable_adjacency(adj, batch_size)
        attention_scores = attention_scores * adj_mask + (1 - adj_mask) * (-1e9)
        
        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        # Weighted sum of the values, then merge the heads back.
        attended_values = torch.matmul(attention_weights, V)
        attended_values = attended_values.transpose(1, 2).contiguous().view(
            batch_size, n_nodes, self.out_features
        )
        
        # Output projection.
        return self.output(attended_values)

# ==================== 滤波器组和多图建模 ====================

class FilterBankProcessor(nn.Module):
    """
    Filter bank that splits a signal into frequency bands with an ideal
    (brick-wall) FFT band-pass filter per band.
    """
    def __init__(self, sampling_rate: int = 250, 
                 frequency_bands: List[Tuple[float, float]] = None):
        """
        Args:
            sampling_rate: Sampling rate in Hz used to build the frequency grid.
            frequency_bands: List of (low_freq, high_freq) pairs in Hz. When
                None, the five default EEG bands below are used.
        """
        super(FilterBankProcessor, self).__init__()
        self.sampling_rate = sampling_rate
        
        if frequency_bands is None:
            # Default EEG bands.
            self.frequency_bands = [
                (0.5, 4),    # Delta
                (4, 8),      # Theta
                (8, 13),     # Alpha
                (13, 30),    # Beta
                (30, 100)    # Gamma
            ]
        else:
            self.frequency_bands = frequency_bands
        
        self.num_bands = len(self.frequency_bands)
        
    def apply_bandpass_filter(self, signal: torch.Tensor, 
                            low_freq: float, high_freq: float) -> torch.Tensor:
        """
        Band-pass filter a signal by zeroing FFT bins outside the band.
        
        Args:
            signal: Input signal (batch_size, n_channels, n_samples)
            low_freq: Lower cut-off frequency in Hz (inclusive)
            high_freq: Upper cut-off frequency in Hz (inclusive)
        
        Returns:
            Filtered signal, same shape as ``signal``.
        """
        # Frequency of every FFT bin, created directly on the signal's device.
        freqs = torch.fft.fftfreq(
            signal.shape[-1], 1 / self.sampling_rate, device=signal.device
        )
        
        # Keep positive and negative frequencies whose magnitude is in the band.
        bin_magnitude = torch.abs(freqs)
        mask = ((bin_magnitude >= low_freq) & (bin_magnitude <= high_freq)).float()
        
        # Zero the rejected bins and go back to the time domain.
        filtered_fft = torch.fft.fft(signal, dim=-1) * mask
        return torch.fft.ifft(filtered_fft, dim=-1).real
    
    def forward(self, signal: torch.Tensor) -> List[torch.Tensor]:
        """
        Apply every band-pass filter of the bank to the signal.
        
        Args:
            signal: Input signal (batch_size, n_channels, n_samples)
        
        Returns:
            List with one filtered signal per frequency band, in band order.
        """
        return [
            self.apply_bandpass_filter(signal, low_freq, high_freq)
            for low_freq, high_freq in self.frequency_bands
        ]

class MultiGraphBuilder(nn.Module):
    """
    Multi-graph builder: one correlation-based S-Graph per filter-bank band.
    """
    def __init__(self, n_channels: int, frequency_bands: List[Tuple[float, float]] = None):
        """
        Args:
            n_channels: Number of EEG channels; used for the fallback graph size.
            frequency_bands: Passed to FilterBankProcessor; None selects its defaults.
        """
        super(MultiGraphBuilder, self).__init__()
        self.n_channels = n_channels
        
        # Filter bank.
        self.filter_bank = FilterBankProcessor(frequency_bands=frequency_bands)
        self.num_bands = self.filter_bank.num_bands
        
    def build_correlation_graphs(self, eeg_signals: torch.Tensor) -> List[torch.Tensor]:
        """
        Build one correlation graph per frequency band.

        The graphs are computed with NumPy on the host, so no gradient flows
        from them back to ``eeg_signals``.
        
        Args:
            eeg_signals: EEG signals (batch_size, n_channels, n_samples)
        
        Returns:
            List with one float tensor per band, each stacking the per-sample
            correlation matrices, placed on the device of ``eeg_signals``.
        """
        # Apply the filter bank.
        filtered_signals = self.filter_bank(eeg_signals)
        
        correlation_graphs = []
        
        for filtered_signal in filtered_signals:
            # detach() is required because Tensor.numpy() refuses tensors that
            # require grad; one host transfer per band replaces one per sample.
            band_np = filtered_signal.detach().cpu().numpy()
            batch_graphs = []
            
            for signal_batch in band_np:
                # Expected to be (n_channels, n_samples).
                if signal_batch.ndim != 2:
                    print(f"警告: signal_batch维度不正确: {signal_batch.shape}")
                    if signal_batch.ndim == 3:
                        # 3-D sample: use its first slice.
                        signal_batch = signal_batch[0]
                    else:
                        # Any other rank: fall back to an identity graph.
                        batch_graphs.append(torch.FloatTensor(np.eye(self.n_channels)))
                        continue
                
                corr_matrix = create_signal_correlation_matrix(signal_batch)
                batch_graphs.append(torch.FloatTensor(corr_matrix))
            
            # Stack into a batch tensor on the input's device.
            correlation_graphs.append(torch.stack(batch_graphs).to(eeg_signals.device))
        
        return correlation_graphs
    
    def forward(self, eeg_signals: torch.Tensor) -> List[torch.Tensor]:
        """
        Build the per-band correlation graphs for a batch of EEG signals.
        
        Args:
            eeg_signals: EEG signals (batch_size, n_channels, n_samples)
        
        Returns:
            List of per-band correlation graphs (see build_correlation_graphs).
        """
        return self.build_correlation_graphs(eeg_signals)

# ==================== 自博弈融合策略 ====================

class SelfGameFusion(nn.Module):
    """
    Self-game fusion: a defender network predicts per-node fusion weights for
    the E-Graph and S-Graph features, and an attacker network produces
    alternative weights used to form an adversarial loss term.
    """
    def __init__(self, feature_dim: int, num_graphs: int = 2):
        """
        Args:
            feature_dim: Feature width of each graph branch.
            num_graphs: Number of fusion weights per node. forward() fuses
                exactly two branches and uses weights 0 and 1.
        """
        super(SelfGameFusion, self).__init__()
        self.feature_dim = feature_dim
        self.num_graphs = num_graphs
        
        # Attacker network (simulates missing electrodes).
        self.attacker = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 2),
            nn.ReLU(),
            nn.Linear(feature_dim // 2, num_graphs),
            nn.Softmax(dim=-1)
        )
        
        # Defender network (fusion weights).
        self.defender = nn.Sequential(
            nn.Linear(feature_dim * num_graphs, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, num_graphs),
            nn.Softmax(dim=-1)
        )
        
        # Adversarial loss weight; forward() clamps it to [0, 0.01] before use.
        self.adversarial_weight = nn.Parameter(torch.tensor(0.1))
        
    def forward(self, egraph_features: torch.Tensor, 
                sgraph_features: torch.Tensor, 
                training: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Fuse the two graph branches and compute the adversarial loss.
        
        Args:
            egraph_features: E-Graph features (batch_size, n_nodes, feature_dim)
            sgraph_features: S-Graph features (batch_size, n_nodes, feature_dim)
            training: When True the adversarial loss is computed; otherwise
                it is returned as 0.
        
        Returns:
            Tuple of the fused features (batch_size, n_nodes, feature_dim) and
            the scalar adversarial loss.
        """
        batch_size, n_nodes, _ = egraph_features.shape
        
        # Defender input: both branches side by side, one row per node.
        # reshape (not view) so non-contiguous inputs are accepted.
        combined_flat = torch.cat(
            [egraph_features, sgraph_features], dim=-1
        ).reshape(batch_size * n_nodes, -1)
        
        # Defender predicts the fusion weights.
        fusion_weights = self.defender(combined_flat).view(batch_size, n_nodes, self.num_graphs)
        
        # Weighted fusion.
        fused_features = (
            egraph_features * fusion_weights[:, :, 0:1]
            + sgraph_features * fusion_weights[:, :, 1:2]
        )
        
        adversarial_loss = torch.tensor(0.0, device=egraph_features.device)
        
        if training:
            # Attacker only sees the E-Graph features.
            egraph_flat = egraph_features.reshape(batch_size * n_nodes, -1)
            
            # Skip the adversarial term when the features contain NaN/inf.
            if torch.isfinite(egraph_flat).all():
                attack_weights = self.attacker(egraph_flat).view(
                    batch_size, n_nodes, self.num_graphs
                )
                
                # Adversarial term: distance between the defender's fusion and
                # the fusion obtained with the attacker's weights.
                attack_fused = (
                    egraph_features * attack_weights[:, :, 0:1]
                    + sgraph_features * attack_weights[:, :, 1:2]
                )
                loss_scale = torch.clamp(self.adversarial_weight, min=0.0, max=0.01)
                adversarial_loss = F.mse_loss(fused_features, attack_fused) * loss_scale
                
                # Keep the loss in a bounded range.
                adversarial_loss = torch.clamp(adversarial_loss, min=0.0, max=10.0)
        
        return fused_features, adversarial_loss

# ==================== 主模型架构 ====================

class Video2EEGSGGNDiffusion(nn.Module):
    """
    Video2EEG-SGGN-Diffusion主模型
    基于MGIF框架的多尺度图神经网络扩散模型
    """
    def __init__(self, 
                 eeg_channels: int = 62,
                 signal_length: int = 200,
                 video_feature_dim: int = 512,
                 hidden_dim: int = 256,
                 num_diffusion_steps: int = 1000,
                 electrode_positions: Optional[np.ndarray] = None,
                 frequency_bands: Optional[List[Tuple[float, float]]] = None):
        """
        Build all sub-modules of the model.

        Args:
            eeg_channels: Number of EEG channels (graph nodes).
            signal_length: Samples per EEG channel; used as the input feature
                width of the graph processors and the output projection width.
            video_feature_dim: Width of the video encoder output.
            hidden_dim: Hidden width shared by the graph, attention, fusion
                and diffusion sub-modules.
            num_diffusion_steps: Number of denoising iterations at inference.
            electrode_positions: Electrode coordinates used for the E-Graph;
                when None, positions are generated by
                _create_default_electrode_positions().
            frequency_bands: (low, high) bands for the S-Graph filter bank;
                None selects the filter bank's defaults.
        """
        super(Video2EEGSGGNDiffusion, self).__init__()
        
        self.eeg_channels = eeg_channels
        self.signal_length = signal_length
        self.video_feature_dim = video_feature_dim
        self.hidden_dim = hidden_dim
        self.num_diffusion_steps = num_diffusion_steps
        
        # Default electrode positions when none are supplied.
        # NOTE(review): the default generator draws random points, so the
        # E-Graph differs between instantiations unless positions are passed in.
        if electrode_positions is None:
            self.electrode_positions = self._create_default_electrode_positions()
        else:
            self.electrode_positions = electrode_positions
        
        # E-Graph adjacency, stored as a buffer (follows the module across
        # devices and is part of the state dict, but is not trained).
        # NOTE(review): presumably electrode_positions has eeg_channels rows;
        # this is not checked here.
        self.register_buffer('egraph_adj', 
                           torch.FloatTensor(create_electrode_adjacency_matrix(self.electrode_positions)))
        
        # Video encoder.
        self.video_encoder = self._build_video_encoder()
        
        # Multi-graph builder (one S-Graph per frequency band).
        self.multi_graph_builder = MultiGraphBuilder(eeg_channels, frequency_bands)
        
        # E-Graph processor; the time axis of the EEG is the node feature axis.
        self.egraph_processor = MultiScaleGraphConv(
            in_features=signal_length,
            hidden_features=hidden_dim,
            out_features=hidden_dim,
            num_scales=3
        )
        
        # S-Graph processors, one per frequency band.
        self.sgraph_processors = nn.ModuleList([
            MultiScaleGraphConv(
                in_features=signal_length,
                hidden_features=hidden_dim,
                out_features=hidden_dim,
                num_scales=3
            ) for _ in range(self.multi_graph_builder.num_bands)
        ])
        
        # Spatial graph attention, shared by the E-Graph and S-Graph branches.
        # SpatialGraphAttention asserts out_features % num_heads == 0, so
        # hidden_dim must be divisible by 8.
        self.spatial_attention = SpatialGraphAttention(
            in_features=hidden_dim,
            out_features=hidden_dim,
            num_heads=8
        )
        
        # Self-game fusion of the two graph branches.
        self.self_game_fusion = SelfGameFusion(
            feature_dim=hidden_dim,
            num_graphs=2
        )
        
        # Diffusion (noise prediction) network.
        self.diffusion_model = self._build_diffusion_model()
        
        # Timestep embedding: maps the scalar timestep to hidden_dim features.
        self.time_embedding = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Output projection.
        self.output_projection = nn.Linear(hidden_dim, signal_length)
        
        # Graph-DA augmentation noise levels.
        self.noise_levels = [0, 5, 10, 15]  # dB
        
    def _create_default_electrode_positions(self) -> np.ndarray:
        """
        创建默认的电极位置（简化版10-20系统）
        
        Returns:
            电极位置坐标 (n_channels, 3)
        """
        # 简化的球面坐标系统
        positions = np.random.randn(self.eeg_channels, 3)
        
        # 归一化到单位球面
        norms = np.linalg.norm(positions, axis=1, keepdims=True)
        positions = positions / norms
        
        return positions
    
    def _build_video_encoder(self) -> nn.Module:
        """
        构建视频编码器
        
        Returns:
            视频编码器模块
        """
        class AdaptiveVideoEncoder(nn.Module):
            def __init__(self, video_feature_dim):
                super().__init__()
                self.conv_layers = nn.Sequential(
                    # 3D卷积用于时空特征提取
                    nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3)),
                    nn.BatchNorm3d(64),
                    nn.ReLU(),
                    nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),
                    
                    nn.Conv3d(64, 128, kernel_size=(3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2)),
                    nn.BatchNorm3d(128),
                    nn.ReLU(),
                    nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),
                    
                    nn.Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1)),
                    nn.BatchNorm3d(256),
                    nn.ReLU(),
                    nn.AdaptiveAvgPool3d((None, 4, 4)),  # 固定空间维度为4x4
                )
                
                self.video_feature_dim = video_feature_dim
                self.projection = None  # 延迟初始化
                
            def forward(self, x):
                # 处理不同的输入格式
                if x.dim() == 4:  # (T, C, H, W) - 数据集格式
                    T, C, H, W = x.shape
                    x = x.unsqueeze(0)  # 添加batch维度: (1, T, C, H, W)
                    B = 1
                elif x.dim() == 5:  # (B, T, C, H, W) - 标准格式
                    B, T, C, H, W = x.shape
                else:
                    raise ValueError(f"Unexpected input shape: {x.shape}")
                
                # 重塑为3D卷积输入格式: (B, C, T, H, W)
                x = x.permute(0, 2, 1, 3, 4)
                
                # 通过卷积层
                x = self.conv_layers(x)  # (B, 256, T, 4, 4)
                
                # 展平空间维度
                B, C, T, H, W = x.shape
                x = x.permute(0, 2, 1, 3, 4)  # (B, T, C, H, W)
                x = x.reshape(B, T, -1)  # (B, T, C*H*W)
                
                # 延迟初始化投影层
                if self.projection is None:
                    input_dim = x.shape[-1]
                    self.projection = nn.Linear(input_dim, self.video_feature_dim).to(x.device)
                
                # 投影到目标维度
                x = self.projection(x)  # (B, T, video_feature_dim)
                
                return x
        
        return AdaptiveVideoEncoder(self.video_feature_dim)
    
    def _build_diffusion_model(self) -> nn.Module:
        """
        构建扩散模型
        
        Returns:
            扩散模型
        """
        return nn.Sequential(
            nn.Linear(self.hidden_dim + self.video_feature_dim + self.hidden_dim, self.hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            
            nn.Linear(self.hidden_dim * 2, self.hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            
            nn.Linear(self.hidden_dim * 2, self.hidden_dim),
            nn.ReLU(),
            
            nn.Linear(self.hidden_dim, self.eeg_channels)
        )
    
    def apply_graph_da_augmentation(self, eeg_data: torch.Tensor, 
                                  augmentation_ratio: float = 0.3) -> torch.Tensor:
        """
        应用Graph-DA数据增强
        
        Args:
            eeg_data: EEG数据 (batch_size, n_channels, n_samples)
            augmentation_ratio: 增强比例
        
        Returns:
            增强后的EEG数据
        """
        if not self.training:
            return eeg_data
        
        batch_size = eeg_data.shape[0]
        augmented_data = eeg_data.clone()
        
        # 随机选择一部分样本进行增强
        num_augment = int(batch_size * augmentation_ratio)
        augment_indices = torch.randperm(batch_size)[:num_augment]
        
        for idx in augment_indices:
            # 随机选择噪声级别
            noise_level = np.random.choice(self.noise_levels)
            
            # 转换为numpy进行处理
            signal_np = eeg_data[idx].cpu().numpy()
            
            # 应用噪声增强
            for ch in range(signal_np.shape[0]):
                signal_np[ch] = add_gaussian_white_noise(
                    signal_np[ch], 
                    target_noise_db=noise_level
                )
            
            # 转换回tensor
            augmented_data[idx] = torch.FloatTensor(signal_np).to(eeg_data.device)
        
        return augmented_data
    
    def forward(self, video_frames: torch.Tensor, 
                eeg_data: Optional[torch.Tensor] = None,
                timesteps: Optional[torch.Tensor] = None,
                return_intermediates: bool = False) -> Union[torch.Tensor, Tuple]:
        """
        前向传播
        
        Args:
            video_frames: 视频帧 (batch_size, num_frames, channels, height, width)
            eeg_data: EEG数据 (batch_size, n_channels, n_samples) - 训练时提供
            timesteps: 扩散时间步 (batch_size,) - 训练时提供
            return_intermediates: 是否返回中间结果
        
        Returns:
            生成的EEG信号或训练损失
        """
        batch_size = video_frames.shape[0]
        
        # 1. 视频特征提取
        # 自适应视频编码器会自动处理输入格式
        video_features = self.video_encoder(video_frames)  # (B, T, feature_dim)
        
        if eeg_data is not None and timesteps is not None:
            # 训练模式
            return self._training_forward(video_features, eeg_data, timesteps, return_intermediates)
        else:
            # 推理模式
            return self._inference_forward(video_features, return_intermediates)
    
    def _training_forward(self, video_features: torch.Tensor, 
                         eeg_data: torch.Tensor, 
                         timesteps: torch.Tensor,
                         return_intermediates: bool = False) -> Union[torch.Tensor, Tuple]:
        """
        Training-mode forward pass: noise the target EEG and predict that noise.

        The graph branch is fed the noised EEG, matching _inference_forward(),
        which feeds it the current noisy sample. Feeding it the clean EEG (and
        leaving the noised signal unused) gives the network no input from
        which the sampled noise could be predicted.

        Args:
            video_features: Encoded video (B, T, video_feature_dim)
            eeg_data: Target EEG (B, n_channels, n_samples)
            timesteps: Diffusion timesteps (B,)
            return_intermediates: Return a dict of intermediate tensors
                instead of the bare adversarial loss.

        Returns:
            (predicted_noise, noise, adversarial_loss), or
            (predicted_noise, noise, intermediates) when
            ``return_intermediates`` is True. ``predicted_noise`` and
            ``noise`` have shape (B, n_channels, signal_length).
        """
        # 1. Graph-DA augmentation of the clean target.
        augmented_eeg = self.apply_graph_da_augmentation(eeg_data)
        
        # 2. Forward diffusion: x_t built from the (augmented) clean signal.
        noise = torch.randn_like(augmented_eeg)
        noisy_eeg = self._add_noise(augmented_eeg, noise, timesteps)
        
        # 3. Graph structures.
        # E-Graph (electrode positions), repeated for every sample.
        egraph_adj = self.egraph_adj.unsqueeze(0).expand(eeg_data.shape[0], -1, -1)
        
        # S-Graphs (per-band signal correlation). The builder works in NumPy,
        # so it is given a detached tensor.
        sgraph_adj_list = self.multi_graph_builder(noisy_eeg.detach())
        
        # 4. Graph feature extraction. The EEG is already
        # (batch_size, n_nodes, in_features) with channels as nodes and the
        # time axis as node features. Every scale reuses the same adjacency.
        egraph_features = self.egraph_processor(
            noisy_eeg, 
            [egraph_adj] * 3
        )  # (B, C, hidden_dim)
        
        sgraph_features_list = [
            processor(noisy_eeg, [sgraph_adj] * 3)
            for processor, sgraph_adj in zip(self.sgraph_processors, sgraph_adj_list)
        ]
        
        # Average the per-band S-Graph features.
        sgraph_features = torch.stack(sgraph_features_list, dim=0).mean(dim=0)
        
        # 5. Spatial graph attention (the S-Graph branch uses the first band's graph).
        egraph_features = self.spatial_attention(egraph_features, egraph_adj)
        sgraph_features = self.spatial_attention(sgraph_features, sgraph_adj_list[0])
        
        # 6. Self-game fusion.
        fused_features, adversarial_loss = self.self_game_fusion(
            egraph_features, sgraph_features, training=True
        )
        
        # 7. Noise prediction.
        time_emb = self.time_embedding(timesteps.float().unsqueeze(-1))  # (B, hidden_dim)
        
        # Pool graph features over nodes and video features over frames.
        graph_features_pooled = torch.mean(fused_features, dim=1)  # (B, hidden_dim)
        video_features_pooled = torch.mean(video_features, dim=1)  # (B, video_feature_dim)
        
        diffusion_input = torch.cat([
            graph_features_pooled,
            video_features_pooled,
            time_emb
        ], dim=-1)  # (B, hidden_dim + video_feature_dim + hidden_dim)
        
        predicted_noise = self.diffusion_model(diffusion_input)  # (B, n_channels)
        
        # Broadcast the per-channel prediction along the time axis.
        predicted_noise = predicted_noise.unsqueeze(-1).expand(-1, -1, self.signal_length)
        
        if return_intermediates:
            intermediates = {
                'egraph_features': egraph_features,
                'sgraph_features': sgraph_features,
                'fused_features': fused_features,
                'video_features': video_features,
                'adversarial_loss': adversarial_loss
            }
            return predicted_noise, noise, intermediates
        
        return predicted_noise, noise, adversarial_loss
    
    def _inference_forward(self, video_features: torch.Tensor,
                          return_intermediates: bool = False) -> Union[torch.Tensor, Tuple]:
        """
        Inference-mode forward pass: iterative denoising from random noise.

        Starts from a standard-normal sample of shape
        (batch_size, eeg_channels, signal_length) and runs
        ``num_diffusion_steps`` update steps, rebuilding the graphs from the
        current sample at every step.

        Args:
            video_features: Encoded video (B, T, video_feature_dim)
            return_intermediates: Also return the graph features of the last
                executed step and the video features.

        Returns:
            The generated EEG sample, or (sample, intermediates) when
            ``return_intermediates`` is True.
        """
        batch_size = video_features.shape[0]
        device = video_features.device
        
        # Start from random noise.
        eeg_shape = (batch_size, self.eeg_channels, self.signal_length)
        eeg_sample = torch.randn(eeg_shape, device=device)
        
        # Step-by-step denoising, from the last timestep down to 0.
        for t in reversed(range(self.num_diffusion_steps)):
            timesteps = torch.full((batch_size,), t, device=device, dtype=torch.long)
            
            # Build the graphs from the current EEG estimate.
            # NOTE(review): torch.no_grad() does not switch dropout off;
            # presumably callers put the model in eval mode first — confirm.
            with torch.no_grad():
                # E-Graph, repeated for every sample: (batch, n, n).
                egraph_adj = self.egraph_adj.unsqueeze(0).expand(batch_size, -1, -1)
                
                # S-Graphs (from the current EEG estimate).
                sgraph_adj_list = self.multi_graph_builder(eeg_sample)
                
                # Graph feature extraction; every scale reuses the same adjacency.
                egraph_features = self.egraph_processor(eeg_sample, [egraph_adj] * 3)
                
                sgraph_features_list = []
                for processor, sgraph_adj in zip(self.sgraph_processors, sgraph_adj_list):
                    sgraph_feat = processor(eeg_sample, [sgraph_adj] * 3)
                    sgraph_features_list.append(sgraph_feat)
                
                # Average the per-band S-Graph features.
                sgraph_features = torch.stack(sgraph_features_list, dim=0).mean(dim=0)
                
                # Spatial attention (the S-Graph branch uses the first band's graph).
                # NOTE(review): both adjacencies passed here are batched
                # (batch, n, n), while SpatialGraphAttention.forward documents
                # adj as (n_nodes, n_nodes) — confirm it handles batch_size > 1.
                egraph_features = self.spatial_attention(egraph_features, egraph_adj)
                sgraph_features = self.spatial_attention(sgraph_features, sgraph_adj_list[0])
                
                # Self-game fusion (no adversarial loss at inference).
                fused_features, _ = self.self_game_fusion(
                    egraph_features, sgraph_features, training=False
                )
                
                # Timestep embedding.
                time_emb = self.time_embedding(timesteps.float().unsqueeze(-1))
                
                # Diffusion model input: pooled graph, pooled video, time embedding.
                graph_features_pooled = torch.mean(fused_features, dim=1)
                video_features_pooled = torch.mean(video_features, dim=1)
                
                diffusion_input = torch.cat([
                    graph_features_pooled,
                    video_features_pooled,
                    time_emb
                ], dim=-1)
                
                # Predict the noise per channel and broadcast it along time.
                predicted_noise = self.diffusion_model(diffusion_input)
                predicted_noise = predicted_noise.unsqueeze(-1).expand(-1, -1, self.signal_length)
                
                # Denoising step, with numerical-stability guards.
                # NOTE(review): _add_noise() uses the value of _get_alpha_t as
                # the signal-retention coefficient (sqrt(alpha) * x +
                # sqrt(1 - alpha) * noise). The update below subtracts
                # beta_t * predicted_noise and divides by sqrt(alpha_t);
                # confirm this is the intended reverse step for that schedule.
                alpha_t = self._get_alpha_t(t)
                alpha_t_prev = self._get_alpha_t(t - 1) if t > 0 else torch.tensor(1.0, device=device)
                
                # Keep the alpha values in (0, 1].
                alpha_t = torch.clamp(alpha_t, min=1e-8, max=1.0)
                alpha_t_prev = torch.clamp(alpha_t_prev, min=1e-8, max=1.0)
                
                # Step-wise beta derived from the ratio of consecutive alphas.
                beta_t = 1 - alpha_t / alpha_t_prev if t > 0 else torch.tensor(0.0, device=device)
                beta_t = torch.clamp(beta_t, min=0.0, max=1.0)
                
                # Guard the divisor against zero.
                sqrt_alpha_t = torch.sqrt(alpha_t)
                sqrt_alpha_t = torch.clamp(sqrt_alpha_t, min=1e-8)
                
                # Replace non-finite predictions before using them.
                predicted_noise = torch.nan_to_num(predicted_noise, nan=0.0, posinf=1.0, neginf=-1.0)
                
                eeg_sample = (eeg_sample - beta_t * predicted_noise) / sqrt_alpha_t
                
                # Replace non-finite values in the updated sample.
                eeg_sample = torch.nan_to_num(eeg_sample, nan=0.0, posinf=1.0, neginf=-1.0)
                
                # Every step except the last one re-injects fresh noise.
                if t > 0:
                    noise = torch.randn_like(eeg_sample)
                    sqrt_beta_t = torch.sqrt(torch.clamp(beta_t, min=0.0))
                    eeg_sample += sqrt_beta_t * noise
                    
                    # Final non-finite check for this step.
                    eeg_sample = torch.nan_to_num(eeg_sample, nan=0.0, posinf=1.0, neginf=-1.0)
        
        if return_intermediates:
            # These names are only assigned inside the loop; with
            # num_diffusion_steps == 0 they would be unbound here.
            intermediates = {
                'final_egraph_features': egraph_features,
                'final_sgraph_features': sgraph_features,
                'final_fused_features': fused_features,
                'video_features': video_features
            }
            return eeg_sample, intermediates
        
        return eeg_sample
    
    def _add_noise(self, x: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
        """
        向信号添加噪声
        """
        alpha_t = self._get_alpha_t(timesteps)
        alpha_t = alpha_t.view(-1, 1, 1)  # 广播到正确形状
        
        return torch.sqrt(alpha_t) * x + torch.sqrt(1 - alpha_t) * noise
    
    def _get_alpha_t(self, t: Union[int, torch.Tensor]) -> torch.Tensor:
        """
        获取扩散过程的alpha值
        """
        # 确定设备
        device = t.device if isinstance(t, torch.Tensor) else torch.device('cpu')
        
        if isinstance(t, int):
            t = torch.tensor(t, dtype=torch.float32, device=device)
        
        # 线性调度
        beta_start = 0.0001
        beta_end = 0.02
        
        betas = torch.linspace(beta_start, beta_end, self.num_diffusion_steps, device=device)
        alphas = 1 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        
        if isinstance(t, torch.Tensor):
            if t.dim() == 0:
                return alphas_cumprod[t.long()]
            else:
                return alphas_cumprod[t.long()]
        else:
            return alphas_cumprod[t]
    
    def compute_loss(self, predicted_noise: torch.Tensor, 
                    target_noise: torch.Tensor,
                    adversarial_loss: torch.Tensor,
                    eeg_data: torch.Tensor,
                    generated_eeg: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Combine the diffusion, adversarial and spectral terms into one loss.

        Args:
            predicted_noise: Noise predicted by the model.
            target_noise: Noise the prediction is compared against (MSE).
            adversarial_loss: Scalar adversarial term; clamped to [0, 100]
                before it is weighted.
            eeg_data: Real EEG data; only read when generated_eeg is given.
            generated_eeg: Optional generated EEG. When provided, the MSE
                between the squared FFT magnitudes (along the last axis) of
                generated_eeg and eeg_data is added as a frequency term.

        Returns:
            A tuple (total_loss, loss_dict). total_loss is
            diffusion + 0.001 * adversarial + 0.05 * frequency, and loss_dict
            holds the Python-float value of each term under the keys
            'total_loss', 'diffusion_loss', 'adversarial_loss' and
            'frequency_loss'.
        """
        device = predicted_noise.device

        def _finite_or_zero(term: torch.Tensor) -> torch.Tensor:
            # A NaN/Inf term is swapped for a constant zero so it cannot
            # poison the total; note the replacement carries no gradient.
            if not bool(torch.isfinite(term)):
                return torch.tensor(0.0, device=device)
            return term

        # Primary objective: how well the noise was predicted.
        noise_mse = F.mse_loss(predicted_noise, target_noise)

        # Spectral term stays at zero unless a generated signal is supplied.
        if generated_eeg is None:
            spectral_mse = torch.tensor(0.0, device=device)
        else:
            power_real = torch.fft.fft(eeg_data, dim=-1).abs() ** 2
            power_generated = torch.fft.fft(generated_eeg, dim=-1).abs() ** 2
            spectral_mse = F.mse_loss(power_generated, power_real)

        noise_mse = _finite_or_zero(noise_mse)
        adversarial_term = _finite_or_zero(adversarial_loss)
        spectral_mse = _finite_or_zero(spectral_mse)

        # Keep the adversarial term inside a bounded, non-negative range.
        adversarial_term = torch.clamp(adversarial_term, min=0.0, max=100.0)

        # The adversarial term gets a deliberately tiny weight.
        combined = noise_mse + 0.001 * adversarial_term + 0.05 * spectral_mse

        report = {
            'total_loss': combined.item(),
            'diffusion_loss': noise_mse.item(),
            'adversarial_loss': adversarial_term.item(),
            'frequency_loss': spectral_mse.item()
        }
        return combined, report

# ==================== 模型工厂函数 ====================

def create_video2eeg_sggn_diffusion(config: Dict) -> Video2EEGSGGNDiffusion:
    """
    Build a Video2EEGSGGNDiffusion model from a configuration mapping.

    Every recognised key is optional; a missing key falls back to the default
    listed below.

    Args:
        config: Mapping that may contain 'eeg_channels' (default 62),
            'signal_length' (200), 'video_feature_dim' (512),
            'hidden_dim' (256), 'num_diffusion_steps' (1000),
            'electrode_positions' (None) and 'frequency_bands' (None).

    Returns:
        The constructed model instance.
    """
    # Constructor keyword -> fallback used when the config omits it.
    fallbacks = {
        'eeg_channels': 62,
        'signal_length': 200,
        'video_feature_dim': 512,
        'hidden_dim': 256,
        'num_diffusion_steps': 1000,
        'electrode_positions': None,
        'frequency_bands': None,
    }
    model_kwargs = {
        option: config.get(option, fallback)
        for option, fallback in fallbacks.items()
    }
    return Video2EEGSGGNDiffusion(**model_kwargs)

if __name__ == "__main__":
    # Smoke test: push random inputs through the model in both call modes
    # and print the resulting shapes and loss values.
    print("测试Video2EEG-SGGN-Diffusion模型...")

    # Default-configured model, switched to evaluation mode.
    net = Video2EEGSGGNDiffusion()
    net.eval()

    # Random stand-ins for video clips, EEG recordings and diffusion steps.
    n_batch, n_frames = 2, 10
    clips = torch.randn(n_batch, n_frames, 3, 224, 224)
    eeg = torch.randn(n_batch, 62, 200)
    steps = torch.randint(0, 1000, (n_batch,))

    print(f"输入视频形状: {clips.shape}")
    print(f"输入EEG形状: {eeg.shape}")

    # Call with video, EEG and timesteps: yields predicted noise, target
    # noise and an adversarial loss, which are then fed to compute_loss.
    print("\n=== 训练模式测试 ===")
    with torch.no_grad():
        noise_pred, noise_true, adv_term = net(clips, eeg, steps)
        print(f"预测噪声形状: {noise_pred.shape}")
        print(f"目标噪声形状: {noise_true.shape}")
        print(f"对抗损失: {adv_term.item():.4f}")

        loss_total, loss_parts = net.compute_loss(noise_pred, noise_true, adv_term, eeg)
        print(f"总损失: {loss_total.item():.4f}")
        print("损失分解:")
        for term_name, term_value in loss_parts.items():
            print(f"  {term_name}: {term_value:.4f}")

    # Call with video only: the result is reported as the generated EEG.
    print("\n=== 推理模式测试 ===")
    with torch.no_grad():
        sampled = net(clips)
        print(f"生成EEG形状: {sampled.shape}")
        lowest = sampled.min().item()
        highest = sampled.max().item()
        print(f"生成EEG范围: [{lowest:.3f}, {highest:.3f}]")

    # Static feature checklist; these lines are informational only.
    print("\n=== 模型特性验证 ===")
    feature_lines = (
        "✓ Graph-DA数据增强: 高斯白噪声增强",
        "✓ E-Graph构建: 基于电极位置的空间图",
        "✓ S-Graph构建: 基于信号相关性的多频段图",
        "✓ 滤波器组驱动: 多频段图建模",
        "✓ 自博弈融合: 对抗博弈信息融合",
        "✓ 多尺度图卷积: 多层次特征提取",
        "✓ 空间图注意力: 自适应空间权重",
    )
    for feature_line in feature_lines:
        print(feature_line)

    print("\n模型测试完成！")