#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Video2EEG-SGGN-Diffusion模型消融实验脚本
基于用户提供的消融实验方案进行系统性评估

消融实验包括：
1. 无SGGN组件的基线扩散模型
2. 不同扩散步数（25、50、100步）的影响
3. 禁用空间图注意机制的影响

作者: 算法工程师
日期: 2025年1月12日
"""

import os
import sys
import json
import time
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import signal, stats
from scipy.spatial.distance import cosine
from sklearn.metrics import mean_squared_error, mean_absolute_error
from pathlib import Path
import logging
from typing import Dict, List, Tuple, Optional, Union
import warnings
from collections import defaultdict
import pandas as pd
from tqdm import tqdm
import copy

warnings.filterwarnings('ignore')

# 导入模型和数据集
from video2eeg_sggn_diffusion_model import Video2EEGSGGNDiffusion, create_video2eeg_sggn_diffusion
from train_sggn_diffusion import SGGNEEGVideoDataset
from improved_inference_sggn import EEGQualityEvaluator

# Configure logging: INFO level on the root logger, plus a module-level logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Register candidate sans-serif fonts for the Chinese labels used in the plots
plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial Unicode MS', 'DejaVu Sans']
# Disable the Unicode minus glyph for negative tick labels
plt.rcParams['axes.unicode_minus'] = False

class _FrameWiseFlatten(nn.Module):
    """Turn ``(B, C, T, H, W)`` feature maps into per-frame vectors ``(B, T, C*H*W)``."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Bring the temporal axis next to the batch axis before flattening;
        # flattening the raw layout would merge time into the feature axis.
        return x.permute(0, 2, 1, 3, 4).flatten(start_dim=2)


class BaselineDiffusionModel(nn.Module):
    """
    Baseline diffusion model without any SGGN component.

    A small 3-D convolutional encoder produces one feature vector per video
    frame, the frame features are averaged over time, and an MLP predicts the
    noise of every EEG time step from the concatenation of
    ``[eeg values, video feature, timestep embedding]``.

    The parameter-free ``_FrameWiseFlatten`` module sits at the same position
    in ``video_encoder`` that ``nn.Flatten`` used to occupy, so the
    ``state_dict`` keys of the encoder are unchanged.

    Args:
        config: Model configuration. Reads ``video_feature_dim`` (default 512)
            and ``eeg_channels`` (default 14).
    """

    def __init__(self, config: Dict):
        super().__init__()
        self.config = config
        feature_dim = config.get('video_feature_dim', 512)
        eeg_channels = config.get('eeg_channels', 14)

        # Video encoder: (B, 3, T, H, W) -> (B, T, feature_dim)
        self.video_encoder = nn.Sequential(
            nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3)),
            nn.ReLU(),
            nn.AdaptiveAvgPool3d((None, 8, 8)),
            _FrameWiseFlatten(),
            nn.Linear(64 * 8 * 8, feature_dim)
        )

        # Embedding of the (scalar) diffusion timestep
        self.time_embedding = nn.Sequential(
            nn.Linear(1, 128),
            nn.ReLU(),
            nn.Linear(128, feature_dim)
        )

        # Plain MLP noise predictor (no graph structure)
        self.diffusion_net = nn.Sequential(
            nn.Linear(eeg_channels + feature_dim * 2, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, eeg_channels)
        )

    def forward(self, video_frames, eeg_data, timesteps):
        """
        Predict the noise contained in ``eeg_data``.

        Args:
            video_frames: 5-D video tensor, either frames-first
                ``(B, T, 3, H, W)`` or channels-first ``(B, 3, T, H, W)``.
                When axis 1 already has 3 entries the tensor is treated as
                channels-first.
            eeg_data: Noisy EEG of shape ``(B, C, T_eeg)``.
            timesteps: Diffusion timesteps of shape ``(B,)``.

        Returns:
            Predicted noise of shape ``(B, C, T_eeg)``.

        Raises:
            ValueError: If ``video_frames`` is not 5-D.
        """
        if video_frames.dim() != 5:
            raise ValueError(
                f"video_frames 必须是5维张量，实际形状为 {tuple(video_frames.shape)}"
            )

        # Conv3d needs the colour channels on axis 1; convert frames-first input.
        in_channels = self.video_encoder[0].in_channels
        if video_frames.shape[1] != in_channels and video_frames.shape[2] == in_channels:
            video_frames = video_frames.permute(0, 2, 1, 3, 4)

        # Encode the video and average the per-frame features
        video_features = self.video_encoder(video_frames)  # (B, T, D)
        video_features = video_features.mean(dim=1)  # (B, D)

        # Timestep embedding
        time_emb = self.time_embedding(timesteps.float().unsqueeze(-1))  # (B, D)

        # Every EEG time step is processed independently by the same MLP, so
        # all steps are evaluated in a single batched call instead of a loop.
        seq_len = eeg_data.shape[2]
        conditioning = torch.cat([video_features, time_emb], dim=-1)  # (B, 2D)
        conditioning = conditioning.unsqueeze(1).expand(-1, seq_len, -1)  # (B, T_eeg, 2D)
        combined_features = torch.cat(
            [eeg_data.permute(0, 2, 1), conditioning], dim=-1
        )  # (B, T_eeg, C + 2D)

        predicted_noise = self.diffusion_net(combined_features)  # (B, T_eeg, C)

        return predicted_noise.permute(0, 2, 1).contiguous()  # (B, C, T_eeg)

