#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Run All Datasets Ablation Study
运行所有数据集的消融实验对比
"""

import os
import sys
import logging
import argparse
import subprocess
from datetime import datetime
from typing import List, Dict

def setup_logging(verbose: bool = False):
    """Configure root logging to stdout and to a timestamped file in ``outputs/``.

    Creates the ``outputs`` directory when it is missing, then calls
    ``logging.basicConfig`` with two handlers: a stream handler bound to
    ``sys.stdout`` and a file handler that writes a fresh
    ``outputs/run_all_<YYYYmmdd_HHMMSS>.log``.

    Args:
        verbose: Use DEBUG as the root level when true, INFO otherwise.
    """
    level = logging.DEBUG if verbose else logging.INFO

    os.makedirs('outputs', exist_ok=True)
    log_file = f'outputs/run_all_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'

    logging.basicConfig(
        level=level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(sys.stdout),
            # Log messages in this script contain emoji and CJK text. Pin the
            # file encoding to UTF-8 so writing them does not depend on the
            # platform's default locale encoding.
            logging.FileHandler(log_file, mode='w', encoding='utf-8'),
        ],
    )

class MultiDatasetAblationRunner:
    """Run the ``run.py`` ablation study once per dataset and summarise the outcome.

    Every dataset is handled by a child ``run.py --mode ablation`` process.
    This class sequences those runs, records success or failure per dataset,
    and writes a combined plain-text report under ``outputs/``.
    """

    def __init__(self, verbose: bool = False, iterations: int = 20):
        """Store the run options and verify the environment.

        Args:
            verbose: Append ``--verbose`` to every ``run.py`` command.
            iterations: Value passed to ``run.py`` as ``--iterations``.

        Note:
            ``_check_environment`` terminates the process with status 1 when
            ``OPENAI_API_KEY`` is unset or ``run.py`` is not in the current
            working directory.
        """
        self.verbose = verbose
        self.iterations = iterations

        # Default dataset list; callers may replace this attribute.
        self.datasets = ['cifar10', 'cifar100', 'ImageNet16-120']

        # Set when the user presses Ctrl+C so the remaining datasets are skipped.
        self._interrupted = False

        self._check_environment()

        # Use the module logger like every other method of this class.
        logging.getLogger(__name__).info("🚀 多数据集消融实验运行器初始化完成")

    def _check_environment(self):
        """Exit with status 1 unless the API key is set and ``run.py`` exists."""
        logger = logging.getLogger(__name__)

        # The OpenAI API key must be present in the environment.
        if not os.getenv('OPENAI_API_KEY'):
            logger.error("❌ OPENAI_API_KEY环境变量未设置")
            logger.info("💡 请设置: export OPENAI_API_KEY='your-api-key'")
            sys.exit(1)

        # run.py is looked up relative to the current working directory,
        # which is also how it is invoked later.
        if not os.path.exists('run.py'):
            logger.error("❌ run.py文件未找到")
            sys.exit(1)

        logger.info("✅ 环境检查通过")

    def run_single_dataset_ablation(self, dataset: str) -> bool:
        """Run ``run.py --mode ablation`` for one dataset.

        Args:
            dataset: Dataset name passed to ``run.py`` as ``--dataset``.

        Returns:
            True when the child process exits with status 0. False when it
            exits with a non-zero status or the user interrupts it; in the
            latter case the interruption is also remembered so that
            ``run_all_datasets`` stops instead of starting the next dataset.
        """
        import time

        logger = logging.getLogger(__name__)

        logger.info(f"\n{'='*100}")
        logger.info(f"🔬 开始数据集 {dataset.upper()} 的消融实验")
        logger.info(f"{'='*100}")

        # Build the child command line.
        cmd = [
            sys.executable, 'run.py',
            '--mode', 'ablation',
            '--dataset', dataset,
            '--iterations', str(self.iterations),
        ]
        if self.verbose:
            cmd.append('--verbose')

        logger.info(f"📝 执行命令: {' '.join(cmd)}")

        start_time = time.time()

        try:
            # Output is not captured: the child writes straight to this
            # process's stdout/stderr. check=True turns a non-zero exit
            # status into CalledProcessError.
            subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError as e:
            duration = time.time() - start_time

            logger.error(f"❌ 数据集 {dataset} 消融实验失败")
            logger.error(f"⏱️ 用时: {duration/60:.1f} 分钟")
            logger.error(f"🔍 错误代码: {e.returncode}")

            return False
        except KeyboardInterrupt:
            logger.info(f"\n⚠️ 用户中断了数据集 {dataset} 的消融实验")
            # Remember the interruption; returning False alone would let the
            # batch continue with the next dataset.
            self._interrupted = True
            return False

        duration = time.time() - start_time

        logger.info(f"✅ 数据集 {dataset} 消融实验成功完成")
        logger.info(f"⏱️ 用时: {duration/60:.1f} 分钟")

        return True

    def run_all_datasets(self) -> Dict[str, bool]:
        """Run the ablation study for every entry of ``self.datasets`` in order.

        The batch stops early when the user presses Ctrl+C, either during a
        dataset run or during the pause between two datasets. Datasets that
        were never started are absent from the returned mapping.

        Returns:
            Mapping of dataset name to success flag for each dataset that was
            started. Empty when ``self.datasets`` is empty.
        """
        import time

        logger = logging.getLogger(__name__)

        total = len(self.datasets)
        if total == 0:
            # Nothing to run; also avoids dividing by zero in the summary.
            logger.warning("⚠️ 没有要运行的数据集")
            return {}

        self._interrupted = False

        logger.info("🚀 开始所有数据集的消融实验")
        logger.info(f"📊 数据集列表: {', '.join(self.datasets)}")
        logger.info("🔧 每方法架构数: 由轮次自动决定")
        logger.info(f"⏱️ 预计总时间: {total * 2:.0f} 小时")

        results: Dict[str, bool] = {}
        successful_count = 0
        attempted = 0

        overall_start_time = time.time()

        for i, dataset in enumerate(self.datasets, 1):
            logger.info(f"\n📈 总进度: {i}/{total}")

            success = self.run_single_dataset_ablation(dataset)
            results[dataset] = success
            attempted = i

            if success:
                successful_count += 1
                logger.info(f"🎉 {dataset} 完成")
            else:
                logger.error(f"💥 {dataset} 失败")

            # Running tally.
            logger.info(f"📊 当前统计: 成功 {successful_count}/{i}")

            if self._interrupted:
                break

            # Short pause before the next dataset (not after the last one).
            if i < total:
                logger.info("⏸️ 等待 10 秒后继续下一个数据集...")
                try:
                    time.sleep(10)
                except KeyboardInterrupt:
                    self._interrupted = True
                    break

        # Final summary.
        total_duration = time.time() - overall_start_time
        failed_count = attempted - successful_count
        skipped = self.datasets[attempted:]

        logger.info(f"\n{'='*100}")
        if self._interrupted:
            logger.warning("⚠️ 用户中断，已跳过剩余数据集")
        else:
            logger.info("🎊 所有数据集消融实验完成!")
        logger.info(f"⏱️ 总用时: {total_duration/3600:.1f} 小时")
        logger.info("📊 最终统计:")
        logger.info(f"   ✅ 成功: {successful_count}/{total}")
        logger.info(f"   ❌ 失败: {failed_count}/{total}")
        if skipped:
            logger.info(f"   ⏭️ 未运行: {len(skipped)}/{total}")
        logger.info(f"   📈 成功率: {successful_count/total*100:.1f}%")

        # Per-dataset outcome.
        logger.info("\n📋 详细结果:")
        for dataset, success in results.items():
            status = "✅ 成功" if success else "❌ 失败"
            logger.info(f"   {dataset:15s}: {status}")
        for dataset in skipped:
            logger.info(f"   {dataset:15s}: ⏭️ 未运行")

        if successful_count == total:
            logger.info("\n🎉 所有数据集都成功完成消融实验!")
        elif successful_count > 0:
            logger.info("\n⚠️ 部分数据集完成，请检查失败的数据集")
        elif not self._interrupted:
            # After a user interrupt "check your configuration" would be
            # misleading, so this hint is only shown for genuine failures.
            logger.error("\n💥 所有数据集都失败了，请检查配置和环境")

        logger.info("\n💾 所有结果保存在 outputs/ 目录中")

        return results

    def generate_combined_report(self, results: Dict[str, bool]):
        """Write ``outputs/combined_ablation_report_<timestamp>.txt``.

        Args:
            results: Mapping of dataset name to success flag. Datasets listed
                in ``self.datasets`` but missing from this mapping are
                reported as not run.
        """
        logger = logging.getLogger(__name__)

        logger.info("📄 生成合并报告...")

        # Make sure the target directory exists even if nothing created it yet.
        os.makedirs('outputs', exist_ok=True)

        now = datetime.now()
        report_file = f"outputs/combined_ablation_report_{now.strftime('%Y%m%d_%H%M%S')}.txt"

        succeeded = sum(1 for ok in results.values() if ok)
        failed = len(results) - succeeded
        not_run = [name for name in self.datasets if name not in results]

        rule = "-" * 30 + "\n"

        lines = [
            "Combined Multi-Dataset Ablation Study Report\n",
            "=" * 50 + "\n\n",
            f"Execution Time: {now.strftime('%Y-%m-%d %H:%M:%S')}\n",
            f"Total Datasets: {len(self.datasets)}\n",
            f"Successful Datasets: {succeeded}\n",
            f"Failed Datasets: {failed}\n",
        ]
        if not_run:
            lines.append(f"Not Run Datasets: {len(not_run)}\n")
        lines.append("\n")

        lines.append("DATASET RESULTS:\n")
        lines.append(rule)
        for dataset, success in results.items():
            status = "SUCCESS" if success else "FAILED"
            lines.append(f"{dataset:15s}: {status}\n")
        for dataset in not_run:
            lines.append(f"{dataset:15s}: NOT RUN\n")

        lines.extend([
            "\nEXPERIMENT DETAILS:\n",
            rule,
            "• Methods Compared: 4\n",
            "  1. Proposed Method (Complete)\n",
            "  2. Without Partitioning\n",
            "  3. Without LLM (uses PEA)\n",
            "  4. Without ZC Ensemble (uses Synflow)\n\n",
            "• Hardware Devices: 6\n",
            "  - EdgeGPU, Raspi4, EdgeTPU\n",
            "  - Pixel3, Eyeriss, FPGA\n\n",
            "• Evaluation Metrics:\n",
            "  - Hypervolume Ratio (HV)\n",
            "  - Inverted Generational Distance (IGD)\n",
            "  - Pareto Front Size\n\n",
            "OUTPUT FILES:\n",
            rule,
        ])

        # Only datasets that succeeded are listed with their output files.
        for dataset in self.datasets:
            if results.get(dataset):
                lines.extend([
                    f"{dataset}:\n",
                    f"  • ablation_results_{dataset}.json\n",
                    f"  • unified_ablation_comparison_{dataset}.png\n",
                    f"  • ablation_metrics_summary_{dataset}.json\n",
                    f"  • ablation_report_{dataset}.txt\n\n",
                ])

        with open(report_file, 'w', encoding='utf-8') as f:
            f.writelines(lines)

        logger.info(f"📄 合并报告已保存: {report_file}")

def main():
    """Parse the command line, run the ablation study per dataset and exit.

    Exit status:
        0    every dataset succeeded
        1    only some datasets succeeded, or an unexpected exception occurred
        2    every dataset failed
        130  the run was interrupted by the user (Ctrl+C)
    """
    parser = argparse.ArgumentParser(description='运行所有数据集的消融实验')
    # The num-architectures option was removed; the number of architectures
    # is determined by the iteration count.
    parser.add_argument('--verbose', action='store_true',
                        help='启用详细日志')
    parser.add_argument('--iterations', type=int, default=20,
                        help='每个消融方法的轮次数量')
    parser.add_argument('--datasets', type=str, nargs='+',
                        choices=['cifar10', 'cifar100', 'ImageNet16-120'],
                        help='指定要运行的数据集 (默认: 全部)')

    args = parser.parse_args()

    # Configure logging before anything else emits messages.
    setup_logging(args.verbose)
    logger = logging.getLogger(__name__)

    logger.info("🚀 多数据集消融实验启动")

    try:
        runner = MultiDatasetAblationRunner(args.verbose, args.iterations)

        # Restrict the run to the datasets named on the command line.
        if args.datasets:
            # Drop repeated names while keeping the given order. A repeated
            # dataset would run twice yet occupy a single key in the results
            # dict, so the success count could never equal len(runner.datasets)
            # and the exit status would be wrong.
            runner.datasets = list(dict.fromkeys(args.datasets))
            logger.info(f"🎯 指定数据集: {', '.join(runner.datasets)}")

        results = runner.run_all_datasets()

        runner.generate_combined_report(results)

        # Map the outcome to the exit status.
        successful_count = sum(results.values())
        if successful_count == len(runner.datasets):
            logger.info("🎉 所有数据集消融实验成功完成!")
            sys.exit(0)
        elif successful_count > 0:
            logger.warning("⚠️ 部分数据集完成，请检查失败的数据集")
            sys.exit(1)
        else:
            logger.error("💥 所有数据集都失败了")
            sys.exit(2)

    except KeyboardInterrupt:
        logger.info("\n⚠️ 用户中断")
        # 128 + SIGINT(2): an interrupted run must not report success (0)
        # to calling scripts.
        sys.exit(130)
    except Exception as e:
        logger.error(f"❌ 程序执行失败: {e}")
        # The full traceback is only shown at DEBUG level (--verbose).
        import traceback
        logger.debug(traceback.format_exc())
        sys.exit(1)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
