#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Video2EEG-SGGN-Diffusion模型训练脚本
基于MGIF框架的多GPU并行训练，支持Graph-DA数据增强和自博弈融合策略

核心特性:
1. 多GPU分布式训练
2. Graph-DA数据增强
3. E-Graph与S-Graph构建
4. 滤波器组驱动的多图建模
5. 自博弈融合策略
6. 综合质量评估

作者: 算法工程师
日期: 2025年1月12日
"""

import os
import sys
import json
import time
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import signal, stats
from pathlib import Path
import logging
from typing import Dict, List, Tuple, Optional
import warnings
from collections import defaultdict
from datetime import datetime

# Silence every warning category for the whole process.
warnings.filterwarnings('ignore')

# Import the model class and its factory function.
from video2eeg_sggn_diffusion_model import Video2EEGSGGNDiffusion, create_video2eeg_sggn_diffusion

# Configure logging: INFO level on the root logger, plus a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Font setup for matplotlib: list CJK-capable fonts first, and disable the
# Unicode minus glyph for axis labels.
plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial Unicode MS', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False

class SGGNEEGVideoDataset(Dataset):
    """
    SGGN-EEG-Video数据集，支持Graph-DA数据增强
    """
    
    def __init__(self, data_dir: str, split: str = 'train',
                 use_graph_da: bool = True, augmentation_ratio: float = 0.3):
        """
        Build the sample index for one dataset split.

        Every ``subject_*_video_*.npz`` file under ``data_dir`` is a candidate.
        If ``dataset_split.json`` exists and contains ``split``, only the files
        listed there are kept. Each file is opened once to detect its layout:

        * single-window layout (``eeg_data`` + ``video_data`` keys): one sample,
          plus one augmented copy for the training split when Graph-DA is on;
        * multi-window layout (``num_windows`` key): one sample per window,
          plus ``int(num_windows * augmentation_ratio)`` augmented samples for
          the training split when Graph-DA is on.

        Files that cannot be read or match neither layout are logged and
        skipped.

        Args:
            data_dir: Directory holding the ``.npz`` files and the split file.
            split: Split name to load ('train', 'val', 'test').
            use_graph_da: Whether to add Graph-DA augmented samples.
            augmentation_ratio: Fraction of windows to augment (multi-window
                layout only).
        """
        self.data_dir = Path(data_dir)
        self.split = split
        self.use_graph_da = use_graph_da
        self.augmentation_ratio = augmentation_ratio

        # Sort the glob result: directory listing order is not guaranteed, and
        # an index-based sampler needs the same sample order in every process.
        self.data_files = sorted(self.data_dir.glob("subject_*_video_*.npz"))

        # Restrict the candidates to the requested split when a split file exists.
        split_file = self.data_dir / "dataset_split.json"
        if split_file.exists():
            with open(split_file, 'r') as split_fh:
                split_info = json.load(split_fh)
            if split in split_info:
                # A set gives O(1) membership tests; entries carrying the
                # 'enhanced_processed_data/' prefix are reduced to bare names.
                split_filenames = {
                    os.path.basename(entry)
                    if entry.startswith('enhanced_processed_data/') else entry
                    for entry in split_info[split]
                }
                self.data_files = [
                    path for path in self.data_files if path.name in split_filenames
                ]
            else:
                # Without this warning every split would silently share all files.
                logger.warning(f"分割文件中缺少'{split}'分割，将使用全部数据文件")

        logger.info(f"加载{split}数据集: {len(self.data_files)}个文件")

        add_augmented = self.split == 'train' and self.use_graph_da

        # Build the flat sample index.
        self.data_info = []
        for data_file in self.data_files:
            try:
                # The context manager closes the archive; otherwise one file
                # handle per data file stays open for the dataset's lifetime.
                with np.load(data_file) as data:
                    keys = list(data.files)
                    single_window = 'eeg_data' in data and 'video_data' in data
                    num_windows = None
                    if not single_window and 'num_windows' in data:
                        num_windows = int(data['num_windows'])
            except Exception as e:
                logger.warning(f"加载数据文件失败 {data_file}: {e}")
                continue

            if single_window:
                # Single-window layout: the file itself is the sample.
                self.data_info.append({
                    'file': data_file,
                    'window_idx': 0,
                    'is_augmented': False
                })
                if add_augmented:
                    self.data_info.append({
                        'file': data_file,
                        'window_idx': 0,
                        'is_augmented': True
                    })
            elif num_windows is not None:
                # Multi-window layout: one entry per stored window.
                self.data_info.extend(
                    {'file': data_file, 'window_idx': i, 'is_augmented': False}
                    for i in range(num_windows)
                )
                if add_augmented:
                    num_augment = int(num_windows * self.augmentation_ratio)
                    self.data_info.extend(
                        {
                            'file': data_file,
                            'window_idx': i % num_windows,
                            'is_augmented': True
                        }
                        for i in range(num_augment)
                    )
            else:
                logger.warning(f"未知数据格式 {data_file}: 键={keys}")

        logger.info(f"总样本数: {len(self.data_info)} (增强比例: {self.augmentation_ratio if self.use_graph_da else 0})")
    
    def __len__(self):
        """Return the number of entries in the sample index."""
        sample_count = len(self.data_info)
        return sample_count
    
    def apply_graph_da_augmentation(self, eeg_data: np.ndarray) -> np.ndarray:
        """
        Apply one randomly chosen Graph-DA augmentation to an EEG window.

        One of four strategies is drawn uniformly at random:

        * ``noise``: additive Gaussian white noise, per channel, at a random
          signal-to-noise ratio between 0 and 15 dB;
        * ``channel_dropout``: zero out 5%-15% of the channels;
        * ``time_shift``: shift by up to 10 samples in either direction,
          zero-filling the vacated samples;
        * ``amplitude_scale``: multiply by a factor in [0.8, 1.2].

        The input array is never modified.

        Args:
            eeg_data: EEG data of shape (n_channels, n_samples).

        Returns:
            The augmented copy, same shape as the input. Floating inputs keep
            their dtype; other dtypes are promoted to float32 because the
            augmentations are not representable in integer arithmetic.
        """
        augmented_eeg = eeg_data.copy()
        if not np.issubdtype(augmented_eeg.dtype, np.floating):
            # In-place float operations below would raise on integer arrays.
            augmented_eeg = augmented_eeg.astype(np.float32)

        # Draw the augmentation strategy.
        augmentation_type = np.random.choice(['noise', 'channel_dropout', 'time_shift', 'amplitude_scale'])

        if augmentation_type == 'noise':
            # The drawn dB value is treated as an SNR, so noise power is the
            # signal power DIVIDED by 10^(snr/10). Multiplying instead would
            # make the noise 1x-31x stronger than the signal.
            # NOTE(review): assumes the 0-15 dB range was meant as an SNR — confirm.
            snr_db = np.random.uniform(0, 15)
            signal_power = np.mean(augmented_eeg ** 2, axis=1, keepdims=True)
            noise_std = np.sqrt(signal_power / (10 ** (snr_db / 10)))
            noise = np.random.normal(0.0, 1.0, augmented_eeg.shape) * noise_std
            augmented_eeg += noise.astype(augmented_eeg.dtype, copy=False)

        elif augmentation_type == 'channel_dropout':
            # Zero a random subset of channels (possibly none for few channels).
            dropout_ratio = np.random.uniform(0.05, 0.15)
            num_dropout = int(augmented_eeg.shape[0] * dropout_ratio)
            dropout_channels = np.random.choice(augmented_eeg.shape[0], num_dropout, replace=False)
            augmented_eeg[dropout_channels] = 0

        elif augmentation_type == 'time_shift':
            # np.roll wraps around, so the wrapped samples are zeroed afterwards.
            shift_samples = np.random.randint(-10, 11)
            if shift_samples != 0:
                augmented_eeg = np.roll(augmented_eeg, shift_samples, axis=1)
                if shift_samples > 0:
                    augmented_eeg[:, :shift_samples] = 0
                else:
                    augmented_eeg[:, shift_samples:] = 0

        elif augmentation_type == 'amplitude_scale':
            # Global amplitude scaling.
            scale_factor = np.random.uniform(0.8, 1.2)
            augmented_eeg *= scale_factor

        return augmented_eeg
    
    def __getitem__(self, idx):
        """
        Load and preprocess the sample at position ``idx`` of the index.

        Processing steps:

        1. Read the EEG and video arrays for the indexed window (single-window
           or multi-window npz layout) together with subject/video ids.
        2. Apply Graph-DA augmentation to the EEG when the entry is augmented.
        3. Scale video to [0, 1] when values exceed 1, move channels first,
           stride/crop frames to 224x224 and average-pool them to 32x32.
        4. Truncate or zero-pad the EEG to 200 samples; truncate the video to
           200 frames or pad it by repeating the last frame.

        Returns:
            Dict with ``eeg`` (float tensor, channels x 200), ``video`` (float
            tensor, 200 x C x 32 x 32 for frames of at least 224x224),
            ``subject_id``, ``video_id`` and ``is_augmented``. If anything
            fails, the error is logged and an all-zero sample with ids 0 is
            returned.
        """
        info = self.data_info[idx]

        target_length = 200   # EEG samples per window
        target_frames = 200   # video frames per window
        frame_size = 224      # spatial size before pooling
        pooled_size = 32      # spatial size after pooling

        try:
            # Close the archive as soon as the arrays are read; leaving it open
            # leaks one file handle per fetched sample.
            with np.load(info['file']) as data:
                if 'eeg_data' in data and 'video_data' in data:
                    # Single-window layout: arrays are stored directly.
                    eeg_data = data['eeg_data'].astype(np.float32)
                    video_data = data['video_data'].astype(np.float32)

                    # Defaults used when the file name carries no usable ids.
                    subject_id = 1
                    video_id = 1

                    # Parse "subject_<n>" / "video_<n>" from the stem: using the
                    # full name leaves ".npz" glued to the last token, which
                    # makes int() fail for names like subject_3_video_7.npz.
                    name_parts = Path(info['file']).stem.split('_')
                    for tag, value in zip(name_parts, name_parts[1:]):
                        if tag not in ('subject', 'video'):
                            continue
                        try:
                            parsed = int(value)
                        except ValueError:
                            continue  # non-numeric token: keep the default
                        if tag == 'subject':
                            subject_id = parsed
                        else:
                            video_id = parsed
                else:
                    # Multi-window layout: arrays are keyed by window index.
                    window = info['window_idx']
                    eeg_data = data[f"eeg_window_{window}"].astype(np.float32)
                    video_data = data[f"video_window_{window}"].astype(np.float32)

                    # Membership test instead of .get(): the mapping-style
                    # .get() is not available on every NumPy release.
                    subject_id = int(data['subject_id']) if 'subject_id' in data else 1
                    video_id = int(data['video_id']) if 'video_id' in data else 1

            # Graph-DA augmentation touches the EEG only.
            if info['is_augmented']:
                eeg_data = self.apply_graph_da_augmentation(eeg_data)

            # Scale video to [0, 1] when values exceed 1.
            if video_data.max() > 1.0:
                video_data = video_data / 255.0

            # (T, H, W, C) -> (T, C, H, W)
            if video_data.ndim == 4:
                video_data = np.transpose(video_data, (0, 3, 1, 2))

            video_data = video_data.astype(np.float32, copy=False)

            # Strided subsampling followed by a top-left crop to 224x224.
            if video_data.shape[-2:] != (frame_size, frame_size):
                _, _, height, width = video_data.shape
                step_h = max(1, height // frame_size)
                step_w = max(1, width // frame_size)
                video_data = video_data[:, :, ::step_h, ::step_w]
                video_data = video_data[:, :, :frame_size, :frame_size]

            # Average-pool 224x224 -> 32x32 with non-overlapping 7x7 blocks.
            # The reshape/mean form is one vectorized operation instead of one
            # np.mean call per output pixel.
            # NOTE: frames smaller than 224 in either dimension skip this step
            # and keep their reduced size, exactly as before.
            if video_data.shape[-2:] == (frame_size, frame_size):
                n_frames, n_channels = video_data.shape[:2]
                pool = frame_size // pooled_size
                video_data = video_data.reshape(
                    n_frames, n_channels, pooled_size, pool, pooled_size, pool
                ).mean(axis=(3, 5))

            # Fix the EEG length: truncate or zero-pad on the right.
            n_samples = eeg_data.shape[1]
            if n_samples > target_length:
                eeg_data = eeg_data[:, :target_length]
            elif n_samples < target_length:
                eeg_data = np.pad(
                    eeg_data, ((0, 0), (0, target_length - n_samples)), mode='constant'
                )

            # Fix the frame count: truncate or repeat the last frame.
            n_frames = video_data.shape[0]
            if n_frames > target_frames:
                video_data = video_data[:target_frames]
            elif n_frames < target_frames:
                padding = np.repeat(video_data[-1:], target_frames - n_frames, axis=0)
                video_data = np.concatenate([video_data, padding], axis=0)

            # The video stays (T, C, H, W); batching adds the leading B axis.
            return {
                'eeg': torch.from_numpy(np.ascontiguousarray(eeg_data, dtype=np.float32)),
                'video': torch.from_numpy(np.ascontiguousarray(video_data, dtype=np.float32)),
                'subject_id': subject_id,
                'video_id': video_id,
                'is_augmented': info['is_augmented']
            }

        except Exception as e:
            logger.error(f"加载样本失败 {idx}: {e}")
            # Zero fallback. The video shape must match the pooled 32x32 output
            # of the normal path, otherwise batch collation fails on mixed shapes.
            return {
                'eeg': torch.zeros(62, target_length),
                'video': torch.zeros(target_frames, 3, pooled_size, pooled_size),
                'subject_id': 0,
                'video_id': 0,
                'is_augmented': False
            }

class SGGNModelTrainer:
    """
    SGGN模型训练器
    """
    
    def __init__(self, 
                 config: Dict,
                 data_dir: str,
                 output_dir: str = "./sggn_training_output",
                 use_distributed: bool = False):
        """
        Set up everything needed to train: device, process group, model, data
        loaders, optimizer, scheduler, loss, optional AMP scaler and logging.

        Args:
            config: Model and training configuration.
            data_dir: Directory containing the dataset.
            output_dir: Directory for checkpoints and TensorBoard logs; created
                if it does not exist.
            use_distributed: Whether to run with distributed data parallelism.
        """
        self.config = config
        self.data_dir = data_dir
        self.use_distributed = use_distributed

        # Output directory
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Device selection
        if not use_distributed:
            device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
            self.device = torch.device(device_name)
        else:
            # LOCAL_RANK is read from the environment, defaulting to 0.
            self.local_rank = int(os.environ.get('LOCAL_RANK', 0))
            self.device = torch.device(f'cuda:{self.local_rank}')
            torch.cuda.set_device(self.device)

        logger.info(f"使用设备: {self.device}")

        # Process group
        if not use_distributed:
            self.world_size, self.rank = 1, 0
        else:
            dist.init_process_group(backend='nccl')
            self.world_size = dist.get_world_size()
            self.rank = dist.get_rank()

        # Model first: the optimizer needs its parameters.
        self.model = self.create_model()

        # Data loaders
        loaders = self.create_dataloaders()
        self.train_loader, self.val_loader, self.test_loader = loaders

        # Optimizer first: the scheduler wraps it.
        self.optimizer = self.create_optimizer()
        self.scheduler = self.create_scheduler()

        # Loss function
        self.criterion = nn.MSELoss()

        # Mixed precision: the gradient scaler exists only when enabled.
        self.use_mixed_precision = config.get('use_mixed_precision', False)
        if self.use_mixed_precision:
            self.scaler = torch.cuda.amp.GradScaler()

        # TensorBoard writer on rank 0 only
        if self.rank == 0:
            self.writer = SummaryWriter(self.output_dir / 'tensorboard')

        # Training state
        self.current_epoch = 0
        self.best_val_loss = float('inf')
        self.training_history = defaultdict(list)
        
    def create_model(self) -> nn.Module:
        """
        Build the SGGN model from the ``model`` section of the configuration,
        move it to the training device and wrap it in DDP when distributed
        training is enabled.

        Returns:
            The model instance (DDP-wrapped in distributed mode).
        """
        network = create_video2eeg_sggn_diffusion(self.config.get('model', {}))
        network = network.to(self.device)

        if self.use_distributed:
            rank = self.local_rank
            network = DDP(network, device_ids=[rank], output_device=rank)

        # Count trainable parameters only.
        trainable = sum(p.numel() for p in network.parameters() if p.requires_grad)
        logger.info(f"模型参数数量: {trainable:,}")

        return network
    
    def create_dataloaders(self) -> Tuple[DataLoader, DataLoader, DataLoader]:
        """
        Build the train/val/test datasets and their data loaders.

        Only the training set uses Graph-DA augmentation, shuffling and
        ``drop_last``. In distributed mode each loader gets a
        ``DistributedSampler``.

        Returns:
            The training, validation and test data loaders, in that order.
        """
        dataset_config = self.config.get('dataset', {})
        training_config = self.config.get('training', {})
        batch_size = training_config.get('batch_size', 4)
        num_workers = training_config.get('num_workers', 4)

        # Datasets (built in train, val, test order)
        train_dataset = SGGNEEGVideoDataset(
            self.data_dir,
            split='train',
            use_graph_da=dataset_config.get('use_graph_da', True),
            augmentation_ratio=dataset_config.get('augmentation_ratio', 0.3),
        )
        val_dataset, test_dataset = (
            SGGNEEGVideoDataset(self.data_dir, split=name, use_graph_da=False)
            for name in ('val', 'test')
        )

        # Samplers: only needed for distributed training
        train_sampler = val_sampler = test_sampler = None
        if self.use_distributed:
            train_sampler = DistributedSampler(train_dataset, shuffle=True)
            val_sampler = DistributedSampler(val_dataset, shuffle=False)
            test_sampler = DistributedSampler(test_dataset, shuffle=False)

        # Loaders share batch size, worker count and pinned memory.
        common = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
        train_loader = DataLoader(
            train_dataset,
            sampler=train_sampler,
            shuffle=train_sampler is None,  # a sampler and shuffle are mutually exclusive
            drop_last=True,
            **common,
        )
        val_loader = DataLoader(val_dataset, sampler=val_sampler, shuffle=False, **common)
        test_loader = DataLoader(test_dataset, sampler=test_sampler, shuffle=False, **common)

        logger.info(f"数据加载器创建完成:")
        report = (
            ("训练集", train_dataset, train_loader),
            ("验证集", val_dataset, val_loader),
            ("测试集", test_dataset, test_loader),
        )
        for label, dataset, loader in report:
            logger.info(f"  {label}: {len(dataset)} 样本, {len(loader)} 批次")

        return train_loader, val_loader, test_loader
    
    def create_optimizer(self) -> torch.optim.Optimizer:
        """
        Build the optimizer described by the ``optimizer`` config section.

        Supported types are ``AdamW`` (the default) and ``Adam``.

        Returns:
            The optimizer over all model parameters.

        Raises:
            ValueError: If the configured optimizer type is not supported.
        """
        settings = self.config.get('optimizer', {})
        optimizer_type = settings.get('type', 'AdamW')
        learning_rate = settings.get('learning_rate', 1e-4)
        weight_decay = settings.get('weight_decay', 1e-5)

        # Reject unknown types before touching the model.
        if optimizer_type != 'AdamW' and optimizer_type != 'Adam':
            raise ValueError(f"不支持的优化器类型: {optimizer_type}")

        parameters = self.model.parameters()
        if optimizer_type == 'Adam':
            optimizer = torch.optim.Adam(
                parameters, lr=learning_rate, weight_decay=weight_decay
            )
        else:
            optimizer = torch.optim.AdamW(
                parameters,
                lr=learning_rate,
                weight_decay=weight_decay,
                betas=(0.9, 0.999),
            )

        logger.info(f"优化器: {optimizer_type}, 学习率: {learning_rate}, 权重衰减: {weight_decay}")

        return optimizer
    
    def create_scheduler(self) -> torch.optim.lr_scheduler._LRScheduler:
        """
        Build the learning-rate scheduler described by the ``scheduler``
        config section.

        ``CosineAnnealingLR`` (the default) anneals over the configured number
        of epochs down to 1e-6; ``StepLR`` uses ``step_size`` and ``gamma``.
        Any other type yields no scheduler.

        Returns:
            The scheduler, or ``None`` for an unrecognized type.
        """
        scheduler_config = self.config.get('scheduler', {})
        scheduler_type = scheduler_config.get('type', 'CosineAnnealingLR')

        scheduler = None
        if scheduler_type == 'StepLR':
            scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer,
                step_size=scheduler_config.get('step_size', 30),
                gamma=scheduler_config.get('gamma', 0.1),
            )
        elif scheduler_type == 'CosineAnnealingLR':
            # One cosine period spans the whole training run.
            total_epochs = self.config.get('training', {}).get('num_epochs', 100)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, T_max=total_epochs, eta_min=1e-6
            )

        logger.info(f"学习率调度器: {scheduler_type}")

        return scheduler
    
    def train_epoch(self, epoch: int) -> Dict[str, float]:
        """
        Run one training epoch over ``self.train_loader``.

        For every batch: sample random timesteps, run the forward pass (under
        autocast when mixed precision is enabled), compute the loss via the
        model's ``compute_loss``, back-propagate with gradient-norm clipping
        at 1.0 and step the optimizer. Batches with a non-finite loss or that
        raise an exception are logged and skipped.

        "Accuracy" is the percentage of samples whose predicted noise has a
        cosine similarity above 0.5 with the target noise.

        NOTE(review): in distributed mode a rank that skips a batch may fall
        out of step with the other ranks' gradient synchronisation — confirm
        whether skipping is safe for the DDP setup in use.

        Args:
            epoch: Index of the current epoch.

        Returns:
            Mean of every loss term reported by ``compute_loss`` plus
            ``accuracy``, all as plain Python floats.
        """
        self.model.train()

        if self.use_distributed:
            # Reseed the distributed sampler so each epoch shuffles differently.
            self.train_loader.sampler.set_epoch(epoch)

        # compute_loss lives on the wrapped module when the model is DDP-wrapped.
        loss_owner = self.model.module if self.use_distributed else self.model
        is_main = self.rank == 0
        num_batches = len(self.train_loader)

        epoch_losses = defaultdict(list)
        correct_predictions = 0
        total_predictions = 0

        for batch_idx, batch in enumerate(self.train_loader):
            verbose = is_main and batch_idx % 5 == 0

            if verbose:
                current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                logger.info(f"[{current_time}] 开始处理 Epoch {epoch}, Batch {batch_idx}/{num_batches}")

            # Move the batch to the training device.
            video_frames = batch['video'].to(self.device, non_blocking=True)
            eeg_data = batch['eeg'].to(self.device, non_blocking=True)

            # One random diffusion timestep per sample.
            batch_size = video_frames.shape[0]
            timesteps = torch.randint(0, 1000, (batch_size,), device=self.device)

            if batch_idx == 0 and is_main:
                logger.info(f"数据形状 - Video: {video_frames.shape}, EEG: {eeg_data.shape}, Timesteps: {timesteps.shape}")

            if verbose:
                current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                logger.info(f"[{current_time}] 开始前向传播 Batch {batch_idx}")

            try:
                # Forward pass; only the mixed-precision path enters autocast.
                if self.use_mixed_precision:
                    with torch.cuda.amp.autocast():
                        predicted_noise, target_noise, adversarial_loss = self.model(
                            video_frames, eeg_data, timesteps
                        )
                else:
                    predicted_noise, target_noise, adversarial_loss = self.model(
                        video_frames, eeg_data, timesteps
                    )

                if verbose:
                    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                    logger.info(f"[{current_time}] 前向传播完成 Batch {batch_idx}")

                total_loss, loss_dict = loss_owner.compute_loss(
                    predicted_noise, target_noise, adversarial_loss, eeg_data
                )

                # isfinite is False for NaN as well as +/-inf.
                if not torch.isfinite(total_loss):
                    logger.warning(f"检测到无效损失值 Batch {batch_idx}: {total_loss.item()}, 跳过此批次")
                    continue

                # Cap a runaway adversarial term and down-weight it.
                if adversarial_loss > 1000:
                    logger.warning(f"对抗损失过大 Batch {batch_idx}: {adversarial_loss.item()}, 进行裁剪")
                    adversarial_loss = torch.clamp(adversarial_loss, max=1000.0)
                    # NOTE(review): this only back-propagates the diffusion term
                    # if loss_dict['diffusion_loss'] is a graph-attached tensor
                    # rather than a float — confirm against compute_loss.
                    diffusion_loss = loss_dict['diffusion_loss']
                    total_loss = diffusion_loss + 0.01 * adversarial_loss

                # Backward pass, gradient clipping and optimizer step.
                self.optimizer.zero_grad()
                if self.use_mixed_precision:
                    self.scaler.scale(total_loss).backward()
                    # Unscale first so clipping sees the true gradient norms.
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    total_loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                    self.optimizer.step()

            except Exception as e:
                # Best-effort training: keep going, but keep the traceback.
                logger.error(f"训练批次 {batch_idx} 出现错误: {e}, 跳过此批次", exc_info=True)
                continue

            # Record losses as plain floats. Storing tensors would keep every
            # batch's autograd graph alive and break np.mean for tensors that
            # require grad or live on the GPU.
            for key, value in loss_dict.items():
                if torch.is_tensor(value):
                    value = value.detach().float().mean().item()
                epoch_losses[key].append(float(value))

            # Accuracy proxy: cosine similarity of predicted vs. target noise.
            with torch.no_grad():
                predicted_flat = predicted_noise.reshape(batch_size, -1)
                target_flat = target_noise.reshape(batch_size, -1)

                predicted_norm = torch.nn.functional.normalize(predicted_flat, dim=1)
                target_norm = torch.nn.functional.normalize(target_flat, dim=1)

                cosine_sim = torch.sum(predicted_norm * target_norm, dim=1)

                # A similarity above 0.5 counts as a correct prediction.
                correct_predictions += (cosine_sim > 0.5).sum().item()
                total_predictions += batch_size

            # Progress report every 10 batches.
            if batch_idx % 10 == 0 and is_main:
                current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                current_accuracy = correct_predictions / max(total_predictions, 1) * 100

                logger.info(
                    f"[{current_time}] Epoch {epoch}, Batch {batch_idx}/{num_batches}, "
                    f"Loss: {total_loss.item():.4f}, "
                    f"Diffusion: {loss_dict['diffusion_loss']:.4f}, "
                    f"Adversarial: {loss_dict['adversarial_loss']:.4f}, "
                    f"Accuracy: {current_accuracy:.2f}%"
                )

        # Epoch averages and final accuracy.
        avg_losses = {key: float(np.mean(values)) for key, values in epoch_losses.items()}
        avg_losses['accuracy'] = correct_predictions / max(total_predictions, 1) * 100

        return avg_losses
    
    def validate(self, epoch: int) -> Dict[str, float]:
        """
        Evaluate the model on ``self.val_loader`` without gradient tracking.

        Each batch gets random timesteps, a forward pass and a loss from the
        model's ``compute_loss``. "Accuracy" is the percentage of samples whose
        predicted noise has a cosine similarity above 0.5 with the target.

        Args:
            epoch: Index of the current epoch (not used by the computation).

        Returns:
            Mean of every loss term reported by ``compute_loss`` plus
            ``accuracy``, all as plain Python floats. With an empty loader only
            ``accuracy`` (0.0) is present.
        """
        self.model.eval()

        # compute_loss lives on the wrapped module when the model is DDP-wrapped.
        loss_owner = self.model.module if self.use_distributed else self.model

        epoch_losses = defaultdict(list)
        correct_predictions = 0
        total_predictions = 0

        with torch.no_grad():
            for batch in self.val_loader:
                # Move the batch to the evaluation device.
                video_frames = batch['video'].to(self.device, non_blocking=True)
                eeg_data = batch['eeg'].to(self.device, non_blocking=True)

                # One random diffusion timestep per sample.
                batch_size = video_frames.shape[0]
                timesteps = torch.randint(0, 1000, (batch_size,), device=self.device)

                predicted_noise, target_noise, adversarial_loss = self.model(
                    video_frames, eeg_data, timesteps
                )

                _, loss_dict = loss_owner.compute_loss(
                    predicted_noise, target_noise, adversarial_loss, eeg_data
                )

                # Record plain floats: np.mean cannot convert GPU tensors, and
                # CPU tensors would yield numpy float32 instead of float.
                for key, value in loss_dict.items():
                    if torch.is_tensor(value):
                        value = value.detach().float().mean().item()
                    epoch_losses[key].append(float(value))

                # Accuracy proxy: cosine similarity of predicted vs. target noise.
                predicted_flat = predicted_noise.reshape(batch_size, -1)
                target_flat = target_noise.reshape(batch_size, -1)

                predicted_norm = torch.nn.functional.normalize(predicted_flat, dim=1)
                target_norm = torch.nn.functional.normalize(target_flat, dim=1)

                cosine_sim = torch.sum(predicted_norm * target_norm, dim=1)

                # A similarity above 0.5 counts as a correct prediction.
                correct_predictions += (cosine_sim > 0.5).sum().item()
                total_predictions += batch_size

        # Averages and accuracy.
        avg_losses = {key: float(np.mean(values)) for key, values in epoch_losses.items()}
        avg_losses['accuracy'] = correct_predictions / max(total_predictions, 1) * 100

        return avg_losses
    
    def save_checkpoint(self, epoch: int, val_loss: float, is_best: bool = False):
        """
        Write a training checkpoint to the output directory (rank 0 only).

        The checkpoint is saved as ``checkpoint_epoch_<epoch>.pth`` and, when
        ``is_best`` is set, additionally as ``best_model.pth``.

        Args:
            epoch: Epoch index stored in the checkpoint and used in the file name.
            val_loss: Validation loss stored in the checkpoint.
            is_best: Whether to also save the checkpoint as the best model.
        """
        # Other ranks never write checkpoints.
        if self.rank != 0:
            return

        # Save the unwrapped module's weights when the model is DDP-wrapped.
        if self.use_distributed:
            model_state = self.model.module.state_dict()
        else:
            model_state = self.model.state_dict()

        scheduler_state = None
        if self.scheduler:
            scheduler_state = self.scheduler.state_dict()

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model_state,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': scheduler_state,
            'val_loss': val_loss,
            'config': self.config,
            # defaultdict -> plain dict for serialization
            'training_history': dict(self.training_history),
        }

        # Per-epoch checkpoint
        checkpoint_path = self.output_dir / f'checkpoint_epoch_{epoch}.pth'
        torch.save(checkpoint, checkpoint_path)

        # Best-model copy
        if is_best:
            best_path = self.output_dir / 'best_model.pth'
            torch.save(checkpoint, best_path)
            logger.info(f"保存最佳模型: {best_path}")

        logger.info(f"保存检查点: {checkpoint_path}")
    
    def train(self):
        """
        Main training loop.

        Reads ``num_epochs``, ``save_interval`` and ``eval_interval`` from the
        ``training`` config section. Every epoch trains once; every
        ``eval_interval`` epochs the model is validated and metrics go to
        TensorBoard (rank 0). A checkpoint is written

        * immediately whenever validation reaches a new best loss, flagged as
          the best model, so ``best_model.pth`` always holds the weights that
          were actually validated;
        * every ``save_interval`` epochs (except epoch 0);
        * once more after the last epoch.

        The scheduler, if any, is stepped once per epoch.
        """
        training_config = self.config.get('training', {})
        num_epochs = training_config.get('num_epochs', 100)
        # Clamp to 1 so a configured 0 cannot cause a modulo-by-zero.
        save_interval = max(1, training_config.get('save_interval', 10))
        eval_interval = max(1, training_config.get('eval_interval', 5))

        logger.info(f"开始训练，总共 {num_epochs} 个epoch")

        for epoch in range(num_epochs):
            self.current_epoch = epoch
            saved_this_epoch = False

            # Train
            train_losses = self.train_epoch(epoch)

            for key, value in train_losses.items():
                self.training_history[f'train_{key}'].append(value)

            # Validate
            if epoch % eval_interval == 0:
                val_losses = self.validate(epoch)

                for key, value in val_losses.items():
                    self.training_history[f'val_{key}'].append(value)

                # 'total_loss' is absent when no validation batch was processed.
                val_total_loss = val_losses.get('total_loss', float('inf'))
                is_best = val_total_loss < self.best_val_loss
                if is_best:
                    self.best_val_loss = val_total_loss
                    # Save right away: waiting for the next periodic save would
                    # either lose these weights or label later weights as best.
                    self.save_checkpoint(epoch, val_total_loss, True)
                    saved_this_epoch = True

                if self.rank == 0:
                    # TensorBoard
                    for key, value in train_losses.items():
                        self.writer.add_scalar(f'Train/{key}', value, epoch)
                    for key, value in val_losses.items():
                        self.writer.add_scalar(f'Val/{key}', value, epoch)
                    self.writer.add_scalar('Learning_Rate', self.optimizer.param_groups[0]['lr'], epoch)

                    # Console summary
                    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                    train_total_loss = train_losses.get('total_loss', float('nan'))
                    train_acc = train_losses.get('accuracy', 0)
                    val_acc = val_losses.get('accuracy', 0)
                    logger.info(
                        f"[{current_time}] Epoch {epoch}: "
                        f"Train Loss: {train_total_loss:.4f}, "
                        f"Train Acc: {train_acc:.2f}%, "
                        f"Val Loss: {val_total_loss:.4f}, "
                        f"Val Acc: {val_acc:.2f}%, "
                        f"Best Val Loss: {self.best_val_loss:.4f}"
                    )

            # Periodic checkpoint (skipped if this epoch was already saved as best).
            if epoch % save_interval == 0 and epoch > 0 and not saved_this_epoch:
                val_loss = self.training_history.get('val_total_loss', [float('inf')])[-1]
                self.save_checkpoint(epoch, val_loss, False)

            # Learning-rate schedule
            if self.scheduler:
                self.scheduler.step()

        # Final checkpoint after the last epoch.
        if self.rank == 0:
            final_val_loss = self.training_history.get('val_total_loss', [float('inf')])[-1]
            self.save_checkpoint(num_epochs, final_val_loss, False)
            # Flush pending TensorBoard events to disk.
            self.writer.flush()
            
            # 生成训练报告
            self.generate_training_report()
            
            logger.info("训练完成！")
    
    def generate_training_report(self):
        """
        Write the end-of-training artifacts into ``self.output_dir``.

        Does nothing unless ``self.rank`` is 0. Otherwise, in order:
        dumps ``self.training_history`` to ``training_history.json``, calls
        ``self.plot_training_curves()``, and writes a plain-text summary
        (configuration values, best validation loss, last training loss when
        available, and a fixed feature list) to ``training_report.txt``.
        """
        if self.rank != 0:
            return

        # Raw loss history as JSON
        history_path = self.output_dir / 'training_history.json'
        with open(history_path, 'w') as history_file:
            json.dump(dict(self.training_history), history_file, indent=2)

        # Loss curves
        self.plot_training_curves()

        # Text report
        report_path = self.output_dir / 'training_report.txt'
        with open(report_path, 'w', encoding='utf-8') as report_file:
            training_cfg = self.config.get('training', {})
            optimizer_cfg = self.config.get('optimizer', {})
            dataset_cfg = self.config.get('dataset', {})

            report_lines = [
                "Video2EEG-SGGN-Diffusion 训练报告\n",
                "=" * 50 + "\n\n",
                "训练配置:\n",
                "  模型: Video2EEG-SGGN-Diffusion\n",
                f"  总epoch数: {self.current_epoch + 1}\n",
                f"  批次大小: {training_cfg.get('batch_size', 4)}\n",
                f"  学习率: {optimizer_cfg.get('learning_rate', 1e-4)}\n",
                f"  使用Graph-DA: {dataset_cfg.get('use_graph_da', True)}\n",
                f"  数据增强比例: {dataset_cfg.get('augmentation_ratio', 0.3)}\n\n",
                "训练结果:\n",
                f"  最佳验证损失: {self.best_val_loss:.6f}\n",
            ]

            # The final training loss line is only present if that series exists
            if 'train_total_loss' in self.training_history:
                final_train_loss = self.training_history['train_total_loss'][-1]
                report_lines.append(f"  最终训练损失: {final_train_loss:.6f}\n")

            report_lines += [
                "\n模型特性:\n",
                "  ✓ Graph-DA数据增强\n",
                "  ✓ E-Graph与S-Graph构建\n",
                "  ✓ 滤波器组驱动的多图建模\n",
                "  ✓ 自博弈融合策略\n",
                "  ✓ 多尺度图卷积网络\n",
                "  ✓ 空间图注意力机制\n",
            ]

            report_file.writelines(report_lines)

        logger.info(f"训练报告已保存: {report_path}")
    
    def plot_training_curves(self):
        """
        Plot training/validation loss curves and save them as a PNG.

        Draws a 2x2 grid (total, diffusion, adversarial and frequency loss)
        from the series stored in ``self.training_history`` and writes the
        figure to ``self.output_dir / 'training_curves.png'``. Series that are
        missing from the history are simply not drawn. Returns immediately if
        the history is empty.

        Training losses are recorded once per epoch, whereas validation losses
        are recorded only every ``training.eval_interval`` epochs, so the
        validation x-coordinates are scaled by that interval (read from
        ``self.config``, fallback 5) rather than by a fixed constant.
        """
        if not self.training_history:
            return

        # Spacing, in epochs, between consecutive validation entries
        eval_interval = self.config.get('training', {}).get('eval_interval', 5)

        # (history key suffix, panel title, training label, validation label)
        panels = (
            ('total_loss', '总损失', '训练损失', '验证损失'),
            ('diffusion_loss', '扩散损失', '训练扩散损失', '验证扩散损失'),
            ('adversarial_loss', '对抗损失', '训练对抗损失', '验证对抗损失'),
            ('frequency_loss', '频域损失', '训练频域损失', '验证频域损失'),
        )

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle('Video2EEG-SGGN-Diffusion 训练曲线', fontsize=16)

        # axes.flat walks the grid row by row, matching the order of `panels`
        for ax, (loss_name, title, train_label, val_label) in zip(axes.flat, panels):
            train_key = f'train_{loss_name}'
            val_key = f'val_{loss_name}'

            if train_key in self.training_history:
                ax.plot(self.training_history[train_key], label=train_label, color='blue')
            if val_key in self.training_history:
                val_values = self.training_history[val_key]
                val_epochs = np.arange(len(val_values)) * eval_interval
                ax.plot(val_epochs, val_values, label=val_label, color='red')

            ax.set_title(title)
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Loss')
            # Only add a legend when at least one curve was drawn on this panel
            if ax.get_legend_handles_labels()[0]:
                ax.legend()
            ax.grid(True)

        fig.tight_layout()

        # Save the figure, then release it explicitly
        curves_path = self.output_dir / 'training_curves.png'
        fig.savefig(curves_path, dpi=300, bbox_inches='tight')
        plt.close(fig)

        logger.info(f"训练曲线已保存: {curves_path}")

def create_default_config() -> Dict:
    """
    Build the default configuration.

    Returns:
        A newly constructed nested dict with the sections ``model``,
        ``dataset``, ``training``, ``optimizer`` and ``scheduler``.
    """
    # Band edges as (low, high) pairs — presumably in Hz; TODO confirm
    frequency_bands = [
        (0.5, 4),    # Delta
        (4, 8),      # Theta
        (8, 13),     # Alpha
        (13, 30),    # Beta
        (30, 100),   # Gamma
    ]

    model_section = dict(
        eeg_channels=62,
        signal_length=200,
        video_feature_dim=512,
        hidden_dim=256,
        num_diffusion_steps=1000,
        frequency_bands=frequency_bands,
    )

    dataset_section = dict(
        use_graph_da=True,
        augmentation_ratio=0.3,
    )

    training_section = dict(
        num_epochs=10,
        batch_size=1,
        num_workers=1,
        save_interval=3,
        eval_interval=1,
        use_mixed_precision=True,
    )

    optimizer_section = dict(
        type='AdamW',
        learning_rate=1e-5,
        weight_decay=1e-5,
    )

    scheduler_section = dict(type='CosineAnnealingLR')

    return dict(
        model=model_section,
        dataset=dataset_section,
        training=training_section,
        optimizer=optimizer_section,
        scheduler=scheduler_section,
    )

def main():
    """
    Command-line entry point: parse arguments, build the trainer, run training.

    The configuration is loaded from ``--config`` when that file exists;
    otherwise the default configuration is used (with a warning if a path was
    given but not found).

    With ``--distributed``, the ``LOCAL_RANK`` environment variable is set
    from ``--local_rank`` / ``--local-rank`` when that flag is passed
    explicitly. If the flag is absent, a value already present in the
    environment (e.g. exported by the launcher) is kept, and ``0`` is used
    only when nothing is set.

    Returns:
        0 on success, 1 if an exception was raised while constructing the
        trainer or during training.
    """
    parser = argparse.ArgumentParser(description='Video2EEG-SGGN-Diffusion模型训练')
    parser.add_argument('--data_dir', type=str, required=True, help='数据目录路径')
    parser.add_argument('--output_dir', type=str, default='./sggn_training_output', help='输出目录')
    parser.add_argument('--config', type=str, help='配置文件路径')
    parser.add_argument('--distributed', action='store_true', help='使用分布式训练')
    # Accept both spellings: launchers differ on underscore vs. dash. The
    # default is None so "flag not given" can be told apart from "rank 0".
    parser.add_argument('--local_rank', '--local-rank', type=int, default=None, help='本地rank')

    args = parser.parse_args()

    # Load the configuration
    if args.config and os.path.exists(args.config):
        with open(args.config, 'r', encoding='utf-8') as f:
            config = json.load(f)
    else:
        if args.config:
            # Make the silent fallback visible instead of training with
            # unexpected defaults.
            logger.warning(f"配置文件不存在: {args.config}，使用默认配置")
        config = create_default_config()

    # Environment for distributed training
    if args.distributed:
        if args.local_rank is not None:
            os.environ['LOCAL_RANK'] = str(args.local_rank)
        else:
            # Do not overwrite a per-process LOCAL_RANK exported by the
            # launcher; overwriting it with 0 would give every process the
            # same local rank.
            os.environ.setdefault('LOCAL_RANK', '0')

    try:
        # Build the trainer
        trainer = SGGNModelTrainer(
            config=config,
            data_dir=args.data_dir,
            output_dir=args.output_dir,
            use_distributed=args.distributed
        )

        # Run training
        trainer.train()

    except Exception as e:
        # Top-level boundary: log the message together with the traceback
        logger.exception(f"训练过程中发生错误: {e}")
        return 1

    finally:
        # Tear down the process group only if it was actually created;
        # destroying an uninitialised group raises and would mask the
        # original error and the return code.
        if args.distributed and dist.is_available() and dist.is_initialized():
            dist.destroy_process_group()

    return 0

if __name__ == "__main__":
    # Use main()'s return value as the process exit status
    sys.exit(main())