class SGGNModelWithoutSpatialAttention(Video2EEGSGGNDiffusion):
    """
    SGGN model variant whose spatial graph attention is disabled.

    After the parent constructor has run, the ``spatial_attention`` attribute
    is overwritten with ``nn.Identity()``.

    NOTE(review): this only ablates anything if the parent class stores its
    attention module under exactly that attribute name and calls it with a
    single positional argument (``nn.Identity.forward`` accepts only one) —
    confirm against ``Video2EEGSGGNDiffusion``.
    """

    def __init__(self, config: Dict):
        super().__init__(config)
        if not hasattr(self, 'spatial_attention'):
            # Without this check a renamed/missing attribute would silently turn
            # the ablation into a copy of the full model.
            logging.getLogger(__name__).warning(
                "父类未定义 spatial_attention 属性，替换为恒等映射可能不会生效"
            )
        # Replace the spatial attention with an identity mapping
        self.spatial_attention = nn.Identity()

class AblationStudyEngine:
    """
    消融实验引擎
    """
    
    def __init__(self,
                 data_dirs: List[str],
                 output_dir: str = "./ablation_study_output",
                 device: str = 'auto'):
        """
        Set up the ablation-study engine.

        Resolves the compute device, creates the output directory, stores the
        fixed model configuration, and builds the quality evaluator and the
        test data loader.

        Args:
            data_dirs: Candidate data directories, handed to
                ``create_test_dataloader``.
            output_dir: Directory that receives all outputs (created if missing).
            device: ``'auto'`` picks CUDA when available and CPU otherwise; any
                other value is passed to ``torch.device``.
        """
        self.data_dirs = data_dirs
        self.output_dir = Path(output_dir)

        # Resolve the device specification
        device_spec = device
        if device_spec == 'auto':
            device_spec = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device_spec)

        logger.info(f"使用设备: {self.device}")

        # Make sure the output directory exists
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Model configuration shared by every model built by this engine
        self.model_config = dict(
            video_feature_dim=512,
            eeg_channels=14,
            signal_length=250,
            num_diffusion_steps=1000,
            hidden_dim=256,
            dropout=0.1,
        )

        # Quality evaluator and test data
        self.evaluator = EEGQualityEvaluator()
        self.test_loader = self.create_test_dataloader()

        logger.info("消融实验引擎初始化完成")
    
    def create_test_dataloader(self) -> DataLoader:
        """
        Build the test data loader from the first usable data directory.

        The directories in ``self.data_dirs`` are tried in order. A directory
        is skipped (with a warning) when constructing the dataset raises or
        when the resulting test split is empty.

        Returns:
            A ``DataLoader`` with batch size 1 and no shuffling.

        Raises:
            ValueError: If no directory yields a non-empty test dataset.
        """
        test_dataset = None
        for data_dir in self.data_dirs:
            try:
                candidate = SGGNEEGVideoDataset(
                    data_dir,
                    split='test',
                    use_graph_da=False
                )
                num_samples = len(candidate)
            except Exception as e:
                logger.warning(f"从 {data_dir} 加载数据失败: {e}")
                continue

            if num_samples > 0:
                test_dataset = candidate
                logger.info(f"从 {data_dir} 加载测试数据: {num_samples} 样本")
                break

            # An empty split is not an error, but it should not go unnoticed.
            logger.warning(f"{data_dir} 中没有测试样本，尝试下一个目录")

        if test_dataset is None:
            raise ValueError("无法从任何数据目录加载测试数据")

        return DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=2,
            # Pinned host memory only helps host-to-GPU copies; requesting it
            # on a CPU-only run is useless and makes PyTorch emit a warning.
            pin_memory=self.device.type == 'cuda'
        )
    
    def create_baseline_model(self) -> nn.Module:
        """
        Build the baseline model (no SGGN components).

        The model is instantiated from ``self.model_config`` and moved to
        ``self.device``. No checkpoint is loaded here.

        Returns:
            The baseline model, switched to training mode.
        """
        baseline = BaselineDiffusionModel(self.model_config).to(self.device)
        # Training mode is used for consistency with the other model factories.
        baseline.train()
        return baseline
    
    def create_full_sggn_model(self) -> nn.Module:
        """
        Build the complete SGGN model.

        Delegates construction to ``create_video2eeg_sggn_diffusion`` with
        ``self.model_config`` and moves the result to ``self.device``.

        Returns:
            The full SGGN model, switched to training mode.
        """
        sggn_model = create_video2eeg_sggn_diffusion(self.model_config).to(self.device)
        sggn_model.train()
        return sggn_model
    
    def create_sggn_without_spatial_attention(self) -> nn.Module:
        """
        Build the SGGN model variant without spatial attention.

        Instantiates ``SGGNModelWithoutSpatialAttention`` from
        ``self.model_config`` and moves it to ``self.device``.

        Returns:
            The ablated SGGN model, switched to training mode.
        """
        ablated_model = SGGNModelWithoutSpatialAttention(self.model_config).to(self.device)
        ablated_model.train()
        return ablated_model
    
    def generate_eeg_with_steps(self, model: nn.Module, video_frames: torch.Tensor, 
                               num_steps: int = 50) -> torch.Tensor:
        """
        使用指定步数生成EEG信号
        
        Args:
            model: 模型
            video_frames: 视频帧 (B, T, C, H, W)
            num_steps: 扩散步数
        
        Returns:
            生成的EEG信号 (B, C, T)
        """
        with torch.no_grad():
            video_frames = video_frames.to(self.device)
            batch_size = video_frames.shape[0]
            
            # 创建随机时间步
            t = torch.randint(0, num_steps, (batch_size,)).to(self.device)
            
            # 创建噪声EEG作为输入
            noisy_eeg = torch.randn(
                batch_size, 
                self.model_config['eeg_channels'], 
                self.model_config['signal_length']
            ).to(self.device)
            
            # 前向传播
            output = model(video_frames, noisy_eeg, t)
            
            # 处理输出
            if isinstance(output, tuple):
                predicted_noise = output[0]
                generated_eeg = noisy_eeg - predicted_noise
            else:
                generated_eeg = output
            
            # 数值稳定性处理
            if torch.isnan(generated_eeg).any() or torch.isinf(generated_eeg).any():
                logger.warning("生成的EEG包含NaN或无穷大值，使用随机噪声替代")
                generated_eeg = torch.randn_like(generated_eeg) * 0.1
            
            # 裁剪异常值
            generated_eeg = torch.clamp(generated_eeg, -10.0, 10.0)
            
            return generated_eeg
    
    def run_single_experiment(self, model: nn.Module, experiment_name: str,
                            num_steps: int = 50, num_samples: int = 20) -> Dict:
        """
        Run one ablation experiment on the test loader.

        For up to ``num_samples`` successfully processed batches an EEG signal
        is generated, timed and scored with ``self.evaluator``. Batches that
        raise are logged and skipped; they contribute neither metrics nor
        inference time.

        Args:
            model: Model under test.
            experiment_name: Human-readable name used in logs and results.
            num_steps: Forwarded to ``generate_eeg_with_steps``.
            num_samples: Maximum number of samples to evaluate.

        Returns:
            Dict with ``experiment_name``, ``num_steps``, the per-sample lists
            ``metrics``, ``inference_times``, ``generated_eegs`` and
            ``reference_eegs`` and, when at least one sample succeeded,
            ``average_metrics`` (``avg_<key>`` / ``std_<key>`` as plain floats),
            ``total_inference_time`` and ``avg_inference_time``.
        """
        logger.info(f"开始实验: {experiment_name} (步数: {num_steps}, 样本数: {num_samples})")

        results = {
            'experiment_name': experiment_name,
            'num_steps': num_steps,
            'metrics': [],
            'inference_times': [],
            'generated_eegs': [],
            'reference_eegs': []
        }

        sample_count = 0
        total_inference_time = 0.0
        on_cuda = self.device.type == 'cuda'

        for batch in tqdm(self.test_loader, desc=f"{experiment_name} 进度"):
            if sample_count >= num_samples:
                break

            try:
                video_frames = batch['video']
                reference_eeg = batch['eeg']

                start_time = time.perf_counter()
                generated_eeg = self.generate_eeg_with_steps(model, video_frames, num_steps)
                if on_cuda:
                    # CUDA kernels run asynchronously; wait for them so the
                    # measured interval covers the actual computation.
                    torch.cuda.synchronize(self.device)
                inference_time = time.perf_counter() - start_time

                # First (and only) element of the batch as numpy arrays
                generated_eeg_np = generated_eeg.cpu().numpy()[0]
                reference_eeg_np = reference_eeg.cpu().numpy()[0]

                metrics = self.evaluator.evaluate_quality(
                    generated_eeg_np, reference_eeg_np
                )
            except Exception as e:
                logger.error(f"处理样本 {sample_count} 时发生错误: {e}")
                continue

            # Book-keeping happens only for fully processed samples, so the
            # total time and the sample count used for the average stay in sync.
            total_inference_time += inference_time
            results['metrics'].append(metrics)
            results['inference_times'].append(inference_time)
            results['generated_eegs'].append(generated_eeg_np)
            results['reference_eegs'].append(reference_eeg_np)
            sample_count += 1

        # Aggregate the per-sample metrics
        if results['metrics']:
            avg_metrics = {}
            for key in results['metrics'][0]:
                values = [m[key] for m in results['metrics'] if key in m]
                try:
                    numeric = np.asarray(values, dtype=float)
                except (TypeError, ValueError):
                    # Non-numeric or ragged metric values cannot be averaged.
                    logger.warning(f"指标 {key} 不是数值类型，已跳过统计")
                    continue
                if numeric.size:
                    # Plain floats keep the result JSON-serialisable.
                    avg_metrics[f'avg_{key}'] = float(np.mean(numeric))
                    avg_metrics[f'std_{key}'] = float(np.std(numeric))

            results['average_metrics'] = avg_metrics
            results['total_inference_time'] = total_inference_time
            results['avg_inference_time'] = total_inference_time / len(results['metrics'])

        logger.info(f"实验 {experiment_name} 完成: {len(results['metrics'])} 个样本")
        if 'average_metrics' in results:
            def _fmt(value):
                # A missing metric is reported as 'N/A' instead of being pushed
                # through a float format specifier (which would raise).
                return 'N/A' if value is None else f"{value:.6f}"

            summary = results['average_metrics']
            logger.info(f"平均MSE: {_fmt(summary.get('avg_mse'))}")
            logger.info(f"平均相关性: {_fmt(summary.get('avg_mean_correlation'))}")
            logger.info(f"平均推理时间: {results['avg_inference_time']:.3f}s")

        return results
    
    def run_full_ablation_study(self, num_samples: int = 20) -> Dict[str, Dict]:
        """
        Run the complete ablation study.

        Experiments, in order: full SGGN model (50 steps), baseline model
        without SGGN (50 steps), full SGGN model with 25 and 100 steps, and
        the SGGN model without spatial attention (50 steps).

        The step-count experiments reuse the model instance of the first
        experiment, so that ``num_steps`` is the only factor that differs
        between ``full_sggn``, ``sggn_25_steps`` and ``sggn_100_steps``.

        Args:
            num_samples: Number of samples per experiment.

        Returns:
            Mapping from experiment key to the result dict of
            ``run_single_experiment``.
        """
        logger.info(f"开始完整消融实验，每个实验 {num_samples} 个样本")

        all_results = {}

        # 1. Full SGGN model (reference)
        logger.info("=== 实验1: 完整SGGN模型 ===")
        full_model = self.create_full_sggn_model()
        all_results['full_sggn'] = self.run_single_experiment(
            full_model, "完整SGGN模型", num_steps=50, num_samples=num_samples
        )

        # 2. Baseline diffusion model (no SGGN components)
        logger.info("=== 实验2: 基线扩散模型（无SGGN） ===")
        baseline_model = self.create_baseline_model()
        all_results['baseline_no_sggn'] = self.run_single_experiment(
            baseline_model, "基线扩散模型", num_steps=50, num_samples=num_samples
        )
        del baseline_model  # no longer needed; allow its memory to be reclaimed

        # 3. Different numbers of diffusion steps, on the SAME weights as
        #    experiment 1 so the comparison is not confounded by a different
        #    model instance.
        logger.info("=== 实验3: 不同扩散步数 ===")
        for steps in (25, 100):
            logger.info(f"--- 扩散步数: {steps} ---")
            all_results[f'sggn_{steps}_steps'] = self.run_single_experiment(
                full_model, f"SGGN模型({steps}步)", num_steps=steps, num_samples=num_samples
            )
        del full_model

        # 4. Without the spatial graph attention mechanism
        logger.info("=== 实验4: 无空间图注意机制 ===")
        no_spatial_model = self.create_sggn_without_spatial_attention()
        all_results['sggn_no_spatial'] = self.run_single_experiment(
            no_spatial_model, "SGGN模型(无空间注意)", num_steps=50, num_samples=num_samples
        )

        logger.info("完整消融实验完成")

        return all_results
    
    def create_ablation_table(self, all_results: Dict[str, Dict]) -> pd.DataFrame:
        """
        Summarise the ablation results as a table.

        Only the known experiment keys are considered, in a fixed display
        order, and only when their result contains ``average_metrics``.
        Missing metric values are rendered as zero; a missing ``num_steps``
        falls back to 50.

        Args:
            all_results: Mapping from experiment key to result dict.

        Returns:
            DataFrame with one row per reported experiment; MSE, correlation
            and inference time are pre-formatted strings.
        """
        # Display order and row labels
        row_labels = {
            'full_sggn': '完整SGGN模型',
            'baseline_no_sggn': '基线扩散模型(无SGGN)',
            'sggn_25_steps': 'SGGN模型(25步)',
            'sggn_100_steps': 'SGGN模型(100步)',
            'sggn_no_spatial': 'SGGN模型(无空间注意)',
        }

        rows = []
        for key, label in row_labels.items():
            if key not in all_results:
                continue
            outcome = all_results[key]
            if 'average_metrics' not in outcome:
                continue

            averages = outcome['average_metrics']
            mse_text = f"{averages.get('avg_mse', 0):.4f}"
            corr_text = f"{averages.get('avg_mean_correlation', 0):.4f}"
            time_text = f"{outcome.get('avg_inference_time', 0):.2f}"
            rows.append({
                '实验配置': label,
                'MSE': mse_text,
                '相关性': corr_text,
                '平均推理时间(s)': time_text,
                '扩散步数': outcome.get('num_steps', 50),
            })

        return pd.DataFrame(rows)
    
    def generate_visualizations(self, all_results: Dict[str, Dict]):
        """
        Render all figures into ``<output_dir>/visualizations``.

        Args:
            all_results: Mapping from experiment key to result dict.
        """
        viz_dir = self.output_dir / "visualizations"
        viz_dir.mkdir(exist_ok=True)

        # Figures, in the order they are produced:
        # performance bars, inference times, EEG traces, radar chart.
        plotters = (
            self.plot_performance_comparison,
            self.plot_inference_time_comparison,
            self.plot_eeg_quality_comparison,
            self.plot_radar_chart,
        )
        for render in plotters:
            render(all_results, viz_dir)

        logger.info(f"可视化图表已保存到: {viz_dir}")
    
    def plot_performance_comparison(self, all_results: Dict[str, Dict], output_dir: Path):
        """
        Draw side-by-side bar charts of MSE and correlation per experiment.

        Experiments without ``average_metrics`` are left out; missing metric
        values are plotted as zero. The figure is written to
        ``performance_comparison.png`` inside ``output_dir``.
        """
        display_names = {
            'full_sggn': '完整SGGN',
            'baseline_no_sggn': '基线模型',
            'sggn_25_steps': 'SGGN(25步)',
            'sggn_100_steps': 'SGGN(100步)',
            'sggn_no_spatial': 'SGGN(无空间注意)'
        }

        # Collect (label, averaged metrics) for every experiment that has them
        available = [
            (label, all_results[key]['average_metrics'])
            for key, label in display_names.items()
            if key in all_results and 'average_metrics' in all_results[key]
        ]
        experiments = [label for label, _ in available]
        mse_values = [averages.get('avg_mse', 0) for _, averages in available]
        corr_values = [averages.get('avg_mean_correlation', 0) for _, averages in available]

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

        # (axis, values, bar colour, title, y-label, vertical label offset)
        panels = (
            (ax1, mse_values, 'skyblue', 'MSE对比', 'MSE', 0.01),
            (ax2, corr_values, 'lightcoral', '相关性对比', '相关性', 0.0001),
        )
        for axis, heights, colour, title, y_label, offset in panels:
            bars = axis.bar(experiments, heights, color=colour, alpha=0.7)
            axis.set_title(title, fontsize=14, fontweight='bold')
            axis.set_ylabel(y_label)
            axis.tick_params(axis='x', rotation=45)
            axis.grid(True, alpha=0.3)

            # Numeric label above every bar
            for bar, value in zip(bars, heights):
                axis.text(bar.get_x() + bar.get_width()/2, bar.get_height() + offset,
                          f'{value:.4f}', ha='center', va='bottom')

        plt.tight_layout()
        plt.savefig(output_dir / "performance_comparison.png", dpi=300, bbox_inches='tight')
        plt.close()
    
    def plot_inference_time_comparison(self, all_results: Dict[str, Dict], output_dir: Path):
        """
        Draw a bar chart of the average inference time per experiment.

        Bars are coloured by step count (25 green, 50 blue, anything else
        red). A missing ``avg_inference_time`` is plotted as zero. The figure
        is written to ``inference_time_comparison.png`` inside ``output_dir``.
        """
        from matplotlib.patches import Patch

        display_names = {
            'full_sggn': '完整SGGN(50步)',
            'baseline_no_sggn': '基线模型(50步)',
            'sggn_25_steps': 'SGGN(25步)',
            'sggn_100_steps': 'SGGN(100步)',
            'sggn_no_spatial': 'SGGN(无空间注意,50步)'
        }

        def colour_for(step_count):
            # Colour encodes the step count
            if step_count == 25:
                return 'green'
            if step_count == 50:
                return 'blue'
            return 'red'

        selected = [
            (label, all_results[key])
            for key, label in display_names.items()
            if key in all_results
        ]
        experiments = [label for label, _ in selected]
        times = [outcome.get('avg_inference_time', 0) for _, outcome in selected]
        colors = [colour_for(outcome.get('num_steps', 50)) for _, outcome in selected]

        fig, ax = plt.subplots(figsize=(12, 6))

        bars = ax.bar(experiments, times, color=colors, alpha=0.7)
        ax.set_title('推理时间对比', fontsize=14, fontweight='bold')
        ax.set_ylabel('平均推理时间 (秒)')
        ax.tick_params(axis='x', rotation=45)
        ax.grid(True, alpha=0.3)

        # Numeric label above every bar
        for bar, seconds in zip(bars, times):
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.1,
                    f'{seconds:.2f}s', ha='center', va='bottom')

        # Legend explaining the colour coding
        legend_elements = [
            Patch(facecolor=colour, alpha=0.7, label=text)
            for colour, text in (('green', '25步'), ('blue', '50步'), ('red', '100步'))
        ]
        ax.legend(handles=legend_elements, loc='upper right')

        plt.tight_layout()
        plt.savefig(output_dir / "inference_time_comparison.png", dpi=300, bbox_inches='tight')
        plt.close()
    
    def plot_eeg_quality_comparison(self, all_results: Dict[str, Dict], output_dir: Path,
                                    sampling_rate: float = 250.0):
        """
        Plot reference and generated EEG traces of the first sample.

        Up to four channels are drawn, one subplot per channel. The reference
        and the full-SGGN signal are always shown; the baseline signal is added
        when baseline results are available. Every signal uses a time axis
        derived from its own length, so signals of different lengths can be
        shown together. The figure is written to
        ``eeg_quality_comparison.png`` inside ``output_dir``.

        Nothing is plotted when the ``full_sggn`` results are missing or empty.

        Args:
            all_results: Mapping from experiment key to result dict.
            output_dir: Directory receiving the figure.
            sampling_rate: Sampling rate in Hz used to convert sample indices
                to seconds (the original code assumed 250 Hz).

        Raises:
            ValueError: If ``sampling_rate`` is not positive.
        """
        if sampling_rate <= 0:
            raise ValueError(f"sampling_rate 必须为正数: {sampling_rate}")

        full_result = all_results.get('full_sggn')
        if not full_result:
            return

        generated_eegs = full_result.get('generated_eegs')
        reference_eegs = full_result.get('reference_eegs')
        if not generated_eegs or not reference_eegs:
            return

        # Use the first sample
        sample_idx = 0
        traces = [
            ('参考EEG', reference_eegs[sample_idx]),
            ('完整SGGN', generated_eegs[sample_idx]),
        ]
        # The baseline is optional: without it, reference vs. full model is
        # still a meaningful figure.
        baseline_eegs = all_results.get('baseline_no_sggn', {}).get('generated_eegs')
        if baseline_eegs:
            traces.append(('基线模型', baseline_eegs[sample_idx]))

        # Show at most the first four channels that every signal provides
        num_channels = min([4] + [signal_2d.shape[0] for _, signal_2d in traces])
        if num_channels == 0:
            return

        # squeeze=False always yields a 2-D axes array, also for one channel
        fig, axes = plt.subplots(num_channels, 1, figsize=(15, 2*num_channels), squeeze=False)

        for ch in range(num_channels):
            ax = axes[ch, 0]
            for label, signal_2d in traces:
                time_axis = np.arange(signal_2d.shape[1]) / sampling_rate
                ax.plot(time_axis, signal_2d[ch], label=label, alpha=0.8, linewidth=1.5)

            ax.set_title(f'通道 {ch+1} - EEG信号对比')
            ax.set_xlabel('时间 (s)')
            ax.set_ylabel('幅值')
            ax.legend()
            ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(output_dir / "eeg_quality_comparison.png", dpi=300, bbox_inches='tight')
        plt.close(fig)
    
    def plot_radar_chart(self, all_results: Dict[str, Dict], output_dir: Path):
        """
        Draw a radar chart of normalised metrics for three experiments.

        Each metric is mapped to a score in ``[0, 1]`` where larger is better:
        MSE and MAE via ``1 / (1 + value)``, correlation by clipping to
        ``[0, 1]``. A metric that is missing or not finite scores 0, so that an
        absent error metric is never displayed as a perfect result. The figure
        is written to ``radar_chart.png`` inside ``output_dir``; nothing is
        drawn when none of the experiments has ``average_metrics``.
        """
        # Metrics shown on the chart
        metrics_to_plot = ['avg_mse', 'avg_mean_correlation', 'avg_mae']
        metric_labels = ['MSE', '相关性', 'MAE']
        error_metrics = {'avg_mse', 'avg_mae'}  # smaller is better

        experiment_names = {
            'full_sggn': '完整SGGN',
            'baseline_no_sggn': '基线模型',
            'sggn_no_spatial': 'SGGN(无空间注意)'
        }

        def _score(metric, raw):
            # Map a raw metric value to a 0..1 "larger is better" score.
            if raw is None or not np.isfinite(raw):
                return 0.0
            if metric in error_metrics:
                raw = 1 / (1 + max(raw, 0.0))
            # The radial axis is limited to [0, 1]; keep every point inside it.
            return float(np.clip(raw, 0.0, 1.0))

        experiment_data = {}
        for exp_key, exp_name in experiment_names.items():
            if exp_key in all_results and 'average_metrics' in all_results[exp_key]:
                metrics = all_results[exp_key]['average_metrics']
                experiment_data[exp_name] = [
                    _score(metric, metrics.get(metric)) for metric in metrics_to_plot
                ]

        if not experiment_data:
            return

        # One angle per metric; repeat the first angle to close the polygon
        angles = np.linspace(0, 2 * np.pi, len(metric_labels), endpoint=False).tolist()
        angles += angles[:1]

        fig, ax = plt.subplots(figsize=(10, 10), subplot_kw=dict(projection='polar'))

        colors = ['blue', 'red', 'green']
        for i, (exp_name, values) in enumerate(experiment_data.items()):
            closed_values = values + values[:1]  # closed copy; stored list stays untouched
            color = colors[i % len(colors)]
            ax.plot(angles, closed_values, 'o-', linewidth=2, label=exp_name, color=color)
            ax.fill(angles, closed_values, alpha=0.25, color=color)

        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(metric_labels)
        ax.set_ylim(0, 1)
        ax.set_title('消融实验指标雷达图', size=16, fontweight='bold', pad=20)
        ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.0))
        ax.grid(True)

        plt.tight_layout()
        plt.savefig(output_dir / "radar_chart.png", dpi=300, bbox_inches='tight')
        plt.close(fig)
    
    def save_results(self, all_results: Dict[str, Dict]):
        """
        Persist the experiment results and print the summary table.

        Writes into ``self.output_dir``:
          * ``ablation_results.json`` - per-experiment summary (no raw signals)
          * ``ablation_table.csv``    - table from ``create_ablation_table``
          * ``ablation_table.tex``    - the same table in LaTeX; skipped with a
            warning when pandas cannot import what it needs for the export

        Args:
            all_results: Mapping from experiment key to result dict.
        """
        def _json_default(obj):
            # numpy scalars (other than float64) and arrays are not understood
            # by the json module; convert them to built-in Python objects.
            if isinstance(obj, np.generic):
                return obj.item()
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

        # Detailed results
        results_file = self.output_dir / "ablation_results.json"

        serializable_results = {
            exp_name: {
                'experiment_name': result.get('experiment_name', ''),
                'num_steps': result.get('num_steps', 0),
                'average_metrics': result.get('average_metrics', {}),
                'total_inference_time': result.get('total_inference_time', 0),
                'avg_inference_time': result.get('avg_inference_time', 0),
                'num_samples': len(result.get('metrics', []))
            }
            for exp_name, result in all_results.items()
        }

        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump(serializable_results, f, indent=2, ensure_ascii=False,
                      default=_json_default)

        # Summary table as CSV
        table_df = self.create_ablation_table(all_results)
        table_file = self.output_dir / "ablation_table.csv"
        table_df.to_csv(table_file, index=False, encoding='utf-8')

        # Summary table as LaTeX (best effort: a missing optional dependency
        # must not prevent the remaining outputs from being reported)
        latex_file = self.output_dir / "ablation_table.tex"
        try:
            latex_table = table_df.to_latex(index=False, escape=False)
        except ImportError as e:
            logger.warning(f"无法导出LaTeX表格（缺少可选依赖）: {e}")
            latex_file = None
        else:
            with open(latex_file, 'w', encoding='utf-8') as f:
                f.write(latex_table)

        logger.info(f"实验结果已保存:")
        logger.info(f"  详细结果: {results_file}")
        logger.info(f"  CSV表格: {table_file}")
        if latex_file is not None:
            logger.info(f"  LaTeX表格: {latex_file}")

        # Echo the table on the console
        print("\n=== 消融实验结果表格 ===")
        print(table_df.to_string(index=False))
        print("\n")

def main():
    """
    Command-line entry point.

    Parses the arguments, builds the ablation engine, runs every experiment,
    then renders the figures and saves the results.
    """
    default_data_dirs = [
        '/data0/GYF-projects/EEG2Video/Vidoe2EEG/Video2EEG-diffusion/enhanced_processed_data',
        '/data0/GYF-projects/EEG2Video/EEG2Video/data',
        '/data0/GYF-projects/EEG2Video/dataset'
    ]

    cli = argparse.ArgumentParser(description='Video2EEG-SGGN-Diffusion消融实验')
    cli.add_argument('--data_dirs', nargs='+', default=default_data_dirs,
                     help='数据目录列表')
    cli.add_argument('--output_dir', default='./ablation_study_output',
                     help='输出目录')
    cli.add_argument('--num_samples', type=int, default=20,
                     help='每个实验的样本数')
    cli.add_argument('--device', default='auto',
                     help='设备类型 (auto, cuda, cpu)')
    options = cli.parse_args()

    # Build the engine
    study = AblationStudyEngine(
        data_dirs=options.data_dirs,
        output_dir=options.output_dir,
        device=options.device
    )

    # Run all experiments, then visualise and persist the outcome
    outcomes = study.run_full_ablation_study(num_samples=options.num_samples)
    study.generate_visualizations(outcomes)
    study.save_results(outcomes)

    logger.info("消融实验完成！")

# Run the ablation study only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()